{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install numpy spektral tensorflow -q\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import spektral\n",
    "from spektral.datasets.citation import Citation\n",
    "from tqdm import tqdm\n",
    "from spektral.transforms import LayerPreprocess\n",
    "from spektral.layers import GCNConv\n",
    "from spektral.transforms import NormalizeAdj\n"
   ]
  },
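  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Node classification on the Pubmed citation graph with graph convolutional networks: first a hand-rolled GCN and training loop to compare adjacency normalizations, then Spektral's built-in `GCN` model."
   ]
  },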
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pre-processing node features\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/scipy/sparse/_index.py:125: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n",
      "  self._set_arrayXarray(i, j, x)\n"
     ]
    }
   ],
   "source": [
    "dataset = Citation('pubmed', normalize_x=True)\n",
    "dataset = Citation('pubmed', normalize_x=True, transforms=[LayerPreprocess(GCNConv)])\n",
    "\n",
    "adj = dataset.graphs[0].a.todense()\n",
    "features = dataset.graphs[0].x\n",
    "labels = dataset.graphs[0].y\n",
    "train_mask, val_mask, test_mask = dataset.mask_tr, dataset.mask_va, dataset.mask_te\n",
    "\n",
    "def masked_softmax_cross_entropy(logits, labels, mask):\n",
    "  loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)\n",
    "  mask = tf.cast(mask, dtype=tf.float32)\n",
    "  mask /= tf.reduce_mean(mask)\n",
    "  loss *= mask\n",
    "  return tf.reduce_mean(loss)\n",
    "\n",
    "def masked_accuracy(logits, labels, mask):\n",
    "  correct_preds = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))\n",
    "  accuracy = tf.cast(correct_preds, dtype=tf.float32)\n",
    "  mask = tf.cast(mask, dtype=tf.float32)\n",
    "  mask /= tf.reduce_mean(mask)\n",
    "  accuracy *= mask\n",
    "  return tf.reduce_mean(accuracy)\n",
    "\n",
    "def gnn_fn(features, adj, transform, activation):\n",
    "  seq_fts =  transform(features)\n",
    "  ret_fts = tf.matmul(adj, seq_fts)\n",
    "  return activation(ret_fts)\n",
    "\n",
    "def train(features, adj, gnn_fn, units, epochs, lr=0.01):\n",
    "  lyr_1 = tf.keras.layers.Dense(units)\n",
    "  lyr_2 = tf.keras.layers.Dense(3)\n",
    "   \n",
    "  def gnn_net(features, adj):\n",
    "    hidden = gnn_fn(features, adj, lyr_1, tf.nn.relu)\n",
    "    logits = gnn_fn(hidden, adj, lyr_2, tf.identity)\n",
    "    return logits\n",
    "  \n",
    "  optimizer = tf.keras.optimizers.Adam(learning_rate=lr)\n",
    "  best_accuracy = 0.0\n",
    "  \n",
    "  for epoch in range(epochs+1):\n",
    "    with tf.GradientTape() as t:\n",
    "      logits = gnn_net(features, adj)\n",
    "      loss = masked_softmax_cross_entropy(logits, labels, train_mask)\n",
    "\n",
    "    variables = t.watched_variables()\n",
    "    grads = t.gradient(loss, variables)\n",
    "    optimizer.apply_gradients(zip(grads, variables))\n",
    "\n",
    "    logits = gnn_net(features, adj)\n",
    "    val_accuracy = masked_accuracy(logits, labels, val_mask)\n",
    "    test_accuracy = masked_accuracy(logits, labels, test_mask)\n",
    "    \n",
    "    if val_accuracy > best_accuracy:\n",
    "      best_accuracy = val_accuracy\n",
    "      print(f'epoch: {epoch},'\n",
    "            f'train_loss: {loss.numpy()},'\n",
    "            f'val_acc: {val_accuracy.numpy()},'\n",
    "            f'test_acc: {test_accuracy.numpy()}')\n"
   ]
  },
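  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`gnn_fn` implements one graph-convolution step, a dense transform of the node features followed by neighbourhood aggregation with the adjacency matrix:\n",
    "\n",
    "$$H^{(l+1)} = \\sigma\\big(\\hat{A} H^{(l)} W^{(l)}\\big)$$\n",
    "\n",
    "where $\\hat{A}$ is a (possibly normalized) adjacency matrix, $W^{(l)}$ is the weight matrix of the `Dense` layer, and $\\sigma$ is the activation. `train` stacks two such layers into a small GCN and minimizes the masked cross-entropy on the training nodes. The cells below vary only the choice of $\\hat{A}$."
   ]
  },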
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0,train_loss: 1.0988209247589111,val_acc: 0.49799996614456177,test_acc: 0.5019999742507935\n",
      "epoch: 3,train_loss: 1.0498652458190918,val_acc: 0.515999972820282,test_acc: 0.528999924659729\n",
      "epoch: 4,train_loss: 1.0287455320358276,val_acc: 0.5720000267028809,test_acc: 0.6109998226165771\n",
      "epoch: 5,train_loss: 1.0060378313064575,val_acc: 0.6500000357627869,test_acc: 0.6719998121261597\n",
      "epoch: 6,train_loss: 0.9812273383140564,val_acc: 0.6760000586509705,test_acc: 0.6909998059272766\n",
      "epoch: 7,train_loss: 0.9548961520195007,val_acc: 0.6880000233650208,test_acc: 0.6979997158050537\n",
      "epoch: 8,train_loss: 0.9275616407394409,val_acc: 0.7100000977516174,test_acc: 0.6979997158050537\n",
      "epoch: 9,train_loss: 0.899043619632721,val_acc: 0.7140000462532043,test_acc: 0.7019997835159302\n",
      "epoch: 11,train_loss: 0.8381972908973694,val_acc: 0.7180001139640808,test_acc: 0.7079997658729553\n",
      "epoch: 12,train_loss: 0.8062182664871216,val_acc: 0.7220000624656677,test_acc: 0.711999773979187\n",
      "epoch: 13,train_loss: 0.7739933133125305,val_acc: 0.7300000786781311,test_acc: 0.714999794960022\n",
      "epoch: 14,train_loss: 0.7413352131843567,val_acc: 0.7380000948905945,test_acc: 0.7249997854232788\n",
      "epoch: 15,train_loss: 0.7085117697715759,val_acc: 0.7420001029968262,test_acc: 0.7279998064041138\n",
      "epoch: 19,train_loss: 0.5778310894966125,val_acc: 0.7440000772476196,test_acc: 0.7369998097419739\n",
      "epoch: 20,train_loss: 0.5461804866790771,val_acc: 0.7440001368522644,test_acc: 0.7369998097419739\n"
     ]
    }
   ],
   "source": [
    "# on identity train(features, tf.eye(adj.shape[0]), gnn_fn, 32, 20, 0.01)\n",
    "\n",
    "# no normalization: train(features, adj, gnn_fn, 32, 20, 0.01)\n",
    "\n",
    "# normalize adj by degree\n",
    "deg = tf.reduce_sum(adj, axis=-1)\n",
    "train(features, adj / deg, gnn_fn, 32, 20, 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# thomas gibb in graph conv networks: normalize by 1/sqrt(deg)\n",
    "\n",
    "# norm_deg = tf.linalg.diag(1.0 / tf.sqrt(deg))\n",
    "# norm_adj = tf.matmul(norm_deg, tf.matmul(adj, norm_deg))\n",
    "# train(features, norm_adj, gnn_fn, 32, 200, 0.01)"
   ]
  },
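  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the renormalization trick from Kipf & Welling: add self-loops before the symmetric normalization, $\\hat{A} = \\tilde{D}^{-1/2}(A + I)\\tilde{D}^{-1/2}$. Note that `adj` was loaded with `LayerPreprocess(GCNConv)`, which already applies exactly this transform, so the cell below is illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Renormalization trick (illustration only; LayerPreprocess(GCNConv) already did this)\n",
    "a_tilde = np.asarray(adj) + np.eye(adj.shape[0], dtype=np.float32)  # A + I\n",
    "deg_tilde = a_tilde.sum(axis=-1)  # degrees including self-loops, always >= 1\n",
    "d_inv_sqrt = np.diag(1.0 / np.sqrt(deg_tilde))\n",
    "a_hat = (d_inv_sqrt @ a_tilde @ d_inv_sqrt).astype(np.float32)\n",
    "# train(features, a_hat, gnn_fn, 32, 200, 0.01)"
   ]
  },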
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another impl:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "1/1 [==============================] - 0s 429ms/step - loss: 1.1066 - acc: 0.3833 - val_loss: 1.1029 - val_acc: 0.4720\n",
      "Epoch 2/100\n",
      "1/1 [==============================] - 0s 84ms/step - loss: 1.1028 - acc: 0.5167 - val_loss: 1.0993 - val_acc: 0.5660\n",
      "Epoch 3/100\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 1.0975 - acc: 0.6167 - val_loss: 1.0951 - val_acc: 0.6140\n",
      "Epoch 4/100\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 1.0911 - acc: 0.7500 - val_loss: 1.0905 - val_acc: 0.6480\n",
      "Epoch 5/100\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 1.0841 - acc: 0.8667 - val_loss: 1.0855 - val_acc: 0.6700\n",
      "Epoch 6/100\n",
      "1/1 [==============================] - 0s 96ms/step - loss: 1.0779 - acc: 0.8333 - val_loss: 1.0803 - val_acc: 0.6780\n",
      "Epoch 7/100\n",
      "1/1 [==============================] - 0s 98ms/step - loss: 1.0696 - acc: 0.8667 - val_loss: 1.0750 - val_acc: 0.6840\n",
      "Epoch 8/100\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 1.0639 - acc: 0.8667 - val_loss: 1.0696 - val_acc: 0.6880\n",
      "Epoch 9/100\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 1.0566 - acc: 0.8333 - val_loss: 1.0641 - val_acc: 0.6920\n",
      "Epoch 10/100\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 1.0499 - acc: 0.9000 - val_loss: 1.0586 - val_acc: 0.6900\n",
      "Epoch 11/100\n",
      "1/1 [==============================] - 0s 96ms/step - loss: 1.0377 - acc: 0.8833 - val_loss: 1.0531 - val_acc: 0.6900\n",
      "Epoch 12/100\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 1.0297 - acc: 0.8833 - val_loss: 1.0475 - val_acc: 0.6940\n",
      "Epoch 13/100\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 1.0265 - acc: 0.8833 - val_loss: 1.0419 - val_acc: 0.6940\n",
      "Epoch 14/100\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 1.0154 - acc: 0.8167 - val_loss: 1.0364 - val_acc: 0.6960\n",
      "Epoch 15/100\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 1.0056 - acc: 0.8500 - val_loss: 1.0309 - val_acc: 0.7000\n",
      "Epoch 16/100\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.9946 - acc: 0.8500 - val_loss: 1.0254 - val_acc: 0.7080\n",
      "Epoch 17/100\n",
      "1/1 [==============================] - 0s 86ms/step - loss: 0.9747 - acc: 0.9167 - val_loss: 1.0198 - val_acc: 0.7100\n",
      "Epoch 18/100\n",
      "1/1 [==============================] - 0s 98ms/step - loss: 0.9775 - acc: 0.8833 - val_loss: 1.0143 - val_acc: 0.7240\n",
      "Epoch 19/100\n",
      "1/1 [==============================] - 0s 98ms/step - loss: 0.9703 - acc: 0.8167 - val_loss: 1.0088 - val_acc: 0.7280\n",
      "Epoch 20/100\n",
      "1/1 [==============================] - 0s 98ms/step - loss: 0.9489 - acc: 0.8833 - val_loss: 1.0032 - val_acc: 0.7320\n",
      "Epoch 21/100\n",
      "1/1 [==============================] - 0s 99ms/step - loss: 0.9471 - acc: 0.8500 - val_loss: 0.9977 - val_acc: 0.7340\n",
      "Epoch 22/100\n",
      "1/1 [==============================] - 0s 111ms/step - loss: 0.9274 - acc: 0.8833 - val_loss: 0.9921 - val_acc: 0.7360\n",
      "Epoch 23/100\n",
      "1/1 [==============================] - 0s 106ms/step - loss: 0.9140 - acc: 0.9333 - val_loss: 0.9863 - val_acc: 0.7320\n",
      "Epoch 24/100\n",
      "1/1 [==============================] - 0s 103ms/step - loss: 0.9167 - acc: 0.8833 - val_loss: 0.9805 - val_acc: 0.7380\n",
      "Epoch 25/100\n",
      "1/1 [==============================] - 0s 96ms/step - loss: 0.9001 - acc: 0.9000 - val_loss: 0.9748 - val_acc: 0.7380\n",
      "Epoch 26/100\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.8726 - acc: 0.9500 - val_loss: 0.9691 - val_acc: 0.7440\n",
      "Epoch 27/100\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.8624 - acc: 0.9333 - val_loss: 0.9636 - val_acc: 0.7420\n",
      "Epoch 28/100\n",
      "1/1 [==============================] - 0s 105ms/step - loss: 0.8646 - acc: 0.8833 - val_loss: 0.9581 - val_acc: 0.7460\n",
      "Epoch 29/100\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.8331 - acc: 0.9333 - val_loss: 0.9527 - val_acc: 0.7460\n",
      "Epoch 30/100\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 0.8285 - acc: 0.9500 - val_loss: 0.9473 - val_acc: 0.7460\n",
      "Epoch 31/100\n",
      "1/1 [==============================] - 0s 85ms/step - loss: 0.8413 - acc: 0.9167 - val_loss: 0.9419 - val_acc: 0.7460\n",
      "Epoch 32/100\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.8097 - acc: 0.9167 - val_loss: 0.9364 - val_acc: 0.7520\n",
      "Epoch 33/100\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.7880 - acc: 0.9000 - val_loss: 0.9310 - val_acc: 0.7540\n",
      "Epoch 34/100\n",
      "1/1 [==============================] - 0s 94ms/step - loss: 0.7893 - acc: 0.9667 - val_loss: 0.9257 - val_acc: 0.7560\n",
      "Epoch 35/100\n",
      "1/1 [==============================] - 0s 94ms/step - loss: 0.7557 - acc: 0.9667 - val_loss: 0.9205 - val_acc: 0.7580\n",
      "Epoch 36/100\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.7780 - acc: 0.9167 - val_loss: 0.9150 - val_acc: 0.7580\n",
      "Epoch 37/100\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.7612 - acc: 0.9500 - val_loss: 0.9096 - val_acc: 0.7640\n",
      "Epoch 38/100\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.7419 - acc: 0.9500 - val_loss: 0.9043 - val_acc: 0.7660\n",
      "Epoch 39/100\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.7383 - acc: 0.8833 - val_loss: 0.8992 - val_acc: 0.7680\n",
      "Epoch 40/100\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.7241 - acc: 0.9500 - val_loss: 0.8943 - val_acc: 0.7700\n",
      "Epoch 41/100\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.7062 - acc: 0.9500 - val_loss: 0.8897 - val_acc: 0.7720\n",
      "Epoch 42/100\n",
      "1/1 [==============================] - 0s 86ms/step - loss: 0.7407 - acc: 0.8500 - val_loss: 0.8853 - val_acc: 0.7720\n",
      "Epoch 43/100\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.7206 - acc: 0.9667 - val_loss: 0.8810 - val_acc: 0.7740\n",
      "Epoch 44/100\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.6992 - acc: 0.9333 - val_loss: 0.8767 - val_acc: 0.7760\n",
      "Epoch 45/100\n",
      "1/1 [==============================] - 0s 94ms/step - loss: 0.6923 - acc: 0.9500 - val_loss: 0.8725 - val_acc: 0.7760\n",
      "Epoch 46/100\n",
      "1/1 [==============================] - 0s 97ms/step - loss: 0.6670 - acc: 0.9667 - val_loss: 0.8683 - val_acc: 0.7740\n",
      "Epoch 47/100\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.6801 - acc: 0.9333 - val_loss: 0.8643 - val_acc: 0.7720\n",
      "Epoch 48/100\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.6728 - acc: 0.9500 - val_loss: 0.8605 - val_acc: 0.7740\n",
      "Epoch 49/100\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.6843 - acc: 0.9500 - val_loss: 0.8567 - val_acc: 0.7740\n",
      "Epoch 50/100\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.6378 - acc: 0.9333 - val_loss: 0.8531 - val_acc: 0.7780\n",
      "Epoch 51/100\n",
      "1/1 [==============================] - 0s 94ms/step - loss: 0.6508 - acc: 0.9500 - val_loss: 0.8495 - val_acc: 0.7760\n",
      "Epoch 52/100\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.6212 - acc: 0.9500 - val_loss: 0.8461 - val_acc: 0.7740\n",
      "Epoch 53/100\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.6273 - acc: 0.9000 - val_loss: 0.8430 - val_acc: 0.7760\n",
      "Epoch 54/100\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.6076 - acc: 0.9667 - val_loss: 0.8399 - val_acc: 0.7740\n",
      "Epoch 55/100\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.5923 - acc: 0.9333 - val_loss: 0.8370 - val_acc: 0.7740\n",
      "Epoch 56/100\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 0.6108 - acc: 0.9333 - val_loss: 0.8343 - val_acc: 0.7760\n",
      "Epoch 57/100\n",
      "1/1 [==============================] - 0s 94ms/step - loss: 0.5720 - acc: 0.9333 - val_loss: 0.8317 - val_acc: 0.7760\n",
      "Epoch 58/100\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.5809 - acc: 0.9500 - val_loss: 0.8291 - val_acc: 0.7760\n",
      "Epoch 59/100\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 0.6027 - acc: 0.9333 - val_loss: 0.8264 - val_acc: 0.7740\n",
      "Epoch 60/100\n",
      "1/1 [==============================] - 0s 98ms/step - loss: 0.5793 - acc: 0.9333 - val_loss: 0.8238 - val_acc: 0.7740\n",
      "Epoch 61/100\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.5688 - acc: 0.9333 - val_loss: 0.8211 - val_acc: 0.7700\n",
      "Epoch 62/100\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.5791 - acc: 0.8833 - val_loss: 0.8184 - val_acc: 0.7720\n",
      "Epoch 63/100\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.5321 - acc: 0.9667 - val_loss: 0.8156 - val_acc: 0.7720\n",
      "Epoch 64/100\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.5385 - acc: 0.9500 - val_loss: 0.8129 - val_acc: 0.7700\n",
      "Epoch 65/100\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.5376 - acc: 0.9833 - val_loss: 0.8100 - val_acc: 0.7740\n",
      "Epoch 66/100\n",
      "1/1 [==============================] - 0s 97ms/step - loss: 0.5222 - acc: 0.9833 - val_loss: 0.8073 - val_acc: 0.7780\n",
      "Epoch 67/100\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.5196 - acc: 0.9500 - val_loss: 0.8046 - val_acc: 0.7800\n",
      "Epoch 68/100\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.5265 - acc: 0.9500 - val_loss: 0.8022 - val_acc: 0.7860\n",
      "Epoch 69/100\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.5388 - acc: 0.9500 - val_loss: 0.7998 - val_acc: 0.7900\n",
      "Epoch 70/100\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.4889 - acc: 0.9667 - val_loss: 0.7975 - val_acc: 0.7920\n",
      "Epoch 71/100\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 0.5209 - acc: 0.9333 - val_loss: 0.7953 - val_acc: 0.7940\n",
      "Epoch 72/100\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.5043 - acc: 0.9500 - val_loss: 0.7933 - val_acc: 0.7960\n",
      "Epoch 73/100\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.5118 - acc: 0.9667 - val_loss: 0.7914 - val_acc: 0.7960\n",
      "Epoch 74/100\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.4964 - acc: 0.9833 - val_loss: 0.7894 - val_acc: 0.7960\n",
      "Epoch 75/100\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.5301 - acc: 0.9500 - val_loss: 0.7874 - val_acc: 0.7940\n",
      "Epoch 76/100\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.4703 - acc: 0.9833 - val_loss: 0.7856 - val_acc: 0.7940\n",
      "Epoch 77/100\n",
      "1/1 [==============================] - 0s 106ms/step - loss: 0.4777 - acc: 0.9667 - val_loss: 0.7839 - val_acc: 0.7920\n",
      "Epoch 78/100\n",
      "1/1 [==============================] - 0s 94ms/step - loss: 0.4782 - acc: 0.9500 - val_loss: 0.7822 - val_acc: 0.7920\n",
      "Epoch 79/100\n",
      "1/1 [==============================] - 0s 98ms/step - loss: 0.4735 - acc: 0.9667 - val_loss: 0.7804 - val_acc: 0.7920\n",
      "Epoch 80/100\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.4739 - acc: 0.9333 - val_loss: 0.7786 - val_acc: 0.7920\n",
      "Epoch 81/100\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.5066 - acc: 0.9500 - val_loss: 0.7768 - val_acc: 0.7900\n",
      "Epoch 82/100\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 0.4626 - acc: 0.9667 - val_loss: 0.7750 - val_acc: 0.7900\n",
      "Epoch 83/100\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.4573 - acc: 0.9667 - val_loss: 0.7733 - val_acc: 0.7920\n",
      "Epoch 84/100\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.4399 - acc: 1.0000 - val_loss: 0.7715 - val_acc: 0.7920\n",
      "Epoch 85/100\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.4751 - acc: 0.9667 - val_loss: 0.7698 - val_acc: 0.7940\n",
      "Epoch 86/100\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.4270 - acc: 0.9833 - val_loss: 0.7683 - val_acc: 0.7940\n",
      "Epoch 87/100\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.4611 - acc: 0.9500 - val_loss: 0.7667 - val_acc: 0.7940\n",
      "Epoch 88/100\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.3962 - acc: 0.9833 - val_loss: 0.7652 - val_acc: 0.7980\n",
      "Epoch 89/100\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 0.4901 - acc: 0.9667 - val_loss: 0.7640 - val_acc: 0.7980\n",
      "Epoch 90/100\n",
      "1/1 [==============================] - 0s 97ms/step - loss: 0.4380 - acc: 0.9500 - val_loss: 0.7630 - val_acc: 0.7980\n",
      "Epoch 91/100\n",
      "1/1 [==============================] - 0s 100ms/step - loss: 0.4628 - acc: 0.9500 - val_loss: 0.7620 - val_acc: 0.7960\n",
      "Epoch 92/100\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.4599 - acc: 0.9667 - val_loss: 0.7612 - val_acc: 0.7960\n",
      "Epoch 93/100\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 0.4459 - acc: 1.0000 - val_loss: 0.7601 - val_acc: 0.7960\n",
      "Epoch 94/100\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.4031 - acc: 0.9667 - val_loss: 0.7587 - val_acc: 0.7960\n",
      "Epoch 95/100\n",
      "1/1 [==============================] - 0s 86ms/step - loss: 0.4624 - acc: 0.9500 - val_loss: 0.7571 - val_acc: 0.7940\n",
      "Epoch 96/100\n",
      "1/1 [==============================] - 0s 103ms/step - loss: 0.4164 - acc: 0.9667 - val_loss: 0.7559 - val_acc: 0.7940\n",
      "Epoch 97/100\n",
      "1/1 [==============================] - 0s 97ms/step - loss: 0.4244 - acc: 0.9833 - val_loss: 0.7549 - val_acc: 0.7940\n",
      "Epoch 98/100\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 0.4013 - acc: 0.9833 - val_loss: 0.7542 - val_acc: 0.7920\n",
      "Epoch 99/100\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.4387 - acc: 0.9667 - val_loss: 0.7537 - val_acc: 0.7920\n",
      "Epoch 100/100\n",
      "1/1 [==============================] - 0s 101ms/step - loss: 0.4377 - acc: 0.9667 - val_loss: 0.7528 - val_acc: 0.7920\n",
      "1/1 [==============================] - 0s 39ms/step - loss: 0.7651 - acc: 0.7890\n",
      "Done.\n",
      "Test loss: 0.77\n",
      "Test accuracy: 0.79\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "This example implements the experiments on citation networks from the paper:\n",
    "Semi-Supervised Classification with Graph Convolutional Networks (https://arxiv.org/abs/1609.02907)\n",
    "Thomas N. Kipf, Max Welling\n",
    "\"\"\"\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.callbacks import EarlyStopping\n",
    "from tensorflow.keras.losses import CategoricalCrossentropy\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "\n",
    "from spektral.data.loaders import SingleLoader\n",
    "from spektral.models.gcn import GCN\n",
    "\n",
    "learning_rate = .01\n",
    "epochs = 100\n",
    "patience = 10\n",
    "\n",
    "def mask_to_weights(mask):\n",
    "    return mask.astype(np.float32) / np.count_nonzero(mask)\n",
    "\n",
    "weights_tr, weights_va, weights_te = (\n",
    "    mask_to_weights(mask) for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)\n",
    ")\n",
    "\n",
    "# define the model\n",
    "model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)\n",
    "model.compile(\n",
    "    optimizer=Adam(learning_rate),\n",
    "    loss=CategoricalCrossentropy(reduction=\"sum\"),\n",
    "    weighted_metrics=[\"acc\"],\n",
    ")\n",
    "\n",
    "loader_tr = SingleLoader(dataset, sample_weights=weights_tr)\n",
    "loader_va = SingleLoader(dataset, sample_weights=weights_va)\n",
    "loader_te = SingleLoader(dataset, sample_weights=weights_te)\n",
    "\n",
    "# Train\n",
    "model.fit(\n",
    "    loader_tr.load(),\n",
    "    steps_per_epoch=loader_tr.steps_per_epoch,\n",
    "    validation_data=loader_va.load(),\n",
    "    validation_steps=loader_va.steps_per_epoch,\n",
    "    epochs=epochs,\n",
    "    callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],\n",
    ")\n",
    "\n",
    "# Evaluate\n",
    "eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)\n",
    "\n",
    "print('Done.\\n'\n",
    "      f'Test loss: {eval_results[0]:.2f}\\n'\n",
    "      f'Test accuracy: {eval_results[1]:.2f}')"
   ]
  },
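  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, a sketch of a custom Spektral `Dataset`: `read()` must return a list of `Graph` objects, here random graphs with one-hot colour features whose label is the most frequent colour."
   ]
  },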
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PubMed(Dataset):\n",
    "\n",
    "    def __init__(self, n_samples, n_nodes=3, n_min=10, n_max=100, p=0.1, **kwargs):\n",
    "        self.n_samples = n_samples\n",
    "        self.n_nodes = n_nodes\n",
    "        self.n_min = n_min\n",
    "        self.n_max = n_max\n",
    "        self.p = p\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        def make_graph():\n",
    "            n = np.random.randint(self.n_min, self.n_max)\n",
    "            colors = np.random.randint(0, self.n_colors, size=n)\n",
    "\n",
    "            # Node features\n",
    "            x = np.zeros((n, self.n_colors))\n",
    "            x[np.arange(n), colors] = 1\n",
    "\n",
    "            # Edges\n",
    "            a = np.random.rand(n, n) <= self.p\n",
    "            a = np.maximum(a, a.T).astype(int)\n",
    "            a = sp.csr_matrix(a)\n",
    "\n",
    "            # Labels\n",
    "            y = np.zeros((self.n_colors,))\n",
    "            color_counts = x.sum(0)\n",
    "            y[np.argmax(color_counts)] = 1\n",
    "\n",
    "            return Graph(x=x, a=a, y=y)\n",
    "\n",
    "        # We must return a list of Graph objects\n",
    "        return [make_graph() for _ in range(self.n_samples)]\n",
    "\n",
    "\n",
    "data = PubMed(1000, transforms=NormalizeAdj())\n"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "4d4c55ad0dd25f9ca95e4d49a929aa3f71bfb37020ae570a9996c3e164818202"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit ('py3': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}