diff --git a/gnn_pg.ipynb b/gnn_pg.ipynb new file mode 100644 index 0000000..923a908 --- /dev/null +++ b/gnn_pg.ipynb @@ -0,0 +1,633 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install numpy spektral tensorflow -q\n", + "import tensorflow as tf\n", + "import numpy as np\n", + "import spektral\n", + "from spektral.datasets.citation import Citation\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pre-processing node features\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/scipy/sparse/_index.py:125: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n", + " self._set_arrayXarray(i, j, x)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pre-processing node features\n", + "epoch: 0,train_loss: 1.0980415344238281,val_acc: 0.4699999988079071,test_acc: 0.48900002241134644\n", + "epoch: 1,train_loss: 1.0857902765274048,val_acc: 0.531999945640564,test_acc: 0.5269998908042908\n", + "epoch: 5,train_loss: 1.0279830694198608,val_acc: 0.5459999442100525,test_acc: 0.5679998993873596\n", + "epoch: 6,train_loss: 1.0107402801513672,val_acc: 0.6380000710487366,test_acc: 0.6259998679161072\n", + "epoch: 7,train_loss: 0.9920498132705688,val_acc: 0.6820001006126404,test_acc: 0.6719997525215149\n", + "epoch: 8,train_loss: 0.9725067615509033,val_acc: 0.7120000720024109,test_acc: 0.6989997625350952\n", + "epoch: 9,train_loss: 0.9518073201179504,val_acc: 0.724000096321106,test_acc: 0.7109997868537903\n", + "epoch: 10,train_loss: 0.929932177066803,val_acc: 0.736000120639801,test_acc: 0.7179997563362122\n", + "epoch: 14,train_loss: 0.8337627053260803,val_acc: 0.7380000948905945,test_acc: 0.7209997773170471\n", + "epoch: 15,train_loss: 0.8073488473892212,val_acc: 0.7420001029968262,test_acc: 0.7239997982978821\n" + ] + } + ], + "source": [ + "dataset = Citation('pubmed', normalize_x=True)\n", + "dataset = Citation('pubmed', normalize_x=True, transforms=[LayerPreprocess(GCNConv)])\n", + "\n", + "adj = dataset.graphs[0].a.todense()\n", + "features = dataset.graphs[0].x\n", + "labels = dataset.graphs[0].y\n", + "train_mask, val_mask, test_mask = dataset.mask_tr, dataset.mask_va, dataset.mask_te\n", + "\n", + "def masked_softmax_cross_entropy(logits, labels, mask):\n", + " loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)\n", + " mask = tf.cast(mask, dtype=tf.float32)\n", + " mask /= tf.reduce_mean(mask)\n", + " loss *= mask\n", + " return tf.reduce_mean(loss)\n", + "\n", + "def masked_accuracy(logits, labels, mask):\n", + " correct_preds = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))\n", + " accuracy = tf.cast(correct_preds, dtype=tf.float32)\n", + " mask = tf.cast(mask, dtype=tf.float32)\n", + " mask /= tf.reduce_mean(mask)\n", + " accuracy *= mask\n", + " return tf.reduce_mean(accuracy)\n", + "\n", + "def gnn_fn(features, adj, transform, activation):\n", + " seq_fts = transform(features)\n", + " ret_fts = tf.matmul(adj, seq_fts)\n", + " return activation(ret_fts)\n", + "\n", + "def train(features, adj, gnn_fn, units, epochs, lr=0.01):\n", + " lyr_1 = tf.keras.layers.Dense(units)\n", + " lyr_2 = tf.keras.layers.Dense(3)\n", + " \n", + " def gnn_net(features, adj):\n", + " hidden = gnn_fn(features, adj, lyr_1, tf.nn.relu)\n", + " logits = gnn_fn(hidden, adj, lyr_2, tf.identity)\n", + " return logits\n", + " \n", + " optimizer = tf.keras.optimizers.Adam(learning_rate=lr)\n", + " best_accuracy = 0.0\n", + " \n", + " for epoch in range(epochs+1):\n", + " with tf.GradientTape() as t:\n", + " logits = gnn_net(features, adj)\n", + " loss = masked_softmax_cross_entropy(logits, labels, train_mask)\n", + "\n", + " variables = t.watched_variables()\n", + " grads = t.gradient(loss, variables)\n", + " optimizer.apply_gradients(zip(grads, variables))\n", + "\n", + " logits = gnn_net(features, adj)\n", + " val_accuracy = masked_accuracy(logits, labels, val_mask)\n", + " test_accuracy = masked_accuracy(logits, labels, test_mask)\n", + " \n", + " if val_accuracy > best_accuracy:\n", + " best_accuracy = val_accuracy\n", + " print(f'epoch: {epoch},'\n", + " f'train_loss: {loss.numpy()},'\n", + " f'val_acc: {val_accuracy.numpy()},'\n", + " f'test_acc: {test_accuracy.numpy()}')\n", + "\n", + "train(features, adj, gnn_fn, 32, 20, 0.01)" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0,train_loss: 1.0996336936950684,val_acc: 0.5119999647140503,test_acc: 0.4919999837875366\n", + "epoch: 2,train_loss: 1.073289394378662,val_acc: 0.5379999876022339,test_acc: 0.5219999551773071\n", + "epoch: 3,train_loss: 1.0569909811019897,val_acc: 0.6000000238418579,test_acc: 0.5999998450279236\n", + "epoch: 4,train_loss: 1.0379494428634644,val_acc: 0.6700000166893005,test_acc: 0.6739997863769531\n", + "epoch: 5,train_loss: 1.0155421495437622,val_acc: 0.7160000801086426,test_acc: 0.7109997868537903\n", + "epoch: 6,train_loss: 0.9916436672210693,val_acc: 0.7280001044273376,test_acc: 0.7139998078346252\n", + "epoch: 9,train_loss: 0.9153128266334534,val_acc: 0.7300000786781311,test_acc: 0.7069997787475586\n", + "epoch: 11,train_loss: 0.8576377034187317,val_acc: 0.7320001125335693,test_acc: 0.711999773979187\n", + "epoch: 12,train_loss: 0.8269254565238953,val_acc: 0.7380000948905945,test_acc: 0.714999794960022\n", + "epoch: 13,train_loss: 0.7955628037452698,val_acc: 0.7440001368522644,test_acc: 0.7159997820854187\n", + "epoch: 15,train_loss: 0.7319157123565674,val_acc: 0.7460001111030579,test_acc: 0.714999794960022\n", + "epoch: 17,train_loss: 0.6676653623580933,val_acc: 0.7500001192092896,test_acc: 0.7179997563362122\n", + "epoch: 18,train_loss: 0.6357672810554504,val_acc: 0.752000093460083,test_acc: 0.7169997692108154\n" + ] + } + ], + "source": [ + "# on identity train(features, tf.eye(adj.shape[0]), gnn_fn, 32, 20, 0.01)\n", + "\n", + "# normalize adj by degree\n", + "deg = tf.reduce_sum(adj, axis=-1)\n", + "train(features, adj / deg, gnn_fn, 32, 20, 0.01)" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/var/folders/3_/gmvd1nkx285133z5yh3chz2c0000gp/T/ipykernel_40315/3405598776.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mnorm_deg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdiag\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1.9\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdeg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mnorm_adj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnorm_deg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0madj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_deg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_adj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgnn_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m200\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.01\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0;34m\"\"\"Call target, and fall back on dispatchers if there is a TypeError.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 206\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 207\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mTypeError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0;31m# Note: convert_to_eager_tensor currently raises a ValueError, not a\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py\u001b[0m in \u001b[0;36mmatmul\u001b[0;34m(a, b, transpose_a, transpose_b, adjoint_a, adjoint_b, a_is_sparse, b_is_sparse, output_type, name)\u001b[0m\n\u001b[1;32m 3652\u001b[0m a, b, adj_x=adjoint_a, adj_y=adjoint_b, Tout=output_type, name=name)\n\u001b[1;32m 3653\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3654\u001b[0;31m return gen_math_ops.mat_mul(\n\u001b[0m\u001b[1;32m 3655\u001b[0m a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)\n\u001b[1;32m 3656\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py\u001b[0m in \u001b[0;36mmat_mul\u001b[0;34m(a, b, transpose_a, transpose_b, name)\u001b[0m\n\u001b[1;32m 5689\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtld\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_eager\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5690\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 5691\u001b[0;31m _result = pywrap_tfe.TFE_Py_FastPathExecute(\n\u001b[0m\u001b[1;32m 5692\u001b[0m \u001b[0m_ctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"MatMul\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"transpose_a\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtranspose_a\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"transpose_b\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5693\u001b[0m transpose_b)\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# thomas gibb in graph conv networks: normalize by 1/sqrt(deg)\n", + "\n", + "norm_deg = tf.linalg.diag(1.9 / tf.sqrt(deg))\n", + "norm_adj = tf.matmul(norm_deg, tf.matmul(adj, norm_deg))\n", + "train(features, norm_adj, gnn_fn, 32, 200, 0.01)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Another impl:" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/1000\n", + "1/1 [==============================] - 0s 462ms/step - loss: 1.1087 - acc: 0.2833 - val_loss: 1.1010 - val_acc: 0.6500\n", + "Epoch 2/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 1.1005 - acc: 0.6167 - val_loss: 1.0974 - val_acc: 0.7500\n", + "Epoch 3/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 1.0946 - acc: 0.7500 - val_loss: 1.0928 - val_acc: 0.7600\n", + "Epoch 4/1000\n", + "1/1 [==============================] - 0s 101ms/step - loss: 1.0894 - acc: 0.8833 - val_loss: 1.0872 - val_acc: 0.7560\n", + "Epoch 5/1000\n", + "1/1 [==============================] - 0s 93ms/step - loss: 1.0805 - acc: 0.8167 - val_loss: 1.0809 - val_acc: 0.7520\n", + "Epoch 6/1000\n", + "1/1 [==============================] - 0s 98ms/step - loss: 1.0727 - acc: 0.8667 - val_loss: 1.0744 - val_acc: 0.7580\n", + "Epoch 7/1000\n", + "1/1 [==============================] - 0s 95ms/step - loss: 1.0636 - acc: 0.9000 - val_loss: 1.0681 - val_acc: 0.7520\n", + "Epoch 8/1000\n", + "1/1 [==============================] - 0s 105ms/step - loss: 1.0563 - acc: 0.8667 - val_loss: 1.0619 - val_acc: 0.7460\n", + "Epoch 9/1000\n", + "1/1 [==============================] - 0s 113ms/step - loss: 1.0451 - acc: 0.8667 - val_loss: 1.0557 - val_acc: 0.7460\n", + "Epoch 10/1000\n", + "1/1 [==============================] - 0s 120ms/step - loss: 1.0412 - acc: 0.8667 - val_loss: 1.0497 - val_acc: 0.7460\n", + "Epoch 11/1000\n", + "1/1 [==============================] - 0s 111ms/step - loss: 1.0284 - acc: 0.8833 - val_loss: 1.0435 - val_acc: 0.7460\n", + "Epoch 12/1000\n", + "1/1 [==============================] - 0s 101ms/step - loss: 1.0180 - acc: 0.9000 - val_loss: 1.0375 - val_acc: 0.7480\n", + "Epoch 13/1000\n", + "1/1 [==============================] - 0s 93ms/step - loss: 1.0036 - acc: 0.9000 - val_loss: 1.0315 - val_acc: 0.7520\n", + "Epoch 14/1000\n", + "1/1 [==============================] - 0s 95ms/step - loss: 0.9942 - acc: 0.9333 - val_loss: 1.0255 - val_acc: 0.7580\n", + "Epoch 15/1000\n", + "1/1 [==============================] - 0s 94ms/step - loss: 0.9864 - acc: 0.9167 - val_loss: 1.0196 - val_acc: 0.7620\n", + "Epoch 16/1000\n", + "1/1 [==============================] - 0s 93ms/step - loss: 0.9864 - acc: 0.9167 - val_loss: 1.0137 - val_acc: 0.7620\n", + "Epoch 17/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.9612 - acc: 0.9500 - val_loss: 1.0077 - val_acc: 0.7620\n", + "Epoch 18/1000\n", + "1/1 [==============================] - 0s 95ms/step - loss: 0.9546 - acc: 0.9167 - val_loss: 1.0018 - val_acc: 0.7640\n", + "Epoch 19/1000\n", + "1/1 [==============================] - 0s 104ms/step - loss: 0.9423 - acc: 0.9500 - val_loss: 0.9957 - val_acc: 0.7660\n", + "Epoch 20/1000\n", + "1/1 [==============================] - 0s 108ms/step - loss: 0.9271 - acc: 0.9167 - val_loss: 0.9895 - val_acc: 0.7620\n", + "Epoch 21/1000\n", + "1/1 [==============================] - 0s 101ms/step - loss: 0.9319 - acc: 0.8833 - val_loss: 0.9832 - val_acc: 0.7620\n", + "Epoch 22/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.9068 - acc: 0.9000 - val_loss: 0.9770 - val_acc: 0.7640\n", + "Epoch 23/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.8857 - acc: 0.8667 - val_loss: 0.9708 - val_acc: 0.7640\n", + "Epoch 24/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.8685 - acc: 0.8833 - val_loss: 0.9646 - val_acc: 0.7640\n", + "Epoch 25/1000\n", + "1/1 [==============================] - 0s 95ms/step - loss: 0.8696 - acc: 0.9333 - val_loss: 0.9583 - val_acc: 0.7640\n", + "Epoch 26/1000\n", + "1/1 [==============================] - 0s 98ms/step - loss: 0.8706 - acc: 0.9500 - val_loss: 0.9521 - val_acc: 0.7640\n", + "Epoch 27/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.8500 - acc: 0.9167 - val_loss: 0.9459 - val_acc: 0.7680\n", + "Epoch 28/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.8448 - acc: 0.9167 - val_loss: 0.9399 - val_acc: 0.7700\n", + "Epoch 29/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.8462 - acc: 0.9333 - val_loss: 0.9339 - val_acc: 0.7660\n", + "Epoch 30/1000\n", + "1/1 [==============================] - 0s 84ms/step - loss: 0.8169 - acc: 0.9000 - val_loss: 0.9280 - val_acc: 0.7680\n", + "Epoch 31/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.7901 - acc: 0.9500 - val_loss: 0.9223 - val_acc: 0.7700\n", + "Epoch 32/1000\n", + "1/1 [==============================] - 0s 95ms/step - loss: 0.8125 - acc: 0.9500 - val_loss: 0.9167 - val_acc: 0.7700\n", + "Epoch 33/1000\n", + "1/1 [==============================] - 0s 93ms/step - loss: 0.7813 - acc: 0.9667 - val_loss: 0.9110 - val_acc: 0.7700\n", + "Epoch 34/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.7728 - acc: 0.9167 - val_loss: 0.9055 - val_acc: 0.7700\n", + "Epoch 35/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.7708 - acc: 0.9333 - val_loss: 0.9001 - val_acc: 0.7720\n", + "Epoch 36/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.7519 - acc: 0.9333 - val_loss: 0.8949 - val_acc: 0.7700\n", + "Epoch 37/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.7435 - acc: 0.9500 - val_loss: 0.8899 - val_acc: 0.7700\n", + "Epoch 38/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.7134 - acc: 0.9333 - val_loss: 0.8849 - val_acc: 0.7680\n", + "Epoch 39/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.7055 - acc: 0.9500 - val_loss: 0.8800 - val_acc: 0.7680\n", + "Epoch 40/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.6887 - acc: 0.9667 - val_loss: 0.8752 - val_acc: 0.7680\n", + "Epoch 41/1000\n", + "1/1 [==============================] - 0s 93ms/step - loss: 0.6847 - acc: 0.9667 - val_loss: 0.8703 - val_acc: 0.7680\n", + "Epoch 42/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.6918 - acc: 0.9167 - val_loss: 0.8655 - val_acc: 0.7700\n", + "Epoch 43/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.6699 - acc: 0.9333 - val_loss: 0.8608 - val_acc: 0.7700\n", + "Epoch 44/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.6433 - acc: 0.9667 - val_loss: 0.8565 - val_acc: 0.7720\n", + "Epoch 45/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.6713 - acc: 0.9167 - val_loss: 0.8523 - val_acc: 0.7720\n", + "Epoch 46/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.6396 - acc: 0.9667 - val_loss: 0.8482 - val_acc: 0.7720\n", + "Epoch 47/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.6291 - acc: 0.9167 - val_loss: 0.8444 - val_acc: 0.7740\n", + "Epoch 48/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.6029 - acc: 0.9667 - val_loss: 0.8407 - val_acc: 0.7700\n", + "Epoch 49/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.6102 - acc: 0.9500 - val_loss: 0.8371 - val_acc: 0.7700\n", + "Epoch 50/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.6154 - acc: 0.9167 - val_loss: 0.8335 - val_acc: 0.7700\n", + "Epoch 51/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.6160 - acc: 0.9500 - val_loss: 0.8302 - val_acc: 0.7700\n", + "Epoch 52/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.5654 - acc: 0.9667 - val_loss: 0.8271 - val_acc: 0.7680\n", + "Epoch 53/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.6097 - acc: 0.9500 - val_loss: 0.8242 - val_acc: 0.7700\n", + "Epoch 54/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.5976 - acc: 0.9500 - val_loss: 0.8213 - val_acc: 0.7720\n", + "Epoch 55/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.5914 - acc: 0.9000 - val_loss: 0.8188 - val_acc: 0.7720\n", + "Epoch 56/1000\n", + "1/1 [==============================] - 0s 93ms/step - loss: 0.5761 - acc: 0.9667 - val_loss: 0.8164 - val_acc: 0.7760\n", + "Epoch 57/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.5613 - acc: 0.9667 - val_loss: 0.8141 - val_acc: 0.7800\n", + "Epoch 58/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.5773 - acc: 0.9500 - val_loss: 0.8117 - val_acc: 0.7800\n", + "Epoch 59/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.5665 - acc: 0.9500 - val_loss: 0.8094 - val_acc: 0.7800\n", + "Epoch 60/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.5408 - acc: 0.9667 - val_loss: 0.8069 - val_acc: 0.7800\n", + "Epoch 61/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.5425 - acc: 0.9500 - val_loss: 0.8043 - val_acc: 0.7840\n", + "Epoch 62/1000\n", + "1/1 [==============================] - 0s 104ms/step - loss: 0.5627 - acc: 0.9167 - val_loss: 0.8019 - val_acc: 0.7840\n", + "Epoch 63/1000\n", + "1/1 [==============================] - 0s 96ms/step - loss: 0.5338 - acc: 0.9333 - val_loss: 0.7996 - val_acc: 0.7820\n", + "Epoch 64/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.5082 - acc: 0.9667 - val_loss: 0.7972 - val_acc: 0.7820\n", + "Epoch 65/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.5194 - acc: 0.9667 - val_loss: 0.7950 - val_acc: 0.7820\n", + "Epoch 66/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.4943 - acc: 0.9333 - val_loss: 0.7929 - val_acc: 0.7840\n", + "Epoch 67/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.5025 - acc: 0.9667 - val_loss: 0.7909 - val_acc: 0.7820\n", + "Epoch 68/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.4835 - acc: 0.9833 - val_loss: 0.7891 - val_acc: 0.7800\n", + "Epoch 69/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.5119 - acc: 0.9833 - val_loss: 0.7874 - val_acc: 0.7800\n", + "Epoch 70/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.5110 - acc: 0.9667 - val_loss: 0.7857 - val_acc: 0.7800\n", + "Epoch 71/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.4949 - acc: 0.9333 - val_loss: 0.7842 - val_acc: 0.7800\n", + "Epoch 72/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.4662 - acc: 0.9667 - val_loss: 0.7828 - val_acc: 0.7800\n", + "Epoch 73/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.5109 - acc: 0.9833 - val_loss: 0.7810 - val_acc: 0.7780\n", + "Epoch 74/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.4513 - acc: 0.9833 - val_loss: 0.7792 - val_acc: 0.7820\n", + "Epoch 75/1000\n", + "1/1 [==============================] - 0s 86ms/step - loss: 0.4873 - acc: 0.9500 - val_loss: 0.7772 - val_acc: 0.7840\n", + "Epoch 76/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.4595 - acc: 0.9667 - val_loss: 0.7753 - val_acc: 0.7820\n", + "Epoch 77/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.5105 - acc: 0.9500 - val_loss: 0.7736 - val_acc: 0.7860\n", + "Epoch 78/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.4737 - acc: 0.9667 - val_loss: 0.7717 - val_acc: 0.7880\n", + "Epoch 79/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.4541 - acc: 0.9833 - val_loss: 0.7702 - val_acc: 0.7900\n", + "Epoch 80/1000\n", + "1/1 [==============================] - 0s 86ms/step - loss: 0.4477 - acc: 1.0000 - val_loss: 0.7685 - val_acc: 0.7900\n", + "Epoch 81/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.4596 - acc: 0.9500 - val_loss: 0.7667 - val_acc: 0.7920\n", + "Epoch 82/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.4661 - acc: 0.9333 - val_loss: 0.7651 - val_acc: 0.7920\n", + "Epoch 83/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.4452 - acc: 0.9500 - val_loss: 0.7639 - val_acc: 0.7940\n", + "Epoch 84/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.4559 - acc: 0.9667 - val_loss: 0.7631 - val_acc: 0.7940\n", + "Epoch 85/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.4705 - acc: 0.9500 - val_loss: 0.7623 - val_acc: 0.7960\n", + "Epoch 86/1000\n", + "1/1 [==============================] - 0s 86ms/step - loss: 0.4418 - acc: 0.9500 - val_loss: 0.7616 - val_acc: 0.7940\n", + "Epoch 87/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.4240 - acc: 0.9667 - val_loss: 0.7610 - val_acc: 0.7940\n", + "Epoch 88/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.4489 - acc: 0.9667 - val_loss: 0.7603 - val_acc: 0.7940\n", + "Epoch 89/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.4326 - acc: 0.9833 - val_loss: 0.7596 - val_acc: 0.7920\n", + "Epoch 90/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.4337 - acc: 0.9667 - val_loss: 0.7586 - val_acc: 0.7900\n", + "Epoch 91/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.4263 - acc: 0.9667 - val_loss: 0.7579 - val_acc: 0.7900\n", + "Epoch 92/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.4165 - acc: 1.0000 - val_loss: 0.7568 - val_acc: 0.7900\n", + "Epoch 93/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.4393 - acc: 0.9667 - val_loss: 0.7557 - val_acc: 0.7900\n", + "Epoch 94/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.4316 - acc: 1.0000 - val_loss: 0.7544 - val_acc: 0.7900\n", + "Epoch 95/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.4527 - acc: 0.9667 - val_loss: 0.7529 - val_acc: 0.7900\n", + "Epoch 96/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.4238 - acc: 0.9667 - val_loss: 0.7512 - val_acc: 0.7880\n", + "Epoch 97/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.3868 - acc: 0.9833 - val_loss: 0.7496 - val_acc: 0.7940\n", + "Epoch 98/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.3973 - acc: 1.0000 - val_loss: 0.7478 - val_acc: 0.7940\n", + "Epoch 99/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.4228 - acc: 0.9667 - val_loss: 0.7459 - val_acc: 0.7940\n", + "Epoch 100/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3667 - acc: 0.9833 - val_loss: 0.7443 - val_acc: 0.7960\n", + "Epoch 101/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.3947 - acc: 0.9667 - val_loss: 0.7430 - val_acc: 0.7960\n", + "Epoch 102/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3955 - acc: 0.9667 - val_loss: 0.7416 - val_acc: 0.7940\n", + "Epoch 103/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.4087 - acc: 0.9833 - val_loss: 0.7402 - val_acc: 0.7940\n", + "Epoch 104/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.4469 - acc: 0.9833 - val_loss: 0.7387 - val_acc: 0.7940\n", + "Epoch 105/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.3881 - acc: 0.9833 - val_loss: 0.7375 - val_acc: 0.7940\n", + "Epoch 106/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.3735 - acc: 0.9833 - val_loss: 0.7365 - val_acc: 0.7940\n", + "Epoch 107/1000\n", + "1/1 [==============================] - 0s 84ms/step - loss: 0.4000 - acc: 0.9667 - val_loss: 0.7358 - val_acc: 0.7940\n", + "Epoch 108/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3820 - acc: 1.0000 - val_loss: 0.7352 - val_acc: 0.7960\n", + "Epoch 109/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.3557 - acc: 0.9833 - val_loss: 0.7345 - val_acc: 0.7980\n", + "Epoch 110/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.3602 - acc: 1.0000 - val_loss: 0.7337 - val_acc: 0.7960\n", + "Epoch 111/1000\n", + "1/1 [==============================] - 0s 98ms/step - loss: 0.4138 - acc: 0.9667 - val_loss: 0.7329 - val_acc: 0.7940\n", + "Epoch 112/1000\n", + "1/1 [==============================] - 0s 83ms/step - loss: 0.3778 - acc: 0.9667 - val_loss: 0.7324 - val_acc: 0.7960\n", + "Epoch 113/1000\n", + "1/1 [==============================] - 0s 86ms/step - loss: 0.3877 - acc: 0.9667 - val_loss: 0.7322 - val_acc: 0.7920\n", + "Epoch 114/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3661 - acc: 1.0000 - val_loss: 0.7317 - val_acc: 0.7920\n", + "Epoch 115/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.3956 - acc: 0.9667 - val_loss: 0.7310 - val_acc: 0.7920\n", + "Epoch 116/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3616 - acc: 1.0000 - val_loss: 0.7300 - val_acc: 0.7920\n", + "Epoch 117/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.3594 - acc: 0.9667 - val_loss: 0.7291 - val_acc: 0.7880\n", + "Epoch 118/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.3739 - acc: 0.9833 - val_loss: 0.7285 - val_acc: 0.7860\n", + "Epoch 119/1000\n", + "1/1 [==============================] - 0s 86ms/step - loss: 0.3778 - acc: 0.9833 - val_loss: 0.7280 - val_acc: 0.7880\n", + "Epoch 120/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.3669 - acc: 1.0000 - val_loss: 0.7274 - val_acc: 0.7900\n", + "Epoch 121/1000\n", + "1/1 [==============================] - 0s 85ms/step - loss: 0.3288 - acc: 1.0000 - val_loss: 0.7266 - val_acc: 0.7880\n", + "Epoch 122/1000\n", + "1/1 [==============================] - 0s 86ms/step - loss: 0.3726 - acc: 0.9833 - val_loss: 0.7259 - val_acc: 0.7880\n", + "Epoch 123/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.3813 - acc: 0.9667 - val_loss: 0.7250 - val_acc: 0.7900\n", + "Epoch 124/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.3615 - acc: 0.9833 - val_loss: 0.7240 - val_acc: 0.7920\n", + "Epoch 125/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.3283 - acc: 1.0000 - val_loss: 0.7229 - val_acc: 0.7920\n", + "Epoch 126/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.3312 - acc: 0.9833 - val_loss: 0.7215 - val_acc: 0.7940\n", + "Epoch 127/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.4086 - acc: 0.9667 - val_loss: 0.7206 - val_acc: 0.7960\n", + "Epoch 128/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3588 - acc: 0.9667 - val_loss: 0.7195 - val_acc: 0.7920\n", + "Epoch 129/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3678 - acc: 1.0000 - val_loss: 0.7184 - val_acc: 0.7920\n", + "Epoch 130/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.3201 - acc: 0.9833 - val_loss: 0.7175 - val_acc: 0.7920\n", + "Epoch 131/1000\n", + "1/1 [==============================] - 0s 100ms/step - loss: 0.3630 - acc: 0.9833 - val_loss: 0.7167 - val_acc: 0.7920\n", + "Epoch 132/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.3601 - acc: 0.9667 - val_loss: 0.7163 - val_acc: 0.7900\n", + "Epoch 133/1000\n", + "1/1 [==============================] - 0s 83ms/step - loss: 0.3231 - acc: 1.0000 - val_loss: 0.7158 - val_acc: 0.7900\n", + "Epoch 134/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.3124 - acc: 1.0000 - val_loss: 0.7150 - val_acc: 0.7900\n", + "Epoch 135/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.3430 - acc: 0.9833 - val_loss: 0.7142 - val_acc: 0.7920\n", + "Epoch 136/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.3344 - acc: 0.9833 - val_loss: 0.7135 - val_acc: 0.7900\n", + "Epoch 137/1000\n", + "1/1 [==============================] - 0s 84ms/step - loss: 0.3537 - acc: 0.9833 - val_loss: 0.7130 - val_acc: 0.7900\n", + "Epoch 138/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.3424 - acc: 0.9667 - val_loss: 0.7128 - val_acc: 0.7880\n", + "Epoch 139/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.3199 - acc: 1.0000 - val_loss: 0.7125 - val_acc: 0.7840\n", + "Epoch 140/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.3173 - acc: 0.9833 - val_loss: 0.7127 - val_acc: 0.7860\n", + "Epoch 141/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.3474 - acc: 0.9833 - val_loss: 0.7126 - val_acc: 0.7840\n", + "Epoch 142/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3292 - acc: 1.0000 - val_loss: 0.7127 - val_acc: 0.7840\n", + "Epoch 143/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.3552 - acc: 0.9667 - val_loss: 0.7127 - val_acc: 0.7820\n", + "Epoch 144/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.3256 - acc: 1.0000 - val_loss: 0.7125 - val_acc: 0.7840\n", + "Epoch 145/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.3204 - acc: 1.0000 - val_loss: 0.7125 - val_acc: 0.7840\n", + "Epoch 146/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3133 - acc: 0.9833 - val_loss: 0.7122 - val_acc: 0.7800\n", + "Epoch 147/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.3209 - acc: 1.0000 - val_loss: 0.7116 - val_acc: 0.7800\n", + "Epoch 148/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3291 - acc: 1.0000 - val_loss: 0.7100 - val_acc: 0.7820\n", + "Epoch 149/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.3524 - acc: 0.9833 - val_loss: 0.7079 - val_acc: 0.7840\n", + "Epoch 150/1000\n", + "1/1 [==============================] - 0s 99ms/step - loss: 0.3336 - acc: 1.0000 - val_loss: 0.7057 - val_acc: 0.7860\n", + "Epoch 151/1000\n", + "1/1 [==============================] - 0s 86ms/step - loss: 0.3318 - acc: 0.9667 - val_loss: 0.7037 - val_acc: 0.7900\n", + "Epoch 152/1000\n", + "1/1 [==============================] - 0s 82ms/step - loss: 0.3615 - acc: 0.9833 - val_loss: 0.7022 - val_acc: 0.7920\n", + "Epoch 153/1000\n", + "1/1 [==============================] - 0s 86ms/step - loss: 0.3089 - acc: 1.0000 - val_loss: 0.7011 - val_acc: 0.7920\n", + "Epoch 154/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.3616 - acc: 0.9667 - val_loss: 0.7006 - val_acc: 0.7920\n", + "Epoch 155/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.2931 - acc: 1.0000 - val_loss: 0.7000 - val_acc: 0.7940\n", + "Epoch 156/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.3272 - acc: 1.0000 - val_loss: 0.6993 - val_acc: 0.7960\n", + "Epoch 157/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.3255 - acc: 0.9833 - val_loss: 0.6992 - val_acc: 0.7980\n", + "Epoch 158/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.3066 - acc: 0.9833 - val_loss: 0.6993 - val_acc: 0.7980\n", + "Epoch 159/1000\n", + "1/1 [==============================] - 0s 89ms/step - loss: 0.3397 - acc: 0.9500 - val_loss: 0.6993 - val_acc: 0.7960\n", + "Epoch 160/1000\n", + "1/1 [==============================] - 0s 91ms/step - loss: 0.3396 - acc: 0.9833 - val_loss: 0.6993 - val_acc: 0.7960\n", + "Epoch 161/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3275 - acc: 1.0000 - val_loss: 0.6997 - val_acc: 0.7960\n", + "Epoch 162/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.3218 - acc: 0.9667 - val_loss: 0.7001 - val_acc: 0.7940\n", + "Epoch 163/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3345 - acc: 0.9833 - val_loss: 0.7006 - val_acc: 0.7940\n", + "Epoch 164/1000\n", + "1/1 [==============================] - 0s 88ms/step - loss: 0.3028 - acc: 0.9833 - val_loss: 0.7010 - val_acc: 0.7940\n", + "Epoch 165/1000\n", + "1/1 [==============================] - 0s 90ms/step - loss: 0.3109 - acc: 1.0000 - val_loss: 0.7012 - val_acc: 0.7940\n", + "Epoch 166/1000\n", + "1/1 [==============================] - 0s 87ms/step - loss: 0.3116 - acc: 0.9667 - val_loss: 0.7015 - val_acc: 0.7940\n", + "Epoch 167/1000\n", + "1/1 [==============================] - 0s 92ms/step - loss: 0.2926 - acc: 1.0000 - val_loss: 0.7014 - val_acc: 0.7900\n", + "1/1 [==============================] - 0s 38ms/step - loss: 0.7112 - acc: 0.7930\n", + "Done.\n", + "Test loss: 0.7112346291542053\n", + "Test accuracy: 0.7930001020431519\n" + ] + } + ], + "source": [ + "\"\"\"\n", + "This example implements the experiments on citation networks from the paper:\n", + "Semi-Supervised Classification with Graph Convolutional Networks (https://arxiv.org/abs/1609.02907)\n", + "Thomas N. Kipf, Max Welling\n", + "\"\"\"\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "from tensorflow.keras.callbacks import EarlyStopping\n", + "from tensorflow.keras.losses import CategoricalCrossentropy\n", + "from tensorflow.keras.optimizers import Adam\n", + "\n", + "from spektral.data.loaders import SingleLoader\n", + "from spektral.layers import GCNConv\n", + "from spektral.models.gcn import GCN\n", + "from spektral.transforms import LayerPreprocess\n", + "\n", + "learning_rate = .01\n", + "epochs = 100\n", + "patience = 10\n", + "\n", + "# We convert the binary masks to sample weights so that we can compute the\n", + "# average loss over the nodes (following original implementation by\n", + "# Kipf & Welling)\n", + "def mask_to_weights(mask):\n", + " return mask.astype(np.float32) / np.count_nonzero(mask)\n", + "\n", + "\n", + "weights_tr, weights_va, weights_te = (\n", + " mask_to_weights(mask)\n", + " for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)\n", + ")\n", + "\n", + "# define the model\n", + "model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)\n", + "model.compile(\n", + " optimizer=Adam(learning_rate),\n", + " loss=CategoricalCrossentropy(reduction=\"sum\"),\n", + " weighted_metrics=[\"acc\"],\n", + ")\n", + "\n", + "loader_tr = SingleLoader(dataset, sample_weights=weights_tr)\n", + "loader_va = SingleLoader(dataset, sample_weights=weights_va)\n", + "loader_te = SingleLoader(dataset, sample_weights=weights_te)\n", + "\n", + "# Train\n", + "model.fit(\n", + " loader_tr.load(),\n", + " steps_per_epoch=loader_tr.steps_per_epoch,\n", + " validation_data=loader_va.load(),\n", + " validation_steps=loader_va.steps_per_epoch,\n", + " epochs=epochs,\n", + " callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],\n", + ")\n", + "\n", + "# Evaluate\n", + "eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)\n", + "\n", + "print('Done.\\n'\n", + " f'Test loss: {eval_results[0]:2.f}\\n'\n", + " f'Test accuracy: {eval_results[1]:.2f}')" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "4d4c55ad0dd25f9ca95e4d49a929aa3f71bfb37020ae570a9996c3e164818202" + }, + "kernelspec": { + "display_name": "Python 3.9.7 64-bit ('py3': conda)", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}