## Intro

In the context of behavioral data, we are interested in simultaneously modeling speed and accuracy. Yet, most advanced techniques in machine learning cannot capture such a duality of decision making data.


Building on [PonderNet](https://arxiv.org/abs/2107.05407) and [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0), this notebook implements a neural model that captures speed and accuracy of human-like responses.

Given stimulus symbols as inputs, the model produces two outputs:

- Response symbol, which, in comparison with the input stimuli, can be used to measure accuracy).
- Halting probability ($\lambda_n$).

Under the hood, the model iterates over a ICOM-like component to reach a halting point in time. Unlike DDM and ICOM models, all the parameters and outcomes of the current model *seem* cognitively interpretable.

### Additional resources

- [ICOM network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharinghttps://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)


## Problem setting

### Model
Given input and output data, we want to learn a supervised model of the function $X \to y$ as follows:

$
f: X,h_n \mapsto \tilde{y},h_{n+1}, \lambda_n
$

where $X$ and $y$ denote stimulus and response symbols, $\lambda_n$ denotes halting probability at time $n$, and $h_{n}$ is the latent state of the model. The learninig continious up to the time point $N$.

For the brevity and compatibility, both data are one-hot encoded.


### Input

One-hot encoded symbols.

### Output

One-hot encoded symbols.

### Criterion

L = L_cross_entropy + L_halting

In [157]:
# Setup and imports
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

from sklearn.metrics import accuracy_score

import numpy as np
from scipy import stats
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile #FIX storing embeddings using tensorboard

In [158]:
# produce a tarin of spikes and store timestamps of each spike in `spike_timestamps`.

signal_rate = 2
noise_rate = 1
rate = signal_rate + noise_rate
max_duration_in_sec = 10.
resolution_in_sec = .1

n_total_timesteps = int(max_duration_in_sec / resolution_in_sec)
n_spikes = np.random.poisson(rate * max_duration_in_sec)

# method 1: shuffle timesteps
spike_timesteps = np.sort(np.random.choice(n_total_timesteps, size=n_spikes, replace=False))

# method 2: exponential isi -> timestamps
# isi = np.random.exponential(1 / rate, n_spikes)
# spike_timestamps = np.cumsum(isi)

# method 3: homogenous spikes -> timestamps
# spike_timestamps = stats.uniform.rvs(loc=0, scale=max_duration_in_sec, size=n_spikes)

## Mock data

In [226]:


def generate_mock_data(n_subjects, n_trials, n_stimuli):
  """[summary]

  # TODO required data columns: subject_index, trial_index, stimulus_index, accuracy, response_time

  Args:
      n_subjects (int): [description]
      n_trials (int): [description]
      n_stimuli (int): [description]

  Returns:
      (X, accuracies, response_times): A tuple containing generated mock X, accuracies, and response_times (in sec).
  """
  # stimuli
  X = np.random.randint(low=1, high=n_stimuli+1, size=(n_subjects, n_trials))

  # response accuracy
  subject_accuracies = np.random.uniform(low=0.2, high=1.0, size=n_subjects)
  subject_accuracies = np.round(subject_accuracies * n_trials) / n_trials
  accuracies = np.empty(shape=(n_subjects, n_trials))
  for subj in range(n_subjects):
    accuracies[subj,:] = np.random.choice(
      [0,1],
      p=[1-subject_accuracies[subj],subject_accuracies[subj]],
      size=n_trials)

  # generate output w.r.t the accuracy (and fill incorrect trials with invalid response)
  y = np.where(accuracies == 1., X, X+1 % (n_stimuli+1))

  # response time
  response_times = np.random.exponential(.5, size=accuracies.shape)

  if n_subjects == 1:
    X = X.squeeze()
    y = y.squeeze()
    accuracies = accuracies.squeeze()
    response_times = response_times.squeeze()

  return X, y, accuracies, response_times

In [269]:
# mock data parameters
n_subjects = 1
n_trials = 20
n_stimuli = 6

X, y, accuracies, response_times = generate_mock_data(n_subjects, n_trials, n_stimuli)

In [270]:
class ICOM(nn.Module):
    def __init__(self, n_inputs, n_channels, n_outputs):
      super(ICOM, self).__init__()

      self.n_inputs = n_inputs
      # encode: x -> sent_msg
      self.encode = nn.Linear(n_inputs, n_channels, bias=False)

      # transmit: sent_msg -> rcvd_msg
      self.transmit = nn.RNN(n_channels, n_channels)

      # decode: rcvd_msg -> action
      self.decode = nn.Sequential(
        nn.Linear(n_channels,n_outputs, bias=False),
        nn.Softmax(dim=2)
      )

    def forward(self, x, h):

      msg = F.one_hot(x, num_classes=self.n_inputs).type(torch.float)
      msg = self.encode(msg)
      msg, h = self.transmit(msg, h)
      y = self.decode(msg)

      return y.squeeze(), h

In [291]:
class PonderNet(nn.Module):
  def __init__(self, n_inputs, n_embeddings, n_outputs, max_steps):
    super(PonderNet, self).__init__()

    self.output_layer = ICOM(n_inputs, n_embeddings, n_outputs)

    self.halting_layer = nn.Sequential(
      nn.Linear(n_embeddings, 1),
      nn.Sigmoid()
    )

    # \lambda_p
    self.max_steps = max_steps
    self.n_embeddings = n_embeddings
    self.n_outputs = n_outputs

  def forward(self, x):

    batch_size = x.shape[0]

    p = []
    y = []

    p_continue = torch.ones((batch_size,))
    halt = torch.zeros((batch_size,))
    p_m = torch.zeros((batch_size,))
    y_m = torch.zeros((batch_size,))
    p_n = torch.zeros((batch_size,))
    h = torch.zeros((1,batch_size,self.n_embeddings))

    for n in range(1, self.max_steps + 1):
      y_n, h = self.output_layer(x.unsqueeze(0), h)
      
      if n == self.max_steps:
        lambda_n = torch.tensor(1.)
        halt_steps = torch.empty((batch_size,)).fill_(n)
      else:
        lambda_n = self.halting_layer(h)
        halt_steps = torch.empty((batch_size,)).geometric_(lambda_n.detach()[0,0].item()) #FIXME

      if n % 500 == 0:
        print('lambda:',lambda_n)
      p_n = p_continue * lambda_n
      p_continue = p_continue * (1 - lambda_n)

      p.append(p_n)
      y.append(y_n)

      is_halted = (halt_steps <= n).type(torch.float)
      p_m = p_m * (1 - is_halted) + p_n * is_halted
      y_m = y_m * (1 - is_halted) + y_n * is_halted

      if all(halt):
        break

    return torch.stack(y), torch.stack(p), y_m, p_m

In [300]:

# split params
train_size = int(n_trials * .8)
test_size = n_trials - train_size

# training parrms
n_epoches = 100

logs = SummaryWriter()

model = PonderNet(n_stimuli+1, n_stimuli, n_stimuli+1, 100)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

dataset = TensorDataset(torch.tensor(X), torch.tensor(y)-1)
train_subset, test_subset = random_split(dataset, lengths=(train_size,test_size))

X_train, y_train = dataset[train_subset.indices]
X_test, y_test = dataset[test_subset.indices]

for epoch in tqdm(range(n_epoches), desc='Epochs'):

  for X_batch, y_batch in DataLoader(train_subset, batch_size=1):
    model.train()
    optimizer.zero_grad()
    ys, ps, y_pred, p_m = model(X_batch)

  # logs.add_embedding(h.reshape(n_trials,n_stimuli), global_step=epoch, tag='embedding')

  model_accuracy = accuracy_score(y_batch, torch.argmax(y_pred.detach(),dim=0).unsqueeze(0))
  logs.add_scalar('accurracy/train', model_accuracy, epoch)  

  loss = criterion(y_pred.unsqueeze(0), y_batch)
  
  logs.add_scalar('loss/train', loss, epoch)

  loss.backward()
  optimizer.step()

  # model.eval()
  # with torch.no_grad():
  #   _, _, y_pred, _ = model(X_test)
  #   loss = criterion(y_test, y_pred)
  #   logs.add_scalar('loss/test', loss.detach(), epoch)

# tensorboard --logdir=runs

Epochs: 100%|██████████| 100/100 [00:57<00:00,  1.73it/s]


In [None]:
model.eval()
_, _, y_pred, _ = model(X)
y_pred = np.argmax(y_pred.detach().numpy(), axis=1) + 1
y_pred, y

(array([2, 6, 5, 2, 6, 4, 4, 7, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]),
 array([[2, 6, 5, 1, 6, 4, 4, 6, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]]))

In [45]:
# example code to decode a stimulus into multiple sequence (one per channel)

import torch
from torch import nn

n_inputs = 7
max_timestep = 10
n_channels = 5

X = torch.nn.functional.one_hot(torch.tensor(4), num_classes=n_inputs).type(torch.float)

decode = nn.Linear(n_inputs, n_channels * max_timestep)
out = decode(X).reshape((n_channels, max_timestep))

print(out.shape)

torch.Size([5, 10])
