diff --git a/gnn_pg.ipynb b/gnn_pg.ipynb index 923a908..46a30f2 100644 --- a/gnn_pg.ipynb +++ b/gnn_pg.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 36, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -19,12 +19,15 @@ "import numpy as np\n", "import spektral\n", "from spektral.datasets.citation import Citation\n", - "from tqdm import tqdm" + "from tqdm import tqdm\n", + "from spektral.transforms import LayerPreprocess\n", + "from spektral.layers import GCNConv\n", + "from spektral.transforms import NormalizeAdj\n" ] }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -46,17 +49,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Pre-processing node features\n", - "epoch: 0,train_loss: 1.0980415344238281,val_acc: 0.4699999988079071,test_acc: 0.48900002241134644\n", - "epoch: 1,train_loss: 1.0857902765274048,val_acc: 0.531999945640564,test_acc: 0.5269998908042908\n", - "epoch: 5,train_loss: 1.0279830694198608,val_acc: 0.5459999442100525,test_acc: 0.5679998993873596\n", - "epoch: 6,train_loss: 1.0107402801513672,val_acc: 0.6380000710487366,test_acc: 0.6259998679161072\n", - "epoch: 7,train_loss: 0.9920498132705688,val_acc: 0.6820001006126404,test_acc: 0.6719997525215149\n", - "epoch: 8,train_loss: 0.9725067615509033,val_acc: 0.7120000720024109,test_acc: 0.6989997625350952\n", - "epoch: 9,train_loss: 0.9518073201179504,val_acc: 0.724000096321106,test_acc: 0.7109997868537903\n", - "epoch: 10,train_loss: 0.929932177066803,val_acc: 0.736000120639801,test_acc: 0.7179997563362122\n", - "epoch: 14,train_loss: 0.8337627053260803,val_acc: 0.7380000948905945,test_acc: 0.7209997773170471\n", - "epoch: 15,train_loss: 0.8073488473892212,val_acc: 0.7420001029968262,test_acc: 0.7239997982978821\n" + "Pre-processing node features\n" ] } ], @@ -119,39 +112,41 @@ " print(f'epoch: {epoch},'\n", " f'train_loss: {loss.numpy()},'\n", " f'val_acc: {val_accuracy.numpy()},'\n", - " f'test_acc: {test_accuracy.numpy()}')\n", - "\n", - "train(features, adj, gnn_fn, 32, 20, 0.01)" + " f'test_acc: {test_accuracy.numpy()}')\n" ] }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 0,train_loss: 1.0996336936950684,val_acc: 0.5119999647140503,test_acc: 0.4919999837875366\n", - "epoch: 2,train_loss: 1.073289394378662,val_acc: 0.5379999876022339,test_acc: 0.5219999551773071\n", - "epoch: 3,train_loss: 1.0569909811019897,val_acc: 0.6000000238418579,test_acc: 0.5999998450279236\n", - "epoch: 4,train_loss: 1.0379494428634644,val_acc: 0.6700000166893005,test_acc: 0.6739997863769531\n", - "epoch: 5,train_loss: 1.0155421495437622,val_acc: 0.7160000801086426,test_acc: 0.7109997868537903\n", - "epoch: 6,train_loss: 0.9916436672210693,val_acc: 0.7280001044273376,test_acc: 0.7139998078346252\n", - "epoch: 9,train_loss: 0.9153128266334534,val_acc: 0.7300000786781311,test_acc: 0.7069997787475586\n", - "epoch: 11,train_loss: 0.8576377034187317,val_acc: 0.7320001125335693,test_acc: 0.711999773979187\n", - "epoch: 12,train_loss: 0.8269254565238953,val_acc: 0.7380000948905945,test_acc: 0.714999794960022\n", - "epoch: 13,train_loss: 0.7955628037452698,val_acc: 0.7440001368522644,test_acc: 0.7159997820854187\n", - "epoch: 15,train_loss: 0.7319157123565674,val_acc: 0.7460001111030579,test_acc: 0.714999794960022\n", - "epoch: 17,train_loss: 0.6676653623580933,val_acc: 0.7500001192092896,test_acc: 0.7179997563362122\n", - "epoch: 18,train_loss: 0.6357672810554504,val_acc: 0.752000093460083,test_acc: 0.7169997692108154\n" + "epoch: 0,train_loss: 1.0988209247589111,val_acc: 0.49799996614456177,test_acc: 0.5019999742507935\n", + "epoch: 3,train_loss: 1.0498652458190918,val_acc: 0.515999972820282,test_acc: 0.528999924659729\n", + "epoch: 4,train_loss: 1.0287455320358276,val_acc: 0.5720000267028809,test_acc: 0.6109998226165771\n", + "epoch: 5,train_loss: 1.0060378313064575,val_acc: 0.6500000357627869,test_acc: 0.6719998121261597\n", + "epoch: 6,train_loss: 0.9812273383140564,val_acc: 0.6760000586509705,test_acc: 0.6909998059272766\n", + "epoch: 7,train_loss: 0.9548961520195007,val_acc: 0.6880000233650208,test_acc: 0.6979997158050537\n", + "epoch: 8,train_loss: 0.9275616407394409,val_acc: 0.7100000977516174,test_acc: 0.6979997158050537\n", + "epoch: 9,train_loss: 0.899043619632721,val_acc: 0.7140000462532043,test_acc: 0.7019997835159302\n", + "epoch: 11,train_loss: 0.8381972908973694,val_acc: 0.7180001139640808,test_acc: 0.7079997658729553\n", + "epoch: 12,train_loss: 0.8062182664871216,val_acc: 0.7220000624656677,test_acc: 0.711999773979187\n", + "epoch: 13,train_loss: 0.7739933133125305,val_acc: 0.7300000786781311,test_acc: 0.714999794960022\n", + "epoch: 14,train_loss: 0.7413352131843567,val_acc: 0.7380000948905945,test_acc: 0.7249997854232788\n", + "epoch: 15,train_loss: 0.7085117697715759,val_acc: 0.7420001029968262,test_acc: 0.7279998064041138\n", + "epoch: 19,train_loss: 0.5778310894966125,val_acc: 0.7440000772476196,test_acc: 0.7369998097419739\n", + "epoch: 20,train_loss: 0.5461804866790771,val_acc: 0.7440001368522644,test_acc: 0.7369998097419739\n" ] } ], "source": [ "# on identity train(features, tf.eye(adj.shape[0]), gnn_fn, 32, 20, 0.01)\n", "\n", + "# no normalization: train(features, adj, gnn_fn, 32, 20, 0.01)\n", + "\n", "# normalize adj by degree\n", "deg = tf.reduce_sum(adj, axis=-1)\n", "train(features, adj / deg, gnn_fn, 32, 20, 0.01)" @@ -159,30 +154,15 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 23, "metadata": {}, - "outputs": [ - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/var/folders/3_/gmvd1nkx285133z5yh3chz2c0000gp/T/ipykernel_40315/3405598776.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mnorm_deg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdiag\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1.9\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdeg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mnorm_adj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnorm_deg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0madj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_deg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_adj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgnn_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m200\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.01\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0;34m\"\"\"Call target, and fall back on dispatchers if there is a TypeError.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 206\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 207\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mTypeError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0;31m# Note: convert_to_eager_tensor currently raises a ValueError, not a\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py\u001b[0m in \u001b[0;36mmatmul\u001b[0;34m(a, b, transpose_a, transpose_b, adjoint_a, adjoint_b, a_is_sparse, b_is_sparse, output_type, name)\u001b[0m\n\u001b[1;32m 3652\u001b[0m a, b, adj_x=adjoint_a, adj_y=adjoint_b, Tout=output_type, name=name)\n\u001b[1;32m 3653\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3654\u001b[0;31m return gen_math_ops.mat_mul(\n\u001b[0m\u001b[1;32m 3655\u001b[0m a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)\n\u001b[1;32m 3656\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py\u001b[0m in \u001b[0;36mmat_mul\u001b[0;34m(a, b, transpose_a, transpose_b, name)\u001b[0m\n\u001b[1;32m 5689\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtld\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_eager\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5690\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 5691\u001b[0;31m _result = pywrap_tfe.TFE_Py_FastPathExecute(\n\u001b[0m\u001b[1;32m 5692\u001b[0m \u001b[0m_ctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"MatMul\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"transpose_a\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtranspose_a\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"transpose_b\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5693\u001b[0m transpose_b)\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], + "outputs": [], "source": [ "# thomas gibb in graph conv networks: normalize by 1/sqrt(deg)\n", "\n", - "norm_deg = tf.linalg.diag(1.9 / tf.sqrt(deg))\n", - "norm_adj = tf.matmul(norm_deg, tf.matmul(adj, norm_deg))\n", - "train(features, norm_adj, gnn_fn, 32, 200, 0.01)" + "# norm_deg = tf.linalg.diag(1.0 / tf.sqrt(deg))\n", + "# norm_adj = tf.matmul(norm_deg, tf.matmul(adj, norm_deg))\n", + "# train(features, norm_adj, gnn_fn, 32, 200, 0.01)" ] }, { @@ -194,351 +174,217 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1/1000\n", - "1/1 [==============================] - 0s 462ms/step - loss: 1.1087 - acc: 0.2833 - val_loss: 1.1010 - val_acc: 0.6500\n", - "Epoch 2/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 1.1005 - acc: 0.6167 - val_loss: 1.0974 - val_acc: 0.7500\n", - "Epoch 3/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 1.0946 - acc: 0.7500 - val_loss: 1.0928 - val_acc: 0.7600\n", - "Epoch 4/1000\n", - "1/1 [==============================] - 0s 101ms/step - loss: 1.0894 - acc: 0.8833 - val_loss: 1.0872 - val_acc: 0.7560\n", - "Epoch 5/1000\n", - "1/1 [==============================] - 0s 93ms/step - loss: 1.0805 - acc: 0.8167 - val_loss: 1.0809 - val_acc: 0.7520\n", - "Epoch 6/1000\n", - "1/1 [==============================] - 0s 98ms/step - loss: 1.0727 - acc: 0.8667 - val_loss: 1.0744 - val_acc: 0.7580\n", - "Epoch 7/1000\n", - "1/1 [==============================] - 0s 95ms/step - loss: 1.0636 - acc: 0.9000 - val_loss: 1.0681 - val_acc: 0.7520\n", - "Epoch 8/1000\n", - "1/1 [==============================] - 0s 105ms/step - loss: 1.0563 - acc: 0.8667 - val_loss: 1.0619 - val_acc: 0.7460\n", - "Epoch 9/1000\n", - "1/1 [==============================] - 0s 113ms/step - loss: 1.0451 - acc: 0.8667 - val_loss: 1.0557 - val_acc: 0.7460\n", - "Epoch 10/1000\n", - "1/1 [==============================] - 0s 120ms/step - loss: 1.0412 - acc: 0.8667 - val_loss: 1.0497 - val_acc: 0.7460\n", - "Epoch 11/1000\n", - "1/1 [==============================] - 0s 111ms/step - loss: 1.0284 - acc: 0.8833 - val_loss: 1.0435 - val_acc: 0.7460\n", - "Epoch 12/1000\n", - "1/1 [==============================] - 0s 101ms/step - loss: 1.0180 - acc: 0.9000 - val_loss: 1.0375 - val_acc: 0.7480\n", - "Epoch 13/1000\n", - "1/1 [==============================] - 0s 93ms/step - loss: 1.0036 - acc: 0.9000 - val_loss: 1.0315 - val_acc: 0.7520\n", - "Epoch 14/1000\n", - "1/1 [==============================] - 0s 95ms/step - loss: 0.9942 - acc: 0.9333 - val_loss: 1.0255 - val_acc: 0.7580\n", - "Epoch 15/1000\n", - "1/1 [==============================] - 0s 94ms/step - loss: 0.9864 - acc: 0.9167 - val_loss: 1.0196 - val_acc: 0.7620\n", - "Epoch 16/1000\n", - "1/1 [==============================] - 0s 93ms/step - loss: 0.9864 - acc: 0.9167 - val_loss: 1.0137 - val_acc: 0.7620\n", - "Epoch 17/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.9612 - acc: 0.9500 - val_loss: 1.0077 - val_acc: 0.7620\n", - "Epoch 18/1000\n", - "1/1 [==============================] - 0s 95ms/step - loss: 0.9546 - acc: 0.9167 - val_loss: 1.0018 - val_acc: 0.7640\n", - "Epoch 19/1000\n", - "1/1 [==============================] - 0s 104ms/step - loss: 0.9423 - acc: 0.9500 - val_loss: 0.9957 - val_acc: 0.7660\n", - "Epoch 20/1000\n", - "1/1 [==============================] - 0s 108ms/step - loss: 0.9271 - acc: 0.9167 - val_loss: 0.9895 - val_acc: 0.7620\n", - "Epoch 21/1000\n", - "1/1 [==============================] - 0s 101ms/step - loss: 0.9319 - acc: 0.8833 - val_loss: 0.9832 - val_acc: 0.7620\n", - "Epoch 22/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.9068 - acc: 0.9000 - val_loss: 0.9770 - val_acc: 0.7640\n", - "Epoch 23/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.8857 - acc: 0.8667 - val_loss: 0.9708 - val_acc: 0.7640\n", - "Epoch 24/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.8685 - acc: 0.8833 - val_loss: 0.9646 - val_acc: 0.7640\n", - "Epoch 25/1000\n", - "1/1 [==============================] - 0s 95ms/step - loss: 0.8696 - acc: 0.9333 - val_loss: 0.9583 - val_acc: 0.7640\n", - "Epoch 26/1000\n", - "1/1 [==============================] - 0s 98ms/step - loss: 0.8706 - acc: 0.9500 - val_loss: 0.9521 - val_acc: 0.7640\n", - "Epoch 27/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.8500 - acc: 0.9167 - val_loss: 0.9459 - val_acc: 0.7680\n", - "Epoch 28/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.8448 - acc: 0.9167 - val_loss: 0.9399 - val_acc: 0.7700\n", - "Epoch 29/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.8462 - acc: 0.9333 - val_loss: 0.9339 - val_acc: 0.7660\n", - "Epoch 30/1000\n", - "1/1 [==============================] - 0s 84ms/step - loss: 0.8169 - acc: 0.9000 - val_loss: 0.9280 - val_acc: 0.7680\n", - "Epoch 31/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.7901 - acc: 0.9500 - val_loss: 0.9223 - val_acc: 0.7700\n", - "Epoch 32/1000\n", - "1/1 [==============================] - 0s 95ms/step - loss: 0.8125 - acc: 0.9500 - val_loss: 0.9167 - val_acc: 0.7700\n", - "Epoch 33/1000\n", - "1/1 [==============================] - 0s 93ms/step - loss: 0.7813 - acc: 0.9667 - val_loss: 0.9110 - val_acc: 0.7700\n", - "Epoch 34/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.7728 - acc: 0.9167 - val_loss: 0.9055 - val_acc: 0.7700\n", - "Epoch 35/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.7708 - acc: 0.9333 - val_loss: 0.9001 - val_acc: 0.7720\n", - "Epoch 36/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.7519 - acc: 0.9333 - val_loss: 0.8949 - val_acc: 0.7700\n", - "Epoch 37/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.7435 - acc: 0.9500 - val_loss: 0.8899 - val_acc: 0.7700\n", - "Epoch 38/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.7134 - acc: 0.9333 - val_loss: 0.8849 - val_acc: 0.7680\n", - "Epoch 39/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.7055 - acc: 0.9500 - val_loss: 0.8800 - val_acc: 0.7680\n", - "Epoch 40/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.6887 - acc: 0.9667 - val_loss: 0.8752 - val_acc: 0.7680\n", - "Epoch 41/1000\n", - "1/1 [==============================] - 0s 93ms/step - loss: 0.6847 - acc: 0.9667 - val_loss: 0.8703 - val_acc: 0.7680\n", - "Epoch 42/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.6918 - acc: 0.9167 - val_loss: 0.8655 - val_acc: 0.7700\n", - "Epoch 43/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.6699 - acc: 0.9333 - val_loss: 0.8608 - val_acc: 0.7700\n", - "Epoch 44/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.6433 - acc: 0.9667 - val_loss: 0.8565 - val_acc: 0.7720\n", - "Epoch 45/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.6713 - acc: 0.9167 - val_loss: 0.8523 - val_acc: 0.7720\n", - "Epoch 46/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.6396 - acc: 0.9667 - val_loss: 0.8482 - val_acc: 0.7720\n", - "Epoch 47/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.6291 - acc: 0.9167 - val_loss: 0.8444 - val_acc: 0.7740\n", - "Epoch 48/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.6029 - acc: 0.9667 - val_loss: 0.8407 - val_acc: 0.7700\n", - "Epoch 49/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.6102 - acc: 0.9500 - val_loss: 0.8371 - val_acc: 0.7700\n", - "Epoch 50/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.6154 - acc: 0.9167 - val_loss: 0.8335 - val_acc: 0.7700\n", - "Epoch 51/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.6160 - acc: 0.9500 - val_loss: 0.8302 - val_acc: 0.7700\n", - "Epoch 52/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.5654 - acc: 0.9667 - val_loss: 0.8271 - val_acc: 0.7680\n", - "Epoch 53/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.6097 - acc: 0.9500 - val_loss: 0.8242 - val_acc: 0.7700\n", - "Epoch 54/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.5976 - acc: 0.9500 - val_loss: 0.8213 - val_acc: 0.7720\n", - "Epoch 55/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.5914 - acc: 0.9000 - val_loss: 0.8188 - val_acc: 0.7720\n", - "Epoch 56/1000\n", - "1/1 [==============================] - 0s 93ms/step - loss: 0.5761 - acc: 0.9667 - val_loss: 0.8164 - val_acc: 0.7760\n", - "Epoch 57/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.5613 - acc: 0.9667 - val_loss: 0.8141 - val_acc: 0.7800\n", - "Epoch 58/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.5773 - acc: 0.9500 - val_loss: 0.8117 - val_acc: 0.7800\n", - "Epoch 59/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.5665 - acc: 0.9500 - val_loss: 0.8094 - val_acc: 0.7800\n", - "Epoch 60/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.5408 - acc: 0.9667 - val_loss: 0.8069 - val_acc: 0.7800\n", - "Epoch 61/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.5425 - acc: 0.9500 - val_loss: 0.8043 - val_acc: 0.7840\n", - "Epoch 62/1000\n", - "1/1 [==============================] - 0s 104ms/step - loss: 0.5627 - acc: 0.9167 - val_loss: 0.8019 - val_acc: 0.7840\n", - "Epoch 63/1000\n", - "1/1 [==============================] - 0s 96ms/step - loss: 0.5338 - acc: 0.9333 - val_loss: 0.7996 - val_acc: 0.7820\n", - "Epoch 64/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.5082 - acc: 0.9667 - val_loss: 0.7972 - val_acc: 0.7820\n", - "Epoch 65/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.5194 - acc: 0.9667 - val_loss: 0.7950 - val_acc: 0.7820\n", - "Epoch 66/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.4943 - acc: 0.9333 - val_loss: 0.7929 - val_acc: 0.7840\n", - "Epoch 67/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.5025 - acc: 0.9667 - val_loss: 0.7909 - val_acc: 0.7820\n", - "Epoch 68/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.4835 - acc: 0.9833 - val_loss: 0.7891 - val_acc: 0.7800\n", - "Epoch 69/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.5119 - acc: 0.9833 - val_loss: 0.7874 - val_acc: 0.7800\n", - "Epoch 70/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.5110 - acc: 0.9667 - val_loss: 0.7857 - val_acc: 0.7800\n", - "Epoch 71/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.4949 - acc: 0.9333 - val_loss: 0.7842 - val_acc: 0.7800\n", - "Epoch 72/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.4662 - acc: 0.9667 - val_loss: 0.7828 - val_acc: 0.7800\n", - "Epoch 73/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.5109 - acc: 0.9833 - val_loss: 0.7810 - val_acc: 0.7780\n", - "Epoch 74/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.4513 - acc: 0.9833 - val_loss: 0.7792 - val_acc: 0.7820\n", - "Epoch 75/1000\n", - "1/1 [==============================] - 0s 86ms/step - loss: 0.4873 - acc: 0.9500 - val_loss: 0.7772 - val_acc: 0.7840\n", - "Epoch 76/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.4595 - acc: 0.9667 - val_loss: 0.7753 - val_acc: 0.7820\n", - "Epoch 77/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.5105 - acc: 0.9500 - val_loss: 0.7736 - val_acc: 0.7860\n", - "Epoch 78/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.4737 - acc: 0.9667 - val_loss: 0.7717 - val_acc: 0.7880\n", - "Epoch 79/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.4541 - acc: 0.9833 - val_loss: 0.7702 - val_acc: 0.7900\n", - "Epoch 80/1000\n", - "1/1 [==============================] - 0s 86ms/step - loss: 0.4477 - acc: 1.0000 - val_loss: 0.7685 - val_acc: 0.7900\n", - "Epoch 81/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.4596 - acc: 0.9500 - val_loss: 0.7667 - val_acc: 0.7920\n", - "Epoch 82/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.4661 - acc: 0.9333 - val_loss: 0.7651 - val_acc: 0.7920\n", - "Epoch 83/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.4452 - acc: 0.9500 - val_loss: 0.7639 - val_acc: 0.7940\n", - "Epoch 84/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.4559 - acc: 0.9667 - val_loss: 0.7631 - val_acc: 0.7940\n", - "Epoch 85/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.4705 - acc: 0.9500 - val_loss: 0.7623 - val_acc: 0.7960\n", - "Epoch 86/1000\n", - "1/1 [==============================] - 0s 86ms/step - loss: 0.4418 - acc: 0.9500 - val_loss: 0.7616 - val_acc: 0.7940\n", - "Epoch 87/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.4240 - acc: 0.9667 - val_loss: 0.7610 - val_acc: 0.7940\n", - "Epoch 88/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.4489 - acc: 0.9667 - val_loss: 0.7603 - val_acc: 0.7940\n", - "Epoch 89/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.4326 - acc: 0.9833 - val_loss: 0.7596 - val_acc: 0.7920\n", - "Epoch 90/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.4337 - acc: 0.9667 - val_loss: 0.7586 - val_acc: 0.7900\n", - "Epoch 91/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.4263 - acc: 0.9667 - val_loss: 0.7579 - val_acc: 0.7900\n", - "Epoch 92/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.4165 - acc: 1.0000 - val_loss: 0.7568 - val_acc: 0.7900\n", - "Epoch 93/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.4393 - acc: 0.9667 - val_loss: 0.7557 - val_acc: 0.7900\n", - "Epoch 94/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.4316 - acc: 1.0000 - val_loss: 0.7544 - val_acc: 0.7900\n", - "Epoch 95/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.4527 - acc: 0.9667 - val_loss: 0.7529 - val_acc: 0.7900\n", - "Epoch 96/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.4238 - acc: 0.9667 - val_loss: 0.7512 - val_acc: 0.7880\n", - "Epoch 97/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.3868 - acc: 0.9833 - val_loss: 0.7496 - val_acc: 0.7940\n", - "Epoch 98/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.3973 - acc: 1.0000 - val_loss: 0.7478 - val_acc: 0.7940\n", - "Epoch 99/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.4228 - acc: 0.9667 - val_loss: 0.7459 - val_acc: 0.7940\n", - "Epoch 100/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3667 - acc: 0.9833 - val_loss: 0.7443 - val_acc: 0.7960\n", - "Epoch 101/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.3947 - acc: 0.9667 - val_loss: 0.7430 - val_acc: 0.7960\n", - "Epoch 102/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3955 - acc: 0.9667 - val_loss: 0.7416 - val_acc: 0.7940\n", - "Epoch 103/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.4087 - acc: 0.9833 - val_loss: 0.7402 - val_acc: 0.7940\n", - "Epoch 104/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.4469 - acc: 0.9833 - val_loss: 0.7387 - val_acc: 0.7940\n", - "Epoch 105/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.3881 - acc: 0.9833 - val_loss: 0.7375 - val_acc: 0.7940\n", - "Epoch 106/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.3735 - acc: 0.9833 - val_loss: 0.7365 - val_acc: 0.7940\n", - "Epoch 107/1000\n", - "1/1 [==============================] - 0s 84ms/step - loss: 0.4000 - acc: 0.9667 - val_loss: 0.7358 - val_acc: 0.7940\n", - "Epoch 108/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3820 - acc: 1.0000 - val_loss: 0.7352 - val_acc: 0.7960\n", - "Epoch 109/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.3557 - acc: 0.9833 - val_loss: 0.7345 - val_acc: 0.7980\n", - "Epoch 110/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.3602 - acc: 1.0000 - val_loss: 0.7337 - val_acc: 0.7960\n", - "Epoch 111/1000\n", - "1/1 [==============================] - 0s 98ms/step - loss: 0.4138 - acc: 0.9667 - val_loss: 0.7329 - val_acc: 0.7940\n", - "Epoch 112/1000\n", - "1/1 [==============================] - 0s 83ms/step - loss: 0.3778 - acc: 0.9667 - val_loss: 0.7324 - val_acc: 0.7960\n", - "Epoch 113/1000\n", - "1/1 [==============================] - 0s 86ms/step - loss: 0.3877 - acc: 0.9667 - val_loss: 0.7322 - val_acc: 0.7920\n", - "Epoch 114/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3661 - acc: 1.0000 - val_loss: 0.7317 - val_acc: 0.7920\n", - "Epoch 115/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.3956 - acc: 0.9667 - val_loss: 0.7310 - val_acc: 0.7920\n", - "Epoch 116/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3616 - acc: 1.0000 - val_loss: 0.7300 - val_acc: 0.7920\n", - "Epoch 117/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.3594 - acc: 0.9667 - val_loss: 0.7291 - val_acc: 0.7880\n", - "Epoch 118/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.3739 - acc: 0.9833 - val_loss: 0.7285 - val_acc: 0.7860\n", - "Epoch 119/1000\n", - "1/1 [==============================] - 0s 86ms/step - loss: 0.3778 - acc: 0.9833 - val_loss: 0.7280 - val_acc: 0.7880\n", - "Epoch 120/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.3669 - acc: 1.0000 - val_loss: 0.7274 - val_acc: 0.7900\n", - "Epoch 121/1000\n", - "1/1 [==============================] - 0s 85ms/step - loss: 0.3288 - acc: 1.0000 - val_loss: 0.7266 - val_acc: 0.7880\n", - "Epoch 122/1000\n", - "1/1 [==============================] - 0s 86ms/step - loss: 0.3726 - acc: 0.9833 - val_loss: 0.7259 - val_acc: 0.7880\n", - "Epoch 123/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.3813 - acc: 0.9667 - val_loss: 0.7250 - val_acc: 0.7900\n", - "Epoch 124/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.3615 - acc: 0.9833 - val_loss: 0.7240 - val_acc: 0.7920\n", - "Epoch 125/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.3283 - acc: 1.0000 - val_loss: 0.7229 - val_acc: 0.7920\n", - "Epoch 126/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.3312 - acc: 0.9833 - val_loss: 0.7215 - val_acc: 0.7940\n", - "Epoch 127/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.4086 - acc: 0.9667 - val_loss: 0.7206 - val_acc: 0.7960\n", - "Epoch 128/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3588 - acc: 0.9667 - val_loss: 0.7195 - val_acc: 0.7920\n", - "Epoch 129/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3678 - acc: 1.0000 - val_loss: 0.7184 - val_acc: 0.7920\n", - "Epoch 130/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.3201 - acc: 0.9833 - val_loss: 0.7175 - val_acc: 0.7920\n", - "Epoch 131/1000\n", - "1/1 [==============================] - 0s 100ms/step - loss: 0.3630 - acc: 0.9833 - val_loss: 0.7167 - val_acc: 0.7920\n", - "Epoch 132/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.3601 - acc: 0.9667 - val_loss: 0.7163 - val_acc: 0.7900\n", - "Epoch 133/1000\n", - "1/1 [==============================] - 0s 83ms/step - loss: 0.3231 - acc: 1.0000 - val_loss: 0.7158 - val_acc: 0.7900\n", - "Epoch 134/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.3124 - acc: 1.0000 - val_loss: 0.7150 - val_acc: 0.7900\n", - "Epoch 135/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.3430 - acc: 0.9833 - val_loss: 0.7142 - val_acc: 0.7920\n", - "Epoch 136/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.3344 - acc: 0.9833 - val_loss: 0.7135 - val_acc: 0.7900\n", - "Epoch 137/1000\n", - "1/1 [==============================] - 0s 84ms/step - loss: 0.3537 - acc: 0.9833 - val_loss: 0.7130 - val_acc: 0.7900\n", - "Epoch 138/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.3424 - acc: 0.9667 - val_loss: 0.7128 - val_acc: 0.7880\n", - "Epoch 139/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.3199 - acc: 1.0000 - val_loss: 0.7125 - val_acc: 0.7840\n", - "Epoch 140/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.3173 - acc: 0.9833 - val_loss: 0.7127 - val_acc: 0.7860\n", - "Epoch 141/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.3474 - acc: 0.9833 - val_loss: 0.7126 - val_acc: 0.7840\n", - "Epoch 142/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3292 - acc: 1.0000 - val_loss: 0.7127 - val_acc: 0.7840\n", - "Epoch 143/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.3552 - acc: 0.9667 - val_loss: 0.7127 - val_acc: 0.7820\n", - "Epoch 144/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.3256 - acc: 1.0000 - val_loss: 0.7125 - val_acc: 0.7840\n", - "Epoch 145/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.3204 - acc: 1.0000 - val_loss: 0.7125 - val_acc: 0.7840\n", - "Epoch 146/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3133 - acc: 0.9833 - val_loss: 0.7122 - val_acc: 0.7800\n", - "Epoch 147/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.3209 - acc: 1.0000 - val_loss: 0.7116 - val_acc: 0.7800\n", - "Epoch 148/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3291 - acc: 1.0000 - val_loss: 0.7100 - val_acc: 0.7820\n", - "Epoch 149/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.3524 - acc: 0.9833 - val_loss: 0.7079 - val_acc: 0.7840\n", - "Epoch 150/1000\n", - "1/1 [==============================] - 0s 99ms/step - loss: 0.3336 - acc: 1.0000 - val_loss: 0.7057 - val_acc: 0.7860\n", - "Epoch 151/1000\n", - "1/1 [==============================] - 0s 86ms/step - loss: 0.3318 - acc: 0.9667 - val_loss: 0.7037 - val_acc: 0.7900\n", - "Epoch 152/1000\n", - "1/1 [==============================] - 0s 82ms/step - loss: 0.3615 - acc: 0.9833 - val_loss: 0.7022 - val_acc: 0.7920\n", - "Epoch 153/1000\n", - "1/1 [==============================] - 0s 86ms/step - loss: 0.3089 - acc: 1.0000 - val_loss: 0.7011 - val_acc: 0.7920\n", - "Epoch 154/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.3616 - acc: 0.9667 - val_loss: 0.7006 - val_acc: 0.7920\n", - "Epoch 155/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.2931 - acc: 1.0000 - val_loss: 0.7000 - val_acc: 0.7940\n", - "Epoch 156/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.3272 - acc: 1.0000 - val_loss: 0.6993 - val_acc: 0.7960\n", - "Epoch 157/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.3255 - acc: 0.9833 - val_loss: 0.6992 - val_acc: 0.7980\n", - "Epoch 158/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.3066 - acc: 0.9833 - val_loss: 0.6993 - val_acc: 0.7980\n", - "Epoch 159/1000\n", - "1/1 [==============================] - 0s 89ms/step - loss: 0.3397 - acc: 0.9500 - val_loss: 0.6993 - val_acc: 0.7960\n", - "Epoch 160/1000\n", - "1/1 [==============================] - 0s 91ms/step - loss: 0.3396 - acc: 0.9833 - val_loss: 0.6993 - val_acc: 0.7960\n", - "Epoch 161/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3275 - acc: 1.0000 - val_loss: 0.6997 - val_acc: 0.7960\n", - "Epoch 162/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.3218 - acc: 0.9667 - val_loss: 0.7001 - val_acc: 0.7940\n", - "Epoch 163/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3345 - acc: 0.9833 - val_loss: 0.7006 - val_acc: 0.7940\n", - "Epoch 164/1000\n", - "1/1 [==============================] - 0s 88ms/step - loss: 0.3028 - acc: 0.9833 - val_loss: 0.7010 - val_acc: 0.7940\n", - "Epoch 165/1000\n", - "1/1 [==============================] - 0s 90ms/step - loss: 0.3109 - acc: 1.0000 - val_loss: 0.7012 - val_acc: 0.7940\n", - "Epoch 166/1000\n", - "1/1 [==============================] - 0s 87ms/step - loss: 0.3116 - acc: 0.9667 - val_loss: 0.7015 - val_acc: 0.7940\n", - "Epoch 167/1000\n", - "1/1 [==============================] - 0s 92ms/step - loss: 0.2926 - acc: 1.0000 - val_loss: 0.7014 - val_acc: 0.7900\n", - "1/1 [==============================] - 0s 38ms/step - loss: 0.7112 - acc: 0.7930\n", + "Epoch 1/100\n", + "1/1 [==============================] - 0s 429ms/step - loss: 1.1066 - acc: 0.3833 - val_loss: 1.1029 - val_acc: 0.4720\n", + "Epoch 2/100\n", + "1/1 [==============================] - 0s 84ms/step - loss: 1.1028 - acc: 0.5167 - val_loss: 1.0993 - val_acc: 0.5660\n", + "Epoch 3/100\n", + "1/1 [==============================] - 0s 93ms/step - loss: 1.0975 - acc: 0.6167 - val_loss: 1.0951 - val_acc: 0.6140\n", + "Epoch 4/100\n", + "1/1 [==============================] - 0s 90ms/step - loss: 1.0911 - acc: 0.7500 - val_loss: 1.0905 - val_acc: 0.6480\n", + "Epoch 5/100\n", + "1/1 [==============================] - 0s 93ms/step - loss: 1.0841 - acc: 0.8667 - val_loss: 1.0855 - val_acc: 0.6700\n", + "Epoch 6/100\n", + "1/1 [==============================] - 0s 96ms/step - loss: 1.0779 - acc: 0.8333 - val_loss: 1.0803 - val_acc: 0.6780\n", + "Epoch 7/100\n", + "1/1 [==============================] - 0s 98ms/step - loss: 1.0696 - acc: 0.8667 - val_loss: 1.0750 - val_acc: 0.6840\n", + "Epoch 8/100\n", + "1/1 [==============================] - 0s 95ms/step - loss: 1.0639 - acc: 0.8667 - val_loss: 1.0696 - val_acc: 0.6880\n", + "Epoch 9/100\n", + "1/1 [==============================] - 0s 95ms/step - loss: 1.0566 - acc: 0.8333 - val_loss: 1.0641 - val_acc: 0.6920\n", + "Epoch 10/100\n", + "1/1 [==============================] - 0s 95ms/step - loss: 1.0499 - acc: 0.9000 - val_loss: 1.0586 - val_acc: 0.6900\n", + "Epoch 11/100\n", + "1/1 [==============================] - 0s 96ms/step - loss: 1.0377 - acc: 0.8833 - val_loss: 1.0531 - val_acc: 0.6900\n", + "Epoch 12/100\n", + "1/1 [==============================] - 0s 93ms/step - loss: 1.0297 - acc: 0.8833 - val_loss: 1.0475 - val_acc: 0.6940\n", + "Epoch 13/100\n", + "1/1 [==============================] - 0s 93ms/step - loss: 1.0265 - acc: 0.8833 - val_loss: 1.0419 - val_acc: 0.6940\n", + "Epoch 14/100\n", + "1/1 [==============================] - 0s 88ms/step - loss: 1.0154 - acc: 0.8167 - val_loss: 1.0364 - val_acc: 0.6960\n", + "Epoch 15/100\n", + "1/1 [==============================] - 0s 90ms/step - loss: 1.0056 - acc: 0.8500 - val_loss: 1.0309 - val_acc: 0.7000\n", + "Epoch 16/100\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.9946 - acc: 0.8500 - val_loss: 1.0254 - val_acc: 0.7080\n", + "Epoch 17/100\n", + "1/1 [==============================] - 0s 86ms/step - loss: 0.9747 - acc: 0.9167 - val_loss: 1.0198 - val_acc: 0.7100\n", + "Epoch 18/100\n", + "1/1 [==============================] - 0s 98ms/step - loss: 0.9775 - acc: 0.8833 - val_loss: 1.0143 - val_acc: 0.7240\n", + "Epoch 19/100\n", + "1/1 [==============================] - 0s 98ms/step - loss: 0.9703 - acc: 0.8167 - val_loss: 1.0088 - val_acc: 0.7280\n", + "Epoch 20/100\n", + "1/1 [==============================] - 0s 98ms/step - loss: 0.9489 - acc: 0.8833 - val_loss: 1.0032 - val_acc: 0.7320\n", + "Epoch 21/100\n", + "1/1 [==============================] - 0s 99ms/step - loss: 0.9471 - acc: 0.8500 - val_loss: 0.9977 - val_acc: 0.7340\n", + "Epoch 22/100\n", + "1/1 [==============================] - 0s 111ms/step - loss: 0.9274 - acc: 0.8833 - val_loss: 0.9921 - val_acc: 0.7360\n", + "Epoch 23/100\n", + "1/1 [==============================] - 0s 106ms/step - loss: 0.9140 - acc: 0.9333 - val_loss: 0.9863 - val_acc: 0.7320\n", + "Epoch 24/100\n", + "1/1 [==============================] - 0s 103ms/step - loss: 0.9167 - acc: 0.8833 - val_loss: 0.9805 - val_acc: 0.7380\n", + "Epoch 25/100\n", + "1/1 [==============================] - 0s 96ms/step - loss: 0.9001 - acc: 0.9000 - val_loss: 0.9748 - val_acc: 0.7380\n", + "Epoch 26/100\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.8726 - acc: 0.9500 - val_loss: 0.9691 - val_acc: 0.7440\n", + "Epoch 27/100\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.8624 - acc: 0.9333 - val_loss: 0.9636 - val_acc: 0.7420\n", + "Epoch 28/100\n", + "1/1 [==============================] - 0s 105ms/step - loss: 0.8646 - acc: 0.8833 - val_loss: 0.9581 - val_acc: 0.7460\n", + "Epoch 29/100\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.8331 - acc: 0.9333 - val_loss: 0.9527 - val_acc: 0.7460\n", + "Epoch 30/100\n", + "1/1 [==============================] - 0s 93ms/step - loss: 0.8285 - acc: 0.9500 - val_loss: 0.9473 - val_acc: 0.7460\n", + "Epoch 31/100\n", + "1/1 [==============================] - 0s 85ms/step - loss: 0.8413 - acc: 0.9167 - val_loss: 0.9419 - val_acc: 0.7460\n", + "Epoch 32/100\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.8097 - acc: 0.9167 - val_loss: 0.9364 - val_acc: 0.7520\n", + "Epoch 33/100\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.7880 - acc: 0.9000 - val_loss: 0.9310 - val_acc: 0.7540\n", + "Epoch 34/100\n", + "1/1 [==============================] - 0s 94ms/step - loss: 0.7893 - acc: 0.9667 - val_loss: 0.9257 - val_acc: 0.7560\n", + "Epoch 35/100\n", + "1/1 [==============================] - 0s 94ms/step - loss: 0.7557 - acc: 0.9667 - val_loss: 0.9205 - val_acc: 0.7580\n", + "Epoch 36/100\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.7780 - acc: 0.9167 - val_loss: 0.9150 - val_acc: 0.7580\n", + "Epoch 37/100\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.7612 - acc: 0.9500 - val_loss: 0.9096 - val_acc: 0.7640\n", + "Epoch 38/100\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.7419 - acc: 0.9500 - val_loss: 0.9043 - val_acc: 0.7660\n", + "Epoch 39/100\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.7383 - acc: 0.8833 - val_loss: 0.8992 - val_acc: 0.7680\n", + "Epoch 40/100\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.7241 - acc: 0.9500 - val_loss: 0.8943 - val_acc: 0.7700\n", + "Epoch 41/100\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.7062 - acc: 0.9500 - val_loss: 0.8897 - val_acc: 0.7720\n", + "Epoch 42/100\n", + "1/1 [==============================] - 0s 86ms/step - loss: 0.7407 - acc: 0.8500 - val_loss: 0.8853 - val_acc: 0.7720\n", + "Epoch 43/100\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.7206 - acc: 0.9667 - val_loss: 0.8810 - val_acc: 0.7740\n", + "Epoch 44/100\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.6992 - acc: 0.9333 - val_loss: 0.8767 - val_acc: 0.7760\n", + "Epoch 45/100\n", + "1/1 [==============================] - 0s 94ms/step - loss: 0.6923 - acc: 0.9500 - val_loss: 0.8725 - val_acc: 0.7760\n", + "Epoch 46/100\n", + "1/1 [==============================] - 0s 97ms/step - loss: 0.6670 - acc: 0.9667 - val_loss: 0.8683 - val_acc: 0.7740\n", + "Epoch 47/100\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.6801 - acc: 0.9333 - val_loss: 0.8643 - val_acc: 0.7720\n", + "Epoch 48/100\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.6728 - acc: 0.9500 - val_loss: 0.8605 - val_acc: 0.7740\n", + "Epoch 49/100\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.6843 - acc: 0.9500 - val_loss: 0.8567 - val_acc: 0.7740\n", + "Epoch 50/100\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.6378 - acc: 0.9333 - val_loss: 0.8531 - val_acc: 0.7780\n", + "Epoch 51/100\n", + "1/1 [==============================] - 0s 94ms/step - loss: 0.6508 - acc: 0.9500 - val_loss: 0.8495 - val_acc: 0.7760\n", + "Epoch 52/100\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.6212 - acc: 0.9500 - val_loss: 0.8461 - val_acc: 0.7740\n", + "Epoch 53/100\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.6273 - acc: 0.9000 - val_loss: 0.8430 - val_acc: 0.7760\n", + "Epoch 54/100\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.6076 - acc: 0.9667 - val_loss: 0.8399 - val_acc: 0.7740\n", + "Epoch 55/100\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.5923 - acc: 0.9333 - val_loss: 0.8370 - val_acc: 0.7740\n", + "Epoch 56/100\n", + "1/1 [==============================] - 0s 95ms/step - loss: 0.6108 - acc: 0.9333 - val_loss: 0.8343 - val_acc: 0.7760\n", + "Epoch 57/100\n", + "1/1 [==============================] - 0s 94ms/step - loss: 0.5720 - acc: 0.9333 - val_loss: 0.8317 - val_acc: 0.7760\n", + "Epoch 58/100\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.5809 - acc: 0.9500 - val_loss: 0.8291 - val_acc: 0.7760\n", + "Epoch 59/100\n", + "1/1 [==============================] - 0s 95ms/step - loss: 0.6027 - acc: 0.9333 - val_loss: 0.8264 - val_acc: 0.7740\n", + "Epoch 60/100\n", + "1/1 [==============================] - 0s 98ms/step - loss: 0.5793 - acc: 0.9333 - val_loss: 0.8238 - val_acc: 0.7740\n", + "Epoch 61/100\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.5688 - acc: 0.9333 - val_loss: 0.8211 - val_acc: 0.7700\n", + "Epoch 62/100\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.5791 - acc: 0.8833 - val_loss: 0.8184 - val_acc: 0.7720\n", + "Epoch 63/100\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.5321 - acc: 0.9667 - val_loss: 0.8156 - val_acc: 0.7720\n", + "Epoch 64/100\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.5385 - acc: 0.9500 - val_loss: 0.8129 - val_acc: 0.7700\n", + "Epoch 65/100\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.5376 - acc: 0.9833 - val_loss: 0.8100 - val_acc: 0.7740\n", + "Epoch 66/100\n", + "1/1 [==============================] - 0s 97ms/step - loss: 0.5222 - acc: 0.9833 - val_loss: 0.8073 - val_acc: 0.7780\n", + "Epoch 67/100\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.5196 - acc: 0.9500 - val_loss: 0.8046 - val_acc: 0.7800\n", + "Epoch 68/100\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.5265 - acc: 0.9500 - val_loss: 0.8022 - val_acc: 0.7860\n", + "Epoch 69/100\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.5388 - acc: 0.9500 - val_loss: 0.7998 - val_acc: 0.7900\n", + "Epoch 70/100\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.4889 - acc: 0.9667 - val_loss: 0.7975 - val_acc: 0.7920\n", + "Epoch 71/100\n", + "1/1 [==============================] - 0s 93ms/step - loss: 0.5209 - acc: 0.9333 - val_loss: 0.7953 - val_acc: 0.7940\n", + "Epoch 72/100\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.5043 - acc: 0.9500 - val_loss: 0.7933 - val_acc: 0.7960\n", + "Epoch 73/100\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.5118 - acc: 0.9667 - val_loss: 0.7914 - val_acc: 0.7960\n", + "Epoch 74/100\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.4964 - acc: 0.9833 - val_loss: 0.7894 - val_acc: 0.7960\n", + "Epoch 75/100\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.5301 - acc: 0.9500 - val_loss: 0.7874 - val_acc: 0.7940\n", + "Epoch 76/100\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.4703 - acc: 0.9833 - val_loss: 0.7856 - val_acc: 0.7940\n", + "Epoch 77/100\n", + "1/1 [==============================] - 0s 106ms/step - loss: 0.4777 - acc: 0.9667 - val_loss: 0.7839 - val_acc: 0.7920\n", + "Epoch 78/100\n", + "1/1 [==============================] - 0s 94ms/step - loss: 0.4782 - acc: 0.9500 - val_loss: 0.7822 - val_acc: 0.7920\n", + "Epoch 79/100\n", + "1/1 [==============================] - 0s 98ms/step - loss: 0.4735 - acc: 0.9667 - val_loss: 0.7804 - val_acc: 0.7920\n", + "Epoch 80/100\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.4739 - acc: 0.9333 - val_loss: 0.7786 - val_acc: 0.7920\n", + "Epoch 81/100\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.5066 - acc: 0.9500 - val_loss: 0.7768 - val_acc: 0.7900\n", + "Epoch 82/100\n", + "1/1 [==============================] - 0s 95ms/step - loss: 0.4626 - acc: 0.9667 - val_loss: 0.7750 - val_acc: 0.7900\n", + "Epoch 83/100\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.4573 - acc: 0.9667 - val_loss: 0.7733 - val_acc: 0.7920\n", + "Epoch 84/100\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.4399 - acc: 1.0000 - val_loss: 0.7715 - val_acc: 0.7920\n", + "Epoch 85/100\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.4751 - acc: 0.9667 - val_loss: 0.7698 - val_acc: 0.7940\n", + "Epoch 86/100\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.4270 - acc: 0.9833 - val_loss: 0.7683 - val_acc: 0.7940\n", + "Epoch 87/100\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.4611 - acc: 0.9500 - val_loss: 0.7667 - val_acc: 0.7940\n", + "Epoch 88/100\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.3962 - acc: 0.9833 - val_loss: 0.7652 - val_acc: 0.7980\n", + "Epoch 89/100\n", + "1/1 [==============================] - 0s 95ms/step - loss: 0.4901 - acc: 0.9667 - val_loss: 0.7640 - val_acc: 0.7980\n", + "Epoch 90/100\n", + "1/1 [==============================] - 0s 97ms/step - loss: 0.4380 - acc: 0.9500 - val_loss: 0.7630 - val_acc: 0.7980\n", + "Epoch 91/100\n", + "1/1 [==============================] - 0s 100ms/step - loss: 0.4628 - acc: 0.9500 - val_loss: 0.7620 - val_acc: 0.7960\n", + "Epoch 92/100\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.4599 - acc: 0.9667 - val_loss: 0.7612 - val_acc: 0.7960\n", + "Epoch 93/100\n", + "1/1 [==============================] - 0s 95ms/step - loss: 0.4459 - acc: 1.0000 - val_loss: 0.7601 - val_acc: 0.7960\n", + "Epoch 94/100\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.4031 - acc: 0.9667 - val_loss: 0.7587 - val_acc: 0.7960\n", + "Epoch 95/100\n", + "1/1 [==============================] - 0s 86ms/step - loss: 0.4624 - acc: 0.9500 - val_loss: 0.7571 - val_acc: 0.7940\n", + "Epoch 96/100\n", + "1/1 [==============================] - 0s 103ms/step - loss: 0.4164 - acc: 0.9667 - val_loss: 0.7559 - val_acc: 0.7940\n", + "Epoch 97/100\n", + "1/1 [==============================] - 0s 97ms/step - loss: 0.4244 - acc: 0.9833 - val_loss: 0.7549 - val_acc: 0.7940\n", + "Epoch 98/100\n", + "1/1 [==============================] - 0s 93ms/step - loss: 0.4013 - acc: 0.9833 - val_loss: 0.7542 - val_acc: 0.7920\n", + "Epoch 99/100\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.4387 - acc: 0.9667 - val_loss: 0.7537 - val_acc: 0.7920\n", + "Epoch 100/100\n", + "1/1 [==============================] - 0s 101ms/step - loss: 0.4377 - acc: 0.9667 - val_loss: 0.7528 - val_acc: 0.7920\n", + "1/1 [==============================] - 0s 39ms/step - loss: 0.7651 - acc: 0.7890\n", "Done.\n", - "Test loss: 0.7112346291542053\n", - "Test accuracy: 0.7930001020431519\n" + "Test loss: 0.77\n", + "Test accuracy: 0.79\n" ] } ], @@ -555,24 +401,17 @@ "from tensorflow.keras.optimizers import Adam\n", "\n", "from spektral.data.loaders import SingleLoader\n", - "from spektral.layers import GCNConv\n", "from spektral.models.gcn import GCN\n", - "from spektral.transforms import LayerPreprocess\n", "\n", "learning_rate = .01\n", "epochs = 100\n", "patience = 10\n", "\n", - "# We convert the binary masks to sample weights so that we can compute the\n", - "# average loss over the nodes (following original implementation by\n", - "# Kipf & Welling)\n", "def mask_to_weights(mask):\n", " return mask.astype(np.float32) / np.count_nonzero(mask)\n", "\n", - "\n", "weights_tr, weights_va, weights_te = (\n", - " mask_to_weights(mask)\n", - " for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)\n", + " mask_to_weights(mask) for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)\n", ")\n", "\n", "# define the model\n", @@ -601,9 +440,53 @@ "eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)\n", "\n", "print('Done.\\n'\n", - " f'Test loss: {eval_results[0]:2.f}\\n'\n", + " f'Test loss: {eval_results[0]:.2f}\\n'\n", " f'Test accuracy: {eval_results[1]:.2f}')" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class PubMed(Dataset):\n", + "\n", + " def __init__(self, n_samples, n_nodes=3, n_min=10, n_max=100, p=0.1, **kwargs):\n", + " self.n_samples = n_samples\n", + " self.n_nodes = n_nodes\n", + " self.n_min = n_min\n", + " self.n_max = n_max\n", + " self.p = p\n", + " super().__init__(**kwargs)\n", + "\n", + " def read(self):\n", + " def make_graph():\n", + " n = np.random.randint(self.n_min, self.n_max)\n", + " colors = np.random.randint(0, self.n_colors, size=n)\n", + "\n", + " # Node features\n", + " x = np.zeros((n, self.n_colors))\n", + " x[np.arange(n), colors] = 1\n", + "\n", + " # Edges\n", + " a = np.random.rand(n, n) <= self.p\n", + " a = np.maximum(a, a.T).astype(int)\n", + " a = sp.csr_matrix(a)\n", + "\n", + " # Labels\n", + " y = np.zeros((self.n_colors,))\n", + " color_counts = x.sum(0)\n", + " y[np.argmax(color_counts)] = 1\n", + "\n", + " return Graph(x=x, a=a, y=y)\n", + "\n", + " # We must return a list of Graph objects\n", + " return [make_graph() for _ in range(self.n_samples)]\n", + "\n", + "\n", + "data = PubMed(1000, transforms=NormalizeAdj())\n" + ] } ], "metadata": {