{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Spektral `GeneralGNN` smoke test on a synthetic graph\n",
"\n",
"Builds a single random graph wrapped in a Spektral `Dataset`, loads it with `SingleLoader`,\n",
"and runs one forward pass of `GeneralGNN` to inspect the output shape.\n",
"\n",
"## Setup\n",
"\n",
"The original setup notes reference a conda environment named `py38` (`conda activate py38`)\n",
"with Spektral installed via:\n",
"\n",
"```\n",
"%pip install spektral -Uq\n",
"```\n",
"\n",
"**TODO:** pin the Spektral (and TensorFlow) versions this notebook was tested with."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import spektral as sp\n",
"import spektral.models\n",
"\n",
"class CogTextDataset(sp.data.Dataset):\n",
"    \"\"\"Spektral dataset holding one randomly generated graph (synthetic placeholder data).\n",
"\n",
"    The graph has a random 0/1 adjacency matrix, uniform random node and edge\n",
"    features, and random binary per-node labels, all drawn from NumPy's global RNG.\n",
"\n",
"    Args:\n",
"        n_nodes: number of nodes in the graph.\n",
"        node_features_dim: number of features per node.\n",
"        edge_features_dim: number of features per (dense) adjacency entry.\n",
"        **kwargs: forwarded unchanged to `sp.data.Dataset.__init__`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, n_nodes=11, node_features_dim=12, edge_features_dim=13, **kwargs):\n",
"        # Not symmetrised and the diagonal is not masked, so the graph is generally\n",
"        # directed and may contain self-loops.\n",
"        adj = np.random.randint(0, 2, (n_nodes, n_nodes))\n",
"        node_features = np.random.rand(n_nodes, node_features_dim)\n",
"        # NOTE(review): dense (n_nodes, n_nodes, edge_features_dim) layout that also has\n",
"        # entries for non-edges; confirm this is the edge-feature format downstream expects.\n",
"        edge_features = np.random.rand(n_nodes, n_nodes, edge_features_dim)\n",
"        y = np.random.randint(0, 2, n_nodes)\n",
"        graph = sp.data.Graph(x=node_features, a=adj, e=edge_features, y=y)\n",
"        # Set before calling the base constructor: read() below returns this list, and the\n",
"        # Spektral Dataset constructor is expected to call read().\n",
"        self.graphs = [graph]\n",
"        super().__init__(**kwargs)\n",
"\n",
"    def read(self):\n",
"        \"\"\"Return the list of graphs built in __init__.\"\"\"\n",
"        return self.graphs\n",
"\n",
"SEED = 42\n",
"OUTPUT_DIM = 128\n",
"\n",
"# CogTextDataset draws from NumPy's global RNG, so seed it for reproducible runs.\n",
"np.random.seed(SEED)\n",
"\n",
"dataset = CogTextDataset()\n",
"loader = sp.data.loaders.SingleLoader(dataset, epochs=1)\n",
"# Take the first (inputs, target) batch without materialising every batch in a list.\n",
"inputs, target = next(iter(loader))\n",
"\n",
"model = sp.models.GeneralGNN(output=OUTPUT_DIM, pool=None, connectivity='cat')\n",
"# NOTE(review): `inputs` is passed straight from the loader and may include edge\n",
"# features; confirm how GeneralGNN treats inputs beyond node features and adjacency.\n",
"y_pred = model(inputs)\n",
"y_pred.shape"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}