Newer
Older
notebooks / vrc_pondernet.ipynb
Morteza Ansarinia on 30 Aug 2021 12 KB docs
{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "Building on [PonderNet](https://arxiv.org/abs/2107.05407), this notebook implements a neural alternative of the  [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0) model to produce human-like responses.\n",
    "\n",
    "Given stimulus symbols as inputs, the model produces two outputs:\n",
    "\n",
    "- Response symbol, which, in comparison with the input stimuli, can be used to measure accuracy).\n",
    "- Remaining entropy (to be contrasted against a decision threshold and ultimateely halt the process).\n",
    "\n",
    "Under the hood, the model uses a RNN along with multiple Poisson processes to...\n",
    "\n",
    "\n",
    "## Resources\n",
    "\n",
    "- [Network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharinghttps://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "# Setup and imports\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns; sns.set()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 170,
   "source": [
    "# produce a tarin of spikes and store timestamps of each spike in `spike_timestamps`.\n",
    "\n",
    "signal_rate = 2\n",
    "noise_rate = 1\n",
    "rate = signal_rate + noise_rate\n",
    "duration_in_sec = 10.\n",
    "resolution_in_sec = .1\n",
    "\n",
    "n_total_timesteps = int(duration_in_sec / resolution_in_sec)\n",
    "n_spikes = np.random.poisson(rate * duration_in_sec)\n",
    "\n",
    "# method 1: shuffle timesteps\n",
    "spike_timesteps = np.sort(np.random.choice(n_total_timesteps, size=n_spikes, replace=False))\n",
    "\n",
    "# method 2: exponential isi -> timestamps\n",
    "# isi = np.random.exponential(1 / rate, n_spikes)\n",
    "# spike_timestamps = np.cumsum(isi)\n",
    "\n",
    "# method 3: homogenous spikes -> timestamps\n",
    "# spike_timestamps = stats.uniform.rvs(loc=0, scale=duration_in_sec, size=n_spikes)\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## RNN"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "source": [
    "class PonderVRC(nn.Module):\n",
    "  def __init__(self, n_inputs, n_channels):\n",
    "    super(PonderVRC, self).__init__()\n",
    "    self.rnn = nn.RNN(n_inputs, n_channels)\n",
    "    self.fc1 = nn.Linear(n_channels, n_inputs, bias=False)\n",
    "    self.fc2 = nn.Linear(n_channels,1, bias=False)\n",
    "\n",
    "  def forward(self, x):\n",
    "    h = self.rnn(x)\n",
    "    y = self.fc1(h)\n",
    "    y = self.fc1(h)\n",
    "\n",
    "    return y"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Mock data"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 330,
   "source": [
    "n_trials = 30\n",
    "n_stimuli = 6\n",
    "n_subjects = 1\n",
    "\n",
    "# required data columns: subject_index, trial_index, stimulus_index, accuracy, response_time\n",
    "# TODO: generate random data and reshape into the following\n",
    "\n",
    "# stimuli\n",
    "X = np.random.randint(low=1, high=n_stimuli+1, size=(n_subjects, n_trials))\n",
    "\n",
    "# accuracy (index=0)\n",
    "accuracies = np.random.randint(low=0, high=2, size=(n_subjects, n_trials))\n",
    "response_times = np.random.exponential(.5, size=(n_subjects, n_trials))\n",
    "\n",
    "response_times\n",
    "# responses = np.empty((n_subjects, n_trials, 2))\n",
    "# responses[:,:,0] = np.where(accuracies==1., X, )\n",
    "# response_time (index=1)\n",
    "# responses[:,:,1].exponential_(.5)\n",
    "\n"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "array([[0.89434811, 0.64370202, 0.19886853, 1.39208346, 0.41082363,\n",
       "        0.08900332, 1.11360565, 0.46728826, 0.36291653, 0.67963475,\n",
       "        0.45148227, 0.38839379, 0.64743332, 0.41294597, 0.45289691,\n",
       "        0.13357337, 0.85012272, 0.7988117 , 1.23502906, 0.53615726,\n",
       "        0.07061297, 0.80473662, 0.38354505, 0.58555392, 0.38719181,\n",
       "        0.42993123, 0.23014178, 0.13333575, 0.26819837, 0.28917237]])"
      ]
     },
     "metadata": {},
     "execution_count": 330
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "source": [
    "\n",
    "n_epoches = 10\n",
    "\n",
    "model = PonderVRC(10,10)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "criterion = torch.nn.BCELoss()\n",
    "\n",
    "for epoch in range(n_epoches):\n",
    "  model.train()\n",
    "  optimizer.zero_grad()\n",
    "  x = ...\n",
    "  y_true = ...\n",
    "  y_pred = model(x)\n",
    "\n",
    "  loss = criterion(y_pred, y_pred)\n",
    "\n",
    "  loss.backward()\n",
    "  optimizer.step()"
   ],
   "outputs": [
    {
     "output_type": "error",
     "ename": "AssertionError",
     "evalue": "",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-176-fe19421f351b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     10\u001b[0m   \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m...\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m   \u001b[0my_true\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m...\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m   \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m   \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-174-11ee17fdc0d0>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m     \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrnn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m     \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m     \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    887\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    888\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    890\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    891\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.9/site-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m    242\u001b[0m             \u001b[0mmax_batch_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_sizes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    243\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 244\u001b[0;31m             \u001b[0;32massert\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    245\u001b[0m             \u001b[0mbatch_sizes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    246\u001b[0m             \u001b[0mmax_batch_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_first\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAssertionError\u001b[0m: "
     ]
    }
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.9.4",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.9.4 64-bit ('py3': conda)"
  },
  "interpreter": {
   "hash": "5ddcf14c786c671500c086f61f0b66d0417d6c58ff12753e346e191a84f72b84"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}