Newer
Older
notebooks / spektral_generalgnn.ipynb
Morteza Ansarinia on 3 Nov 2021 1 KB add spektral general_gnn (SOTA)
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %conda activate py38\n",
    "# %pip install spektral -Uq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import spektral as sp\n",
    "import spektral.models\n",
    "\n",
    "class CogTextDataset(sp.data.Dataset):\n",
    "  \"\"\"Spektral dataset holding a single randomly generated graph.\n",
    "\n",
    "  The graph has a random binary adjacency matrix, random node features,\n",
    "  random dense edge features and one random binary label per node.\n",
    "\n",
    "  Args:\n",
    "    n_nodes: number of nodes in the generated graph.\n",
    "    node_features_dim: number of features per node.\n",
    "    edge_features_dim: number of features per entry of the dense\n",
    "      (n_nodes, n_nodes) edge-feature array.\n",
    "    seed: seed for the NumPy random generator. `None` (the default)\n",
    "      draws different random data on every instantiation.\n",
    "    **kwargs: forwarded unchanged to `sp.data.Dataset.__init__`.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, n_nodes=11, node_features_dim=12, edge_features_dim=13,\n",
    "               seed=None, **kwargs):\n",
    "    # Stored under underscore-prefixed names so they cannot clash with the\n",
    "    # read-only size properties (e.g. `n_nodes`) of the Spektral base class.\n",
    "    self._n_nodes = n_nodes\n",
    "    self._node_features_dim = node_features_dim\n",
    "    self._edge_features_dim = edge_features_dim\n",
    "    self._seed = seed\n",
    "    # The base constructor calls `read()` and keeps its result as\n",
    "    # `self.graphs`, so the settings above must be assigned first.\n",
    "    super().__init__(**kwargs)\n",
    "\n",
    "  def read(self):\n",
    "    \"\"\"Generate the random graph and return it in a one-element list.\"\"\"\n",
    "    rng = np.random.default_rng(self._seed)\n",
    "    n = self._n_nodes\n",
    "    adj = rng.integers(0, 2, (n, n))  # entries are 0 or 1\n",
    "    node_features = rng.random((n, self._node_features_dim))\n",
    "    edge_features = rng.random((n, n, self._edge_features_dim))\n",
    "    y = rng.integers(0, 2, n)  # one binary label per node\n",
    "    return [sp.data.Graph(x=node_features,\n",
    "                          a=adj,\n",
    "                          e=edge_features,\n",
    "                          y=y)]\n",
    "\n",
    "# Build the one-graph dataset and wrap it in a SingleLoader limited to a\n",
    "# single epoch.\n",
    "dataset = CogTextDataset()\n",
    "loader = sp.data.loaders.SingleLoader(dataset, epochs=1)\n",
    "# Materialise the loader and unpack its first item as (inputs, labels).\n",
    "X, y = list(loader)[0]\n",
    "\n",
    "# GeneralGNN configured with output=128, pooling disabled (pool=None) and\n",
    "# connectivity='cat'.\n",
    "# NOTE(review): the graph is created with edge features `e`; confirm how\n",
    "# GeneralGNN interprets a three-element input before relying on them - TODO.\n",
    "model = sp.models.GeneralGNN(output=128, pool=None, connectivity='cat')\n",
    "# model.compile()\n",
    "# Call the uncompiled model directly on the inputs and display the shape of\n",
    "# its output as the cell result.\n",
    "y_pred = model(X)\n",
    "y_pred.shape"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}