{ "cells": [ { "cell_type": "markdown", "source": [ "Building on [PonderNet](https://arxiv.org/abs/2107.05407), this notebook implements a neural alternative of the [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0) model to produce human-like responses.\n", "\n", "Given stimulus symbols as inputs, the model produces two outputs:\n", "\n", "- Response symbol, which, in comparison with the input stimuli, can be used to measure accuracy).\n", "- Remaining entropy (to be contrasted against a decision threshold and ultimateely halt the process).\n", "\n", "Under the hood, the model uses a RNN along with multiple Poisson processes to...\n", "\n", "\n", "## Resources\n", "\n", "- [Network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharinghttps://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)\n" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Problem setting\n", "\n", "### Model\n", "Given input and output data, we want to learn a supervised model of the function $X \\to y$ as follows:\n", "\n", "$\n", "f: X,h_n \\mapsto \\tilde{y},h_{n+1}, \\lambda_n\n", "$\n", "\n", "where $X$ and $y$ denote stimulus and response symbols, $\\lambda_n$ denotes halting probability at time $n$, and $h_{n}$ is the latent state of the model. The learninig continious up to the time point $N$.\n", "\n", "For the brevity and compatibility, both data are one-hot encoded.\n", "\n", "\n", "### Input\n", "\n", "One-hot encoded symbols.\n", "\n", "### Output\n", "\n", "One-hot encoded symbols.\n", "\n", "### Criterion\n", "\n", "L = L_cross_entropy + L_halting" ], "metadata": {} }, { "cell_type": "code", "execution_count": 206, "source": [ "# Setup and imports\n", "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import TensorDataset, DataLoader, random_split\n", "from torch.utils.tensorboard import SummaryWriter\n", "\n", "from tqdm import tqdm\n", "\n", "from sklearn.metrics import accuracy_score\n", "\n", "import numpy as np\n", "from scipy import stats\n", "import pandas as pd\n", "\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns; sns.set()\n", "\n", "import tensorflow as tf\n", "import tensorboard as tb\n", "tf.io.gfile = tb.compat.tensorflow_stub.io.gfile #FIX storing embeddings using tensorboard" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 202, "source": [ "# produce a tarin of spikes and store timestamps of each spike in `spike_timestamps`.\n", "\n", "signal_rate = 2\n", "noise_rate = 1\n", "rate = signal_rate + noise_rate\n", "max_duration_in_sec = 10.\n", "resolution_in_sec = .1\n", "\n", "n_total_timesteps = int(max_duration_in_sec / resolution_in_sec)\n", "n_spikes = np.random.poisson(rate * max_duration_in_sec)\n", "\n", "# method 1: shuffle timesteps\n", "spike_timesteps = np.sort(np.random.choice(n_total_timesteps, size=n_spikes, replace=False))\n", "\n", "# method 2: exponential isi -> timestamps\n", "# isi = np.random.exponential(1 / rate, n_spikes)\n", "# spike_timestamps = np.cumsum(isi)\n", "\n", "# method 3: homogenous spikes -> timestamps\n", "# spike_timestamps = stats.uniform.rvs(loc=0, scale=max_duration_in_sec, size=n_spikes)" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Mock data" ], "metadata": {} }, { "cell_type": "code", "execution_count": 203, "source": [ "\n", "\n", "def generate_mock_data(n_subjects, n_trials, n_stimuli):\n", " \"\"\"[summary]\n", "\n", " # TODO required data columns: subject_index, trial_index, stimulus_index, accuracy, response_time\n", "\n", " Args:\n", " n_subjects (int): [description]\n", " n_trials (int): [description]\n", " n_stimuli (int): [description]\n", "\n", " Returns:\n", " (X, accuracies, response_times): A tuple containing generated mock X, accuracies, and response_times (in sec).\n", " \"\"\"\n", " # stimuli\n", " X = np.random.randint(low=1, high=n_stimuli+1, size=(n_subjects, n_trials))\n", "\n", " # response accuracy\n", " subject_accuracies = np.random.uniform(low=0.2, high=1.0, size=n_subjects)\n", " subject_accuracies = np.round(subject_accuracies * n_trials) / n_trials\n", " accuracies = np.empty(shape=(n_subjects, n_trials))\n", " for subj in range(n_subjects):\n", " accuracies[subj,:] = np.random.choice(\n", " [0,1],\n", " p=[1-subject_accuracies[subj],subject_accuracies[subj]],\n", " size=n_trials)\n", "\n", " # generate output w.r.t the accuracy (and fill incorrect trials with invalid response)\n", " y = np.where(accuracies == 1., X, X+1 % (n_stimuli+1))\n", "\n", " # response time\n", " response_times = np.random.exponential(.5, size=accuracies.shape)\n", "\n", " return X, y, accuracies, response_times" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 204, "source": [ "# mock data parameters\n", "n_subjects = 1\n", "n_trials = 20\n", "n_stimuli = 6\n", "\n", "X, y, accuracies, response_times = generate_mock_data(n_subjects, n_trials, n_stimuli)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 210, "source": [ "class PonderRNN(nn.Module):\n", " def __init__(self, n_inputs, n_channels, n_outputs, halting_prob_prior=0.2):\n", " super(PonderRNN, self).__init__()\n", " self.encode = nn.Sequential( # encode: x -> sent_msg\n", " nn.Linear(n_inputs, n_channels, bias=False),\n", " )\n", " self.transmit = nn.Sequential( # transmit: sent_msg -> rcvd_msg\n", " nn.RNN(n_channels, n_channels),\n", " )\n", " self.decode = nn.Sequential( # decode: rcvd_msg -> action\n", " nn.Linear(n_channels,n_outputs, bias=False),\n", " nn.Softmax(dim=2)\n", " )\n", "\n", " # \\lambda_p\n", " self.halting_prob = halting_prob_prior\n", "\n", " def forward(self, x):\n", " # x: one stimulus category, output: y[1..N] + halting_prob[1..N]\n", " # step 1: x -> x_n (repeat)\n", " # step 2: x_n -> y_n\n", "\n", " # VRC\n", " msg = F.one_hot(x).type(torch.float)\n", " msg = self.encode(msg)\n", " msg, _ = self.transmit(msg)\n", " msg = self.decode(msg)\n", " y = msg.squeeze()\n", "\n", " halting_n = torch.distributions.Geometric(self.halting_prob).sample().detach()\n", " # self.halting_dist = torch.distributions.Geometric(prob)\n", " # self.halting_probs = torch.cat(halting_prob, halting_prob)\n", "\n", " return y, halting_n\n", "\n", "\n", "# split params\n", "train_size = int(n_trials * .8)\n", "test_size = n_trials - train_size\n", "\n", "# training parrms\n", "n_epoches = 1500\n", "\n", "logs = SummaryWriter()\n", "\n", "model = PonderRNN(n_stimuli+1, n_stimuli, n_stimuli+1)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", "criterion = torch.nn.CrossEntropyLoss()\n", "\n", "dataset = TensorDataset(torch.tensor(X), torch.tensor(y)-1)\n", "train_subset, test_subset = random_split(dataset, lengths=(train_size,test_size))\n", "\n", "X_test, y_test = dataset[test_subset.indices]\n", "\n", "print(X_test)\n", "\n", "xx\n", "\n", "for epoch in tqdm(range(n_epoches), desc='Epochs'):\n", "\n", " for X_batch, y_batch in DataLoader(train_subset, batch_size=1):\n", " model.train()\n", " optimizer.zero_grad()\n", " y_pred, _ = model(X_batch)\n", "\n", " # logs.add_embedding(h.reshape(n_trials,n_stimuli), global_step=epoch, tag='embedding')\n", "\n", " model_accuracy = accuracy_score(y_batch.squeeze(), torch.argmax(y_pred.detach(),dim=1))\n", " logs.add_scalar('accurracy/train', model_accuracy, epoch) \n", "\n", " loss = criterion(y_pred, y_batch.squeeze())\n", " logs.add_scalar('loss/train', loss, epoch)\n", "\n", " loss.backward()\n", " optimizer.step()\n", "\n", " model.eval()\n", " with torch.no_grad():\n", " y_pred = model(X_test)\n", " loss = criterion(y_test, y_pred)\n", " logger.add_scalar('loss/test', loss.detach(), epoch)\n", "\n", "\n", "# tensorboard --logdir=runs" ], "outputs": [ { "output_type": "error", "ename": "ValueError", "evalue": "Sum of input lengths does not equal the length of the input dataset!", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m<ipython-input-210-d7a3afc60ba2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTensorDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mtrain_subset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_subset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandom_split\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlengths\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtest_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mX_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtest_subset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.9/site-packages/torch/utils/data/dataset.py\u001b[0m in \u001b[0;36mrandom_split\u001b[0;34m(dataset, lengths, generator)\u001b[0m\n\u001b[1;32m 349\u001b[0m \u001b[0;31m# Cannot verify that dataset is Sized\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 350\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlengths\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 351\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Sum of input lengths does not equal the length of the input dataset!\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 352\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 353\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandperm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlengths\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgenerator\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtolist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: Sum of input lengths does not equal the length of the input dataset!" ] } ], "metadata": {} }, { "cell_type": "code", "execution_count": 137, "source": [ "model.eval()\n", "y_pred, _ = model(X_train)\n", "y_pred = np.argmax(y_pred.detach().numpy(), axis=1) + 1\n", "y_pred, y" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(array([2, 6, 5, 2, 6, 4, 4, 7, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]),\n", " array([[2, 6, 5, 1, 6, 4, 4, 6, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]]))" ] }, "metadata": {}, "execution_count": 137 } ], "metadata": {} }, { "cell_type": "code", "execution_count": 199, "source": [ "halting_prob = .1\n", "x = torch.distributions.Geometric(halting_prob).sample().item()\n", "type(x)\n", "# .2*.2*.2*.8" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "float" ] }, "metadata": {}, "execution_count": 199 } ], "metadata": {} } ], "metadata": { "orig_nbformat": 4, "language_info": { "name": "python", "version": "3.9.4", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "kernelspec": { "name": "python3", "display_name": "Python 3.9.4 64-bit ('py3': conda)" }, "interpreter": { "hash": "5ddcf14c786c671500c086f61f0b66d0417d6c58ff12753e346e191a84f72b84" } }, "nbformat": 4, "nbformat_minor": 2 }