diff --git a/PonderModel.ipynb b/PonderModel.ipynb new file mode 100644 index 0000000..c629f77 --- /dev/null +++ b/PonderModel.ipynb @@ -0,0 +1,430 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "## Intro\n", + "\n", + "In the context of behavioral data, we are interested in simultaneously modeling speed and accuracy. Yet, most advanced techniques in machine learning cannot capture such a duality of decision making data.\n", + "\n", + "\n", + "Building on [PonderNet](https://arxiv.org/abs/2107.05407) and [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0), this notebook implements a neural model that captures speed and accuracy of human-like responses.\n", + "\n", + "Given stimulus symbols as inputs, the model produces two outputs:\n", + "\n", + "- Response symbol, which, in comparison with the input stimuli, can be used to measure accuracy).\n", + "- Halting probability ($\\lambda_n$).\n", + "\n", + "Under the hood, the model iterates over a ICOM-like component to reach a halting point in time. Unlike DDM and ICOM models, all the parameters and outcomes of the current model *seem* cognitively interpretable.\n", + "\n", + "### Additional resources\n", + "\n", + "- [ICOM network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharinghttps://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)\n" + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "## Problem setting\n", + "\n", + "### Model\n", + "Given input and output data, we want to learn a supervised model of the function $X \\to y$ as follows:\n", + "\n", + "$\n", + "f: X,h_n \\mapsto \\tilde{y},h_{n+1}, \\lambda_n\n", + "$\n", + "\n", + "where $X$ and $y$ denote stimulus and response symbols, $\\lambda_n$ denotes halting probability at time $n$, and $h_{n}$ is the latent state of the model. The learninig continious up to the time point $N$.\n", + "\n", + "For the brevity and compatibility, both data are one-hot encoded.\n", + "\n", + "\n", + "### Input\n", + "\n", + "One-hot encoded symbols.\n", + "\n", + "### Output\n", + "\n", + "One-hot encoded symbols.\n", + "\n", + "### Criterion\n", + "\n", + "L = L_cross_entropy + L_halting" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 157, + "source": [ + "# Setup and imports\n", + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import TensorDataset, DataLoader, random_split\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "\n", + "from tqdm import tqdm\n", + "\n", + "from sklearn.metrics import accuracy_score\n", + "\n", + "import numpy as np\n", + "from scipy import stats\n", + "import pandas as pd\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns; sns.set()\n", + "\n", + "import tensorflow as tf\n", + "import tensorboard as tb\n", + "tf.io.gfile = tb.compat.tensorflow_stub.io.gfile #FIX storing embeddings using tensorboard" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 158, + "source": [ + "# produce a tarin of spikes and store timestamps of each spike in `spike_timestamps`.\n", + "\n", + "signal_rate = 2\n", + "noise_rate = 1\n", + "rate = signal_rate + noise_rate\n", + "max_duration_in_sec = 10.\n", + "resolution_in_sec = .1\n", + "\n", + "n_total_timesteps = int(max_duration_in_sec / resolution_in_sec)\n", + "n_spikes = np.random.poisson(rate * max_duration_in_sec)\n", + "\n", + "# method 1: shuffle timesteps\n", + "spike_timesteps = np.sort(np.random.choice(n_total_timesteps, size=n_spikes, replace=False))\n", + "\n", + "# method 2: exponential isi -> timestamps\n", + "# isi = np.random.exponential(1 / rate, n_spikes)\n", + "# spike_timestamps = np.cumsum(isi)\n", + "\n", + "# method 3: homogenous spikes -> timestamps\n", + "# spike_timestamps = stats.uniform.rvs(loc=0, scale=max_duration_in_sec, size=n_spikes)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "## Mock data" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 226, + "source": [ + "\n", + "\n", + "def generate_mock_data(n_subjects, n_trials, n_stimuli):\n", + " \"\"\"[summary]\n", + "\n", + " # TODO required data columns: subject_index, trial_index, stimulus_index, accuracy, response_time\n", + "\n", + " Args:\n", + " n_subjects (int): [description]\n", + " n_trials (int): [description]\n", + " n_stimuli (int): [description]\n", + "\n", + " Returns:\n", + " (X, accuracies, response_times): A tuple containing generated mock X, accuracies, and response_times (in sec).\n", + " \"\"\"\n", + " # stimuli\n", + " X = np.random.randint(low=1, high=n_stimuli+1, size=(n_subjects, n_trials))\n", + "\n", + " # response accuracy\n", + " subject_accuracies = np.random.uniform(low=0.2, high=1.0, size=n_subjects)\n", + " subject_accuracies = np.round(subject_accuracies * n_trials) / n_trials\n", + " accuracies = np.empty(shape=(n_subjects, n_trials))\n", + " for subj in range(n_subjects):\n", + " accuracies[subj,:] = np.random.choice(\n", + " [0,1],\n", + " p=[1-subject_accuracies[subj],subject_accuracies[subj]],\n", + " size=n_trials)\n", + "\n", + " # generate output w.r.t the accuracy (and fill incorrect trials with invalid response)\n", + " y = np.where(accuracies == 1., X, X+1 % (n_stimuli+1))\n", + "\n", + " # response time\n", + " response_times = np.random.exponential(.5, size=accuracies.shape)\n", + "\n", + " if n_subjects == 1:\n", + " X = X.squeeze()\n", + " y = y.squeeze()\n", + " accuracies = accuracies.squeeze()\n", + " response_times = response_times.squeeze()\n", + "\n", + " return X, y, accuracies, response_times" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 269, + "source": [ + "# mock data parameters\n", + "n_subjects = 1\n", + "n_trials = 20\n", + "n_stimuli = 6\n", + "\n", + "X, y, accuracies, response_times = generate_mock_data(n_subjects, n_trials, n_stimuli)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 270, + "source": [ + "class ICOM(nn.Module):\n", + " def __init__(self, n_inputs, n_channels, n_outputs):\n", + " super(ICOM, self).__init__()\n", + "\n", + " self.n_inputs = n_inputs\n", + " # encode: x -> sent_msg\n", + " self.encode = nn.Linear(n_inputs, n_channels, bias=False)\n", + "\n", + " # transmit: sent_msg -> rcvd_msg\n", + " self.transmit = nn.RNN(n_channels, n_channels)\n", + "\n", + " # decode: rcvd_msg -> action\n", + " self.decode = nn.Sequential(\n", + " nn.Linear(n_channels,n_outputs, bias=False),\n", + " nn.Softmax(dim=2)\n", + " )\n", + "\n", + " def forward(self, x, h):\n", + "\n", + " msg = F.one_hot(x, num_classes=self.n_inputs).type(torch.float)\n", + " msg = self.encode(msg)\n", + " msg, h = self.transmit(msg, h)\n", + " y = self.decode(msg)\n", + "\n", + " return y.squeeze(), h" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 291, + "source": [ + "class PonderNet(nn.Module):\n", + " def __init__(self, n_inputs, n_embeddings, n_outputs, max_steps):\n", + " super(PonderNet, self).__init__()\n", + "\n", + " self.output_layer = ICOM(n_inputs, n_embeddings, n_outputs)\n", + "\n", + " self.halting_layer = nn.Sequential(\n", + " nn.Linear(n_embeddings, 1),\n", + " nn.Sigmoid()\n", + " )\n", + "\n", + " # \\lambda_p\n", + " self.max_steps = max_steps\n", + " self.n_embeddings = n_embeddings\n", + " self.n_outputs = n_outputs\n", + "\n", + " def forward(self, x):\n", + "\n", + " batch_size = x.shape[0]\n", + "\n", + " p = []\n", + " y = []\n", + "\n", + " p_continue = torch.ones((batch_size,))\n", + " halt = torch.zeros((batch_size,))\n", + " p_m = torch.zeros((batch_size,))\n", + " y_m = torch.zeros((batch_size,))\n", + " p_n = torch.zeros((batch_size,))\n", + " h = torch.zeros((1,batch_size,self.n_embeddings))\n", + "\n", + " for n in range(1, self.max_steps + 1):\n", + " y_n, h = self.output_layer(x.unsqueeze(0), h)\n", + " \n", + " if n == self.max_steps:\n", + " lambda_n = torch.tensor(1.)\n", + " halt_steps = torch.empty((batch_size,)).fill_(n)\n", + " else:\n", + " lambda_n = self.halting_layer(h)\n", + " halt_steps = torch.empty((batch_size,)).geometric_(lambda_n.detach()[0,0].item()) #FIXME\n", + "\n", + " if n % 500 == 0:\n", + " print('lambda:',lambda_n)\n", + " p_n = p_continue * lambda_n\n", + " p_continue = p_continue * (1 - lambda_n)\n", + "\n", + " p.append(p_n)\n", + " y.append(y_n)\n", + "\n", + " is_halted = (halt_steps <= n).type(torch.float)\n", + " p_m = p_m * (1 - is_halted) + p_n * is_halted\n", + " y_m = y_m * (1 - is_halted) + y_n * is_halted\n", + "\n", + " if all(halt):\n", + " break\n", + "\n", + " return torch.stack(y), torch.stack(p), y_m, p_m" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 300, + "source": [ + "\n", + "# split params\n", + "train_size = int(n_trials * .8)\n", + "test_size = n_trials - train_size\n", + "\n", + "# training parrms\n", + "n_epoches = 100\n", + "\n", + "logs = SummaryWriter()\n", + "\n", + "model = PonderNet(n_stimuli+1, n_stimuli, n_stimuli+1, 100)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "dataset = TensorDataset(torch.tensor(X), torch.tensor(y)-1)\n", + "train_subset, test_subset = random_split(dataset, lengths=(train_size,test_size))\n", + "\n", + "X_train, y_train = dataset[train_subset.indices]\n", + "X_test, y_test = dataset[test_subset.indices]\n", + "\n", + "for epoch in tqdm(range(n_epoches), desc='Epochs'):\n", + "\n", + " for X_batch, y_batch in DataLoader(train_subset, batch_size=1):\n", + " model.train()\n", + " optimizer.zero_grad()\n", + " ys, ps, y_pred, p_m = model(X_batch)\n", + "\n", + " # logs.add_embedding(h.reshape(n_trials,n_stimuli), global_step=epoch, tag='embedding')\n", + "\n", + " model_accuracy = accuracy_score(y_batch, torch.argmax(y_pred.detach(),dim=0).unsqueeze(0))\n", + " logs.add_scalar('accurracy/train', model_accuracy, epoch) \n", + "\n", + " loss = criterion(y_pred.unsqueeze(0), y_batch)\n", + " \n", + " logs.add_scalar('loss/train', loss, epoch)\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # model.eval()\n", + " # with torch.no_grad():\n", + " # _, _, y_pred, _ = model(X_test)\n", + " # loss = criterion(y_test, y_pred)\n", + " # logs.add_scalar('loss/test', loss.detach(), epoch)\n", + "\n", + "# tensorboard --logdir=runs" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Epochs: 100%|██████████| 100/100 [00:57<00:00, 1.73it/s]\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "model.eval()\n", + "_, _, y_pred, _ = model(X)\n", + "y_pred = np.argmax(y_pred.detach().numpy(), axis=1) + 1\n", + "y_pred, y" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(array([2, 6, 5, 2, 6, 4, 4, 7, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]),\n", + " array([[2, 6, 5, 1, 6, 4, 4, 6, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]]))" + ] + }, + "metadata": {}, + "execution_count": 137 + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 45, + "source": [ + "# example code to decode a stimulus into multiple sequence (one per channel)\n", + "\n", + "import torch\n", + "from torch import nn\n", + "\n", + "n_inputs = 7\n", + "max_timestep = 10\n", + "n_channels = 5\n", + "\n", + "X = torch.nn.functional.one_hot(torch.tensor(4), num_classes=n_inputs).type(torch.float)\n", + "\n", + "decode = nn.Linear(n_inputs, n_channels * max_timestep)\n", + "out = decode(X).reshape((n_channels, max_timestep))\n", + "\n", + "print(out.shape)" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([5, 10])\n" + ] + } + ], + "metadata": {} + } + ], + "metadata": { + "orig_nbformat": 4, + "language_info": { + "name": "python", + "version": "3.9.4", + "mimetype": "text/x-python", + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "pygments_lexer": "ipython3", + "nbconvert_exporter": "python", + "file_extension": ".py" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3.9.4 64-bit" + }, + "interpreter": { + "hash": "5ddcf14c786c671500c086f61f0b66d0417d6c58ff12753e346e191a84f72b84" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/PonderModel.ipynb b/PonderModel.ipynb new file mode 100644 index 0000000..c629f77 --- /dev/null +++ b/PonderModel.ipynb @@ -0,0 +1,430 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "## Intro\n", + "\n", + "In the context of behavioral data, we are interested in simultaneously modeling speed and accuracy. Yet, most advanced techniques in machine learning cannot capture such a duality of decision making data.\n", + "\n", + "\n", + "Building on [PonderNet](https://arxiv.org/abs/2107.05407) and [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0), this notebook implements a neural model that captures speed and accuracy of human-like responses.\n", + "\n", + "Given stimulus symbols as inputs, the model produces two outputs:\n", + "\n", + "- Response symbol, which, in comparison with the input stimuli, can be used to measure accuracy).\n", + "- Halting probability ($\\lambda_n$).\n", + "\n", + "Under the hood, the model iterates over a ICOM-like component to reach a halting point in time. Unlike DDM and ICOM models, all the parameters and outcomes of the current model *seem* cognitively interpretable.\n", + "\n", + "### Additional resources\n", + "\n", + "- [ICOM network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharinghttps://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)\n" + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "## Problem setting\n", + "\n", + "### Model\n", + "Given input and output data, we want to learn a supervised model of the function $X \\to y$ as follows:\n", + "\n", + "$\n", + "f: X,h_n \\mapsto \\tilde{y},h_{n+1}, \\lambda_n\n", + "$\n", + "\n", + "where $X$ and $y$ denote stimulus and response symbols, $\\lambda_n$ denotes halting probability at time $n$, and $h_{n}$ is the latent state of the model. The learninig continious up to the time point $N$.\n", + "\n", + "For the brevity and compatibility, both data are one-hot encoded.\n", + "\n", + "\n", + "### Input\n", + "\n", + "One-hot encoded symbols.\n", + "\n", + "### Output\n", + "\n", + "One-hot encoded symbols.\n", + "\n", + "### Criterion\n", + "\n", + "L = L_cross_entropy + L_halting" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 157, + "source": [ + "# Setup and imports\n", + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import TensorDataset, DataLoader, random_split\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "\n", + "from tqdm import tqdm\n", + "\n", + "from sklearn.metrics import accuracy_score\n", + "\n", + "import numpy as np\n", + "from scipy import stats\n", + "import pandas as pd\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns; sns.set()\n", + "\n", + "import tensorflow as tf\n", + "import tensorboard as tb\n", + "tf.io.gfile = tb.compat.tensorflow_stub.io.gfile #FIX storing embeddings using tensorboard" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 158, + "source": [ + "# produce a tarin of spikes and store timestamps of each spike in `spike_timestamps`.\n", + "\n", + "signal_rate = 2\n", + "noise_rate = 1\n", + "rate = signal_rate + noise_rate\n", + "max_duration_in_sec = 10.\n", + "resolution_in_sec = .1\n", + "\n", + "n_total_timesteps = int(max_duration_in_sec / resolution_in_sec)\n", + "n_spikes = np.random.poisson(rate * max_duration_in_sec)\n", + "\n", + "# method 1: shuffle timesteps\n", + "spike_timesteps = np.sort(np.random.choice(n_total_timesteps, size=n_spikes, replace=False))\n", + "\n", + "# method 2: exponential isi -> timestamps\n", + "# isi = np.random.exponential(1 / rate, n_spikes)\n", + "# spike_timestamps = np.cumsum(isi)\n", + "\n", + "# method 3: homogenous spikes -> timestamps\n", + "# spike_timestamps = stats.uniform.rvs(loc=0, scale=max_duration_in_sec, size=n_spikes)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "## Mock data" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 226, + "source": [ + "\n", + "\n", + "def generate_mock_data(n_subjects, n_trials, n_stimuli):\n", + " \"\"\"[summary]\n", + "\n", + " # TODO required data columns: subject_index, trial_index, stimulus_index, accuracy, response_time\n", + "\n", + " Args:\n", + " n_subjects (int): [description]\n", + " n_trials (int): [description]\n", + " n_stimuli (int): [description]\n", + "\n", + " Returns:\n", + " (X, accuracies, response_times): A tuple containing generated mock X, accuracies, and response_times (in sec).\n", + " \"\"\"\n", + " # stimuli\n", + " X = np.random.randint(low=1, high=n_stimuli+1, size=(n_subjects, n_trials))\n", + "\n", + " # response accuracy\n", + " subject_accuracies = np.random.uniform(low=0.2, high=1.0, size=n_subjects)\n", + " subject_accuracies = np.round(subject_accuracies * n_trials) / n_trials\n", + " accuracies = np.empty(shape=(n_subjects, n_trials))\n", + " for subj in range(n_subjects):\n", + " accuracies[subj,:] = np.random.choice(\n", + " [0,1],\n", + " p=[1-subject_accuracies[subj],subject_accuracies[subj]],\n", + " size=n_trials)\n", + "\n", + " # generate output w.r.t the accuracy (and fill incorrect trials with invalid response)\n", + " y = np.where(accuracies == 1., X, X+1 % (n_stimuli+1))\n", + "\n", + " # response time\n", + " response_times = np.random.exponential(.5, size=accuracies.shape)\n", + "\n", + " if n_subjects == 1:\n", + " X = X.squeeze()\n", + " y = y.squeeze()\n", + " accuracies = accuracies.squeeze()\n", + " response_times = response_times.squeeze()\n", + "\n", + " return X, y, accuracies, response_times" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 269, + "source": [ + "# mock data parameters\n", + "n_subjects = 1\n", + "n_trials = 20\n", + "n_stimuli = 6\n", + "\n", + "X, y, accuracies, response_times = generate_mock_data(n_subjects, n_trials, n_stimuli)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 270, + "source": [ + "class ICOM(nn.Module):\n", + " def __init__(self, n_inputs, n_channels, n_outputs):\n", + " super(ICOM, self).__init__()\n", + "\n", + " self.n_inputs = n_inputs\n", + " # encode: x -> sent_msg\n", + " self.encode = nn.Linear(n_inputs, n_channels, bias=False)\n", + "\n", + " # transmit: sent_msg -> rcvd_msg\n", + " self.transmit = nn.RNN(n_channels, n_channels)\n", + "\n", + " # decode: rcvd_msg -> action\n", + " self.decode = nn.Sequential(\n", + " nn.Linear(n_channels,n_outputs, bias=False),\n", + " nn.Softmax(dim=2)\n", + " )\n", + "\n", + " def forward(self, x, h):\n", + "\n", + " msg = F.one_hot(x, num_classes=self.n_inputs).type(torch.float)\n", + " msg = self.encode(msg)\n", + " msg, h = self.transmit(msg, h)\n", + " y = self.decode(msg)\n", + "\n", + " return y.squeeze(), h" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 291, + "source": [ + "class PonderNet(nn.Module):\n", + " def __init__(self, n_inputs, n_embeddings, n_outputs, max_steps):\n", + " super(PonderNet, self).__init__()\n", + "\n", + " self.output_layer = ICOM(n_inputs, n_embeddings, n_outputs)\n", + "\n", + " self.halting_layer = nn.Sequential(\n", + " nn.Linear(n_embeddings, 1),\n", + " nn.Sigmoid()\n", + " )\n", + "\n", + " # \\lambda_p\n", + " self.max_steps = max_steps\n", + " self.n_embeddings = n_embeddings\n", + " self.n_outputs = n_outputs\n", + "\n", + " def forward(self, x):\n", + "\n", + " batch_size = x.shape[0]\n", + "\n", + " p = []\n", + " y = []\n", + "\n", + " p_continue = torch.ones((batch_size,))\n", + " halt = torch.zeros((batch_size,))\n", + " p_m = torch.zeros((batch_size,))\n", + " y_m = torch.zeros((batch_size,))\n", + " p_n = torch.zeros((batch_size,))\n", + " h = torch.zeros((1,batch_size,self.n_embeddings))\n", + "\n", + " for n in range(1, self.max_steps + 1):\n", + " y_n, h = self.output_layer(x.unsqueeze(0), h)\n", + " \n", + " if n == self.max_steps:\n", + " lambda_n = torch.tensor(1.)\n", + " halt_steps = torch.empty((batch_size,)).fill_(n)\n", + " else:\n", + " lambda_n = self.halting_layer(h)\n", + " halt_steps = torch.empty((batch_size,)).geometric_(lambda_n.detach()[0,0].item()) #FIXME\n", + "\n", + " if n % 500 == 0:\n", + " print('lambda:',lambda_n)\n", + " p_n = p_continue * lambda_n\n", + " p_continue = p_continue * (1 - lambda_n)\n", + "\n", + " p.append(p_n)\n", + " y.append(y_n)\n", + "\n", + " is_halted = (halt_steps <= n).type(torch.float)\n", + " p_m = p_m * (1 - is_halted) + p_n * is_halted\n", + " y_m = y_m * (1 - is_halted) + y_n * is_halted\n", + "\n", + " if all(halt):\n", + " break\n", + "\n", + " return torch.stack(y), torch.stack(p), y_m, p_m" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 300, + "source": [ + "\n", + "# split params\n", + "train_size = int(n_trials * .8)\n", + "test_size = n_trials - train_size\n", + "\n", + "# training parrms\n", + "n_epoches = 100\n", + "\n", + "logs = SummaryWriter()\n", + "\n", + "model = PonderNet(n_stimuli+1, n_stimuli, n_stimuli+1, 100)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "dataset = TensorDataset(torch.tensor(X), torch.tensor(y)-1)\n", + "train_subset, test_subset = random_split(dataset, lengths=(train_size,test_size))\n", + "\n", + "X_train, y_train = dataset[train_subset.indices]\n", + "X_test, y_test = dataset[test_subset.indices]\n", + "\n", + "for epoch in tqdm(range(n_epoches), desc='Epochs'):\n", + "\n", + " for X_batch, y_batch in DataLoader(train_subset, batch_size=1):\n", + " model.train()\n", + " optimizer.zero_grad()\n", + " ys, ps, y_pred, p_m = model(X_batch)\n", + "\n", + " # logs.add_embedding(h.reshape(n_trials,n_stimuli), global_step=epoch, tag='embedding')\n", + "\n", + " model_accuracy = accuracy_score(y_batch, torch.argmax(y_pred.detach(),dim=0).unsqueeze(0))\n", + " logs.add_scalar('accurracy/train', model_accuracy, epoch) \n", + "\n", + " loss = criterion(y_pred.unsqueeze(0), y_batch)\n", + " \n", + " logs.add_scalar('loss/train', loss, epoch)\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # model.eval()\n", + " # with torch.no_grad():\n", + " # _, _, y_pred, _ = model(X_test)\n", + " # loss = criterion(y_test, y_pred)\n", + " # logs.add_scalar('loss/test', loss.detach(), epoch)\n", + "\n", + "# tensorboard --logdir=runs" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Epochs: 100%|██████████| 100/100 [00:57<00:00, 1.73it/s]\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "model.eval()\n", + "_, _, y_pred, _ = model(X)\n", + "y_pred = np.argmax(y_pred.detach().numpy(), axis=1) + 1\n", + "y_pred, y" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(array([2, 6, 5, 2, 6, 4, 4, 7, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]),\n", + " array([[2, 6, 5, 1, 6, 4, 4, 6, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]]))" + ] + }, + "metadata": {}, + "execution_count": 137 + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 45, + "source": [ + "# example code to decode a stimulus into multiple sequence (one per channel)\n", + "\n", + "import torch\n", + "from torch import nn\n", + "\n", + "n_inputs = 7\n", + "max_timestep = 10\n", + "n_channels = 5\n", + "\n", + "X = torch.nn.functional.one_hot(torch.tensor(4), num_classes=n_inputs).type(torch.float)\n", + "\n", + "decode = nn.Linear(n_inputs, n_channels * max_timestep)\n", + "out = decode(X).reshape((n_channels, max_timestep))\n", + "\n", + "print(out.shape)" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([5, 10])\n" + ] + } + ], + "metadata": {} + } + ], + "metadata": { + "orig_nbformat": 4, + "language_info": { + "name": "python", + "version": "3.9.4", + "mimetype": "text/x-python", + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "pygments_lexer": "ipython3", + "nbconvert_exporter": "python", + "file_extension": ".py" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3.9.4 64-bit" + }, + "interpreter": { + "hash": "5ddcf14c786c671500c086f61f0b66d0417d6c58ff12753e346e191a84f72b84" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/VRC_PonderRNN.ipynb b/VRC_PonderRNN.ipynb deleted file mode 100644 index 98a30aa..0000000 --- a/VRC_PonderRNN.ipynb +++ /dev/null @@ -1,430 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "source": [ - "## Intro\n", - "\n", - "In the context of behavioral data, we are interested in simultaneously modeling speed and accuracy. Yet, most advanced techniques in machine learning cannot capture such a duality of decision making data.\n", - "\n", - "\n", - "Building on [PonderNet](https://arxiv.org/abs/2107.05407) and [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0), this notebook implements a neural model that captures speed and accuracy of human-like responses.\n", - "\n", - "Given stimulus symbols as inputs, the model produces two outputs:\n", - "\n", - "- Response symbol, which, in comparison with the input stimuli, can be used to measure accuracy).\n", - "- Halting probability ($\\lambda_n$).\n", - "\n", - "Under the hood, the model iterates over a ICOM-like component to reach a halting point in time. Unlike DDM and ICOM models, all the parameters and outcomes of the current model *seem* cognitively interpretable.\n", - "\n", - "### Additional resources\n", - "\n", - "- [ICOM network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharinghttps://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)\n" - ], - "metadata": {} - }, - { - "cell_type": "markdown", - "source": [ - "## Problem setting\n", - "\n", - "### Model\n", - "Given input and output data, we want to learn a supervised model of the function $X \\to y$ as follows:\n", - "\n", - "$\n", - "f: X,h_n \\mapsto \\tilde{y},h_{n+1}, \\lambda_n\n", - "$\n", - "\n", - "where $X$ and $y$ denote stimulus and response symbols, $\\lambda_n$ denotes halting probability at time $n$, and $h_{n}$ is the latent state of the model. The learninig continious up to the time point $N$.\n", - "\n", - "For the brevity and compatibility, both data are one-hot encoded.\n", - "\n", - "\n", - "### Input\n", - "\n", - "One-hot encoded symbols.\n", - "\n", - "### Output\n", - "\n", - "One-hot encoded symbols.\n", - "\n", - "### Criterion\n", - "\n", - "L = L_cross_entropy + L_halting" - ], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 157, - "source": [ - "# Setup and imports\n", - "import torch\n", - "from torch import nn\n", - "import torch.nn.functional as F\n", - "from torch.utils.data import TensorDataset, DataLoader, random_split\n", - "from torch.utils.tensorboard import SummaryWriter\n", - "\n", - "from tqdm import tqdm\n", - "\n", - "from sklearn.metrics import accuracy_score\n", - "\n", - "import numpy as np\n", - "from scipy import stats\n", - "import pandas as pd\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import seaborn as sns; sns.set()\n", - "\n", - "import tensorflow as tf\n", - "import tensorboard as tb\n", - "tf.io.gfile = tb.compat.tensorflow_stub.io.gfile #FIX storing embeddings using tensorboard" - ], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 158, - "source": [ - "# produce a tarin of spikes and store timestamps of each spike in `spike_timestamps`.\n", - "\n", - "signal_rate = 2\n", - "noise_rate = 1\n", - "rate = signal_rate + noise_rate\n", - "max_duration_in_sec = 10.\n", - "resolution_in_sec = .1\n", - "\n", - "n_total_timesteps = int(max_duration_in_sec / resolution_in_sec)\n", - "n_spikes = np.random.poisson(rate * max_duration_in_sec)\n", - "\n", - "# method 1: shuffle timesteps\n", - "spike_timesteps = np.sort(np.random.choice(n_total_timesteps, size=n_spikes, replace=False))\n", - "\n", - "# method 2: exponential isi -> timestamps\n", - "# isi = np.random.exponential(1 / rate, n_spikes)\n", - "# spike_timestamps = np.cumsum(isi)\n", - "\n", - "# method 3: homogenous spikes -> timestamps\n", - "# spike_timestamps = stats.uniform.rvs(loc=0, scale=max_duration_in_sec, size=n_spikes)" - ], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "markdown", - "source": [ - "## Mock data" - ], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 226, - "source": [ - "\n", - "\n", - "def generate_mock_data(n_subjects, n_trials, n_stimuli):\n", - " \"\"\"[summary]\n", - "\n", - " # TODO required data columns: subject_index, trial_index, stimulus_index, accuracy, response_time\n", - "\n", - " Args:\n", - " n_subjects (int): [description]\n", - " n_trials (int): [description]\n", - " n_stimuli (int): [description]\n", - "\n", - " Returns:\n", - " (X, accuracies, response_times): A tuple containing generated mock X, accuracies, and response_times (in sec).\n", - " \"\"\"\n", - " # stimuli\n", - " X = np.random.randint(low=1, high=n_stimuli+1, size=(n_subjects, n_trials))\n", - "\n", - " # response accuracy\n", - " subject_accuracies = np.random.uniform(low=0.2, high=1.0, size=n_subjects)\n", - " subject_accuracies = np.round(subject_accuracies * n_trials) / n_trials\n", - " accuracies = np.empty(shape=(n_subjects, n_trials))\n", - " for subj in range(n_subjects):\n", - " accuracies[subj,:] = np.random.choice(\n", - " [0,1],\n", - " p=[1-subject_accuracies[subj],subject_accuracies[subj]],\n", - " size=n_trials)\n", - "\n", - " # generate output w.r.t the accuracy (and fill incorrect trials with invalid response)\n", - " y = np.where(accuracies == 1., X, X+1 % (n_stimuli+1))\n", - "\n", - " # response time\n", - " response_times = np.random.exponential(.5, size=accuracies.shape)\n", - "\n", - " if n_subjects == 1:\n", - " X = X.squeeze()\n", - " y = y.squeeze()\n", - " accuracies = accuracies.squeeze()\n", - " response_times = response_times.squeeze()\n", - "\n", - " return X, y, accuracies, response_times" - ], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 269, - "source": [ - "# mock data parameters\n", - "n_subjects = 1\n", - "n_trials = 20\n", - "n_stimuli = 6\n", - "\n", - "X, y, accuracies, response_times = generate_mock_data(n_subjects, n_trials, n_stimuli)" - ], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 270, - "source": [ - "class ICOM(nn.Module):\n", - " def __init__(self, n_inputs, n_channels, n_outputs):\n", - " super(ICOM, self).__init__()\n", - "\n", - " self.n_inputs = n_inputs\n", - " # encode: x -> sent_msg\n", - " self.encode = nn.Linear(n_inputs, n_channels, bias=False)\n", - "\n", - " # transmit: sent_msg -> rcvd_msg\n", - " self.transmit = nn.RNN(n_channels, n_channels)\n", - "\n", - " # decode: rcvd_msg -> action\n", - " self.decode = nn.Sequential(\n", - " nn.Linear(n_channels,n_outputs, bias=False),\n", - " nn.Softmax(dim=2)\n", - " )\n", - "\n", - " def forward(self, x, h):\n", - "\n", - " msg = F.one_hot(x, num_classes=self.n_inputs).type(torch.float)\n", - " msg = self.encode(msg)\n", - " msg, h = self.transmit(msg, h)\n", - " y = self.decode(msg)\n", - "\n", - " return y.squeeze(), h" - ], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 291, - "source": [ - "class PonderNet(nn.Module):\n", - " def __init__(self, n_inputs, n_embeddings, n_outputs, max_steps):\n", - " super(PonderNet, self).__init__()\n", - "\n", - " self.output_layer = ICOM(n_inputs, n_embeddings, n_outputs)\n", - "\n", - " self.halting_layer = nn.Sequential(\n", - " nn.Linear(n_embeddings, 1),\n", - " nn.Sigmoid()\n", - " )\n", - "\n", - " # \\lambda_p\n", - " self.max_steps = max_steps\n", - " self.n_embeddings = n_embeddings\n", - " self.n_outputs = n_outputs\n", - "\n", - " def forward(self, x):\n", - "\n", - " batch_size = x.shape[0]\n", - "\n", - " p = []\n", - " y = []\n", - "\n", - " p_continue = torch.ones((batch_size,))\n", - " halt = torch.zeros((batch_size,))\n", - " p_m = torch.zeros((batch_size,))\n", - " y_m = torch.zeros((batch_size,))\n", - " p_n = torch.zeros((batch_size,))\n", - " h = torch.zeros((1,batch_size,self.n_embeddings))\n", - "\n", - " for n in range(1, self.max_steps + 1):\n", - " y_n, h = self.output_layer(x.unsqueeze(0), h)\n", - " \n", - " if n == self.max_steps:\n", - " lambda_n = torch.tensor(1.)\n", - " halt_steps = torch.empty((batch_size,)).fill_(n)\n", - " else:\n", - " lambda_n = self.halting_layer(h)\n", - " halt_steps = torch.empty((batch_size,)).geometric_(lambda_n.detach()[0,0].item()) #FIXME\n", - "\n", - " if n % 500 == 0:\n", - " print('lambda:',lambda_n)\n", - " p_n = p_continue * lambda_n\n", - " p_continue = p_continue * (1 - lambda_n)\n", - "\n", - " p.append(p_n)\n", - " y.append(y_n)\n", - "\n", - " is_halted = (halt_steps <= n).type(torch.float)\n", - " p_m = p_m * (1 - is_halted) + p_n * is_halted\n", - " y_m = y_m * (1 - is_halted) + y_n * is_halted\n", - "\n", - " if all(halt):\n", - " break\n", - "\n", - " return torch.stack(y), torch.stack(p), y_m, p_m" - ], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 300, - "source": [ - "\n", - "# split params\n", - "train_size = int(n_trials * .8)\n", - "test_size = n_trials - train_size\n", - "\n", - "# training parrms\n", - "n_epoches = 100\n", - "\n", - "logs = SummaryWriter()\n", - "\n", - "model = PonderNet(n_stimuli+1, n_stimuli, n_stimuli+1, 100)\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", - "criterion = torch.nn.CrossEntropyLoss()\n", - "\n", - "dataset = TensorDataset(torch.tensor(X), torch.tensor(y)-1)\n", - "train_subset, test_subset = random_split(dataset, lengths=(train_size,test_size))\n", - "\n", - "X_train, y_train = dataset[train_subset.indices]\n", - "X_test, y_test = dataset[test_subset.indices]\n", - "\n", - "for epoch in tqdm(range(n_epoches), desc='Epochs'):\n", - "\n", - " for X_batch, y_batch in DataLoader(train_subset, batch_size=1):\n", - " model.train()\n", - " optimizer.zero_grad()\n", - " ys, ps, y_pred, p_m = model(X_batch)\n", - "\n", - " # logs.add_embedding(h.reshape(n_trials,n_stimuli), global_step=epoch, tag='embedding')\n", - "\n", - " model_accuracy = accuracy_score(y_batch, torch.argmax(y_pred.detach(),dim=0).unsqueeze(0))\n", - " logs.add_scalar('accurracy/train', model_accuracy, epoch) \n", - "\n", - " loss = criterion(y_pred.unsqueeze(0), y_batch)\n", - " \n", - " logs.add_scalar('loss/train', loss, epoch)\n", - "\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # model.eval()\n", - " # with torch.no_grad():\n", - " # _, _, y_pred, _ = model(X_test)\n", - " # loss = criterion(y_test, y_pred)\n", - " # logs.add_scalar('loss/test', loss.detach(), epoch)\n", - "\n", - "# tensorboard --logdir=runs" - ], - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Epochs: 100%|██████████| 100/100 [00:57<00:00, 1.73it/s]\n" - ] - } - ], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": null, - "source": [ - "model.eval()\n", - "_, _, y_pred, _ = model(X)\n", - "y_pred = np.argmax(y_pred.detach().numpy(), axis=1) + 1\n", - "y_pred, y" - ], - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(array([2, 6, 5, 2, 6, 4, 4, 7, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]),\n", - " array([[2, 6, 5, 1, 6, 4, 4, 6, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]]))" - ] - }, - "metadata": {}, - "execution_count": 137 - } - ], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": null, - "source": [], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 45, - "source": [ - "# example code to decode a stimulus into multiple sequence (one per channel)\n", - "\n", - "import torch\n", - "from torch import nn\n", - "\n", - "n_inputs = 7\n", - "max_timestep = 10\n", - "n_channels = 5\n", - "\n", - "X = torch.nn.functional.one_hot(torch.tensor(4), num_classes=n_inputs).type(torch.float)\n", - "\n", - "decode = nn.Linear(n_inputs, n_channels * max_timestep)\n", - "out = decode(X).reshape((n_channels, max_timestep))\n", - "\n", - "print(out.shape)" - ], - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([5, 10])\n" - ] - } - ], - "metadata": {} - } - ], - "metadata": { - "orig_nbformat": 4, - "language_info": { - "name": "python", - "version": "3.9.4", - "mimetype": "text/x-python", - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "pygments_lexer": "ipython3", - "nbconvert_exporter": "python", - "file_extension": ".py" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3.9.4 64-bit ('py3': conda)" - }, - "interpreter": { - "hash": "5ddcf14c786c671500c086f61f0b66d0417d6c58ff12753e346e191a84f72b84" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} \ No newline at end of file