{ "cells": [ { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install numpy spektral tensorflow -q\n", "import tensorflow as tf\n", "import numpy as np\n", "import spektral\n", "from spektral.datasets.citation import Citation\n", "from tqdm import tqdm\n", "from spektral.transforms import LayerPreprocess\n", "from spektral.layers import GCNConv\n", "from spektral.transforms import NormalizeAdj\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Pre-processing node features\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/scipy/sparse/_index.py:125: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n", " self._set_arrayXarray(i, j, x)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Pre-processing node features\n" ] } ], "source": [ "dataset = Citation('pubmed', normalize_x=True)\n", "dataset = Citation('pubmed', normalize_x=True, transforms=[LayerPreprocess(GCNConv)])\n", "\n", "adj = dataset.graphs[0].a.todense()\n", "features = dataset.graphs[0].x\n", "labels = dataset.graphs[0].y\n", "train_mask, val_mask, test_mask = dataset.mask_tr, dataset.mask_va, dataset.mask_te\n", "\n", "def masked_softmax_cross_entropy(logits, labels, mask):\n", " loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)\n", " mask = tf.cast(mask, dtype=tf.float32)\n", " mask /= tf.reduce_mean(mask)\n", " loss *= mask\n", " return tf.reduce_mean(loss)\n", "\n", "def masked_accuracy(logits, labels, mask):\n", " correct_preds = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))\n", " accuracy = tf.cast(correct_preds, dtype=tf.float32)\n", " mask = tf.cast(mask, dtype=tf.float32)\n", " mask /= tf.reduce_mean(mask)\n", " accuracy *= mask\n", " return tf.reduce_mean(accuracy)\n", "\n", "def gnn_fn(features, adj, transform, activation):\n", " seq_fts = transform(features)\n", " ret_fts = tf.matmul(adj, seq_fts)\n", " return activation(ret_fts)\n", "\n", "def train(features, adj, gnn_fn, units, epochs, lr=0.01):\n", " lyr_1 = tf.keras.layers.Dense(units)\n", " lyr_2 = tf.keras.layers.Dense(3)\n", " \n", " def gnn_net(features, adj):\n", " hidden = gnn_fn(features, adj, lyr_1, tf.nn.relu)\n", " logits = gnn_fn(hidden, adj, lyr_2, tf.identity)\n", " return logits\n", " \n", " optimizer = tf.keras.optimizers.Adam(learning_rate=lr)\n", " best_accuracy = 0.0\n", " \n", " for epoch in range(epochs+1):\n", " with tf.GradientTape() as t:\n", " logits = gnn_net(features, adj)\n", " loss = masked_softmax_cross_entropy(logits, labels, train_mask)\n", "\n", " variables = t.watched_variables()\n", " grads = t.gradient(loss, variables)\n", " optimizer.apply_gradients(zip(grads, variables))\n", "\n", " logits = gnn_net(features, adj)\n", " val_accuracy = masked_accuracy(logits, labels, val_mask)\n", " test_accuracy = masked_accuracy(logits, labels, test_mask)\n", " \n", " if val_accuracy > best_accuracy:\n", " best_accuracy = val_accuracy\n", " print(f'epoch: {epoch},'\n", " f'train_loss: {loss.numpy()},'\n", " f'val_acc: {val_accuracy.numpy()},'\n", " f'test_acc: {test_accuracy.numpy()}')\n" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "epoch: 0,train_loss: 1.0988209247589111,val_acc: 0.49799996614456177,test_acc: 0.5019999742507935\n", "epoch: 3,train_loss: 1.0498652458190918,val_acc: 0.515999972820282,test_acc: 0.528999924659729\n", "epoch: 4,train_loss: 1.0287455320358276,val_acc: 0.5720000267028809,test_acc: 0.6109998226165771\n", "epoch: 5,train_loss: 1.0060378313064575,val_acc: 0.6500000357627869,test_acc: 0.6719998121261597\n", "epoch: 6,train_loss: 0.9812273383140564,val_acc: 0.6760000586509705,test_acc: 0.6909998059272766\n", "epoch: 7,train_loss: 0.9548961520195007,val_acc: 0.6880000233650208,test_acc: 0.6979997158050537\n", "epoch: 8,train_loss: 0.9275616407394409,val_acc: 0.7100000977516174,test_acc: 0.6979997158050537\n", "epoch: 9,train_loss: 0.899043619632721,val_acc: 0.7140000462532043,test_acc: 0.7019997835159302\n", "epoch: 11,train_loss: 0.8381972908973694,val_acc: 0.7180001139640808,test_acc: 0.7079997658729553\n", "epoch: 12,train_loss: 0.8062182664871216,val_acc: 0.7220000624656677,test_acc: 0.711999773979187\n", "epoch: 13,train_loss: 0.7739933133125305,val_acc: 0.7300000786781311,test_acc: 0.714999794960022\n", "epoch: 14,train_loss: 0.7413352131843567,val_acc: 0.7380000948905945,test_acc: 0.7249997854232788\n", "epoch: 15,train_loss: 0.7085117697715759,val_acc: 0.7420001029968262,test_acc: 0.7279998064041138\n", "epoch: 19,train_loss: 0.5778310894966125,val_acc: 0.7440000772476196,test_acc: 0.7369998097419739\n", "epoch: 20,train_loss: 0.5461804866790771,val_acc: 0.7440001368522644,test_acc: 0.7369998097419739\n" ] } ], "source": [ "# on identity train(features, tf.eye(adj.shape[0]), gnn_fn, 32, 20, 0.01)\n", "\n", "# no normalization: train(features, adj, gnn_fn, 32, 20, 0.01)\n", "\n", "# normalize adj by degree\n", "deg = tf.reduce_sum(adj, axis=-1)\n", "train(features, adj / deg, gnn_fn, 32, 20, 0.01)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# thomas gibb in graph conv networks: normalize by 1/sqrt(deg)\n", "\n", "# norm_deg = tf.linalg.diag(1.0 / tf.sqrt(deg))\n", "# norm_adj = tf.matmul(norm_deg, tf.matmul(adj, norm_deg))\n", "# train(features, norm_adj, gnn_fn, 32, 200, 0.01)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another impl:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "1/1 [==============================] - 0s 429ms/step - loss: 1.1066 - acc: 0.3833 - val_loss: 1.1029 - val_acc: 0.4720\n", "Epoch 2/100\n", "1/1 [==============================] - 0s 84ms/step - loss: 1.1028 - acc: 0.5167 - val_loss: 1.0993 - val_acc: 0.5660\n", "Epoch 3/100\n", "1/1 [==============================] - 0s 93ms/step - loss: 1.0975 - acc: 0.6167 - val_loss: 1.0951 - val_acc: 0.6140\n", "Epoch 4/100\n", "1/1 [==============================] - 0s 90ms/step - loss: 1.0911 - acc: 0.7500 - val_loss: 1.0905 - val_acc: 0.6480\n", "Epoch 5/100\n", "1/1 [==============================] - 0s 93ms/step - loss: 1.0841 - acc: 0.8667 - val_loss: 1.0855 - val_acc: 0.6700\n", "Epoch 6/100\n", "1/1 [==============================] - 0s 96ms/step - loss: 1.0779 - acc: 0.8333 - val_loss: 1.0803 - val_acc: 0.6780\n", "Epoch 7/100\n", "1/1 [==============================] - 0s 98ms/step - loss: 1.0696 - acc: 0.8667 - val_loss: 1.0750 - val_acc: 0.6840\n", "Epoch 8/100\n", "1/1 [==============================] - 0s 95ms/step - loss: 1.0639 - acc: 0.8667 - val_loss: 1.0696 - val_acc: 0.6880\n", "Epoch 9/100\n", 
"1/1 [==============================] - 0s 95ms/step - loss: 1.0566 - acc: 0.8333 - val_loss: 1.0641 - val_acc: 0.6920\n", "Epoch 10/100\n", "1/1 [==============================] - 0s 95ms/step - loss: 1.0499 - acc: 0.9000 - val_loss: 1.0586 - val_acc: 0.6900\n", "Epoch 11/100\n", "1/1 [==============================] - 0s 96ms/step - loss: 1.0377 - acc: 0.8833 - val_loss: 1.0531 - val_acc: 0.6900\n", "Epoch 12/100\n", "1/1 [==============================] - 0s 93ms/step - loss: 1.0297 - acc: 0.8833 - val_loss: 1.0475 - val_acc: 0.6940\n", "Epoch 13/100\n", "1/1 [==============================] - 0s 93ms/step - loss: 1.0265 - acc: 0.8833 - val_loss: 1.0419 - val_acc: 0.6940\n", "Epoch 14/100\n", "1/1 [==============================] - 0s 88ms/step - loss: 1.0154 - acc: 0.8167 - val_loss: 1.0364 - val_acc: 0.6960\n", "Epoch 15/100\n", "1/1 [==============================] - 0s 90ms/step - loss: 1.0056 - acc: 0.8500 - val_loss: 1.0309 - val_acc: 0.7000\n", "Epoch 16/100\n", "1/1 [==============================] - 0s 89ms/step - loss: 0.9946 - acc: 0.8500 - val_loss: 1.0254 - val_acc: 0.7080\n", "Epoch 17/100\n", "1/1 [==============================] - 0s 86ms/step - loss: 0.9747 - acc: 0.9167 - val_loss: 1.0198 - val_acc: 0.7100\n", "Epoch 18/100\n", "1/1 [==============================] - 0s 98ms/step - loss: 0.9775 - acc: 0.8833 - val_loss: 1.0143 - val_acc: 0.7240\n", "Epoch 19/100\n", "1/1 [==============================] - 0s 98ms/step - loss: 0.9703 - acc: 0.8167 - val_loss: 1.0088 - val_acc: 0.7280\n", "Epoch 20/100\n", "1/1 [==============================] - 0s 98ms/step - loss: 0.9489 - acc: 0.8833 - val_loss: 1.0032 - val_acc: 0.7320\n", "Epoch 21/100\n", "1/1 [==============================] - 0s 99ms/step - loss: 0.9471 - acc: 0.8500 - val_loss: 0.9977 - val_acc: 0.7340\n", "Epoch 22/100\n", "1/1 [==============================] - 0s 111ms/step - loss: 0.9274 - acc: 0.8833 - val_loss: 0.9921 - val_acc: 0.7360\n", "Epoch 23/100\n", "1/1 [==============================] - 0s 106ms/step - loss: 0.9140 - acc: 0.9333 - val_loss: 0.9863 - val_acc: 0.7320\n", "Epoch 24/100\n", "1/1 [==============================] - 0s 103ms/step - loss: 0.9167 - acc: 0.8833 - val_loss: 0.9805 - val_acc: 0.7380\n", "Epoch 25/100\n", "1/1 [==============================] - 0s 96ms/step - loss: 0.9001 - acc: 0.9000 - val_loss: 0.9748 - val_acc: 0.7380\n", "Epoch 26/100\n", "1/1 [==============================] - 0s 88ms/step - loss: 0.8726 - acc: 0.9500 - val_loss: 0.9691 - val_acc: 0.7440\n", "Epoch 27/100\n", "1/1 [==============================] - 0s 90ms/step - loss: 0.8624 - acc: 0.9333 - val_loss: 0.9636 - val_acc: 0.7420\n", "Epoch 28/100\n", "1/1 [==============================] - 0s 105ms/step - loss: 0.8646 - acc: 0.8833 - val_loss: 0.9581 - val_acc: 0.7460\n", "Epoch 29/100\n", "1/1 [==============================] - 0s 91ms/step - loss: 0.8331 - acc: 0.9333 - val_loss: 0.9527 - val_acc: 0.7460\n", "Epoch 30/100\n", "1/1 [==============================] - 0s 93ms/step - loss: 0.8285 - acc: 0.9500 - val_loss: 0.9473 - val_acc: 0.7460\n", "Epoch 31/100\n", "1/1 [==============================] - 0s 85ms/step - loss: 0.8413 - acc: 0.9167 - val_loss: 0.9419 - val_acc: 0.7460\n", "Epoch 32/100\n", "1/1 [==============================] - 0s 91ms/step - loss: 0.8097 - acc: 0.9167 - val_loss: 0.9364 - val_acc: 0.7520\n", "Epoch 33/100\n", "1/1 [==============================] - 0s 90ms/step - loss: 0.7880 - acc: 0.9000 - val_loss: 0.9310 - val_acc: 0.7540\n", "Epoch 34/100\n", "1/1 
[==============================] - 0s 94ms/step - loss: 0.7893 - acc: 0.9667 - val_loss: 0.9257 - val_acc: 0.7560\n", "Epoch 35/100\n", "1/1 [==============================] - 0s 94ms/step - loss: 0.7557 - acc: 0.9667 - val_loss: 0.9205 - val_acc: 0.7580\n", "Epoch 36/100\n", "1/1 [==============================] - 0s 88ms/step - loss: 0.7780 - acc: 0.9167 - val_loss: 0.9150 - val_acc: 0.7580\n", "Epoch 37/100\n", "1/1 [==============================] - 0s 90ms/step - loss: 0.7612 - acc: 0.9500 - val_loss: 0.9096 - val_acc: 0.7640\n", "Epoch 38/100\n", "1/1 [==============================] - 0s 91ms/step - loss: 0.7419 - acc: 0.9500 - val_loss: 0.9043 - val_acc: 0.7660\n", "Epoch 39/100\n", "1/1 [==============================] - 0s 92ms/step - loss: 0.7383 - acc: 0.8833 - val_loss: 0.8992 - val_acc: 0.7680\n", "Epoch 40/100\n", "1/1 [==============================] - 0s 91ms/step - loss: 0.7241 - acc: 0.9500 - val_loss: 0.8943 - val_acc: 0.7700\n", "Epoch 41/100\n", "1/1 [==============================] - 0s 89ms/step - loss: 0.7062 - acc: 0.9500 - val_loss: 0.8897 - val_acc: 0.7720\n", "Epoch 42/100\n", "1/1 [==============================] - 0s 86ms/step - loss: 0.7407 - acc: 0.8500 - val_loss: 0.8853 - val_acc: 0.7720\n", "Epoch 43/100\n", "1/1 [==============================] - 0s 89ms/step - loss: 0.7206 - acc: 0.9667 - val_loss: 0.8810 - val_acc: 0.7740\n", "Epoch 44/100\n", "1/1 [==============================] - 0s 90ms/step - loss: 0.6992 - acc: 0.9333 - val_loss: 0.8767 - val_acc: 0.7760\n", "Epoch 45/100\n", "1/1 [==============================] - 0s 94ms/step - loss: 0.6923 - acc: 0.9500 - val_loss: 0.8725 - val_acc: 0.7760\n", "Epoch 46/100\n", "1/1 [==============================] - 0s 97ms/step - loss: 0.6670 - acc: 0.9667 - val_loss: 0.8683 - val_acc: 0.7740\n", "Epoch 47/100\n", "1/1 [==============================] - 0s 87ms/step - loss: 0.6801 - acc: 0.9333 - val_loss: 0.8643 - val_acc: 0.7720\n", "Epoch 48/100\n", "1/1 [==============================] - 0s 91ms/step - loss: 0.6728 - acc: 0.9500 - val_loss: 0.8605 - val_acc: 0.7740\n", "Epoch 49/100\n", "1/1 [==============================] - 0s 92ms/step - loss: 0.6843 - acc: 0.9500 - val_loss: 0.8567 - val_acc: 0.7740\n", "Epoch 50/100\n", "1/1 [==============================] - 0s 90ms/step - loss: 0.6378 - acc: 0.9333 - val_loss: 0.8531 - val_acc: 0.7780\n", "Epoch 51/100\n", "1/1 [==============================] - 0s 94ms/step - loss: 0.6508 - acc: 0.9500 - val_loss: 0.8495 - val_acc: 0.7760\n", "Epoch 52/100\n", "1/1 [==============================] - 0s 90ms/step - loss: 0.6212 - acc: 0.9500 - val_loss: 0.8461 - val_acc: 0.7740\n", "Epoch 53/100\n", "1/1 [==============================] - 0s 91ms/step - loss: 0.6273 - acc: 0.9000 - val_loss: 0.8430 - val_acc: 0.7760\n", "Epoch 54/100\n", "1/1 [==============================] - 0s 88ms/step - loss: 0.6076 - acc: 0.9667 - val_loss: 0.8399 - val_acc: 0.7740\n", "Epoch 55/100\n", "1/1 [==============================] - 0s 90ms/step - loss: 0.5923 - acc: 0.9333 - val_loss: 0.8370 - val_acc: 0.7740\n", "Epoch 56/100\n", "1/1 [==============================] - 0s 95ms/step - loss: 0.6108 - acc: 0.9333 - val_loss: 0.8343 - val_acc: 0.7760\n", "Epoch 57/100\n", "1/1 [==============================] - 0s 94ms/step - loss: 0.5720 - acc: 0.9333 - val_loss: 0.8317 - val_acc: 0.7760\n", "Epoch 58/100\n", "1/1 [==============================] - 0s 87ms/step - loss: 0.5809 - acc: 0.9500 - val_loss: 0.8291 - val_acc: 0.7760\n", "Epoch 59/100\n", "1/1 
[==============================] - 0s 95ms/step - loss: 0.6027 - acc: 0.9333 - val_loss: 0.8264 - val_acc: 0.7740\n", "Epoch 60/100\n", "1/1 [==============================] - 0s 98ms/step - loss: 0.5793 - acc: 0.9333 - val_loss: 0.8238 - val_acc: 0.7740\n", "Epoch 61/100\n", "1/1 [==============================] - 0s 89ms/step - loss: 0.5688 - acc: 0.9333 - val_loss: 0.8211 - val_acc: 0.7700\n", "Epoch 62/100\n", "1/1 [==============================] - 0s 89ms/step - loss: 0.5791 - acc: 0.8833 - val_loss: 0.8184 - val_acc: 0.7720\n", "Epoch 63/100\n", "1/1 [==============================] - 0s 92ms/step - loss: 0.5321 - acc: 0.9667 - val_loss: 0.8156 - val_acc: 0.7720\n", "Epoch 64/100\n", "1/1 [==============================] - 0s 91ms/step - loss: 0.5385 - acc: 0.9500 - val_loss: 0.8129 - val_acc: 0.7700\n", "Epoch 65/100\n", "1/1 [==============================] - 0s 92ms/step - loss: 0.5376 - acc: 0.9833 - val_loss: 0.8100 - val_acc: 0.7740\n", "Epoch 66/100\n", "1/1 [==============================] - 0s 97ms/step - loss: 0.5222 - acc: 0.9833 - val_loss: 0.8073 - val_acc: 0.7780\n", "Epoch 67/100\n", "1/1 [==============================] - 0s 91ms/step - loss: 0.5196 - acc: 0.9500 - val_loss: 0.8046 - val_acc: 0.7800\n", "Epoch 68/100\n", "1/1 [==============================] - 0s 89ms/step - loss: 0.5265 - acc: 0.9500 - val_loss: 0.8022 - val_acc: 0.7860\n", "Epoch 69/100\n", "1/1 [==============================] - 0s 92ms/step - loss: 0.5388 - acc: 0.9500 - val_loss: 0.7998 - val_acc: 0.7900\n", "Epoch 70/100\n", "1/1 [==============================] - 0s 90ms/step - loss: 0.4889 - acc: 0.9667 - val_loss: 0.7975 - val_acc: 0.7920\n", "Epoch 71/100\n", "1/1 [==============================] - 0s 93ms/step - loss: 0.5209 - acc: 0.9333 - val_loss: 0.7953 - val_acc: 0.7940\n", "Epoch 72/100\n", "1/1 [==============================] - 0s 92ms/step - loss: 0.5043 - acc: 0.9500 - val_loss: 0.7933 - val_acc: 0.7960\n", "Epoch 73/100\n", "1/1 [==============================] - 0s 90ms/step - loss: 0.5118 - acc: 0.9667 - val_loss: 0.7914 - val_acc: 0.7960\n", "Epoch 74/100\n", "1/1 [==============================] - 0s 90ms/step - loss: 0.4964 - acc: 0.9833 - val_loss: 0.7894 - val_acc: 0.7960\n", "Epoch 75/100\n", "1/1 [==============================] - 0s 92ms/step - loss: 0.5301 - acc: 0.9500 - val_loss: 0.7874 - val_acc: 0.7940\n", "Epoch 76/100\n", "1/1 [==============================] - 0s 89ms/step - loss: 0.4703 - acc: 0.9833 - val_loss: 0.7856 - val_acc: 0.7940\n", "Epoch 77/100\n", "1/1 [==============================] - 0s 106ms/step - loss: 0.4777 - acc: 0.9667 - val_loss: 0.7839 - val_acc: 0.7920\n", "Epoch 78/100\n", "1/1 [==============================] - 0s 94ms/step - loss: 0.4782 - acc: 0.9500 - val_loss: 0.7822 - val_acc: 0.7920\n", "Epoch 79/100\n", "1/1 [==============================] - 0s 98ms/step - loss: 0.4735 - acc: 0.9667 - val_loss: 0.7804 - val_acc: 0.7920\n", "Epoch 80/100\n", "1/1 [==============================] - 0s 88ms/step - loss: 0.4739 - acc: 0.9333 - val_loss: 0.7786 - val_acc: 0.7920\n", "Epoch 81/100\n", "1/1 [==============================] - 0s 92ms/step - loss: 0.5066 - acc: 0.9500 - val_loss: 0.7768 - val_acc: 0.7900\n", "Epoch 82/100\n", "1/1 [==============================] - 0s 95ms/step - loss: 0.4626 - acc: 0.9667 - val_loss: 0.7750 - val_acc: 0.7900\n", "Epoch 83/100\n", "1/1 [==============================] - 0s 88ms/step - loss: 0.4573 - acc: 0.9667 - val_loss: 0.7733 - val_acc: 0.7920\n", "Epoch 84/100\n", "1/1 
[==============================] - 0s 88ms/step - loss: 0.4399 - acc: 1.0000 - val_loss: 0.7715 - val_acc: 0.7920\n", "Epoch 85/100\n", "1/1 [==============================] - 0s 91ms/step - loss: 0.4751 - acc: 0.9667 - val_loss: 0.7698 - val_acc: 0.7940\n", "Epoch 86/100\n", "1/1 [==============================] - 0s 89ms/step - loss: 0.4270 - acc: 0.9833 - val_loss: 0.7683 - val_acc: 0.7940\n", "Epoch 87/100\n", "1/1 [==============================] - 0s 91ms/step - loss: 0.4611 - acc: 0.9500 - val_loss: 0.7667 - val_acc: 0.7940\n", "Epoch 88/100\n", "1/1 [==============================] - 0s 92ms/step - loss: 0.3962 - acc: 0.9833 - val_loss: 0.7652 - val_acc: 0.7980\n", "Epoch 89/100\n", "1/1 [==============================] - 0s 95ms/step - loss: 0.4901 - acc: 0.9667 - val_loss: 0.7640 - val_acc: 0.7980\n", "Epoch 90/100\n", "1/1 [==============================] - 0s 97ms/step - loss: 0.4380 - acc: 0.9500 - val_loss: 0.7630 - val_acc: 0.7980\n", "Epoch 91/100\n", "1/1 [==============================] - 0s 100ms/step - loss: 0.4628 - acc: 0.9500 - val_loss: 0.7620 - val_acc: 0.7960\n", "Epoch 92/100\n", "1/1 [==============================] - 0s 91ms/step - loss: 0.4599 - acc: 0.9667 - val_loss: 0.7612 - val_acc: 0.7960\n", "Epoch 93/100\n", "1/1 [==============================] - 0s 95ms/step - loss: 0.4459 - acc: 1.0000 - val_loss: 0.7601 - val_acc: 0.7960\n", "Epoch 94/100\n", "1/1 [==============================] - 0s 87ms/step - loss: 0.4031 - acc: 0.9667 - val_loss: 0.7587 - val_acc: 0.7960\n", "Epoch 95/100\n", "1/1 [==============================] - 0s 86ms/step - loss: 0.4624 - acc: 0.9500 - val_loss: 0.7571 - val_acc: 0.7940\n", "Epoch 96/100\n", "1/1 [==============================] - 0s 103ms/step - loss: 0.4164 - acc: 0.9667 - val_loss: 0.7559 - val_acc: 0.7940\n", "Epoch 97/100\n", "1/1 [==============================] - 0s 97ms/step - loss: 0.4244 - acc: 0.9833 - val_loss: 0.7549 - val_acc: 0.7940\n", "Epoch 98/100\n", "1/1 [==============================] - 0s 93ms/step - loss: 0.4013 - acc: 0.9833 - val_loss: 0.7542 - val_acc: 0.7920\n", "Epoch 99/100\n", "1/1 [==============================] - 0s 91ms/step - loss: 0.4387 - acc: 0.9667 - val_loss: 0.7537 - val_acc: 0.7920\n", "Epoch 100/100\n", "1/1 [==============================] - 0s 101ms/step - loss: 0.4377 - acc: 0.9667 - val_loss: 0.7528 - val_acc: 0.7920\n", "1/1 [==============================] - 0s 39ms/step - loss: 0.7651 - acc: 0.7890\n", "Done.\n", "Test loss: 0.77\n", "Test accuracy: 0.79\n" ] } ], "source": [ "\"\"\"\n", "This example implements the experiments on citation networks from the paper:\n", "Semi-Supervised Classification with Graph Convolutional Networks (https://arxiv.org/abs/1609.02907)\n", "Thomas N. Kipf, Max Welling\n", "\"\"\"\n", "import numpy as np\n", "import tensorflow as tf\n", "from tensorflow.keras.callbacks import EarlyStopping\n", "from tensorflow.keras.losses import CategoricalCrossentropy\n", "from tensorflow.keras.optimizers import Adam\n", "\n", "from spektral.data.loaders import SingleLoader\n", "from spektral.models.gcn import GCN\n", "\n", "learning_rate = 0.01\n", "epochs = 100\n", "patience = 10\n", "\n", "def mask_to_weights(mask):\n", "    # Convert a boolean node mask into per-node sample weights that sum to 1.\n", "    return mask.astype(np.float32) / np.count_nonzero(mask)\n", "\n", "weights_tr, weights_va, weights_te = (\n", "    mask_to_weights(mask) for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)\n", ")\n", "\n", "# Define the model\n", "model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)\n", "model.compile(\n", "    optimizer=Adam(learning_rate),\n", "    loss=CategoricalCrossentropy(reduction=\"sum\"),\n", "    weighted_metrics=[\"acc\"],\n", ")\n", "\n", "loader_tr = SingleLoader(dataset, sample_weights=weights_tr)\n", "loader_va = SingleLoader(dataset, sample_weights=weights_va)\n", "loader_te = SingleLoader(dataset, sample_weights=weights_te)\n", "\n", "# Train\n", "model.fit(\n", "    loader_tr.load(),\n", "    steps_per_epoch=loader_tr.steps_per_epoch,\n", "    validation_data=loader_va.load(),\n", "    validation_steps=loader_va.steps_per_epoch,\n", "    epochs=epochs,\n", "    callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],\n", ")\n", "\n", "# Evaluate\n", "eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)\n", "\n", "print('Done.\\n'\n", "      f'Test loss: {eval_results[0]:.2f}\\n'\n", "      f'Test accuracy: {eval_results[1]:.2f}')" ] },
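{ "cell_type": "markdown", "metadata": {}, "source": [ "Finally, the skeleton of a custom Spektral `Dataset`. Despite the class name, it does not read PubMed from disk: `read()` generates random graphs whose node features are one-hot colors and whose label is the most frequent color. That is enough to show the contract a custom dataset must satisfy, namely that `read()` returns a list of `Graph` objects; the generator itself is purely illustrative.\n" ] },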
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import scipy.sparse as sp\n", "\n", "from spektral.data import Dataset, Graph\n", "\n", "class PubMed(Dataset):\n", "    # A synthetic example dataset: random graphs of n_min..n_max nodes with\n", "    # one-hot color features, labelled by the most frequent color.\n", "\n", "    def __init__(self, n_samples, n_colors=3, n_min=10, n_max=100, p=0.1, **kwargs):\n", "        self.n_samples = n_samples\n", "        self.n_colors = n_colors\n", "        self.n_min = n_min\n", "        self.n_max = n_max\n", "        self.p = p\n", "        super().__init__(**kwargs)\n", "\n", "    def read(self):\n", "        def make_graph():\n", "            n = np.random.randint(self.n_min, self.n_max)\n", "            colors = np.random.randint(0, self.n_colors, size=n)\n", "\n", "            # Node features: one-hot encoding of each node's color\n", "            x = np.zeros((n, self.n_colors))\n", "            x[np.arange(n), colors] = 1\n", "\n", "            # Edges: symmetric random adjacency with edge probability p\n", "            a = np.random.rand(n, n) <= self.p\n", "            a = np.maximum(a, a.T).astype(int)\n", "            a = sp.csr_matrix(a)\n", "\n", "            # Label: the most frequent color in the graph\n", "            y = np.zeros((self.n_colors,))\n", "            color_counts = x.sum(0)\n", "            y[np.argmax(color_counts)] = 1\n", "\n", "            return Graph(x=x, a=a, y=y)\n", "\n", "        # read() must return a list of Graph objects\n", "        return [make_graph() for _ in range(self.n_samples)]\n", "\n", "\n", "data = PubMed(1000, transforms=NormalizeAdj())\n" ] } ], "metadata": { "interpreter": { "hash": "4d4c55ad0dd25f9ca95e4d49a929aa3f71bfb37020ae570a9996c3e164818202" }, "kernelspec": { "display_name": "Python 3.9.7 64-bit ('py3': conda)", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }