{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GeneralGNN smoke test on a random graph\n",
    "\n",
    "Wraps one randomly generated graph in a Spektral `Dataset`, feeds it through `GeneralGNN` in node-level mode (`pool=None`) and checks the output shape.\n",
    "\n",
    "Run on a Python 3.8 kernel with Spektral installed. Uncomment the next cell to install it, and pin the version you test against (`%conda activate` inside a notebook does not switch the running kernel's environment)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install -q -U spektral"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy.sparse\n",
    "\n",
    "import spektral as sp\n",
    "import spektral.models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds the random graph only; GeneralGNN's weight initialisation is left unseeded.\n",
    "SEED = 42\n",
    "N_NODES = 11\n",
    "NODE_FEATURES_DIM = 12\n",
    "EDGE_FEATURES_DIM = 13\n",
    "OUTPUT_DIM = 128"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset\n",
    "\n",
    "A single random graph with a sparse adjacency matrix, node features, per-edge features and binary node labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CogTextDataset(sp.data.Dataset):\n",
    "    \"\"\"Spektral dataset holding a single randomly generated graph.\n",
    "\n",
    "    Args:\n",
    "        n_nodes: number of nodes in the graph.\n",
    "        node_features_dim: length of each node's feature vector.\n",
    "        edge_features_dim: length of each edge's feature vector.\n",
    "        seed: seed for NumPy's generator; ``None`` draws a new graph every time.\n",
    "        **kwargs: forwarded to ``spektral.data.Dataset``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_nodes=11, node_features_dim=12, edge_features_dim=13,\n",
    "                 seed=None, **kwargs):\n",
    "        # The base __init__ calls self.read(), so configuration must be stored first.\n",
    "        # Underscore names avoid shadowing any attribute/property of Dataset.\n",
    "        self._num_nodes = n_nodes\n",
    "        self._node_features_dim = node_features_dim\n",
    "        self._edge_features_dim = edge_features_dim\n",
    "        self._seed = seed\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def read(self):\n",
    "        \"\"\"Generate the dataset's graphs: a list containing exactly one Graph.\"\"\"\n",
    "        rng = np.random.default_rng(self._seed)\n",
    "        n = self._num_nodes\n",
    "        # Spektral's message-passing layers require a sparse adjacency; the loader\n",
    "        # converts scipy sparse matrices to tf.SparseTensor.\n",
    "        adj = scipy.sparse.csr_matrix(\n",
    "            rng.integers(0, 2, size=(n, n)).astype(np.float32))\n",
    "        # With a sparse adjacency, edge features hold one row per non-zero entry,\n",
    "        # in the CSR matrix's row-major order.\n",
    "        edge_features = rng.random((adj.nnz, self._edge_features_dim), dtype=np.float32)\n",
    "        node_features = rng.random((n, self._node_features_dim), dtype=np.float32)\n",
    "        labels = rng.integers(0, 2, size=n)\n",
    "        return [sp.data.Graph(x=node_features, a=adj, e=edge_features, y=labels)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model forward pass\n",
    "\n",
    "`SingleLoader` emits the whole graph as one batch; `GeneralGNN` with `pool=None` returns one embedding per node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = CogTextDataset(n_nodes=N_NODES,\n",
    "                         node_features_dim=NODE_FEATURES_DIM,\n",
    "                         edge_features_dim=EDGE_FEATURES_DIM,\n",
    "                         seed=SEED)\n",
    "loader = sp.data.loaders.SingleLoader(dataset, epochs=1)\n",
    "# inputs holds the graph's non-None matrices in (x, a, e) order; target is y.\n",
    "inputs, target = next(iter(loader))\n",
    "x, a = inputs[0], inputs[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = sp.models.GeneralGNN(output=OUTPUT_DIM, pool=None, connectivity='cat')\n",
    "# GeneralGNN reads a 3-element input as (x, a, i), where i would be graph indices,\n",
    "# not edge features, so only (x, a) is passed. NOTE(review): confirm against the\n",
    "# installed Spektral version.\n",
    "node_embeddings = model([x, a])\n",
    "assert node_embeddings.shape == (N_NODES, OUTPUT_DIM)\n",
    "node_embeddings.shape"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}