{
"cells": [
{
"cell_type": "markdown",
"source": [
"A variable-rate coding model of human response time can be implemented using PyTorch (see the [network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing))."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 68,
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"import numpy as np\n",
"from scipy import stats\n",
"import pandas as pd\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns; sns.set()"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Encode the input"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Produce a homogeneous Poisson spike train and store the timestamp of each spike in `spike_timestamps`.\n",
"\n",
"SEED = 42  # fixed seed so that Restart & Run All reproduces the same spike train\n",
"rate = 2  # mean firing rate, in spikes per second\n",
"duration_in_sec = 10.\n",
"\n",
"rng = np.random.default_rng(SEED)\n",
"\n",
"\n",
"def poisson_spike_train(rate, duration_in_sec, rng):\n",
"    \"\"\"Return the sorted spike times of a homogeneous Poisson process on [0, duration_in_sec).\n",
"\n",
"    The spike count is drawn from Poisson(rate * duration_in_sec); conditioned on that count,\n",
"    the spike times of a homogeneous Poisson process are i.i.d. uniform over the window.\n",
"\n",
"    Args:\n",
"        rate: mean number of spikes per second (must be >= 0).\n",
"        duration_in_sec: length of the time window, in seconds.\n",
"        rng: `numpy.random.Generator` used for every random draw.\n",
"\n",
"    Returns:\n",
"        1-D float array of spike timestamps in ascending order.\n",
"    \"\"\"\n",
"    n_spikes = rng.poisson(rate * duration_in_sec)\n",
"    timestamps = np.sort(rng.uniform(0., duration_in_sec, size=n_spikes))\n",
"    # `uniform` targets the half-open interval [low, high), but floating-point rounding can\n",
"    # still yield `high`; drop such samples so every spike lies strictly inside the window.\n",
"    return timestamps[timestamps < duration_in_sec]\n",
"\n",
"\n",
"spike_timestamps = poisson_spike_train(rate, duration_in_sec, rng)\n",
"assert np.all(np.diff(spike_timestamps) >= 0), \"spike timestamps must be sorted\"\n",
"\n",
"spike_timestamps.shape"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## RNN"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"class Net(nn.Module):\n",
"    \"\"\"Single-layer Elman RNN followed by a bias-free linear read-out.\n",
"\n",
"    Args:\n",
"        n_inputs: number of features per time step fed to the RNN.\n",
"        n_hiddens: size of the RNN hidden state.\n",
"        n_outputs: number of output units per time step.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, n_inputs, n_hiddens, n_outputs):\n",
"        super().__init__()\n",
"        self.rnn = nn.RNN(n_inputs, n_hiddens)\n",
"        # The read-out maps each hidden state to the outputs, so its sizes must follow the\n",
"        # constructor arguments instead of being hard-coded.\n",
"        self.fc1 = nn.Linear(n_hiddens, n_outputs, bias=False)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Run the RNN over `x` and apply the read-out at every time step.\n",
"\n",
"        With nn.RNN's default `batch_first=False`, `x` is expected as\n",
"        (seq_len, batch, n_inputs). The result is (seq_len, batch, n_outputs) and holds\n",
"        raw, unbounded scores (logits): no sigmoid is applied here.\n",
"        \"\"\"\n",
"        # nn.RNN returns a tuple (output at every time step, final hidden state);\n",
"        # only the per-step output is passed to the read-out.\n",
"        h, _ = self.rnn(x)\n",
"        y = self.fc1(h)\n",
"\n",
"        return y"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Train the RNN"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# The model's read-out is linear (no sigmoid), so its outputs are logits. BCEWithLogitsLoss\n",
"# applies the sigmoid internally and accepts unbounded inputs, whereas BCELoss requires\n",
"# probabilities in [0, 1] and is less numerically stable.\n",
"\n",
"n_epochs = 10\n",
"\n",
"model = Net(10, 10, 10)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
"criterion = torch.nn.BCEWithLogitsLoss()\n",
"\n",
"losses = []  # loss value recorded at each epoch, for later inspection/plotting\n",
"for epoch in range(n_epochs):\n",
"    model.train()\n",
"    optimizer.zero_grad()\n",
"    # TODO: build the training batch. `x` must be a float tensor of shape\n",
"    # (seq_len, batch, 10) and `y_true` a float tensor with the same shape as the\n",
"    # model output, (seq_len, batch, 10), holding target values in [0, 1].\n",
"    x = ...\n",
"    y_true = ...\n",
"    if x is Ellipsis or y_true is Ellipsis:\n",
"        raise NotImplementedError('Replace the `x` / `y_true` placeholders with real training data.')\n",
"    y_pred = model(x)\n",
"\n",
"    # Compare the prediction against the target (not against itself).\n",
"    loss = criterion(y_pred, y_true)\n",
"\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    losses.append(loss.item())"
],
"outputs": [],
"metadata": {}
}
],
"metadata": {
"orig_nbformat": 4,
"language_info": {
"name": "python",
"version": "3.9.4",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.9.4 64-bit ('py3': conda)"
},
"interpreter": {
"hash": "5ddcf14c786c671500c086f61f0b66d0417d6c58ff12753e346e191a84f72b84"
}
},
"nbformat": 4,
"nbformat_minor": 2
}