{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install numpy spektral tensorflow -q\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import spektral\n",
    "from spektral.datasets.citation import Citation\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pre-processing node features\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/scipy/sparse/_index.py:125: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n",
      "  self._set_arrayXarray(i, j, x)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pre-processing node features\n",
      "epoch: 0,train_loss: 1.0980415344238281,val_acc: 0.4699999988079071,test_acc: 0.48900002241134644\n",
      "epoch: 1,train_loss: 1.0857902765274048,val_acc: 0.531999945640564,test_acc: 0.5269998908042908\n",
      "epoch: 5,train_loss: 1.0279830694198608,val_acc: 0.5459999442100525,test_acc: 0.5679998993873596\n",
      "epoch: 6,train_loss: 1.0107402801513672,val_acc: 0.6380000710487366,test_acc: 0.6259998679161072\n",
      "epoch: 7,train_loss: 0.9920498132705688,val_acc: 0.6820001006126404,test_acc: 0.6719997525215149\n",
      "epoch: 8,train_loss: 0.9725067615509033,val_acc: 0.7120000720024109,test_acc: 0.6989997625350952\n",
      "epoch: 9,train_loss: 0.9518073201179504,val_acc: 0.724000096321106,test_acc: 0.7109997868537903\n",
      "epoch: 10,train_loss: 0.929932177066803,val_acc: 0.736000120639801,test_acc: 0.7179997563362122\n",
      "epoch: 14,train_loss: 0.8337627053260803,val_acc: 0.7380000948905945,test_acc: 0.7209997773170471\n",
      "epoch: 15,train_loss: 0.8073488473892212,val_acc: 0.7420001029968262,test_acc: 0.7239997982978821\n"
     ]
    }
   ],
   "source": [
    "dataset = Citation('pubmed', normalize_x=True)\n",
    "dataset = Citation('pubmed', normalize_x=True, transforms=[LayerPreprocess(GCNConv)])\n",
    "\n",
    "adj = dataset.graphs[0].a.todense()\n",
    "features = dataset.graphs[0].x\n",
    "labels = dataset.graphs[0].y\n",
    "train_mask, val_mask, test_mask = dataset.mask_tr, dataset.mask_va, dataset.mask_te\n",
    "\n",
    "def masked_softmax_cross_entropy(logits, labels, mask):\n",
    "  loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)\n",
    "  mask = tf.cast(mask, dtype=tf.float32)\n",
    "  mask /= tf.reduce_mean(mask)\n",
    "  loss *= mask\n",
    "  return tf.reduce_mean(loss)\n",
    "\n",
    "def masked_accuracy(logits, labels, mask):\n",
    "  correct_preds = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))\n",
    "  accuracy = tf.cast(correct_preds, dtype=tf.float32)\n",
    "  mask = tf.cast(mask, dtype=tf.float32)\n",
    "  mask /= tf.reduce_mean(mask)\n",
    "  accuracy *= mask\n",
    "  return tf.reduce_mean(accuracy)\n",
    "\n",
    "def gnn_fn(features, adj, transform, activation):\n",
    "  seq_fts =  transform(features)\n",
    "  ret_fts = tf.matmul(adj, seq_fts)\n",
    "  return activation(ret_fts)\n",
    "\n",
    "def train(features, adj, gnn_fn, units, epochs, lr=0.01):\n",
    "  lyr_1 = tf.keras.layers.Dense(units)\n",
    "  lyr_2 = tf.keras.layers.Dense(3)\n",
    "   \n",
    "  def gnn_net(features, adj):\n",
    "    hidden = gnn_fn(features, adj, lyr_1, tf.nn.relu)\n",
    "    logits = gnn_fn(hidden, adj, lyr_2, tf.identity)\n",
    "    return logits\n",
    "  \n",
    "  optimizer = tf.keras.optimizers.Adam(learning_rate=lr)\n",
    "  best_accuracy = 0.0\n",
    "  \n",
    "  for epoch in range(epochs+1):\n",
    "    with tf.GradientTape() as t:\n",
    "      logits = gnn_net(features, adj)\n",
    "      loss = masked_softmax_cross_entropy(logits, labels, train_mask)\n",
    "\n",
    "    variables = t.watched_variables()\n",
    "    grads = t.gradient(loss, variables)\n",
    "    optimizer.apply_gradients(zip(grads, variables))\n",
    "\n",
    "    logits = gnn_net(features, adj)\n",
    "    val_accuracy = masked_accuracy(logits, labels, val_mask)\n",
    "    test_accuracy = masked_accuracy(logits, labels, test_mask)\n",
    "    \n",
    "    if val_accuracy > best_accuracy:\n",
    "      best_accuracy = val_accuracy\n",
    "      print(f'epoch: {epoch},'\n",
    "            f'train_loss: {loss.numpy()},'\n",
    "            f'val_acc: {val_accuracy.numpy()},'\n",
    "            f'test_acc: {test_accuracy.numpy()}')\n",
    "\n",
    "train(features, adj, gnn_fn, 32, 20, 0.01)"
   ]
  },
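  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two quick sanity checks on toy inputs (hypothetical examples, not drawn from the dataset): `gnn_fn` with identity transform and activation is just a neighbour sum, and the `mask /= reduce_mean(mask)` trick makes a global mean equal the masked mean."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy path graph 0-1-2: with identity transform and activation, gnn_fn\n",
    "# reduces to summing each node's neighbour features.\n",
    "toy_adj = tf.constant([[0., 1., 0.],\n",
    "                       [1., 0., 1.],\n",
    "                       [0., 1., 0.]])\n",
    "toy_x = tf.constant([[1.], [2.], [4.]])\n",
    "print(gnn_fn(toy_x, toy_adj, tf.identity, tf.identity))\n",
    "# -> [[2.], [5.], [2.]]: node 1 sums nodes 0 and 2; the end nodes see only node 1\n",
    "\n",
    "# Mask rescaling: the mean over all entries equals the mean over masked ones.\n",
    "toy_mask = tf.constant([1., 0., 1., 1.])\n",
    "toy_vals = tf.constant([3., 100., 5., 7.])\n",
    "print(tf.reduce_mean(toy_vals * (toy_mask / tf.reduce_mean(toy_mask))))\n",
    "# -> 5.0, the mean of the masked values 3, 5, 7"
   ]
  },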
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0,train_loss: 1.0996336936950684,val_acc: 0.5119999647140503,test_acc: 0.4919999837875366\n",
      "epoch: 2,train_loss: 1.073289394378662,val_acc: 0.5379999876022339,test_acc: 0.5219999551773071\n",
      "epoch: 3,train_loss: 1.0569909811019897,val_acc: 0.6000000238418579,test_acc: 0.5999998450279236\n",
      "epoch: 4,train_loss: 1.0379494428634644,val_acc: 0.6700000166893005,test_acc: 0.6739997863769531\n",
      "epoch: 5,train_loss: 1.0155421495437622,val_acc: 0.7160000801086426,test_acc: 0.7109997868537903\n",
      "epoch: 6,train_loss: 0.9916436672210693,val_acc: 0.7280001044273376,test_acc: 0.7139998078346252\n",
      "epoch: 9,train_loss: 0.9153128266334534,val_acc: 0.7300000786781311,test_acc: 0.7069997787475586\n",
      "epoch: 11,train_loss: 0.8576377034187317,val_acc: 0.7320001125335693,test_acc: 0.711999773979187\n",
      "epoch: 12,train_loss: 0.8269254565238953,val_acc: 0.7380000948905945,test_acc: 0.714999794960022\n",
      "epoch: 13,train_loss: 0.7955628037452698,val_acc: 0.7440001368522644,test_acc: 0.7159997820854187\n",
      "epoch: 15,train_loss: 0.7319157123565674,val_acc: 0.7460001111030579,test_acc: 0.714999794960022\n",
      "epoch: 17,train_loss: 0.6676653623580933,val_acc: 0.7500001192092896,test_acc: 0.7179997563362122\n",
      "epoch: 18,train_loss: 0.6357672810554504,val_acc: 0.752000093460083,test_acc: 0.7169997692108154\n"
     ]
    }
   ],
   "source": [
    "# on identity train(features, tf.eye(adj.shape[0]), gnn_fn, 32, 20, 0.01)\n",
    "\n",
    "# normalize adj by degree\n",
    "deg = tf.reduce_sum(adj, axis=-1)\n",
    "train(features, adj / deg, gnn_fn, 32, 20, 0.01)"
   ]
  },
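  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On the same toy graph, row-normalizing by degree turns the neighbour sum into a neighbour mean; a minimal sketch of the `adj / deg[:, None]` step above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# D^-1 A on the toy graph: every row of the normalized adjacency sums to 1,\n",
    "# so aggregation averages the neighbours instead of summing them.\n",
    "toy_deg = tf.reduce_sum(toy_adj, axis=-1)  # degrees: [1., 2., 1.]\n",
    "toy_norm = toy_adj / toy_deg[:, None]\n",
    "print(tf.matmul(toy_norm, toy_x))\n",
    "# -> [[2.], [2.5], [2.]]: node 1 averages its neighbours, (1 + 4) / 2"
   ]
  },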
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/3_/gmvd1nkx285133z5yh3chz2c0000gp/T/ipykernel_40315/3405598776.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mnorm_deg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdiag\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1.9\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdeg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mnorm_adj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnorm_deg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0madj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_deg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_adj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgnn_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m200\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.01\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    204\u001b[0m     \u001b[0;34m\"\"\"Call target, and fall back on dispatchers if there is a TypeError.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    205\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 206\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    207\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mTypeError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    208\u001b[0m       \u001b[0;31m# Note: convert_to_eager_tensor currently raises a ValueError, not a\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py\u001b[0m in \u001b[0;36mmatmul\u001b[0;34m(a, b, transpose_a, transpose_b, adjoint_a, adjoint_b, a_is_sparse, b_is_sparse, output_type, name)\u001b[0m\n\u001b[1;32m   3652\u001b[0m             a, b, adj_x=adjoint_a, adj_y=adjoint_b, Tout=output_type, name=name)\n\u001b[1;32m   3653\u001b[0m       \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3654\u001b[0;31m         return gen_math_ops.mat_mul(\n\u001b[0m\u001b[1;32m   3655\u001b[0m             a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)\n\u001b[1;32m   3656\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/Caskroom/miniforge/base/envs/py3/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py\u001b[0m in \u001b[0;36mmat_mul\u001b[0;34m(a, b, transpose_a, transpose_b, name)\u001b[0m\n\u001b[1;32m   5689\u001b[0m   \u001b[0;32mif\u001b[0m \u001b[0mtld\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_eager\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   5690\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 5691\u001b[0;31m       _result = pywrap_tfe.TFE_Py_FastPathExecute(\n\u001b[0m\u001b[1;32m   5692\u001b[0m         \u001b[0m_ctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"MatMul\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"transpose_a\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtranspose_a\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"transpose_b\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   5693\u001b[0m         transpose_b)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# thomas gibb in graph conv networks: normalize by 1/sqrt(deg)\n",
    "\n",
    "norm_deg = tf.linalg.diag(1.9 / tf.sqrt(deg))\n",
    "norm_adj = tf.matmul(norm_deg, tf.matmul(adj, norm_deg))\n",
    "train(features, norm_adj, gnn_fn, 32, 200, 0.01)"
   ]
  },
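  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a minimal sketch of the full Kipf & Welling filter on the toy graph: add self-loops, then normalize symmetrically, i.e. A_hat = D^-1/2 (A + I) D^-1/2. (To my understanding this is also what `LayerPreprocess(GCNConv)` applies to the dataset's adjacency.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A_hat = D^-1/2 (A + I) D^-1/2: self-loops keep each node's own features\n",
    "# in the aggregation, and the symmetric scaling averages over both ends of\n",
    "# each edge.\n",
    "adj_sl = toy_adj + tf.eye(3)\n",
    "deg_sl = tf.reduce_sum(adj_sl, axis=-1)\n",
    "d_inv_sqrt = tf.linalg.diag(1.0 / tf.sqrt(deg_sl))\n",
    "adj_hat = tf.matmul(d_inv_sqrt, tf.matmul(adj_sl, d_inv_sqrt))\n",
    "print(adj_hat)"
   ]
  },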
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another impl:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/1000\n",
      "1/1 [==============================] - 0s 462ms/step - loss: 1.1087 - acc: 0.2833 - val_loss: 1.1010 - val_acc: 0.6500\n",
      "Epoch 2/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 1.1005 - acc: 0.6167 - val_loss: 1.0974 - val_acc: 0.7500\n",
      "Epoch 3/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 1.0946 - acc: 0.7500 - val_loss: 1.0928 - val_acc: 0.7600\n",
      "Epoch 4/1000\n",
      "1/1 [==============================] - 0s 101ms/step - loss: 1.0894 - acc: 0.8833 - val_loss: 1.0872 - val_acc: 0.7560\n",
      "Epoch 5/1000\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 1.0805 - acc: 0.8167 - val_loss: 1.0809 - val_acc: 0.7520\n",
      "Epoch 6/1000\n",
      "1/1 [==============================] - 0s 98ms/step - loss: 1.0727 - acc: 0.8667 - val_loss: 1.0744 - val_acc: 0.7580\n",
      "Epoch 7/1000\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 1.0636 - acc: 0.9000 - val_loss: 1.0681 - val_acc: 0.7520\n",
      "Epoch 8/1000\n",
      "1/1 [==============================] - 0s 105ms/step - loss: 1.0563 - acc: 0.8667 - val_loss: 1.0619 - val_acc: 0.7460\n",
      "Epoch 9/1000\n",
      "1/1 [==============================] - 0s 113ms/step - loss: 1.0451 - acc: 0.8667 - val_loss: 1.0557 - val_acc: 0.7460\n",
      "Epoch 10/1000\n",
      "1/1 [==============================] - 0s 120ms/step - loss: 1.0412 - acc: 0.8667 - val_loss: 1.0497 - val_acc: 0.7460\n",
      "Epoch 11/1000\n",
      "1/1 [==============================] - 0s 111ms/step - loss: 1.0284 - acc: 0.8833 - val_loss: 1.0435 - val_acc: 0.7460\n",
      "Epoch 12/1000\n",
      "1/1 [==============================] - 0s 101ms/step - loss: 1.0180 - acc: 0.9000 - val_loss: 1.0375 - val_acc: 0.7480\n",
      "Epoch 13/1000\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 1.0036 - acc: 0.9000 - val_loss: 1.0315 - val_acc: 0.7520\n",
      "Epoch 14/1000\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 0.9942 - acc: 0.9333 - val_loss: 1.0255 - val_acc: 0.7580\n",
      "Epoch 15/1000\n",
      "1/1 [==============================] - 0s 94ms/step - loss: 0.9864 - acc: 0.9167 - val_loss: 1.0196 - val_acc: 0.7620\n",
      "Epoch 16/1000\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 0.9864 - acc: 0.9167 - val_loss: 1.0137 - val_acc: 0.7620\n",
      "Epoch 17/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.9612 - acc: 0.9500 - val_loss: 1.0077 - val_acc: 0.7620\n",
      "Epoch 18/1000\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 0.9546 - acc: 0.9167 - val_loss: 1.0018 - val_acc: 0.7640\n",
      "Epoch 19/1000\n",
      "1/1 [==============================] - 0s 104ms/step - loss: 0.9423 - acc: 0.9500 - val_loss: 0.9957 - val_acc: 0.7660\n",
      "Epoch 20/1000\n",
      "1/1 [==============================] - 0s 108ms/step - loss: 0.9271 - acc: 0.9167 - val_loss: 0.9895 - val_acc: 0.7620\n",
      "Epoch 21/1000\n",
      "1/1 [==============================] - 0s 101ms/step - loss: 0.9319 - acc: 0.8833 - val_loss: 0.9832 - val_acc: 0.7620\n",
      "Epoch 22/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.9068 - acc: 0.9000 - val_loss: 0.9770 - val_acc: 0.7640\n",
      "Epoch 23/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.8857 - acc: 0.8667 - val_loss: 0.9708 - val_acc: 0.7640\n",
      "Epoch 24/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.8685 - acc: 0.8833 - val_loss: 0.9646 - val_acc: 0.7640\n",
      "Epoch 25/1000\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 0.8696 - acc: 0.9333 - val_loss: 0.9583 - val_acc: 0.7640\n",
      "Epoch 26/1000\n",
      "1/1 [==============================] - 0s 98ms/step - loss: 0.8706 - acc: 0.9500 - val_loss: 0.9521 - val_acc: 0.7640\n",
      "Epoch 27/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.8500 - acc: 0.9167 - val_loss: 0.9459 - val_acc: 0.7680\n",
      "Epoch 28/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.8448 - acc: 0.9167 - val_loss: 0.9399 - val_acc: 0.7700\n",
      "Epoch 29/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.8462 - acc: 0.9333 - val_loss: 0.9339 - val_acc: 0.7660\n",
      "Epoch 30/1000\n",
      "1/1 [==============================] - 0s 84ms/step - loss: 0.8169 - acc: 0.9000 - val_loss: 0.9280 - val_acc: 0.7680\n",
      "Epoch 31/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.7901 - acc: 0.9500 - val_loss: 0.9223 - val_acc: 0.7700\n",
      "Epoch 32/1000\n",
      "1/1 [==============================] - 0s 95ms/step - loss: 0.8125 - acc: 0.9500 - val_loss: 0.9167 - val_acc: 0.7700\n",
      "Epoch 33/1000\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 0.7813 - acc: 0.9667 - val_loss: 0.9110 - val_acc: 0.7700\n",
      "Epoch 34/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.7728 - acc: 0.9167 - val_loss: 0.9055 - val_acc: 0.7700\n",
      "Epoch 35/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.7708 - acc: 0.9333 - val_loss: 0.9001 - val_acc: 0.7720\n",
      "Epoch 36/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.7519 - acc: 0.9333 - val_loss: 0.8949 - val_acc: 0.7700\n",
      "Epoch 37/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.7435 - acc: 0.9500 - val_loss: 0.8899 - val_acc: 0.7700\n",
      "Epoch 38/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.7134 - acc: 0.9333 - val_loss: 0.8849 - val_acc: 0.7680\n",
      "Epoch 39/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.7055 - acc: 0.9500 - val_loss: 0.8800 - val_acc: 0.7680\n",
      "Epoch 40/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.6887 - acc: 0.9667 - val_loss: 0.8752 - val_acc: 0.7680\n",
      "Epoch 41/1000\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 0.6847 - acc: 0.9667 - val_loss: 0.8703 - val_acc: 0.7680\n",
      "Epoch 42/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.6918 - acc: 0.9167 - val_loss: 0.8655 - val_acc: 0.7700\n",
      "Epoch 43/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.6699 - acc: 0.9333 - val_loss: 0.8608 - val_acc: 0.7700\n",
      "Epoch 44/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.6433 - acc: 0.9667 - val_loss: 0.8565 - val_acc: 0.7720\n",
      "Epoch 45/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.6713 - acc: 0.9167 - val_loss: 0.8523 - val_acc: 0.7720\n",
      "Epoch 46/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.6396 - acc: 0.9667 - val_loss: 0.8482 - val_acc: 0.7720\n",
      "Epoch 47/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.6291 - acc: 0.9167 - val_loss: 0.8444 - val_acc: 0.7740\n",
      "Epoch 48/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.6029 - acc: 0.9667 - val_loss: 0.8407 - val_acc: 0.7700\n",
      "Epoch 49/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.6102 - acc: 0.9500 - val_loss: 0.8371 - val_acc: 0.7700\n",
      "Epoch 50/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.6154 - acc: 0.9167 - val_loss: 0.8335 - val_acc: 0.7700\n",
      "Epoch 51/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.6160 - acc: 0.9500 - val_loss: 0.8302 - val_acc: 0.7700\n",
      "Epoch 52/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.5654 - acc: 0.9667 - val_loss: 0.8271 - val_acc: 0.7680\n",
      "Epoch 53/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.6097 - acc: 0.9500 - val_loss: 0.8242 - val_acc: 0.7700\n",
      "Epoch 54/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.5976 - acc: 0.9500 - val_loss: 0.8213 - val_acc: 0.7720\n",
      "Epoch 55/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.5914 - acc: 0.9000 - val_loss: 0.8188 - val_acc: 0.7720\n",
      "Epoch 56/1000\n",
      "1/1 [==============================] - 0s 93ms/step - loss: 0.5761 - acc: 0.9667 - val_loss: 0.8164 - val_acc: 0.7760\n",
      "Epoch 57/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.5613 - acc: 0.9667 - val_loss: 0.8141 - val_acc: 0.7800\n",
      "Epoch 58/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.5773 - acc: 0.9500 - val_loss: 0.8117 - val_acc: 0.7800\n",
      "Epoch 59/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.5665 - acc: 0.9500 - val_loss: 0.8094 - val_acc: 0.7800\n",
      "Epoch 60/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.5408 - acc: 0.9667 - val_loss: 0.8069 - val_acc: 0.7800\n",
      "Epoch 61/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.5425 - acc: 0.9500 - val_loss: 0.8043 - val_acc: 0.7840\n",
      "Epoch 62/1000\n",
      "1/1 [==============================] - 0s 104ms/step - loss: 0.5627 - acc: 0.9167 - val_loss: 0.8019 - val_acc: 0.7840\n",
      "Epoch 63/1000\n",
      "1/1 [==============================] - 0s 96ms/step - loss: 0.5338 - acc: 0.9333 - val_loss: 0.7996 - val_acc: 0.7820\n",
      "Epoch 64/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.5082 - acc: 0.9667 - val_loss: 0.7972 - val_acc: 0.7820\n",
      "Epoch 65/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.5194 - acc: 0.9667 - val_loss: 0.7950 - val_acc: 0.7820\n",
      "Epoch 66/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.4943 - acc: 0.9333 - val_loss: 0.7929 - val_acc: 0.7840\n",
      "Epoch 67/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.5025 - acc: 0.9667 - val_loss: 0.7909 - val_acc: 0.7820\n",
      "Epoch 68/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.4835 - acc: 0.9833 - val_loss: 0.7891 - val_acc: 0.7800\n",
      "Epoch 69/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.5119 - acc: 0.9833 - val_loss: 0.7874 - val_acc: 0.7800\n",
      "Epoch 70/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.5110 - acc: 0.9667 - val_loss: 0.7857 - val_acc: 0.7800\n",
      "Epoch 71/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.4949 - acc: 0.9333 - val_loss: 0.7842 - val_acc: 0.7800\n",
      "Epoch 72/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.4662 - acc: 0.9667 - val_loss: 0.7828 - val_acc: 0.7800\n",
      "Epoch 73/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.5109 - acc: 0.9833 - val_loss: 0.7810 - val_acc: 0.7780\n",
      "Epoch 74/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.4513 - acc: 0.9833 - val_loss: 0.7792 - val_acc: 0.7820\n",
      "Epoch 75/1000\n",
      "1/1 [==============================] - 0s 86ms/step - loss: 0.4873 - acc: 0.9500 - val_loss: 0.7772 - val_acc: 0.7840\n",
      "Epoch 76/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.4595 - acc: 0.9667 - val_loss: 0.7753 - val_acc: 0.7820\n",
      "Epoch 77/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.5105 - acc: 0.9500 - val_loss: 0.7736 - val_acc: 0.7860\n",
      "Epoch 78/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.4737 - acc: 0.9667 - val_loss: 0.7717 - val_acc: 0.7880\n",
      "Epoch 79/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.4541 - acc: 0.9833 - val_loss: 0.7702 - val_acc: 0.7900\n",
      "Epoch 80/1000\n",
      "1/1 [==============================] - 0s 86ms/step - loss: 0.4477 - acc: 1.0000 - val_loss: 0.7685 - val_acc: 0.7900\n",
      "Epoch 81/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.4596 - acc: 0.9500 - val_loss: 0.7667 - val_acc: 0.7920\n",
      "Epoch 82/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.4661 - acc: 0.9333 - val_loss: 0.7651 - val_acc: 0.7920\n",
      "Epoch 83/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.4452 - acc: 0.9500 - val_loss: 0.7639 - val_acc: 0.7940\n",
      "Epoch 84/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.4559 - acc: 0.9667 - val_loss: 0.7631 - val_acc: 0.7940\n",
      "Epoch 85/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.4705 - acc: 0.9500 - val_loss: 0.7623 - val_acc: 0.7960\n",
      "Epoch 86/1000\n",
      "1/1 [==============================] - 0s 86ms/step - loss: 0.4418 - acc: 0.9500 - val_loss: 0.7616 - val_acc: 0.7940\n",
      "Epoch 87/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.4240 - acc: 0.9667 - val_loss: 0.7610 - val_acc: 0.7940\n",
      "Epoch 88/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.4489 - acc: 0.9667 - val_loss: 0.7603 - val_acc: 0.7940\n",
      "Epoch 89/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.4326 - acc: 0.9833 - val_loss: 0.7596 - val_acc: 0.7920\n",
      "Epoch 90/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.4337 - acc: 0.9667 - val_loss: 0.7586 - val_acc: 0.7900\n",
      "Epoch 91/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.4263 - acc: 0.9667 - val_loss: 0.7579 - val_acc: 0.7900\n",
      "Epoch 92/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.4165 - acc: 1.0000 - val_loss: 0.7568 - val_acc: 0.7900\n",
      "Epoch 93/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.4393 - acc: 0.9667 - val_loss: 0.7557 - val_acc: 0.7900\n",
      "Epoch 94/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.4316 - acc: 1.0000 - val_loss: 0.7544 - val_acc: 0.7900\n",
      "Epoch 95/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.4527 - acc: 0.9667 - val_loss: 0.7529 - val_acc: 0.7900\n",
      "Epoch 96/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.4238 - acc: 0.9667 - val_loss: 0.7512 - val_acc: 0.7880\n",
      "Epoch 97/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.3868 - acc: 0.9833 - val_loss: 0.7496 - val_acc: 0.7940\n",
      "Epoch 98/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.3973 - acc: 1.0000 - val_loss: 0.7478 - val_acc: 0.7940\n",
      "Epoch 99/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.4228 - acc: 0.9667 - val_loss: 0.7459 - val_acc: 0.7940\n",
      "Epoch 100/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3667 - acc: 0.9833 - val_loss: 0.7443 - val_acc: 0.7960\n",
      "Epoch 101/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.3947 - acc: 0.9667 - val_loss: 0.7430 - val_acc: 0.7960\n",
      "Epoch 102/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3955 - acc: 0.9667 - val_loss: 0.7416 - val_acc: 0.7940\n",
      "Epoch 103/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.4087 - acc: 0.9833 - val_loss: 0.7402 - val_acc: 0.7940\n",
      "Epoch 104/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.4469 - acc: 0.9833 - val_loss: 0.7387 - val_acc: 0.7940\n",
      "Epoch 105/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.3881 - acc: 0.9833 - val_loss: 0.7375 - val_acc: 0.7940\n",
      "Epoch 106/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.3735 - acc: 0.9833 - val_loss: 0.7365 - val_acc: 0.7940\n",
      "Epoch 107/1000\n",
      "1/1 [==============================] - 0s 84ms/step - loss: 0.4000 - acc: 0.9667 - val_loss: 0.7358 - val_acc: 0.7940\n",
      "Epoch 108/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3820 - acc: 1.0000 - val_loss: 0.7352 - val_acc: 0.7960\n",
      "Epoch 109/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.3557 - acc: 0.9833 - val_loss: 0.7345 - val_acc: 0.7980\n",
      "Epoch 110/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.3602 - acc: 1.0000 - val_loss: 0.7337 - val_acc: 0.7960\n",
      "Epoch 111/1000\n",
      "1/1 [==============================] - 0s 98ms/step - loss: 0.4138 - acc: 0.9667 - val_loss: 0.7329 - val_acc: 0.7940\n",
      "Epoch 112/1000\n",
      "1/1 [==============================] - 0s 83ms/step - loss: 0.3778 - acc: 0.9667 - val_loss: 0.7324 - val_acc: 0.7960\n",
      "Epoch 113/1000\n",
      "1/1 [==============================] - 0s 86ms/step - loss: 0.3877 - acc: 0.9667 - val_loss: 0.7322 - val_acc: 0.7920\n",
      "Epoch 114/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3661 - acc: 1.0000 - val_loss: 0.7317 - val_acc: 0.7920\n",
      "Epoch 115/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.3956 - acc: 0.9667 - val_loss: 0.7310 - val_acc: 0.7920\n",
      "Epoch 116/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3616 - acc: 1.0000 - val_loss: 0.7300 - val_acc: 0.7920\n",
      "Epoch 117/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.3594 - acc: 0.9667 - val_loss: 0.7291 - val_acc: 0.7880\n",
      "Epoch 118/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.3739 - acc: 0.9833 - val_loss: 0.7285 - val_acc: 0.7860\n",
      "Epoch 119/1000\n",
      "1/1 [==============================] - 0s 86ms/step - loss: 0.3778 - acc: 0.9833 - val_loss: 0.7280 - val_acc: 0.7880\n",
      "Epoch 120/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.3669 - acc: 1.0000 - val_loss: 0.7274 - val_acc: 0.7900\n",
      "Epoch 121/1000\n",
      "1/1 [==============================] - 0s 85ms/step - loss: 0.3288 - acc: 1.0000 - val_loss: 0.7266 - val_acc: 0.7880\n",
      "Epoch 122/1000\n",
      "1/1 [==============================] - 0s 86ms/step - loss: 0.3726 - acc: 0.9833 - val_loss: 0.7259 - val_acc: 0.7880\n",
      "Epoch 123/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.3813 - acc: 0.9667 - val_loss: 0.7250 - val_acc: 0.7900\n",
      "Epoch 124/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.3615 - acc: 0.9833 - val_loss: 0.7240 - val_acc: 0.7920\n",
      "Epoch 125/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.3283 - acc: 1.0000 - val_loss: 0.7229 - val_acc: 0.7920\n",
      "Epoch 126/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.3312 - acc: 0.9833 - val_loss: 0.7215 - val_acc: 0.7940\n",
      "Epoch 127/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.4086 - acc: 0.9667 - val_loss: 0.7206 - val_acc: 0.7960\n",
      "Epoch 128/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3588 - acc: 0.9667 - val_loss: 0.7195 - val_acc: 0.7920\n",
      "Epoch 129/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3678 - acc: 1.0000 - val_loss: 0.7184 - val_acc: 0.7920\n",
      "Epoch 130/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.3201 - acc: 0.9833 - val_loss: 0.7175 - val_acc: 0.7920\n",
      "Epoch 131/1000\n",
      "1/1 [==============================] - 0s 100ms/step - loss: 0.3630 - acc: 0.9833 - val_loss: 0.7167 - val_acc: 0.7920\n",
      "Epoch 132/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.3601 - acc: 0.9667 - val_loss: 0.7163 - val_acc: 0.7900\n",
      "Epoch 133/1000\n",
      "1/1 [==============================] - 0s 83ms/step - loss: 0.3231 - acc: 1.0000 - val_loss: 0.7158 - val_acc: 0.7900\n",
      "Epoch 134/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.3124 - acc: 1.0000 - val_loss: 0.7150 - val_acc: 0.7900\n",
      "Epoch 135/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.3430 - acc: 0.9833 - val_loss: 0.7142 - val_acc: 0.7920\n",
      "Epoch 136/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.3344 - acc: 0.9833 - val_loss: 0.7135 - val_acc: 0.7900\n",
      "Epoch 137/1000\n",
      "1/1 [==============================] - 0s 84ms/step - loss: 0.3537 - acc: 0.9833 - val_loss: 0.7130 - val_acc: 0.7900\n",
      "Epoch 138/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.3424 - acc: 0.9667 - val_loss: 0.7128 - val_acc: 0.7880\n",
      "Epoch 139/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.3199 - acc: 1.0000 - val_loss: 0.7125 - val_acc: 0.7840\n",
      "Epoch 140/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.3173 - acc: 0.9833 - val_loss: 0.7127 - val_acc: 0.7860\n",
      "Epoch 141/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.3474 - acc: 0.9833 - val_loss: 0.7126 - val_acc: 0.7840\n",
      "Epoch 142/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3292 - acc: 1.0000 - val_loss: 0.7127 - val_acc: 0.7840\n",
      "Epoch 143/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.3552 - acc: 0.9667 - val_loss: 0.7127 - val_acc: 0.7820\n",
      "Epoch 144/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.3256 - acc: 1.0000 - val_loss: 0.7125 - val_acc: 0.7840\n",
      "Epoch 145/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.3204 - acc: 1.0000 - val_loss: 0.7125 - val_acc: 0.7840\n",
      "Epoch 146/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3133 - acc: 0.9833 - val_loss: 0.7122 - val_acc: 0.7800\n",
      "Epoch 147/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.3209 - acc: 1.0000 - val_loss: 0.7116 - val_acc: 0.7800\n",
      "Epoch 148/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3291 - acc: 1.0000 - val_loss: 0.7100 - val_acc: 0.7820\n",
      "Epoch 149/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.3524 - acc: 0.9833 - val_loss: 0.7079 - val_acc: 0.7840\n",
      "Epoch 150/1000\n",
      "1/1 [==============================] - 0s 99ms/step - loss: 0.3336 - acc: 1.0000 - val_loss: 0.7057 - val_acc: 0.7860\n",
      "Epoch 151/1000\n",
      "1/1 [==============================] - 0s 86ms/step - loss: 0.3318 - acc: 0.9667 - val_loss: 0.7037 - val_acc: 0.7900\n",
      "Epoch 152/1000\n",
      "1/1 [==============================] - 0s 82ms/step - loss: 0.3615 - acc: 0.9833 - val_loss: 0.7022 - val_acc: 0.7920\n",
      "Epoch 153/1000\n",
      "1/1 [==============================] - 0s 86ms/step - loss: 0.3089 - acc: 1.0000 - val_loss: 0.7011 - val_acc: 0.7920\n",
      "Epoch 154/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.3616 - acc: 0.9667 - val_loss: 0.7006 - val_acc: 0.7920\n",
      "Epoch 155/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.2931 - acc: 1.0000 - val_loss: 0.7000 - val_acc: 0.7940\n",
      "Epoch 156/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.3272 - acc: 1.0000 - val_loss: 0.6993 - val_acc: 0.7960\n",
      "Epoch 157/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.3255 - acc: 0.9833 - val_loss: 0.6992 - val_acc: 0.7980\n",
      "Epoch 158/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.3066 - acc: 0.9833 - val_loss: 0.6993 - val_acc: 0.7980\n",
      "Epoch 159/1000\n",
      "1/1 [==============================] - 0s 89ms/step - loss: 0.3397 - acc: 0.9500 - val_loss: 0.6993 - val_acc: 0.7960\n",
      "Epoch 160/1000\n",
      "1/1 [==============================] - 0s 91ms/step - loss: 0.3396 - acc: 0.9833 - val_loss: 0.6993 - val_acc: 0.7960\n",
      "Epoch 161/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3275 - acc: 1.0000 - val_loss: 0.6997 - val_acc: 0.7960\n",
      "Epoch 162/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.3218 - acc: 0.9667 - val_loss: 0.7001 - val_acc: 0.7940\n",
      "Epoch 163/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3345 - acc: 0.9833 - val_loss: 0.7006 - val_acc: 0.7940\n",
      "Epoch 164/1000\n",
      "1/1 [==============================] - 0s 88ms/step - loss: 0.3028 - acc: 0.9833 - val_loss: 0.7010 - val_acc: 0.7940\n",
      "Epoch 165/1000\n",
      "1/1 [==============================] - 0s 90ms/step - loss: 0.3109 - acc: 1.0000 - val_loss: 0.7012 - val_acc: 0.7940\n",
      "Epoch 166/1000\n",
      "1/1 [==============================] - 0s 87ms/step - loss: 0.3116 - acc: 0.9667 - val_loss: 0.7015 - val_acc: 0.7940\n",
      "Epoch 167/1000\n",
      "1/1 [==============================] - 0s 92ms/step - loss: 0.2926 - acc: 1.0000 - val_loss: 0.7014 - val_acc: 0.7900\n",
      "1/1 [==============================] - 0s 38ms/step - loss: 0.7112 - acc: 0.7930\n",
      "Done.\n",
      "Test loss: 0.7112346291542053\n",
      "Test accuracy: 0.7930001020431519\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "This example implements the experiments on citation networks from the paper:\n",
    "Semi-Supervised Classification with Graph Convolutional Networks (https://arxiv.org/abs/1609.02907)\n",
    "Thomas N. Kipf, Max Welling\n",
    "\"\"\"\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.callbacks import EarlyStopping\n",
    "from tensorflow.keras.losses import CategoricalCrossentropy\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "\n",
    "from spektral.data.loaders import SingleLoader\n",
    "from spektral.layers import GCNConv\n",
    "from spektral.models.gcn import GCN\n",
    "from spektral.transforms import LayerPreprocess\n",
    "\n",
    "learning_rate = .01\n",
    "epochs = 100\n",
    "patience = 10\n",
    "\n",
    "# We convert the binary masks to sample weights so that we can compute the\n",
    "# average loss over the nodes (following original implementation by\n",
    "# Kipf & Welling)\n",
    "def mask_to_weights(mask):\n",
    "    return mask.astype(np.float32) / np.count_nonzero(mask)\n",
    "\n",
    "\n",
    "weights_tr, weights_va, weights_te = (\n",
    "    mask_to_weights(mask)\n",
    "    for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)\n",
    ")\n",
    "\n",
    "# define the model\n",
    "model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)\n",
    "model.compile(\n",
    "    optimizer=Adam(learning_rate),\n",
    "    loss=CategoricalCrossentropy(reduction=\"sum\"),\n",
    "    weighted_metrics=[\"acc\"],\n",
    ")\n",
    "\n",
    "loader_tr = SingleLoader(dataset, sample_weights=weights_tr)\n",
    "loader_va = SingleLoader(dataset, sample_weights=weights_va)\n",
    "loader_te = SingleLoader(dataset, sample_weights=weights_te)\n",
    "\n",
    "# Train\n",
    "model.fit(\n",
    "    loader_tr.load(),\n",
    "    steps_per_epoch=loader_tr.steps_per_epoch,\n",
    "    validation_data=loader_va.load(),\n",
    "    validation_steps=loader_va.steps_per_epoch,\n",
    "    epochs=epochs,\n",
    "    callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],\n",
    ")\n",
    "\n",
    "# Evaluate\n",
    "eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)\n",
    "\n",
    "print('Done.\\n'\n",
    "      f'Test loss: {eval_results[0]:2.f}\\n'\n",
    "      f'Test accuracy: {eval_results[1]:.2f}')"
   ]
  }
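  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy check of `mask_to_weights` (hypothetical mask): the weights sum to 1, which is why the `sum`-reduced loss with these sample weights equals the mean loss over the masked nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_bool_mask = np.array([True, False, True, True])\n",
    "print(mask_to_weights(toy_bool_mask))\n",
    "# -> [0.33333334, 0., 0.33333334, 0.33333334]: nonzero weights sum to 1"
   ]
  }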
 ],
 "metadata": {
  "interpreter": {
   "hash": "4d4c55ad0dd25f9ca95e4d49a929aa3f71bfb37020ae570a9996c3e164818202"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit ('py3': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}