diff --git a/VRC_PonderRNN.ipynb b/VRC_PonderRNN.ipynb
index e6c2393..8ef7d13 100644
--- a/VRC_PonderRNN.ipynb
+++ b/VRC_PonderRNN.ipynb
@@ -52,12 +52,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 114,
+   "execution_count": 206,
    "source": [
     "# Setup and imports\n",
     "import torch\n",
     "from torch import nn\n",
     "import torch.nn.functional as F\n",
+    "from torch.utils.data import TensorDataset, DataLoader, random_split\n",
     "from torch.utils.tensorboard import SummaryWriter\n",
     "\n",
     "from tqdm import tqdm\n",
@@ -80,7 +81,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 202,
    "source": [
     "# produce a train of spikes and store timestamps of each spike in `spike_timestamps`.\n",
@@ -115,7 +116,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 203,
    "source": [
     "\n",
     "\n",
@@ -158,7 +159,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 204,
    "source": [
     "# mock data parameters\n",
     "n_subjects = 1\n",
@@ -172,10 +173,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 117,
+   "execution_count": 210,
    "source": [
     "class PonderRNN(nn.Module):\n",
-    "    def __init__(self, n_inputs, n_channels, n_outputs, halting_prob_prior=0.0):\n",
+    "    def __init__(self, n_inputs, n_channels, n_outputs, halting_prob_prior=0.2):\n",
     "        super(PonderRNN, self).__init__()\n",
     "        self.encode = nn.Sequential( # encode: x -> sent_msg\n",
     "            nn.Linear(n_inputs, n_channels, bias=False),\n",
@@ -189,7 +190,7 @@
     "        )\n",
     "\n",
     "        # \\lambda_p\n",
-    "        self.halting_prob_prior = halting_prob_prior\n",
+    "        self.halting_prob = halting_prob_prior\n",
     "\n",
     "    def forward(self, x):\n",
-    "        # x: one stimulus category, output: y[1..N] + halting_prob[1..N]\n",
+    "        # x: one stimulus category; output: y and the sampled halting step N\n",
@@ -203,16 +204,18 @@
     "        msg = self.decode(msg)\n",
     "        y = msg.squeeze()\n",
     "\n",
-    "        # TODO\n",
-    "        halting_prob = 0.\n",
-    "        # lambda_n = ...\n",
-    "        # halting_prob = torch.distributions.Geometric(self.halting_prob_prior)\n",
-    "        # self.halting_dist = torch.distributions.Geometric(prob)\n",
-    "        # self.halting_probs = torch.cat(halting_prob, halting_prob)\n",
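+    "        # Sample the number of pondering steps from the geometric prior.\n",
+    "        # Note: torch.distributions.Geometric(p) counts failures before the\n",
+    "        # first success, so the expected step count is (1 - p) / p (4 for p = 0.2).\n",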
+    "        halting_n = torch.distributions.Geometric(self.halting_prob).sample()\n",
     "\n",
-    "        return y, halting_prob\n",
+    "        return y, halting_n\n",
     "\n",
     "\n",
+    "# split params: sizes are derived from len(X) so they always sum to len(dataset)\n",
+    "train_size = int(len(X) * .8)\n",
+    "test_size = len(X) - train_size\n",
+    "\n",
+    "# training params\n",
     "n_epoches = 1500\n",
     "\n",
     "logs = SummaryWriter()\n",
@@ -221,33 +224,53 @@
     "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
     "criterion = torch.nn.CrossEntropyLoss()\n",
     "\n",
-    "X_train = torch.tensor(X)\n",
-    "y_train = torch.tensor(y) - 1\n",
+    "dataset = TensorDataset(torch.tensor(X), torch.tensor(y) - 1)\n",
+    "train_subset, test_subset = random_split(dataset, lengths=(train_size, test_size))\n",
+    "\n",
+    "X_test, y_test = dataset[test_subset.indices]\n",
     "\n",
     "for epoch in tqdm(range(n_epoches), desc='Epochs'):\n",
-    "    model.train()\n",
-    "    optimizer.zero_grad()\n",
-    "    y_pred, _ = model(X_train)\n",
     "\n",
+    "    for X_batch, y_batch in DataLoader(train_subset, batch_size=1):\n",
+    "        model.train()\n",
+    "        optimizer.zero_grad()\n",
+    "        y_pred, _ = model(X_batch)\n",
+    "        # keep the batch dim: squeeze() in forward() drops it when batch_size == 1\n",
+    "        y_pred = y_pred.reshape(len(X_batch), -1)\n",
     "\n",
-    "    # logs.add_embedding(h.reshape(n_trials,n_stimuli), global_step=epoch, tag='embedding')\n",
+    "        # logs.add_embedding(h.reshape(n_trials,n_stimuli), global_step=epoch, tag='embedding')\n",
     "\n",
-    "    model_accuracy = accuracy_score(y_train.squeeze(), torch.argmax(y_pred.detach(),dim=1))\n",
-    "    logs.add_scalar('accurracy/train', model_accuracy, epoch) \n",
+    "        model_accuracy = accuracy_score(y_batch.reshape(-1), torch.argmax(y_pred.detach(), dim=1))\n",
+    "        logs.add_scalar('accuracy/train', model_accuracy, epoch)\n",
     "\n",
-    "    loss = criterion(y_pred, y_train.squeeze())\n",
-    "\n",
-    "    logs.add_scalar('loss/train', loss, epoch)\n",
+    "        loss = criterion(y_pred, y_batch.reshape(-1))\n",
+    "        logs.add_scalar('loss/train', loss, epoch)\n",
     "\n",
-    "    loss.backward()\n",
-    "    optimizer.step()\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
     "\n",
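+    "    # Evaluate once per epoch on the held-out split; torch.no_grad() avoids\n",
+    "    # building an autograd graph for the test forward pass.\n",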
+    "    model.eval()\n",
+    "    with torch.no_grad():\n",
+    "        y_pred, _ = model(X_test)\n",
+    "        loss = criterion(y_pred, y_test.reshape(-1))\n",
+    "        logs.add_scalar('loss/test', loss, epoch)\n",
+    "\n",
+    "\n",
     "# tensorboard --logdir=runs"
    ],
-   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stderr",
-     "text": [
-      "Epochs: 100%|██████████| 1500/1500 [00:02<00:00, 659.68it/s]\n"
-     ]
-    }
-   ],
+   "outputs": [],
@@ -279,20 +302,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 107,
+   "execution_count": 199,
    "source": [
-    "torch.distributions.Geometric(.01).sample()"
+    "halting_prob = .1\n",
+    "x = torch.distributions.Geometric(halting_prob).sample().item()\n",
+    "type(x)\n",
+    "# pmf: P(N = k) = (1 - p)**k * p, e.g. 0.8**3 * 0.2 for p = 0.2, k = 3"
    ],
    "outputs": [
     {
      "output_type": "execute_result",
      "data": {
       "text/plain": [
-       "tensor(56.)"
+       "float"
       ]
      },
      "metadata": {},
-     "execution_count": 107
+     "execution_count": 199
     }
    ],
    "metadata": {}