{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "A variable-rate coding model of human response time can be implemented using PyTorch (see the [network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing))."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns; sns.set()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Seed every random source so Restart & Run All reproduces the same spikes and weights.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "rng = np.random.default_rng(SEED)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Encode the input\n",
    "\n",
    "Sample a homogeneous Poisson spike train over a fixed window: draw the spike count from a Poisson distribution, then scatter that many spikes uniformly over the window."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Produce a train of spikes and store the sorted timestamp (in seconds) of each spike in `spike_timestamps`.\n",
    "rate = 2               # mean firing rate, spikes per second\n",
    "duration_in_sec = 10.  # length of the observation window, seconds\n",
    "\n",
    "n_spikes = rng.poisson(rate * duration_in_sec)\n",
    "spike_timestamps = np.sort(\n",
    "    stats.uniform.rvs(loc=0, scale=duration_in_sec, size=n_spikes, random_state=rng)\n",
    ")\n",
    "# uniform.rvs samples [loc, loc + scale); this filter guards the open upper bound.\n",
    "spike_timestamps = spike_timestamps[spike_timestamps < duration_in_sec]\n",
    "\n",
    "spike_timestamps.shape"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## RNN\n",
    "\n",
    "A single-layer RNN reads the input sequence. A bias-free linear readout followed by a sigmoid turns each hidden state into output probabilities."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Single-layer RNN with a bias-free linear readout and sigmoid output.\n",
    "\n",
    "    Args:\n",
    "        n_inputs: number of input features per time step.\n",
    "        n_hiddens: size of the RNN hidden state.\n",
    "        n_outputs: number of output units per time step.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_inputs, n_hiddens, n_outputs):\n",
    "        super().__init__()\n",
    "        self.rnn = nn.RNN(n_inputs, n_hiddens)\n",
    "        # The readout maps each hidden state to the outputs, so its sizes must\n",
    "        # follow the constructor arguments rather than a fixed 10 x 10.\n",
    "        self.fc1 = nn.Linear(n_hiddens, n_outputs, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return per-time-step output probabilities for the sequence `x`.\n",
    "\n",
    "        Args:\n",
    "            x: input tensor laid out as (seq_len, batch, n_inputs), the\n",
    "                default (batch_first=False) layout of `nn.RNN`.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (seq_len, batch, n_outputs) with values in [0, 1],\n",
    "            suitable as the input of `nn.BCELoss`.\n",
    "        \"\"\"\n",
    "        # nn.RNN returns (output, h_n); the readout needs the per-step output.\n",
    "        h, _ = self.rnn(x)\n",
    "        # BCELoss requires probabilities, so squash the linear readout.\n",
    "        return torch.sigmoid(self.fc1(h))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Train the RNN\n",
    "\n",
    "**TODO:** the input `x` and targets `y_true` below are synthetic placeholders that only exercise the training loop. Replace them with the encoded stimuli and the observed responses."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "N_INPUTS, N_HIDDENS, N_OUTPUTS = 10, 10, 10\n",
    "n_epochs = 10\n",
    "\n",
    "model = Net(N_INPUTS, N_HIDDENS, N_OUTPUTS)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "criterion = nn.BCELoss()  # expects probabilities, which Net's sigmoid output provides\n",
    "\n",
    "# TODO: synthetic placeholder batch -- replace with the encoded stimuli and the\n",
    "# observed targets. Layout follows nn.RNN's default: (seq_len, batch, features).\n",
    "seq_len, batch_size = 20, 8\n",
    "x = torch.randn(seq_len, batch_size, N_INPUTS)\n",
    "y_true = torch.randint(0, 2, (seq_len, batch_size, N_OUTPUTS)).float()\n",
    "\n",
    "losses = []\n",
    "for epoch in range(n_epochs):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    y_pred = model(x)\n",
    "\n",
    "    # Score the predictions against the targets, not against themselves.\n",
    "    loss = criterion(y_pred, y_true)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    losses.append(loss.item())\n",
    "\n",
    "losses"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.9.4",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.9.4 64-bit ('py3': conda)"
  },
  "interpreter": {
   "hash": "5ddcf14c786c671500c086f61f0b66d0417d6c58ff12753e346e191a84f72b84"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}