{
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install numpy spektral tensorflow -q\n",
"\n",
"# Third-party imports, grouped by package.\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"import spektral\n",
"from spektral.datasets.citation import Citation\n",
"from spektral.layers import GCNConv\n",
"from spektral.transforms import LayerPreprocess, NormalizeAdj\n",
"from tqdm import tqdm\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pre-processing node features\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/scipy/sparse/_index.py:125: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n",
" self._set_arrayXarray(i, j, x)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pre-processing node features\n"
]
}
],
"source": [
"# Load PubMed once, with GCN adjacency preprocessing baked into the dataset\n",
"# (the previous duplicate load without transforms was discarded work).\n",
"dataset = Citation('pubmed', normalize_x=True, transforms=[LayerPreprocess(GCNConv)])\n",
"\n",
"adj = dataset.graphs[0].a.todense()\n",
"features = dataset.graphs[0].x\n",
"labels = dataset.graphs[0].y\n",
"train_mask, val_mask, test_mask = dataset.mask_tr, dataset.mask_va, dataset.mask_te\n",
"\n",
"def masked_softmax_cross_entropy(logits, labels, mask):\n",
"    \"\"\"Mean softmax cross-entropy over the nodes selected by `mask`.\"\"\"\n",
"    loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)\n",
"    mask = tf.cast(mask, dtype=tf.float32)\n",
"    # Rescale so averaging over all nodes equals averaging over masked nodes.\n",
"    mask /= tf.reduce_mean(mask)\n",
"    loss *= mask\n",
"    return tf.reduce_mean(loss)\n",
"\n",
"def masked_accuracy(logits, labels, mask):\n",
"    \"\"\"Classification accuracy over the nodes selected by `mask`.\"\"\"\n",
"    correct_preds = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))\n",
"    accuracy = tf.cast(correct_preds, dtype=tf.float32)\n",
"    mask = tf.cast(mask, dtype=tf.float32)\n",
"    mask /= tf.reduce_mean(mask)\n",
"    accuracy *= mask\n",
"    return tf.reduce_mean(accuracy)\n",
"\n",
"def gnn_fn(features, adj, transform, activation):\n",
"    \"\"\"One graph-convolution step: transform node features, aggregate over `adj`.\"\"\"\n",
"    seq_fts = transform(features)\n",
"    ret_fts = tf.matmul(adj, seq_fts)\n",
"    return activation(ret_fts)\n",
"\n",
"def train(features, adj, gnn_fn, units, epochs, lr=0.01):\n",
"    \"\"\"Train a two-layer GNN; print val/test accuracy whenever val improves.\"\"\"\n",
"    lyr_1 = tf.keras.layers.Dense(units)\n",
"    # Derive the output width from the labels instead of hard-coding 3,\n",
"    # so the cell works for any citation dataset.\n",
"    lyr_2 = tf.keras.layers.Dense(labels.shape[-1])\n",
"\n",
"    def gnn_net(features, adj):\n",
"        hidden = gnn_fn(features, adj, lyr_1, tf.nn.relu)\n",
"        logits = gnn_fn(hidden, adj, lyr_2, tf.identity)\n",
"        return logits\n",
"\n",
"    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)\n",
"    best_accuracy = 0.0\n",
"\n",
"    for epoch in range(epochs + 1):\n",
"        with tf.GradientTape() as t:\n",
"            logits = gnn_net(features, adj)\n",
"            loss = masked_softmax_cross_entropy(logits, labels, train_mask)\n",
"\n",
"        variables = t.watched_variables()\n",
"        grads = t.gradient(loss, variables)\n",
"        optimizer.apply_gradients(zip(grads, variables))\n",
"\n",
"        # Re-evaluate after the parameter update for reporting.\n",
"        logits = gnn_net(features, adj)\n",
"        val_accuracy = masked_accuracy(logits, labels, val_mask)\n",
"        test_accuracy = masked_accuracy(logits, labels, test_mask)\n",
"\n",
"        if val_accuracy > best_accuracy:\n",
"            best_accuracy = val_accuracy\n",
"            print(f'epoch: {epoch},'\n",
"                  f'train_loss: {loss.numpy()},'\n",
"                  f'val_acc: {val_accuracy.numpy()},'\n",
"                  f'test_acc: {test_accuracy.numpy()}')\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0,train_loss: 1.0988209247589111,val_acc: 0.49799996614456177,test_acc: 0.5019999742507935\n",
"epoch: 3,train_loss: 1.0498652458190918,val_acc: 0.515999972820282,test_acc: 0.528999924659729\n",
"epoch: 4,train_loss: 1.0287455320358276,val_acc: 0.5720000267028809,test_acc: 0.6109998226165771\n",
"epoch: 5,train_loss: 1.0060378313064575,val_acc: 0.6500000357627869,test_acc: 0.6719998121261597\n",
"epoch: 6,train_loss: 0.9812273383140564,val_acc: 0.6760000586509705,test_acc: 0.6909998059272766\n",
"epoch: 7,train_loss: 0.9548961520195007,val_acc: 0.6880000233650208,test_acc: 0.6979997158050537\n",
"epoch: 8,train_loss: 0.9275616407394409,val_acc: 0.7100000977516174,test_acc: 0.6979997158050537\n",
"epoch: 9,train_loss: 0.899043619632721,val_acc: 0.7140000462532043,test_acc: 0.7019997835159302\n",
"epoch: 11,train_loss: 0.8381972908973694,val_acc: 0.7180001139640808,test_acc: 0.7079997658729553\n",
"epoch: 12,train_loss: 0.8062182664871216,val_acc: 0.7220000624656677,test_acc: 0.711999773979187\n",
"epoch: 13,train_loss: 0.7739933133125305,val_acc: 0.7300000786781311,test_acc: 0.714999794960022\n",
"epoch: 14,train_loss: 0.7413352131843567,val_acc: 0.7380000948905945,test_acc: 0.7249997854232788\n",
"epoch: 15,train_loss: 0.7085117697715759,val_acc: 0.7420001029968262,test_acc: 0.7279998064041138\n",
"epoch: 19,train_loss: 0.5778310894966125,val_acc: 0.7440000772476196,test_acc: 0.7369998097419739\n",
"epoch: 20,train_loss: 0.5461804866790771,val_acc: 0.7440001368522644,test_acc: 0.7369998097419739\n"
]
}
],
"source": [
"# on identity train(features, tf.eye(adj.shape[0]), gnn_fn, 32, 20, 0.01)\n",
"\n",
"# no normalization: train(features, adj, gnn_fn, 32, 20, 0.01)\n",
"\n",
"# normalize adj by degree: D^-1 A averages each node's neighbours.\n",
"# deg has shape (N,), so a plain `adj / deg` would broadcast along the LAST\n",
"# axis and scale COLUMNS (A D^-1); index with [:, None] to scale rows.\n",
"deg = tf.reduce_sum(adj, axis=-1)\n",
"train(features, adj / deg[:, None], gnn_fn, 32, 20, 0.01)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# Kipf & Welling, Graph Convolutional Networks: symmetric normalization by 1/sqrt(deg)\n",
"\n",
"# norm_deg = tf.linalg.diag(1.0 / tf.sqrt(deg))\n",
"# norm_adj = tf.matmul(norm_deg, tf.matmul(adj, norm_deg))\n",
"# train(features, norm_adj, gnn_fn, 32, 200, 0.01)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another implementation, using Spektral's built-in `GCN` model and data loaders:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1/1 [==============================] - 0s 429ms/step - loss: 1.1066 - acc: 0.3833 - val_loss: 1.1029 - val_acc: 0.4720\n",
"Epoch 2/100\n",
"1/1 [==============================] - 0s 84ms/step - loss: 1.1028 - acc: 0.5167 - val_loss: 1.0993 - val_acc: 0.5660\n",
"Epoch 3/100\n",
"1/1 [==============================] - 0s 93ms/step - loss: 1.0975 - acc: 0.6167 - val_loss: 1.0951 - val_acc: 0.6140\n",
"Epoch 4/100\n",
"1/1 [==============================] - 0s 90ms/step - loss: 1.0911 - acc: 0.7500 - val_loss: 1.0905 - val_acc: 0.6480\n",
"Epoch 5/100\n",
"1/1 [==============================] - 0s 93ms/step - loss: 1.0841 - acc: 0.8667 - val_loss: 1.0855 - val_acc: 0.6700\n",
"Epoch 6/100\n",
"1/1 [==============================] - 0s 96ms/step - loss: 1.0779 - acc: 0.8333 - val_loss: 1.0803 - val_acc: 0.6780\n",
"Epoch 7/100\n",
"1/1 [==============================] - 0s 98ms/step - loss: 1.0696 - acc: 0.8667 - val_loss: 1.0750 - val_acc: 0.6840\n",
"Epoch 8/100\n",
"1/1 [==============================] - 0s 95ms/step - loss: 1.0639 - acc: 0.8667 - val_loss: 1.0696 - val_acc: 0.6880\n",
"Epoch 9/100\n",
"1/1 [==============================] - 0s 95ms/step - loss: 1.0566 - acc: 0.8333 - val_loss: 1.0641 - val_acc: 0.6920\n",
"Epoch 10/100\n",
"1/1 [==============================] - 0s 95ms/step - loss: 1.0499 - acc: 0.9000 - val_loss: 1.0586 - val_acc: 0.6900\n",
"Epoch 11/100\n",
"1/1 [==============================] - 0s 96ms/step - loss: 1.0377 - acc: 0.8833 - val_loss: 1.0531 - val_acc: 0.6900\n",
"Epoch 12/100\n",
"1/1 [==============================] - 0s 93ms/step - loss: 1.0297 - acc: 0.8833 - val_loss: 1.0475 - val_acc: 0.6940\n",
"Epoch 13/100\n",
"1/1 [==============================] - 0s 93ms/step - loss: 1.0265 - acc: 0.8833 - val_loss: 1.0419 - val_acc: 0.6940\n",
"Epoch 14/100\n",
"1/1 [==============================] - 0s 88ms/step - loss: 1.0154 - acc: 0.8167 - val_loss: 1.0364 - val_acc: 0.6960\n",
"Epoch 15/100\n",
"1/1 [==============================] - 0s 90ms/step - loss: 1.0056 - acc: 0.8500 - val_loss: 1.0309 - val_acc: 0.7000\n",
"Epoch 16/100\n",
"1/1 [==============================] - 0s 89ms/step - loss: 0.9946 - acc: 0.8500 - val_loss: 1.0254 - val_acc: 0.7080\n",
"Epoch 17/100\n",
"1/1 [==============================] - 0s 86ms/step - loss: 0.9747 - acc: 0.9167 - val_loss: 1.0198 - val_acc: 0.7100\n",
"Epoch 18/100\n",
"1/1 [==============================] - 0s 98ms/step - loss: 0.9775 - acc: 0.8833 - val_loss: 1.0143 - val_acc: 0.7240\n",
"Epoch 19/100\n",
"1/1 [==============================] - 0s 98ms/step - loss: 0.9703 - acc: 0.8167 - val_loss: 1.0088 - val_acc: 0.7280\n",
"Epoch 20/100\n",
"1/1 [==============================] - 0s 98ms/step - loss: 0.9489 - acc: 0.8833 - val_loss: 1.0032 - val_acc: 0.7320\n",
"Epoch 21/100\n",
"1/1 [==============================] - 0s 99ms/step - loss: 0.9471 - acc: 0.8500 - val_loss: 0.9977 - val_acc: 0.7340\n",
"Epoch 22/100\n",
"1/1 [==============================] - 0s 111ms/step - loss: 0.9274 - acc: 0.8833 - val_loss: 0.9921 - val_acc: 0.7360\n",
"Epoch 23/100\n",
"1/1 [==============================] - 0s 106ms/step - loss: 0.9140 - acc: 0.9333 - val_loss: 0.9863 - val_acc: 0.7320\n",
"Epoch 24/100\n",
"1/1 [==============================] - 0s 103ms/step - loss: 0.9167 - acc: 0.8833 - val_loss: 0.9805 - val_acc: 0.7380\n",
"Epoch 25/100\n",
"1/1 [==============================] - 0s 96ms/step - loss: 0.9001 - acc: 0.9000 - val_loss: 0.9748 - val_acc: 0.7380\n",
"Epoch 26/100\n",
"1/1 [==============================] - 0s 88ms/step - loss: 0.8726 - acc: 0.9500 - val_loss: 0.9691 - val_acc: 0.7440\n",
"Epoch 27/100\n",
"1/1 [==============================] - 0s 90ms/step - loss: 0.8624 - acc: 0.9333 - val_loss: 0.9636 - val_acc: 0.7420\n",
"Epoch 28/100\n",
"1/1 [==============================] - 0s 105ms/step - loss: 0.8646 - acc: 0.8833 - val_loss: 0.9581 - val_acc: 0.7460\n",
"Epoch 29/100\n",
"1/1 [==============================] - 0s 91ms/step - loss: 0.8331 - acc: 0.9333 - val_loss: 0.9527 - val_acc: 0.7460\n",
"Epoch 30/100\n",
"1/1 [==============================] - 0s 93ms/step - loss: 0.8285 - acc: 0.9500 - val_loss: 0.9473 - val_acc: 0.7460\n",
"Epoch 31/100\n",
"1/1 [==============================] - 0s 85ms/step - loss: 0.8413 - acc: 0.9167 - val_loss: 0.9419 - val_acc: 0.7460\n",
"Epoch 32/100\n",
"1/1 [==============================] - 0s 91ms/step - loss: 0.8097 - acc: 0.9167 - val_loss: 0.9364 - val_acc: 0.7520\n",
"Epoch 33/100\n",
"1/1 [==============================] - 0s 90ms/step - loss: 0.7880 - acc: 0.9000 - val_loss: 0.9310 - val_acc: 0.7540\n",
"Epoch 34/100\n",
"1/1 [==============================] - 0s 94ms/step - loss: 0.7893 - acc: 0.9667 - val_loss: 0.9257 - val_acc: 0.7560\n",
"Epoch 35/100\n",
"1/1 [==============================] - 0s 94ms/step - loss: 0.7557 - acc: 0.9667 - val_loss: 0.9205 - val_acc: 0.7580\n",
"Epoch 36/100\n",
"1/1 [==============================] - 0s 88ms/step - loss: 0.7780 - acc: 0.9167 - val_loss: 0.9150 - val_acc: 0.7580\n",
"Epoch 37/100\n",
"1/1 [==============================] - 0s 90ms/step - loss: 0.7612 - acc: 0.9500 - val_loss: 0.9096 - val_acc: 0.7640\n",
"Epoch 38/100\n",
"1/1 [==============================] - 0s 91ms/step - loss: 0.7419 - acc: 0.9500 - val_loss: 0.9043 - val_acc: 0.7660\n",
"Epoch 39/100\n",
"1/1 [==============================] - 0s 92ms/step - loss: 0.7383 - acc: 0.8833 - val_loss: 0.8992 - val_acc: 0.7680\n",
"Epoch 40/100\n",
"1/1 [==============================] - 0s 91ms/step - loss: 0.7241 - acc: 0.9500 - val_loss: 0.8943 - val_acc: 0.7700\n",
"Epoch 41/100\n",
"1/1 [==============================] - 0s 89ms/step - loss: 0.7062 - acc: 0.9500 - val_loss: 0.8897 - val_acc: 0.7720\n",
"Epoch 42/100\n",
"1/1 [==============================] - 0s 86ms/step - loss: 0.7407 - acc: 0.8500 - val_loss: 0.8853 - val_acc: 0.7720\n",
"Epoch 43/100\n",
"1/1 [==============================] - 0s 89ms/step - loss: 0.7206 - acc: 0.9667 - val_loss: 0.8810 - val_acc: 0.7740\n",
"Epoch 44/100\n",
"1/1 [==============================] - 0s 90ms/step - loss: 0.6992 - acc: 0.9333 - val_loss: 0.8767 - val_acc: 0.7760\n",
"Epoch 45/100\n",
"1/1 [==============================] - 0s 94ms/step - loss: 0.6923 - acc: 0.9500 - val_loss: 0.8725 - val_acc: 0.7760\n",
"Epoch 46/100\n",
"1/1 [==============================] - 0s 97ms/step - loss: 0.6670 - acc: 0.9667 - val_loss: 0.8683 - val_acc: 0.7740\n",
"Epoch 47/100\n",
"1/1 [==============================] - 0s 87ms/step - loss: 0.6801 - acc: 0.9333 - val_loss: 0.8643 - val_acc: 0.7720\n",
"Epoch 48/100\n",
"1/1 [==============================] - 0s 91ms/step - loss: 0.6728 - acc: 0.9500 - val_loss: 0.8605 - val_acc: 0.7740\n",
"Epoch 49/100\n",
"1/1 [==============================] - 0s 92ms/step - loss: 0.6843 - acc: 0.9500 - val_loss: 0.8567 - val_acc: 0.7740\n",
"Epoch 50/100\n",
"1/1 [==============================] - 0s 90ms/step - loss: 0.6378 - acc: 0.9333 - val_loss: 0.8531 - val_acc: 0.7780\n",
"Epoch 51/100\n",
"1/1 [==============================] - 0s 94ms/step - loss: 0.6508 - acc: 0.9500 - val_loss: 0.8495 - val_acc: 0.7760\n",
"Epoch 52/100\n",
"1/1 [==============================] - 0s 90ms/step - loss: 0.6212 - acc: 0.9500 - val_loss: 0.8461 - val_acc: 0.7740\n",
"Epoch 53/100\n",
"1/1 [==============================] - 0s 91ms/step - loss: 0.6273 - acc: 0.9000 - val_loss: 0.8430 - val_acc: 0.7760\n",
"Epoch 54/100\n",
"1/1 [==============================] - 0s 88ms/step - loss: 0.6076 - acc: 0.9667 - val_loss: 0.8399 - val_acc: 0.7740\n",
"Epoch 55/100\n",
"1/1 [==============================] - 0s 90ms/step - loss: 0.5923 - acc: 0.9333 - val_loss: 0.8370 - val_acc: 0.7740\n",
"Epoch 56/100\n",
"1/1 [==============================] - 0s 95ms/step - loss: 0.6108 - acc: 0.9333 - val_loss: 0.8343 - val_acc: 0.7760\n",
"Epoch 57/100\n",
"1/1 [==============================] - 0s 94ms/step - loss: 0.5720 - acc: 0.9333 - val_loss: 0.8317 - val_acc: 0.7760\n",
"Epoch 58/100\n",
"1/1 [==============================] - 0s 87ms/step - loss: 0.5809 - acc: 0.9500 - val_loss: 0.8291 - val_acc: 0.7760\n",
"Epoch 59/100\n",
"1/1 [==============================] - 0s 95ms/step - loss: 0.6027 - acc: 0.9333 - val_loss: 0.8264 - val_acc: 0.7740\n",
"Epoch 60/100\n",
"1/1 [==============================] - 0s 98ms/step - loss: 0.5793 - acc: 0.9333 - val_loss: 0.8238 - val_acc: 0.7740\n",
"Epoch 61/100\n",
"1/1 [==============================] - 0s 89ms/step - loss: 0.5688 - acc: 0.9333 - val_loss: 0.8211 - val_acc: 0.7700\n",
"Epoch 62/100\n",
"1/1 [==============================] - 0s 89ms/step - loss: 0.5791 - acc: 0.8833 - val_loss: 0.8184 - val_acc: 0.7720\n",
"Epoch 63/100\n",
"1/1 [==============================] - 0s 92ms/step - loss: 0.5321 - acc: 0.9667 - val_loss: 0.8156 - val_acc: 0.7720\n",
"Epoch 64/100\n",
"1/1 [==============================] - 0s 91ms/step - loss: 0.5385 - acc: 0.9500 - val_loss: 0.8129 - val_acc: 0.7700\n",
"Epoch 65/100\n",
"1/1 [==============================] - 0s 92ms/step - loss: 0.5376 - acc: 0.9833 - val_loss: 0.8100 - val_acc: 0.7740\n",
"Epoch 66/100\n",
"1/1 [==============================] - 0s 97ms/step - loss: 0.5222 - acc: 0.9833 - val_loss: 0.8073 - val_acc: 0.7780\n",
"Epoch 67/100\n",
"1/1 [==============================] - 0s 91ms/step - loss: 0.5196 - acc: 0.9500 - val_loss: 0.8046 - val_acc: 0.7800\n",
"Epoch 68/100\n",
"1/1 [==============================] - 0s 89ms/step - loss: 0.5265 - acc: 0.9500 - val_loss: 0.8022 - val_acc: 0.7860\n",
"Epoch 69/100\n",
"1/1 [==============================] - 0s 92ms/step - loss: 0.5388 - acc: 0.9500 - val_loss: 0.7998 - val_acc: 0.7900\n",
"Epoch 70/100\n",
"1/1 [==============================] - 0s 90ms/step - loss: 0.4889 - acc: 0.9667 - val_loss: 0.7975 - val_acc: 0.7920\n",
"Epoch 71/100\n",
"1/1 [==============================] - 0s 93ms/step - loss: 0.5209 - acc: 0.9333 - val_loss: 0.7953 - val_acc: 0.7940\n",
"Epoch 72/100\n",
"1/1 [==============================] - 0s 92ms/step - loss: 0.5043 - acc: 0.9500 - val_loss: 0.7933 - val_acc: 0.7960\n",
"Epoch 73/100\n",
"1/1 [==============================] - 0s 90ms/step - loss: 0.5118 - acc: 0.9667 - val_loss: 0.7914 - val_acc: 0.7960\n",
"Epoch 74/100\n",
"1/1 [==============================] - 0s 90ms/step - loss: 0.4964 - acc: 0.9833 - val_loss: 0.7894 - val_acc: 0.7960\n",
"Epoch 75/100\n",
"1/1 [==============================] - 0s 92ms/step - loss: 0.5301 - acc: 0.9500 - val_loss: 0.7874 - val_acc: 0.7940\n",
"Epoch 76/100\n",
"1/1 [==============================] - 0s 89ms/step - loss: 0.4703 - acc: 0.9833 - val_loss: 0.7856 - val_acc: 0.7940\n",
"Epoch 77/100\n",
"1/1 [==============================] - 0s 106ms/step - loss: 0.4777 - acc: 0.9667 - val_loss: 0.7839 - val_acc: 0.7920\n",
"Epoch 78/100\n",
"1/1 [==============================] - 0s 94ms/step - loss: 0.4782 - acc: 0.9500 - val_loss: 0.7822 - val_acc: 0.7920\n",
"Epoch 79/100\n",
"1/1 [==============================] - 0s 98ms/step - loss: 0.4735 - acc: 0.9667 - val_loss: 0.7804 - val_acc: 0.7920\n",
"Epoch 80/100\n",
"1/1 [==============================] - 0s 88ms/step - loss: 0.4739 - acc: 0.9333 - val_loss: 0.7786 - val_acc: 0.7920\n",
"Epoch 81/100\n",
"1/1 [==============================] - 0s 92ms/step - loss: 0.5066 - acc: 0.9500 - val_loss: 0.7768 - val_acc: 0.7900\n",
"Epoch 82/100\n",
"1/1 [==============================] - 0s 95ms/step - loss: 0.4626 - acc: 0.9667 - val_loss: 0.7750 - val_acc: 0.7900\n",
"Epoch 83/100\n",
"1/1 [==============================] - 0s 88ms/step - loss: 0.4573 - acc: 0.9667 - val_loss: 0.7733 - val_acc: 0.7920\n",
"Epoch 84/100\n",
"1/1 [==============================] - 0s 88ms/step - loss: 0.4399 - acc: 1.0000 - val_loss: 0.7715 - val_acc: 0.7920\n",
"Epoch 85/100\n",
"1/1 [==============================] - 0s 91ms/step - loss: 0.4751 - acc: 0.9667 - val_loss: 0.7698 - val_acc: 0.7940\n",
"Epoch 86/100\n",
"1/1 [==============================] - 0s 89ms/step - loss: 0.4270 - acc: 0.9833 - val_loss: 0.7683 - val_acc: 0.7940\n",
"Epoch 87/100\n",
"1/1 [==============================] - 0s 91ms/step - loss: 0.4611 - acc: 0.9500 - val_loss: 0.7667 - val_acc: 0.7940\n",
"Epoch 88/100\n",
"1/1 [==============================] - 0s 92ms/step - loss: 0.3962 - acc: 0.9833 - val_loss: 0.7652 - val_acc: 0.7980\n",
"Epoch 89/100\n",
"1/1 [==============================] - 0s 95ms/step - loss: 0.4901 - acc: 0.9667 - val_loss: 0.7640 - val_acc: 0.7980\n",
"Epoch 90/100\n",
"1/1 [==============================] - 0s 97ms/step - loss: 0.4380 - acc: 0.9500 - val_loss: 0.7630 - val_acc: 0.7980\n",
"Epoch 91/100\n",
"1/1 [==============================] - 0s 100ms/step - loss: 0.4628 - acc: 0.9500 - val_loss: 0.7620 - val_acc: 0.7960\n",
"Epoch 92/100\n",
"1/1 [==============================] - 0s 91ms/step - loss: 0.4599 - acc: 0.9667 - val_loss: 0.7612 - val_acc: 0.7960\n",
"Epoch 93/100\n",
"1/1 [==============================] - 0s 95ms/step - loss: 0.4459 - acc: 1.0000 - val_loss: 0.7601 - val_acc: 0.7960\n",
"Epoch 94/100\n",
"1/1 [==============================] - 0s 87ms/step - loss: 0.4031 - acc: 0.9667 - val_loss: 0.7587 - val_acc: 0.7960\n",
"Epoch 95/100\n",
"1/1 [==============================] - 0s 86ms/step - loss: 0.4624 - acc: 0.9500 - val_loss: 0.7571 - val_acc: 0.7940\n",
"Epoch 96/100\n",
"1/1 [==============================] - 0s 103ms/step - loss: 0.4164 - acc: 0.9667 - val_loss: 0.7559 - val_acc: 0.7940\n",
"Epoch 97/100\n",
"1/1 [==============================] - 0s 97ms/step - loss: 0.4244 - acc: 0.9833 - val_loss: 0.7549 - val_acc: 0.7940\n",
"Epoch 98/100\n",
"1/1 [==============================] - 0s 93ms/step - loss: 0.4013 - acc: 0.9833 - val_loss: 0.7542 - val_acc: 0.7920\n",
"Epoch 99/100\n",
"1/1 [==============================] - 0s 91ms/step - loss: 0.4387 - acc: 0.9667 - val_loss: 0.7537 - val_acc: 0.7920\n",
"Epoch 100/100\n",
"1/1 [==============================] - 0s 101ms/step - loss: 0.4377 - acc: 0.9667 - val_loss: 0.7528 - val_acc: 0.7920\n",
"1/1 [==============================] - 0s 39ms/step - loss: 0.7651 - acc: 0.7890\n",
"Done.\n",
"Test loss: 0.77\n",
"Test accuracy: 0.79\n"
]
}
],
"source": [
"\"\"\"\n",
"This example implements the experiments on citation networks from the paper:\n",
"Semi-Supervised Classification with Graph Convolutional Networks (https://arxiv.org/abs/1609.02907)\n",
"Thomas N. Kipf, Max Welling\n",
"\"\"\"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow.keras.callbacks import EarlyStopping\n",
"from tensorflow.keras.losses import CategoricalCrossentropy\n",
"from tensorflow.keras.optimizers import Adam\n",
"\n",
"from spektral.data.loaders import SingleLoader\n",
"from spektral.models.gcn import GCN\n",
"\n",
"# Hyperparameters.\n",
"learning_rate = 0.01\n",
"epochs = 100\n",
"patience = 10\n",
"\n",
"def mask_to_weights(mask):\n",
"    \"\"\"Turn a boolean node mask into per-node sample weights that sum to 1.\"\"\"\n",
"    return mask.astype(np.float32) / np.count_nonzero(mask)\n",
"\n",
"weights_tr = mask_to_weights(dataset.mask_tr)\n",
"weights_va = mask_to_weights(dataset.mask_va)\n",
"weights_te = mask_to_weights(dataset.mask_te)\n",
"\n",
"# Define the model.\n",
"model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)\n",
"model.compile(\n",
"    optimizer=Adam(learning_rate),\n",
"    loss=CategoricalCrossentropy(reduction=\"sum\"),\n",
"    weighted_metrics=[\"acc\"],\n",
")\n",
"\n",
"# One loader per split: the same single graph, different sample weights.\n",
"loader_tr = SingleLoader(dataset, sample_weights=weights_tr)\n",
"loader_va = SingleLoader(dataset, sample_weights=weights_va)\n",
"loader_te = SingleLoader(dataset, sample_weights=weights_te)\n",
"\n",
"# Train, restoring the best weights seen on the validation split.\n",
"model.fit(\n",
"    loader_tr.load(),\n",
"    steps_per_epoch=loader_tr.steps_per_epoch,\n",
"    validation_data=loader_va.load(),\n",
"    validation_steps=loader_va.steps_per_epoch,\n",
"    epochs=epochs,\n",
"    callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],\n",
")\n",
"\n",
"# Evaluate on the held-out test split.\n",
"eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)\n",
"\n",
"print('Done.\\n'\n",
"      f'Test loss: {eval_results[0]:.2f}\\n'\n",
"      f'Test accuracy: {eval_results[1]:.2f}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# These names were previously undefined in this notebook (NameError on a\n",
"# fresh kernel): Dataset/Graph come from spektral.data, sp is scipy.sparse.\n",
"import scipy.sparse as sp\n",
"from spektral.data import Dataset, Graph\n",
"\n",
"class PubMed(Dataset):\n",
"    \"\"\"Synthetic dataset of random graphs labelled by their majority node colour.\n",
"\n",
"    NOTE(review): despite the name, this does not load PubMed — it generates\n",
"    random Erdos-Renyi graphs with one-hot colour features.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, n_samples, n_nodes=3, n_min=10, n_max=100, p=0.1, **kwargs):\n",
"        self.n_samples = n_samples  # number of graphs to generate\n",
"        self.n_nodes = n_nodes\n",
"        # read() consumes self.n_colors, which was never assigned before\n",
"        # (AttributeError); n_nodes is used as the number of colours/classes.\n",
"        self.n_colors = n_nodes\n",
"        self.n_min = n_min  # min/max node count per graph\n",
"        self.n_max = n_max\n",
"        self.p = p  # edge probability\n",
"        super().__init__(**kwargs)\n",
"\n",
"    def read(self):\n",
"        def make_graph():\n",
"            n = np.random.randint(self.n_min, self.n_max)\n",
"            colors = np.random.randint(0, self.n_colors, size=n)\n",
"\n",
"            # Node features: one-hot colour per node\n",
"            x = np.zeros((n, self.n_colors))\n",
"            x[np.arange(n), colors] = 1\n",
"\n",
"            # Edges: symmetric Erdos-Renyi with probability p\n",
"            a = np.random.rand(n, n) <= self.p\n",
"            a = np.maximum(a, a.T).astype(int)\n",
"            a = sp.csr_matrix(a)\n",
"\n",
"            # Labels: one-hot of the most frequent colour\n",
"            y = np.zeros((self.n_colors,))\n",
"            color_counts = x.sum(0)\n",
"            y[np.argmax(color_counts)] = 1\n",
"\n",
"            return Graph(x=x, a=a, y=y)\n",
"\n",
"        # We must return a list of Graph objects\n",
"        return [make_graph() for _ in range(self.n_samples)]\n",
"\n",
"\n",
"data = PubMed(1000, transforms=NormalizeAdj())\n"
]
}
],
"metadata": {
"interpreter": {
"hash": "4d4c55ad0dd25f9ca95e4d49a929aa3f71bfb37020ae570a9996c3e164818202"
},
"kernelspec": {
"display_name": "Python 3.9.7 64-bit ('py3': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}