Newer
Older
notebooks / vrc_torch.ipynb
{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "A variable rate coding model of human response time can be implemented using PyTorch (see the [network model](https://drive.google.com/file/d/16eiUUwKGWfh9pu9VUxzlx046hQNHV0Qe/view?usp=sharing))."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns; sns.set()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Encode the input"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Produce a Poisson spike train and store the time of each spike (in seconds,\n",
    "# sorted ascending) in `spike_timestamps`.\n",
    "\n",
    "SEED = 42  # fixed seed so Restart & Run All reproduces the same spike train\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "rate = 2  # mean firing rate, spikes per second\n",
    "duration_in_sec = 10.\n",
    "\n",
    "# Given the spike count, spike times of a homogeneous Poisson process are i.i.d.\n",
    "# uniform over the window.\n",
    "n_spikes = rng.poisson(rate * duration_in_sec)\n",
    "spike_timestamps = np.sort(rng.uniform(0., duration_in_sec, size=n_spikes))\n",
    "# guard against float rounding ever producing exactly `duration_in_sec`\n",
    "spike_timestamps = spike_timestamps[spike_timestamps < duration_in_sec]\n",
    "\n",
    "spike_timestamps.shape"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## RNN"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "class Net(nn.Module):\n",
    "  \"\"\"Single-layer RNN followed by a sigmoid read-out applied at every time step.\n",
    "\n",
    "  Args:\n",
    "    n_inputs: number of input features per time step.\n",
    "    n_hiddens: size of the RNN hidden state.\n",
    "    n_outputs: number of read-out units per time step.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, n_inputs, n_hiddens, n_outputs):\n",
    "    super().__init__()\n",
    "    self.rnn = nn.RNN(n_inputs, n_hiddens)\n",
    "    # size the read-out from the constructor args (was hard-coded to 10x10)\n",
    "    self.fc1 = nn.Linear(n_hiddens, n_outputs, bias=False)\n",
    "\n",
    "  def forward(self, x):\n",
    "    \"\"\"Return per-step output probabilities for the input sequence `x`.\n",
    "\n",
    "    `x` is (seq_len, batch, n_inputs) -- nn.RNN uses batch_first=False by\n",
    "    default -- or (seq_len, n_inputs) for a single unbatched sequence. The\n",
    "    result keeps the leading dims of `x` with a last dim of `n_outputs`.\n",
    "    \"\"\"\n",
    "    # nn.RNN returns (output, h_n); read out from every step's hidden state.\n",
    "    h, _ = self.rnn(x)\n",
    "    # sigmoid keeps outputs in [0, 1], as required by the BCELoss used in training.\n",
    "    y = torch.sigmoid(self.fc1(h))\n",
    "\n",
    "    return y"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Train the RNN"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "n_epoches = 10\n",
    "n_inputs, n_hiddens, n_outputs = 10, 10, 10\n",
    "\n",
    "# TODO: build the training tensors (e.g. from `spike_timestamps`) before running this cell.\n",
    "#   x:      float tensor of shape (seq_len, batch, n_inputs), nn.RNN's default layout\n",
    "#   y_true: float targets in [0, 1], shaped like the model output\n",
    "# The same full batch is reused every epoch; draw inside the loop to resample.\n",
    "x = ...\n",
    "y_true = ...\n",
    "if x is ... or y_true is ...:\n",
    "  raise NotImplementedError(\"define `x` and `y_true` before training\")\n",
    "\n",
    "model = Net(n_inputs, n_hiddens, n_outputs)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "criterion = torch.nn.BCELoss()\n",
    "\n",
    "model.train()  # set once; nothing in the loop switches to eval mode\n",
    "losses = []\n",
    "for epoch in range(n_epoches):\n",
    "  optimizer.zero_grad()\n",
    "  y_pred = model(x)\n",
    "  # compare predictions with the targets, not with themselves\n",
    "  loss = criterion(y_pred, y_true)\n",
    "  loss.backward()\n",
    "  optimizer.step()\n",
    "  losses.append(loss.item())"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.9.4",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.9.4 64-bit ('py3': conda)"
  },
  "interpreter": {
   "hash": "5ddcf14c786c671500c086f61f0b66d0417d6c58ff12753e346e191a84f72b84"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}