Newer
Older
notebooks / weighted_random_walk.ipynb
Morteza Ansarinia on 23 Jan 2022 4 KB Created using Colaboratory
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "weighted_random_walk.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyMdSPfxyctAhjlb3X6mFu53",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/morteza/notebooks/blob/master/weighted_random_walk.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "3oRsjozrAf5F",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "78b36777-7a12-4cac-9f1c-de194c56f5e8"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "  Building wheel for csrgraph (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "# Colab setup; %pip (unlike !pip) installs into the running kernel's environment.\n",
        "%pip install --quiet networkx csrgraph"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import networkx as nx\n",
        "import csrgraph as cg\n",
        "import numpy as np"
      ],
      "metadata": {
        "id": "8C3CEM_yAjN_"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "def weighted_metapath_random_walk(graph, n: int = None, length: int = None, metapaths=None):\n",
        "  \"\"\"Sample `n` loop-free weighted random walks whose node types follow a metapath.\n",
        "\n",
        "  Walks are drawn with csrgraph's weighted `random_walks` and filtered by\n",
        "  rejection: a walk is kept only if the sequence of its nodes' 'type'\n",
        "  attributes equals one of `metapaths` and it visits no node twice.\n",
        "  Sampling repeats until `n` walks have been collected.\n",
        "\n",
        "  Args:\n",
        "    graph: networkx graph; every node needs a 'type' attribute.\n",
        "    n: number of walks to return (required).\n",
        "    length: nodes per walk; inferred from `metapaths` when they all share one length.\n",
        "    metapaths: sequences of node types, e.g. [['task', 'construct', 'task']].\n",
        "\n",
        "  Returns:\n",
        "    List of max(n, 0) walks, each a list of node labels.\n",
        "\n",
        "  Raises:\n",
        "    TypeError: if `n` or `metapaths` is not given.\n",
        "    ValueError: if `length` cannot be inferred, or no metapath has `length`\n",
        "      nodes (no walk could ever match, so sampling would never finish).\n",
        "  \"\"\"\n",
        "  if n is None or metapaths is None:\n",
        "    raise TypeError('weighted_metapath_random_walk() requires `n` and `metapaths`')\n",
        "  # Hashable tuples give O(1) lookups and treat list and tuple metapaths alike.\n",
        "  allowed = {tuple(path) for path in metapaths}\n",
        "  if length is None:\n",
        "    lengths = {len(path) for path in allowed}\n",
        "    if len(lengths) != 1:\n",
        "      raise ValueError('`length` is required when metapaths differ in length')\n",
        "    length = lengths.pop()\n",
        "  if not any(len(path) == length for path in allowed):\n",
        "    raise ValueError('no metapath has length {}; no walk could match'.format(length))\n",
        "  # TODO walks longer than a metapath are also valid if they repeat the same metapath\n",
        "\n",
        "  _g = cg.csrgraph(graph)\n",
        "  # csrgraph walks contain integer node positions; map them back to node labels.\n",
        "  # NOTE(review): assumes csrgraph indexes nodes in list(graph) order -- verify.\n",
        "  labels = list(graph)\n",
        "\n",
        "  valid_walks = []\n",
        "  while len(valid_walks) < n:\n",
        "    for walk in _g.random_walks(length, n, start_nodes=None):\n",
        "      nodes = [labels[i] for i in walk]\n",
        "      types = tuple(graph.nodes[node]['type'] for node in nodes)\n",
        "      if types in allowed and len(set(nodes)) == len(nodes):  # no repeated node\n",
        "        valid_walks.append(nodes)\n",
        "\n",
        "  return valid_walks[:n]\n",
        "\n",
        "\n",
        "\n",
        "# EXAMPLE USAGE\n",
        "# Seed NumPy so the edge weights (and hence the graph) are reproducible.\n",
        "# NOTE(review): csrgraph presumably samples walks with its own (numba) RNG,\n",
        "# so the sampled walks may still vary between runs.\n",
        "rng = np.random.default_rng(42)\n",
        "weights = rng.random((5, 5))\n",
        "np.fill_diagonal(weights, 0)  # no self-loops\n",
        "\n",
        "G = nx.from_numpy_array(weights)\n",
        "node_types = ['task', 'task', 'construct', 'construct', 'construct']\n",
        "nx.set_node_attributes(G, dict(enumerate(node_types)), 'type')\n",
        "\n",
        "metapaths = [\n",
        "  ['task', 'construct', 'task'],\n",
        "  ['construct', 'task', 'construct'],\n",
        "]\n",
        "\n",
        "weighted_metapath_random_walk(G, 30, 3, metapaths)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3QWQH4Q1A2TU",
        "outputId": "93b798ba-81f1-49de-f4f3-73446cc07038"
      },
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[[0, 4, 1],\n",
              " [1, 2, 0],\n",
              " [0, 4, 1],\n",
              " [2, 0, 4],\n",
              " [2, 1, 4],\n",
              " [2, 1, 3],\n",
              " [4, 1, 2],\n",
              " [3, 1, 4],\n",
              " [2, 1, 4],\n",
              " [3, 0, 2],\n",
              " [2, 0, 4],\n",
              " [4, 0, 2],\n",
              " [2, 1, 4],\n",
              " [4, 0, 3],\n",
              " [4, 1, 3],\n",
              " [3, 1, 4],\n",
              " [3, 1, 2],\n",
              " [3, 0, 4],\n",
              " [4, 1, 3],\n",
              " [2, 1, 3],\n",
              " [4, 0, 3],\n",
              " [4, 1, 2],\n",
              " [2, 1, 4],\n",
              " [0, 2, 1],\n",
              " [1, 4, 0],\n",
              " [2, 1, 4],\n",
              " [0, 4, 1],\n",
              " [3, 1, 4],\n",
              " [4, 1, 2],\n",
              " [0, 4, 1]]"
            ]
          },
          "metadata": {},
          "execution_count": 31
        }
      ]
    }
  ]
}