diff --git a/VRC_PonderRNN.ipynb b/VRC_PonderRNN.ipynb
index 8ef7d13..8dc13b0 100644
--- a/VRC_PonderRNN.ipynb
+++ b/VRC_PonderRNN.ipynb
@@ -52,7 +52,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 206,
+   "execution_count": 157,
    "source": [
     "# Setup and imports\n",
     "import torch\n",
@@ -81,7 +81,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 202,
+   "execution_count": 158,
    "source": [
     "# produce a train of spikes and store timestamps of each spike in `spike_timestamps`.\n",
     "\n",
@@ -116,7 +116,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 203,
+   "execution_count": 226,
    "source": [
     "\n",
     "\n",
@@ -152,6 +152,12 @@
     "    # response time\n",
     "    response_times = np.random.exponential(.5, size=accuracies.shape)\n",
     "\n",
+    "    if n_subjects == 1:\n",
+    "        X = X.squeeze()\n",
+    "        y = y.squeeze()\n",
+    "        accuracies = accuracies.squeeze()\n",
+    "        response_times = response_times.squeeze()\n",
+    "\n",
     "    return X, y, accuracies, response_times"
    ],
    "outputs": [],
@@ -159,7 +165,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 204,
+   "execution_count": 269,
    "source": [
     "# mock data parameters\n",
     "n_subjects = 1\n",
@@ -173,104 +179,158 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 210,
+   "execution_count": 270,
    "source": [
-    "class PonderRNN(nn.Module):\n",
-    "    def __init__(self, n_inputs, n_channels, n_outputs, halting_prob_prior=0.2):\n",
-    "        super(PonderRNN, self).__init__()\n",
-    "        self.encode = nn.Sequential( # encode: x -> sent_msg\n",
-    "            nn.Linear(n_inputs, n_channels, bias=False),\n",
-    "        )\n",
-    "        self.transmit = nn.Sequential( # transmit: sent_msg -> rcvd_msg\n",
-    "            nn.RNN(n_channels, n_channels),\n",
-    "        )\n",
-    "        self.decode = nn.Sequential( # decode: rcvd_msg -> action\n",
-    "            nn.Linear(n_channels,n_outputs, bias=False),\n",
-    "            nn.Softmax(dim=2)\n",
+    "class ICOM(nn.Module):\n",
+    "    def __init__(self, n_inputs, n_channels, n_outputs):\n",
+    "        super(ICOM, self).__init__()\n",
+    "\n",
+    "        self.n_inputs = n_inputs\n",
+    "        # encode: x -> sent_msg\n",
+    "        self.encode = nn.Linear(n_inputs, n_channels, bias=False)\n",
+    "\n",
+    "        # transmit: sent_msg -> rcvd_msg\n",
+    "        self.transmit = nn.RNN(n_channels, n_channels)\n",
+    "\n",
+    "        # decode: rcvd_msg -> action\n",
+    "        self.decode = nn.Sequential(\n",
+    "            nn.Linear(n_channels, n_outputs, bias=False),\n",
+    "            nn.Softmax(dim=2)\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, x, h):\n",
+    "\n",
+    "        msg = F.one_hot(x, num_classes=self.n_inputs).type(torch.float)\n",
+    "        msg = self.encode(msg)\n",
+    "        msg, h = self.transmit(msg, h)\n",
+    "        y = self.decode(msg)\n",
+    "\n",
+    "        return y.squeeze(), h"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 291,
+   "source": [
+    "class PonderNet(nn.Module):\n",
+    "    def __init__(self, n_inputs, n_embeddings, n_outputs, max_steps):\n",
+    "        super(PonderNet, self).__init__()\n",
+    "\n",
+    "        self.output_layer = ICOM(n_inputs, n_embeddings, n_outputs)\n",
+    "\n",
+    "        self.halting_layer = nn.Sequential(\n",
+    "            nn.Linear(n_embeddings, 1),\n",
+    "            nn.Sigmoid()\n",
     "        )\n",
     "\n",
     "        # \\lambda_p\n",
-    "        self.halting_prob = halting_prob_prior\n",
+    "        self.max_steps = max_steps\n",
+    "        self.n_embeddings = n_embeddings\n",
+    "        self.n_outputs = n_outputs\n",
     "\n",
     "    def forward(self, x):\n",
-    "        # x: one stimulus category, output: y[1..N] + halting_prob[1..N]\n",
-    "        # step 1: x -> x_n (repeat)\n",
-    "        # step 2: x_n -> y_n\n",
     "\n",
-    "        # VRC\n",
-    "        msg = F.one_hot(x).type(torch.float)\n",
-    "        msg = self.encode(msg)\n",
-    "        msg, _ = self.transmit(msg)\n",
-    "        msg = self.decode(msg)\n",
-    "        y = msg.squeeze()\n",
+    "        batch_size = x.shape[0]\n",
     "\n",
-    "        halting_n = torch.distributions.Geometric(self.halting_prob).sample().detach()\n",
-    "        # self.halting_dist = torch.distributions.Geometric(prob)\n",
-    "        # self.halting_probs = torch.cat(halting_prob, halting_prob)\n",
+    "        p = []\n",
+    "        y = []\n",
     "\n",
-    "        return y, halting_n\n",
+    "        p_continue = torch.ones((batch_size,))\n",
+    "        halt = torch.zeros((batch_size,))\n",
+    "        p_m = torch.zeros((batch_size,))\n",
+    "        y_m = torch.zeros((batch_size,))\n",
+    "        p_n = torch.zeros((batch_size,))\n",
+    "        h = torch.zeros((1, batch_size, self.n_embeddings))\n",
     "\n",
+    "        for n in range(1, self.max_steps + 1):\n",
+    "            y_n, h = self.output_layer(x.unsqueeze(0), h)\n",
+    "\n",
+    "            if n == self.max_steps:\n",
+    "                lambda_n = torch.tensor(1.)\n",
+    "                halt_steps = torch.empty((batch_size,)).fill_(n)\n",
+    "            else:\n",
+    "                lambda_n = self.halting_layer(h)\n",
+    "                halt_steps = torch.empty((batch_size,)).geometric_(lambda_n.detach()[0,0].item())  # FIXME: a single scalar lambda is shared across the whole batch\n",
+    "\n",
+    "            if n % 500 == 0:\n",
+    "                print('lambda:', lambda_n)\n",
+    "            p_n = p_continue * lambda_n\n",
+    "            p_continue = p_continue * (1 - lambda_n)\n",
+    "\n",
+    "            p.append(p_n)\n",
+    "            y.append(y_n)\n",
+    "\n",
+    "            is_halted = (halt_steps <= n).type(torch.float)\n",
+    "            p_m = p_m * (1 - is_halted) + p_n * is_halted\n",
+    "            y_m = y_m * (1 - is_halted) + y_n * is_halted\n",
+    "            halt = torch.maximum(halt, is_halted)  # remember which samples have halted so the loop can stop early\n",
+    "            if halt.all():\n",
+    "                break\n",
+    "\n",
+    "        return torch.stack(y), torch.stack(p), y_m, p_m"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 300,
+   "source": [
     "\n",
     "# split params\n",
     "train_size = int(n_trials * .8)\n",
     "test_size = n_trials - train_size\n",
     "\n",
     "# training params\n",
-    "n_epoches = 1500\n",
+    "n_epoches = 100\n",
     "\n",
     "logs = SummaryWriter()\n",
     "\n",
-    "model = PonderRNN(n_stimuli+1, n_stimuli, n_stimuli+1)\n",
+    "model = PonderNet(n_stimuli+1, n_stimuli, n_stimuli+1, 100)\n",
     "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
     "criterion = torch.nn.CrossEntropyLoss()\n",
     "\n",
     "dataset = TensorDataset(torch.tensor(X), torch.tensor(y)-1)\n",
     "train_subset, test_subset = random_split(dataset, lengths=(train_size,test_size))\n",
     "\n",
+    "X_train, y_train = dataset[train_subset.indices]\n",
     "X_test, y_test = dataset[test_subset.indices]\n",
     "\n",
-    "print(X_test)\n",
-    "\n",
-    "xx\n",
-    "\n",
     "for epoch in tqdm(range(n_epoches), desc='Epochs'):\n",
     "\n",
     "    for X_batch, y_batch in DataLoader(train_subset, batch_size=1):\n",
     "        model.train()\n",
     "        optimizer.zero_grad()\n",
-    "        y_pred, _ = model(X_batch)\n",
+    "        ys, ps, y_pred, p_m = model(X_batch)\n",
     "\n",
     "        # logs.add_embedding(h.reshape(n_trials,n_stimuli), global_step=epoch, tag='embedding')\n",
     "\n",
-    "        model_accuracy = accuracy_score(y_batch.squeeze(), torch.argmax(y_pred.detach(),dim=1))\n",
+    "        model_accuracy = accuracy_score(y_batch, torch.argmax(y_pred.detach(), dim=0).unsqueeze(0))\n",
     "        logs.add_scalar('accuracy/train', model_accuracy, epoch)\n",
     "\n",
-    "        loss = criterion(y_pred, y_batch.squeeze())\n",
+    "        loss = criterion(y_pred.unsqueeze(0), y_batch)\n",
+    "\n",
     "        logs.add_scalar('loss/train', loss, epoch)\n",
     "\n",
     "        loss.backward()\n",
     "        optimizer.step()\n",
     "\n",
-    "    model.eval()\n",
-    "    with torch.no_grad():\n",
-    "        y_pred = model(X_test)\n",
-    "        loss = criterion(y_test, y_pred)\n",
-    "        logger.add_scalar('loss/test', loss.detach(), epoch)\n",
-    "\n",
+    "    # model.eval()\n",
+    "    # with torch.no_grad():\n",
+    "    #     _, _, y_pred, _ = model(X_test)\n",
+    "    #     loss = criterion(y_pred, y_test)\n",
+    "    #     logs.add_scalar('loss/test', loss.detach(), epoch)\n",
     "\n",
     "# tensorboard --logdir=runs"
    ],
    "outputs": [
     {
-     "output_type": "error",
-     "ename": "ValueError",
-     "evalue": "Sum of input lengths does not equal the length of the input dataset!",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
-      "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTensorDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mtrain_subset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_subset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandom_split\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlengths\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtest_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mX_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtest_subset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;32m~/miniconda3/envs/py3/lib/python3.9/site-packages/torch/utils/data/dataset.py\u001b[0m in \u001b[0;36mrandom_split\u001b[0;34m(dataset, lengths, generator)\u001b[0m\n\u001b[1;32m 349\u001b[0m \u001b[0;31m# Cannot verify that dataset is Sized\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 350\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlengths\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 351\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Sum of input lengths does not equal the length of the input dataset!\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 352\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 353\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandperm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlengths\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgenerator\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtolist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: Sum of input lengths does not equal the length of the input dataset!" + "output_type": "stream", + "name": "stderr", + "text": [ + "Epochs: 100%|██████████| 100/100 [00:57<00:00, 1.73it/s]\n" ] } ], @@ -278,10 +338,10 @@ }, { "cell_type": "code", - "execution_count": 137, + "execution_count": null, "source": [ "model.eval()\n", - "y_pred, _ = model(X_train)\n", + "_, _, y_pred, _ = model(X)\n", "y_pred = np.argmax(y_pred.detach().numpy(), axis=1) + 1\n", "y_pred, y" ], @@ -302,25 +362,16 @@ }, { "cell_type": "code", - "execution_count": 199, - "source": [ - "halting_prob = .1\n", - "x = torch.distributions.Geometric(halting_prob).sample().item()\n", - "type(x)\n", - "# .2*.2*.2*.8" - ], - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "float" - ] - }, - "metadata": {}, - "execution_count": 199 - } - ], + "execution_count": null, + "source": [], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [], + "outputs": [], "metadata": {} } ],