diff --git a/vrc_pondernet.ipynb b/vrc_pondernet.ipynb new file mode 100644 index 0000000..64304c6 --- /dev/null +++ b/vrc_pondernet.ipynb @@ -0,0 +1,219 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "Building on [PonderNet](https://arxiv.org/abs/2107.05407), this notebook implements a neural alternative to the [Variable Rate Coding](https://doi.org/10.32470/CCN.2019.1397-0) model to produce human-like responses.\n", + "\n", + "Given stimulus symbols as inputs, the model produces two outputs:\n", + "\n", + "- Response symbol, which can be compared with the input stimulus to measure accuracy.\n", + "- Remaining entropy (compared against a decision threshold to ultimately halt the process).\n", + "\n", + "Under the hood, the model uses an RNN along with multiple Poisson processes to...\n", + "\n", + "\n", + "## Resources\n", + "\n", + "- [Network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing)\n" + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 3, + "source": [ + "# Setup and imports\n", + "import torch\n", + "from torch import nn\n", + "\n", + "import numpy as np\n", + "from scipy import stats\n", + "import pandas as pd\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns; sns.set()" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 170, + "source": [ + "# produce a train of spikes and store the time of each spike (`spike_timesteps` for the discrete method, `spike_timestamps` for the continuous ones).\n", + "\n", + "signal_rate = 2\n", + "noise_rate = 1\n", + "rate = signal_rate + noise_rate\n", + "duration_in_sec = 10.\n", + "resolution_in_sec = .1\n", + "\n", + "n_total_timesteps = int(duration_in_sec / resolution_in_sec)\n", + "n_spikes = np.random.poisson(rate * duration_in_sec)\n", + "\n", + "# method 1: shuffle timesteps\n", + "spike_timesteps = np.sort(np.random.choice(n_total_timesteps, size=n_spikes, replace=False))\n", + "\n", + "# method 2: exponential isi -> timestamps\n", + "# isi = np.random.exponential(1 / rate, n_spikes)\n", + "# spike_timestamps = np.cumsum(isi)\n", + "\n", + "# method 3: homogeneous spikes -> timestamps\n", + "# spike_timestamps = stats.uniform.rvs(loc=0, scale=duration_in_sec, size=n_spikes)\n" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "## RNN" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 174, + "source": [ + "class PonderVRC(nn.Module):\n", + "    def __init__(self, n_inputs, n_channels):\n", + "        super(PonderVRC, self).__init__()\n", + "        self.rnn = nn.RNN(n_inputs, n_channels)\n", + "        self.fc1 = nn.Linear(n_channels, n_inputs, bias=False)  # response-symbol head\n", + "        self.fc2 = nn.Linear(n_channels, 1, bias=False)  # remaining-entropy (halting) head\n", + "\n", + "    def forward(self, x):\n", + "        h, _ = self.rnn(x)  # nn.RNN returns (output, h_n); keep the per-step outputs\n", + "        y = self.fc1(h)  # response symbol\n", + "        entropy = self.fc2(h)  # remaining entropy, compared against the halting threshold\n", + "\n", + "        return y, entropy" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "## Mock data" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 330, + "source": [ + "n_trials = 30\n", + "n_stimuli = 6\n", + "n_subjects = 1\n", + "\n", + "# required data columns: subject_index, trial_index, stimulus_index, accuracy, response_time\n", + "# TODO: generate random data and reshape into the following\n", + "\n", + "# stimuli\n", + "X = np.random.randint(low=1, 
high=n_stimuli+1, size=(n_subjects, n_trials))\n", + "\n", + "# accuracy (index=0)\n", + "accuracies = np.random.randint(low=0, high=2, size=(n_subjects, n_trials))\n", + "response_times = np.random.exponential(.5, size=(n_subjects, n_trials))\n", + "\n", + "response_times\n", + "# responses = np.empty((n_subjects, n_trials, 2))\n", + "# responses[:,:,0] = np.where(accuracies==1., X, )\n", + "# response_time (index=1)\n", + "# responses[:,:,1].exponential_(.5)\n", + "\n" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "array([[0.89434811, 0.64370202, 0.19886853, 1.39208346, 0.41082363,\n", + "        0.08900332, 1.11360565, 0.46728826, 0.36291653, 0.67963475,\n", + "        0.45148227, 0.38839379, 0.64743332, 0.41294597, 0.45289691,\n", + "        0.13357337, 0.85012272, 0.7988117 , 1.23502906, 0.53615726,\n", + "        0.07061297, 0.80473662, 0.38354505, 0.58555392, 0.38719181,\n", + "        0.42993123, 0.23014178, 0.13333575, 0.26819837, 0.28917237]])" + ] + }, + "metadata": {}, + "execution_count": 330 + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 176, + "source": [ + "\n", + "n_epochs = 10\n", + "\n", + "model = PonderVRC(10, 10)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", + "criterion = torch.nn.BCELoss()\n", + "\n", + "for epoch in range(n_epochs):\n", + "    model.train()\n", + "    optimizer.zero_grad()\n", + "    x = ...  # TODO: stimulus tensor\n", + "    y_true = ...  # TODO: target responses\n", + "    y_pred, entropy = model(x)\n", + "\n", + "    loss = criterion(y_pred, y_true)\n", + "\n", + "    loss.backward()\n", + "    optimizer.step()" + ], + "outputs": [ + { + "output_type": "error", + "ename": "AssertionError", + "evalue": "", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m...\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0my_true\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m...\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0my_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrnn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.9/site-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0mmax_batch_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_sizes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 243\u001b[0m 
\u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 244\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 245\u001b[0m \u001b[0mbatch_sizes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 246\u001b[0m \u001b[0mmax_batch_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_first\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], + "metadata": {} + } + ], + "metadata": { + "orig_nbformat": 4, + "language_info": { + "name": "python", + "version": "3.9.4", + "mimetype": "text/x-python", + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "pygments_lexer": "ipython3", + "nbconvert_exporter": "python", + "file_extension": ".py" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3.9.4 64-bit ('py3': conda)" + }, + "interpreter": { + "hash": "5ddcf14c786c671500c086f61f0b66d0417d6c58ff12753e346e191a84f72b84" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/vrc_torch.ipynb b/vrc_torch.ipynb deleted file mode 100644 index 040492e..0000000 --- a/vrc_torch.ipynb +++ /dev/null @@ -1,148 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "source": [
- "A variable rate coding model of human respose time can be implemented using PyTorch (see the [network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharinghttps://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing))." - ], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 68, - "source": [ - "import torch\n", - "from torch import nn\n", - "\n", - "import numpy as np\n", - "from scipy import stats\n", - "import pandas as pd\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import seaborn as sns; sns.set()" - ], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "markdown", - "source": [ - "## Load and prepare the data" - ], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 171, - "source": [ - "# produce a tarin of spikes and store timestamps of each spike in `spike_timestamps`.\n", - "\n", - "rate = 2\n", - "duration_in_sec = 10.\n", - "\n", - "n_spikes = np.random.poisson(rate * duration_in_sec)\n", - "spike_timestamps = stats.uniform.rvs(loc=0, scale=duration_in_sec, size=n_spikes)\n", - "spike_timestamps = np.sort(spike_timestamps)\n", - "spike_timestamps = spike_timestamps[spike_timestamps