{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# PonderICOM: Joint Modeling of Accuracy and Speed in Cognitive Tasks\n",
    "## Intro\n",
    "\n",
    "In the context of behavioral data, we are interested in simultaneously modeling speed and accuracy. Yet, most advanced techniques in machine learning cannot capture such a duality of decision making data.\n",
    "\n",
    "\n",
    "Building on [PonderNet](https://arxiv.org/abs/2107.05407) and [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0), this notebook implements a neural model that captures speed and accuracy of human-like responses.\n",
    "\n",
    "Given stimulus symbols as inputs, the model produces two outputs:\n",
    "\n",
    "- Response symbol, which, in comparison with the input stimuli, can be used to measure accuracy).\n",
    "- Halting probability ($\\lambda_n$).\n",
    "\n",
    "Under the hood, the model iterates over a ICOM-like component to reach a halting point in time. Unlike DDM and ICOM models, all the parameters and outcomes of the current model *seem* cognitively interpretable.\n",
    "\n",
    "### Additional resources\n",
    "\n",
    "- [ICOM network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharinghttps://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Problem setting\n",
    "\n",
    "### Model\n",
    "Given input and output data, we want to learn a supervised model of the function $X \\to y$ as follows:\n",
    "\n",
    "$\n",
    "f: X,h_n \\mapsto \\tilde{y},h_{n+1}, \\lambda_n\n",
    "$\n",
    "\n",
    "where $X$ and $y$ denote stimulus and response symbols, $\\lambda_n$ denotes halting probability at time $n$, and $h_{n}$ is the latent state of the model. The learninig continious up to the time point $N$.\n",
    "\n",
    "For the brevity and compatibility, both data are one-hot encoded.\n",
    "\n",
    "\n",
    "### Input\n",
    "\n",
    "One-hot encoded symbols.\n",
    "\n",
    "### Output\n",
    "\n",
    "One-hot encoded symbols.\n",
    "\n",
    "### Criterion\n",
    "\n",
    "L = L_cross_entropy + L_halting"
   ],
   "metadata": {}
  },
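  {
   "cell_type": "markdown",
   "source": [
    "As a concrete sketch, the criterion can be unpacked in the PonderNet style: a reconstruction term (per-step cross-entropy weighted by the halting distribution $p_n$) plus a regularization term (a KL divergence between $p_n$ and a geometric prior). The helper below is only an illustration under these assumptions; `ponder_loss`, `lambda_p`, and `beta` are illustrative names, not the exact objective trained later in this notebook."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# a sketch of a PonderNet-style criterion, assuming a geometric prior over\n",
    "# halting steps; `ponder_loss`, `lambda_p`, and `beta` are illustrative names\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def ponder_loss(ys, ps, y_true, lambda_p=0.2, beta=0.01):\n",
    "  \"\"\"ys: (n_steps, n_classes) per-step logits; ps: (n_steps,) halting mass;\n",
    "  y_true: scalar target class index.\"\"\"\n",
    "  n_steps = ps.shape[0]\n",
    "  # reconstruction: expected cross-entropy under the halting distribution\n",
    "  rec = sum(ps[n] * F.cross_entropy(ys[n].unsqueeze(0), y_true.unsqueeze(0))\n",
    "            for n in range(n_steps))\n",
    "  # regularization: KL(p_n || geometric prior with success rate lambda_p)\n",
    "  prior = lambda_p * (1 - lambda_p) ** torch.arange(n_steps, dtype=torch.float)\n",
    "  prior = prior / prior.sum()\n",
    "  reg = F.kl_div(prior.log(), ps / ps.sum(), reduction='sum')\n",
    "  return rec + beta * reg\n",
    "\n",
    "# tiny usage example with random values\n",
    "print(ponder_loss(torch.randn(5, 7), torch.softmax(torch.randn(5), dim=0), torch.tensor(3)))"
   ],
   "outputs": [],
   "metadata": {}
  },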
  {
   "cell_type": "code",
   "execution_count": 157,
   "source": [
    "# Setup and imports\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset, DataLoader, random_split\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns; sns.set()\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorboard as tb\n",
    "tf.io.gfile = tb.compat.tensorflow_stub.io.gfile #FIX storing embeddings using tensorboard"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "source": [
    "# produce a tarin of spikes and store timestamps of each spike in `spike_timestamps`.\n",
    "\n",
    "signal_rate = 2\n",
    "noise_rate = 1\n",
    "rate = signal_rate + noise_rate\n",
    "max_duration_in_sec = 10.\n",
    "resolution_in_sec = .1\n",
    "\n",
    "n_total_timesteps = int(max_duration_in_sec / resolution_in_sec)\n",
    "n_spikes = np.random.poisson(rate * max_duration_in_sec)\n",
    "\n",
    "# method 1: shuffle timesteps\n",
    "spike_timesteps = np.sort(np.random.choice(n_total_timesteps, size=n_spikes, replace=False))\n",
    "\n",
    "# method 2: exponential isi -> timestamps\n",
    "# isi = np.random.exponential(1 / rate, n_spikes)\n",
    "# spike_timestamps = np.cumsum(isi)\n",
    "\n",
    "# method 3: homogenous spikes -> timestamps\n",
    "# spike_timestamps = stats.uniform.rvs(loc=0, scale=max_duration_in_sec, size=n_spikes)"
   ],
   "outputs": [],
   "metadata": {}
  },
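  {
   "cell_type": "markdown",
   "source": [
    "For downstream use, the sampled spikes can be viewed as a binary train on the simulation grid; a minimal sketch, assuming `spike_timesteps` from method 1 above:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# expand `spike_timesteps` into a binary spike train on the resolution grid\n",
    "spike_train = np.zeros(n_total_timesteps)\n",
    "spike_train[spike_timesteps] = 1\n",
    "\n",
    "# sanity check: the empirical rate should be close to `rate` (in spikes/sec)\n",
    "print(spike_train.sum() / max_duration_in_sec)"
   ],
   "outputs": [],
   "metadata": {}
  },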
  {
   "cell_type": "markdown",
   "source": [
    "## Mock data"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "source": [
    "\n",
    "\n",
    "def generate_mock_data(n_subjects, n_trials, n_stimuli):\n",
    "  \"\"\"[summary]\n",
    "\n",
    "  # TODO required data columns: subject_index, trial_index, stimulus_index, accuracy, response_time\n",
    "\n",
    "  Args:\n",
    "      n_subjects (int): [description]\n",
    "      n_trials (int): [description]\n",
    "      n_stimuli (int): [description]\n",
    "\n",
    "  Returns:\n",
    "      (X, accuracies, response_times): A tuple containing generated mock X, accuracies, and response_times (in sec).\n",
    "  \"\"\"\n",
    "  # stimuli\n",
    "  X = np.random.randint(low=1, high=n_stimuli+1, size=(n_subjects, n_trials))\n",
    "\n",
    "  # response accuracy\n",
    "  subject_accuracies = np.random.uniform(low=0.2, high=1.0, size=n_subjects)\n",
    "  subject_accuracies = np.round(subject_accuracies * n_trials) / n_trials\n",
    "  accuracies = np.empty(shape=(n_subjects, n_trials))\n",
    "  for subj in range(n_subjects):\n",
    "    accuracies[subj,:] = np.random.choice(\n",
    "      [0,1],\n",
    "      p=[1-subject_accuracies[subj],subject_accuracies[subj]],\n",
    "      size=n_trials)\n",
    "\n",
    "  # generate output w.r.t the accuracy (and fill incorrect trials with invalid response)\n",
    "  y = np.where(accuracies == 1., X, X+1 % (n_stimuli+1))\n",
    "\n",
    "  # response time\n",
    "  response_times = np.random.exponential(.5, size=accuracies.shape)\n",
    "\n",
    "  if n_subjects == 1:\n",
    "    X = X.squeeze()\n",
    "    y = y.squeeze()\n",
    "    accuracies = accuracies.squeeze()\n",
    "    response_times = response_times.squeeze()\n",
    "\n",
    "  return X, y, accuracies, response_times"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 269,
   "source": [
    "# mock data parameters\n",
    "n_subjects = 1\n",
    "n_trials = 20\n",
    "n_stimuli = 6\n",
    "\n",
    "X, y, accuracies, response_times = generate_mock_data(n_subjects, n_trials, n_stimuli)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 270,
   "source": [
    "class ICOM(nn.Module):\n",
    "    def __init__(self, n_inputs, n_channels, n_outputs):\n",
    "      super(ICOM, self).__init__()\n",
    "\n",
    "      self.n_inputs = n_inputs\n",
    "      # encode: x -> sent_msg\n",
    "      self.encode = nn.Linear(n_inputs, n_channels, bias=False)\n",
    "\n",
    "      # transmit: sent_msg -> rcvd_msg\n",
    "      self.transmit = nn.RNN(n_channels, n_channels)\n",
    "\n",
    "      # decode: rcvd_msg -> action\n",
    "      self.decode = nn.Sequential(\n",
    "        nn.Linear(n_channels,n_outputs, bias=False),\n",
    "        nn.Softmax(dim=2)\n",
    "      )\n",
    "\n",
    "    def forward(self, x, h):\n",
    "\n",
    "      msg = F.one_hot(x, num_classes=self.n_inputs).type(torch.float)\n",
    "      msg = self.encode(msg)\n",
    "      msg, h = self.transmit(msg, h)\n",
    "      y = self.decode(msg)\n",
    "\n",
    "      return y.squeeze(), h"
   ],
   "outputs": [],
   "metadata": {}
  },
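  {
   "cell_type": "markdown",
   "source": [
    "A quick shape check for a single `ICOM` forward pass (an illustrative smoke test; the sizes mirror the mock-data settings above):"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# smoke test: one ICOM step on a single symbol (sizes are illustrative)\n",
    "icom = ICOM(n_inputs=n_stimuli + 1, n_channels=n_stimuli, n_outputs=n_stimuli + 1)\n",
    "x0 = torch.tensor([[1]])             # (seq=1, batch=1) symbol index\n",
    "h0 = torch.zeros((1, 1, n_stimuli))  # (layers, batch, channels)\n",
    "y0, h1 = icom(x0, h0)\n",
    "print(y0.shape, h1.shape)            # probabilities over outputs, next hidden state"
   ],
   "outputs": [],
   "metadata": {}
  },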
  {
   "cell_type": "code",
   "execution_count": 291,
   "source": [
    "class PonderNet(nn.Module):\n",
    "  def __init__(self, n_inputs, n_embeddings, n_outputs, max_steps):\n",
    "    super(PonderNet, self).__init__()\n",
    "\n",
    "    self.output_layer = ICOM(n_inputs, n_embeddings, n_outputs)\n",
    "\n",
    "    self.halting_layer = nn.Sequential(\n",
    "      nn.Linear(n_embeddings, 1),\n",
    "      nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "    # \\lambda_p\n",
    "    self.max_steps = max_steps\n",
    "    self.n_embeddings = n_embeddings\n",
    "    self.n_outputs = n_outputs\n",
    "\n",
    "  def forward(self, x):\n",
    "\n",
    "    batch_size = x.shape[0]\n",
    "\n",
    "    p = []\n",
    "    y = []\n",
    "\n",
    "    p_continue = torch.ones((batch_size,))\n",
    "    halt = torch.zeros((batch_size,))\n",
    "    p_m = torch.zeros((batch_size,))\n",
    "    y_m = torch.zeros((batch_size,))\n",
    "    p_n = torch.zeros((batch_size,))\n",
    "    h = torch.zeros((1,batch_size,self.n_embeddings))\n",
    "\n",
    "    for n in range(1, self.max_steps + 1):\n",
    "      y_n, h = self.output_layer(x.unsqueeze(0), h)\n",
    "      \n",
    "      if n == self.max_steps:\n",
    "        lambda_n = torch.tensor(1.)\n",
    "        halt_steps = torch.empty((batch_size,)).fill_(n)\n",
    "      else:\n",
    "        lambda_n = self.halting_layer(h)\n",
    "        halt_steps = torch.empty((batch_size,)).geometric_(lambda_n.detach()[0,0].item()) #FIXME\n",
    "\n",
    "      if n % 500 == 0:\n",
    "        print('lambda:',lambda_n)\n",
    "      p_n = p_continue * lambda_n\n",
    "      p_continue = p_continue * (1 - lambda_n)\n",
    "\n",
    "      p.append(p_n)\n",
    "      y.append(y_n)\n",
    "\n",
    "      is_halted = (halt_steps <= n).type(torch.float)\n",
    "      p_m = p_m * (1 - is_halted) + p_n * is_halted\n",
    "      y_m = y_m * (1 - is_halted) + y_n * is_halted\n",
    "\n",
    "      if all(halt):\n",
    "        break\n",
    "\n",
    "    return torch.stack(y), torch.stack(p), y_m, p_m"
   ],
   "outputs": [],
   "metadata": {}
  },
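  {
   "cell_type": "markdown",
   "source": [
    "Before training, an illustrative forward pass to inspect the per-step predictions and the halting mass (shapes here reflect a batch of one trial):"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# smoke test: one PonderNet forward pass on a single trial (illustrative)\n",
    "pondernet = PonderNet(n_stimuli + 1, n_stimuli, n_stimuli + 1, max_steps=20)\n",
    "x0 = torch.tensor([1])  # batch of one stimulus index\n",
    "ys, ps, y_m, p_m = pondernet(x0)\n",
    "print(ys.shape, ps.shape)  # per-step predictions and per-step halting mass"
   ],
   "outputs": [],
   "metadata": {}
  },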
  {
   "cell_type": "code",
   "execution_count": 300,
   "source": [
    "\n",
    "# split params\n",
    "train_size = int(n_trials * .8)\n",
    "test_size = n_trials - train_size\n",
    "\n",
    "# training parrms\n",
    "n_epoches = 100\n",
    "\n",
    "logs = SummaryWriter()\n",
    "\n",
    "model = PonderNet(n_stimuli+1, n_stimuli, n_stimuli+1, 100)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "dataset = TensorDataset(torch.tensor(X), torch.tensor(y)-1)\n",
    "train_subset, test_subset = random_split(dataset, lengths=(train_size,test_size))\n",
    "\n",
    "X_train, y_train = dataset[train_subset.indices]\n",
    "X_test, y_test = dataset[test_subset.indices]\n",
    "\n",
    "for epoch in tqdm(range(n_epoches), desc='Epochs'):\n",
    "\n",
    "  for X_batch, y_batch in DataLoader(train_subset, batch_size=1):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    ys, ps, y_pred, p_m = model(X_batch)\n",
    "\n",
    "  # logs.add_embedding(h.reshape(n_trials,n_stimuli), global_step=epoch, tag='embedding')\n",
    "\n",
    "  model_accuracy = accuracy_score(y_batch, torch.argmax(y_pred.detach(),dim=0).unsqueeze(0))\n",
    "  logs.add_scalar('accurracy/train', model_accuracy, epoch)  \n",
    "\n",
    "  loss = criterion(y_pred.unsqueeze(0), y_batch)\n",
    "  \n",
    "  logs.add_scalar('loss/train', loss, epoch)\n",
    "\n",
    "  loss.backward()\n",
    "  optimizer.step()\n",
    "\n",
    "  # model.eval()\n",
    "  # with torch.no_grad():\n",
    "  #   _, _, y_pred, _ = model(X_test)\n",
    "  #   loss = criterion(y_test, y_pred)\n",
    "  #   logs.add_scalar('loss/test', loss.detach(), epoch)\n",
    "\n",
    "# tensorboard --logdir=runs"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Epochs: 100%|██████████| 100/100 [00:57<00:00,  1.73it/s]\n"
     ]
    }
   ],
   "metadata": {}
  },
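  {
   "cell_type": "markdown",
   "source": [
    "Since the halting distribution is the model's analogue of response time, one hedged RT proxy is the expected halting step under $p_n$, scaled by an assumed per-step duration. `step_duration_in_sec` below is an illustrative constant, not a fitted parameter:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# illustrative: expected halting step as a response-time proxy for one trial\n",
    "step_duration_in_sec = 0.05  # assumed duration of one pondering step\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "  _, ps, _, _ = model(X_train[:1])\n",
    "  ps = ps.flatten()  # halting mass per step\n",
    "  steps = torch.arange(1, ps.numel() + 1, dtype=torch.float)\n",
    "  rt_proxy = (ps / ps.sum() * steps).sum() * step_duration_in_sec\n",
    "  print(f'expected RT proxy: {rt_proxy.item():.3f} sec')"
   ],
   "outputs": [],
   "metadata": {}
  },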
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "model.eval()\n",
    "_, _, y_pred, _ = model(X)\n",
    "y_pred = np.argmax(y_pred.detach().numpy(), axis=1) + 1\n",
    "y_pred, y"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(array([2, 6, 5, 2, 6, 4, 4, 7, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]),\n",
       " array([[2, 6, 5, 1, 6, 4, 4, 6, 3, 7, 5, 3, 4, 3, 3, 6, 3, 5, 7, 2]]))"
      ]
     },
     "metadata": {},
     "execution_count": 137
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "source": [
    "# example code to decode a stimulus into multiple sequence (one per channel)\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "n_inputs = 7\n",
    "max_timestep = 10\n",
    "n_channels = 5\n",
    "\n",
    "X = torch.nn.functional.one_hot(torch.tensor(4), num_classes=n_inputs).type(torch.float)\n",
    "\n",
    "decode = nn.Linear(n_inputs, n_channels * max_timestep)\n",
    "out = decode(X).reshape((n_channels, max_timestep))\n",
    "\n",
    "print(out.shape)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "torch.Size([5, 10])\n"
     ]
    }
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.9.4",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.9.4 64-bit ('py3': conda)"
  },
  "interpreter": {
   "hash": "5ddcf14c786c671500c086f61f0b66d0417d6c58ff12753e346e191a84f72b84"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}