Building on [PonderNet](https://arxiv.org/abs/2107.05407), this notebook implements a neural alternative of the  [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0) model to produce human-like responses.

Given stimulus symbols as inputs, the model produces two outputs:

- Response symbol, which, in comparison with the input stimuli, can be used to measure accuracy).
- Remaining entropy (to be contrasted against a decision threshold and ultimateely halt the process).

Under the hood, the model uses a RNN along with multiple Poisson processes to...


## Resources

- [Network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharinghttps://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)


In [3]:
# Setup and imports
import torch
from torch import nn

import numpy as np
from scipy import stats
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

In [170]:
# produce a tarin of spikes and store timestamps of each spike in `spike_timestamps`.

signal_rate = 2
noise_rate = 1
rate = signal_rate + noise_rate
duration_in_sec = 10.
resolution_in_sec = .1

n_total_timesteps = int(duration_in_sec / resolution_in_sec)
n_spikes = np.random.poisson(rate * duration_in_sec)

# method 1: shuffle timesteps
spike_timesteps = np.sort(np.random.choice(n_total_timesteps, size=n_spikes, replace=False))

# method 2: exponential isi -> timestamps
# isi = np.random.exponential(1 / rate, n_spikes)
# spike_timestamps = np.cumsum(isi)

# method 3: homogenous spikes -> timestamps
# spike_timestamps = stats.uniform.rvs(loc=0, scale=duration_in_sec, size=n_spikes)


## RNN

In [174]:
class PonderVRC(nn.Module):
  def __init__(self, n_inputs, n_channels):
    super(PonderVRC, self).__init__()
    self.rnn = nn.RNN(n_inputs, n_channels)
    self.fc1 = nn.Linear(n_channels, n_inputs, bias=False)
    self.fc2 = nn.Linear(n_channels,1, bias=False)

  def forward(self, x):
    h = self.rnn(x)
    y = self.fc1(h)
    y = self.fc1(h)

    return y

## Mock data

In [330]:
n_trials = 30
n_stimuli = 6
n_subjects = 1

# required data columns: subject_index, trial_index, stimulus_index, accuracy, response_time
# TODO: generate random data and reshape into the following

# stimuli
X = np.random.randint(low=1, high=n_stimuli+1, size=(n_subjects, n_trials))

# accuracy (index=0)
accuracies = np.random.randint(low=0, high=2, size=(n_subjects, n_trials))
response_times = np.random.exponential(.5, size=(n_subjects, n_trials))

response_times
# responses = np.empty((n_subjects, n_trials, 2))
# responses[:,:,0] = np.where(accuracies==1., X, )
# response_time (index=1)
# responses[:,:,1].exponential_(.5)



array([[0.89434811, 0.64370202, 0.19886853, 1.39208346, 0.41082363,
        0.08900332, 1.11360565, 0.46728826, 0.36291653, 0.67963475,
        0.45148227, 0.38839379, 0.64743332, 0.41294597, 0.45289691,
        0.13357337, 0.85012272, 0.7988117 , 1.23502906, 0.53615726,
        0.07061297, 0.80473662, 0.38354505, 0.58555392, 0.38719181,
        0.42993123, 0.23014178, 0.13333575, 0.26819837, 0.28917237]])

In [176]:

n_epoches = 10

model = PonderVRC(10,10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.BCELoss()

for epoch in range(n_epoches):
  model.train()
  optimizer.zero_grad()
  x = ...
  y_true = ...
  y_pred = model(x)

  loss = criterion(y_pred, y_pred)

  loss.backward()
  optimizer.step()

AssertionError: 