Building on [PonderNet](https://arxiv.org/abs/2107.05407), this notebook implements a neural alternative to the [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0) model to produce human-like responses.

Given stimulus symbols as inputs, the model produces two outputs:

- Response symbol, which can be compared with the input stimulus to measure accuracy.
- Remaining entropy, which is contrasted against a decision threshold to ultimately halt the process.

Under the hood, the model uses an RNN along with multiple Poisson processes to...


## Resources

- [Network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)


## Problem setting

### Model
Given input and output data, we want to learn a supervised model of the function $X \to y$ as follows:

$$
f: X, h_n \mapsto \tilde{y}, h_{n+1}, \lambda_n
$$

where $X$ and $y$ denote stimulus and response symbols, $\tilde{y}$ is the predicted response, $\lambda_n$ denotes the halting probability at time $n$, and $h_{n}$ is the latent state of the model. Learning continues up to time point $N$.

For brevity and compatibility, both stimulus and response symbols are one-hot encoded.


### Input

One-hot encoded symbols.

### Output

One-hot encoded symbols.

### Criterion

$L = L_{\text{cross-entropy}} + L_{\text{halting}}$

In [114]:
# Setup and imports
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

from sklearn.metrics import accuracy_score

import numpy as np
from scipy import stats
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile #FIX storing embeddings using tensorboard

In [2]:
# Produce a train of spikes. `spike_timesteps` holds the sorted indices of the
# time bins that contain a spike; `spike_timestamps` holds the matching times
# (in seconds) of those bins.

signal_rate = 2
noise_rate = 1
rate = signal_rate + noise_rate  # combined rate used for the Poisson spike count
max_duration_in_sec = 10.
resolution_in_sec = .1

# round() instead of plain truncation: a float ratio such as 0.3 / 0.1 lands
# just below the integer and int() alone would drop a bin.
n_total_timesteps = int(round(max_duration_in_sec / resolution_in_sec))

# Bins are drawn without replacement below, so there can never be more spikes
# than bins; cap the Poisson draw so `np.random.choice` cannot raise.
n_spikes = min(int(np.random.poisson(rate * max_duration_in_sec)), n_total_timesteps)

# method 1: shuffle timesteps
spike_timesteps = np.sort(np.random.choice(n_total_timesteps, size=n_spikes, replace=False))
spike_timestamps = spike_timesteps * resolution_in_sec  # bin index -> seconds

# method 2: exponential isi -> timestamps
# isi = np.random.exponential(1 / rate, n_spikes)
# spike_timestamps = np.cumsum(isi)

# method 3: homogeneous spikes -> timestamps
# spike_timestamps = stats.uniform.rvs(loc=0, scale=max_duration_in_sec, size=n_spikes)

## Mock data

In [3]:


def generate_mock_data(n_subjects, n_trials, n_stimuli):
    """Generate random stimuli, responses, accuracies and response times.

    # TODO required data columns: subject_index, trial_index, stimulus_index, accuracy, response_time

    Every returned array has shape ``(n_subjects, n_trials)``. Values are drawn
    from NumPy's global random state, so seed ``np.random`` for reproducibility.

    Args:
        n_subjects (int): number of subjects (rows of the returned arrays).
        n_trials (int): number of trials per subject (columns of the returned arrays).
        n_stimuli (int): number of stimulus labels; stimuli are drawn from 1..n_stimuli.

    Returns:
        (X, y, accuracies, response_times): A tuple containing the generated mock
        stimuli, responses, per-trial accuracies (0. or 1.), and response_times (in sec).
        On correct trials ``y == X``; on incorrect trials ``y == X + 1``, so
        responses range over 1..n_stimuli + 1.
    """
    shape = (n_subjects, n_trials)

    # stimuli: integer labels in [1, n_stimuli]
    X = np.random.randint(low=1, high=n_stimuli + 1, size=shape)

    # response accuracy: one target accuracy per subject, snapped to a multiple
    # of 1 / n_trials, then used as the per-trial probability of being correct
    subject_accuracies = np.random.uniform(low=0.2, high=1.0, size=n_subjects)
    subject_accuracies = np.round(subject_accuracies * n_trials) / n_trials
    accuracies = np.empty(shape=shape)
    for subj, p_correct in enumerate(subject_accuracies):
        accuracies[subj, :] = np.random.choice(
            [0, 1],
            p=[1 - p_correct, p_correct],
            size=n_trials)

    # generate output w.r.t. the accuracy: incorrect trials get the next label up.
    # The previous `X+1 % (n_stimuli+1)` parsed as `X + (1 % (n_stimuli+1))`,
    # i.e. plain `X + 1`; that value is kept and now written explicitly.
    y = np.where(accuracies == 1., X, X + 1)

    # response time
    response_times = np.random.exponential(.5, size=shape)

    return X, y, accuracies, response_times

In [4]:
# mock data parameters: subjects, trials per subject, distinct stimulus labels
n_subjects, n_trials, n_stimuli = 1, 20, 6

# draw one mock dataset with those dimensions
X, y, accuracies, response_times = generate_mock_data(
    n_subjects=n_subjects,
    n_trials=n_trials,
    n_stimuli=n_stimuli,
)

In [117]:
class PonderRNN(nn.Module):
    """Encode -> transmit (RNN) -> decode network over integer stimulus labels.

    Args:
        n_inputs (int): width of the one-hot input, i.e. number of input classes.
        n_channels (int): size of the encoded message and of the RNN hidden state.
        n_outputs (int): number of output classes.
        halting_prob_prior (float): stored as ``self.halting_prob_prior``; not
            used by ``forward`` yet (halting is still a TODO).
    """

    def __init__(self, n_inputs, n_channels, n_outputs, halting_prob_prior=0.0):
        super().__init__()

        # Fixed one-hot width, so encoding does not depend on which labels
        # happen to be present in a given batch.
        self.n_inputs = n_inputs

        self.encode = nn.Sequential(  # encode: x -> sent_msg
            nn.Linear(n_inputs, n_channels, bias=False),
        )
        self.transmit = nn.Sequential(  # transmit: sent_msg -> rcvd_msg
            nn.RNN(n_channels, n_channels),
        )
        self.decode = nn.Sequential(  # decode: rcvd_msg -> action
            nn.Linear(n_channels, n_outputs, bias=False),
            nn.Softmax(dim=2)
        )

        # \lambda_p
        self.halting_prob_prior = halting_prob_prior

    def forward(self, x):
        """Map integer labels to class probabilities.

        Args:
            x: integer tensor of labels in ``[0, n_inputs)``. The softmax runs
                over dim 2, so ``x`` must be 2-D.
                NOTE(review): ``nn.RNN`` is built with its defaults, which (per
                PyTorch's ``batch_first=False``) treat dim 0 of ``x`` as the
                sequence axis -- confirm this is intended.

        Returns:
            (y, halting_prob): ``y`` holds the softmax probabilities with all
            singleton dimensions squeezed out; ``halting_prob`` is the
            placeholder constant ``0.`` until halting is implemented.
        """
        # x: one stimulus category, output: y[1..N] + halting_prob[1..N]
        # step 1: x -> x_n (repeat)
        # step 2: x_n -> y_n

        # VRC
        # num_classes is passed explicitly: without it one_hot sizes its output
        # from the largest label in x, which no longer matches the encoder's
        # input width whenever the top label is missing from the batch.
        msg = F.one_hot(x.long(), num_classes=self.n_inputs).type(torch.float)
        msg = self.encode(msg)
        msg, _ = self.transmit(msg)  # the RNN returns (output, last hidden state)
        msg = self.decode(msg)
        y = msg.squeeze()

        # TODO
        halting_prob = 0.
        # lambda_n = ...
        # halting_prob = torch.distributions.Geometric(self.halting_prob_prior)
        # self.halting_dist = torch.distributions.Geometric(prob)
        # self.halting_probs = torch.cat(halting_prob, halting_prob)

        return y, halting_prob


n_epoches = 1500

logs = SummaryWriter()

model = PonderRNN(n_stimuli+1, n_stimuli, n_stimuli+1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# PonderRNN's decoder ends in a softmax, so the model returns probabilities.
# CrossEntropyLoss expects unnormalized logits and applies log-softmax itself,
# so feeding it probabilities squashes them twice. NLLLoss on the log of the
# probabilities gives the actual cross-entropy.
criterion = torch.nn.NLLLoss()

X_train = torch.tensor(X)
y_train = torch.tensor(y) - 1  # responses are 1-based; class indices start at 0
targets = y_train.squeeze().long()  # loop-invariant, computed once

model.train()  # mode does not change inside the loop
for epoch in tqdm(range(n_epoches), desc='Epochs'):
    optimizer.zero_grad()
    y_pred, _ = model(X_train)

    # logs.add_embedding(h.reshape(n_trials,n_stimuli), global_step=epoch, tag='embedding')

    model_accuracy = accuracy_score(targets, torch.argmax(y_pred.detach(), dim=1))
    logs.add_scalar('accurracy/train', model_accuracy, epoch)

    # clamp before log so a probability that underflows to 0 cannot produce -inf
    loss = criterion(torch.log(torch.clamp(y_pred, min=1e-8)), targets)

    logs.add_scalar('loss/train', loss.item(), epoch)
    loss.backward()
    optimizer.step()

# make sure the last events reach disk before inspecting them
logs.flush()

# tensorboard --logdir=runs

Epochs: 100%|██████████| 1500/1500 [00:02<00:00, 659.68it/s]


In [137]:
# Evaluate on the training stimuli and convert the most probable class index
# back to a 1-based response label; the last line displays it next to `y`.
model.eval()
with torch.no_grad():
    y_prob, _ = model(X_train)
y_pred = y_prob.argmax(dim=1).numpy() + 1
y_pred, y

(array([2, 6, 5, 2, 6, 4, 4, 7, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]),
 array([[2, 6, 5, 1, 6, 4, 4, 6, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]]))

In [107]:
# Draw a single sample from a geometric distribution with success probability 0.01.
torch.distributions.Geometric(probs=.01).sample()

tensor(56.)