diff --git a/tfp_keras_embedding_playground.ipynb b/tfp_keras_embedding_playground.ipynb
new file mode 100644
index 0000000..3fb2166
--- /dev/null
+++ b/tfp_keras_embedding_playground.ipynb
@@ -0,0 +1,400 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "tfp_keras_embedding_playground.ipynb",
+      "provenance": [],
+      "collapsed_sections": [],
+      "authorship_tag": "ABX9TyOUTUu0qPp27YU64gUuClas",
+      "include_colab_link": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "view-in-github",
+        "colab_type": "text"
+      },
+      "source": [
+        " "
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 204
+        },
+        "id": "zNYhsni5Q_yO",
+        "outputId": "d7373ff1-dc2c-436d-9240-703c3fe5461d"
+      },
+      "source": [
+        "import pandas as pd\n",
+        "import numpy as np\n",
+        "import scipy.stats as stats\n",
+        "\n",
+        "import tensorflow as tf\n",
+        "import tensorflow.keras.layers as layers\n",
+        "import tensorflow_probability as tfp\n",
+        "\n",
+        "tfd = tfp.distributions\n",
+        "\n",
+        "# Data\n",
+        "n_samples = 10000\n",
+        "n_labels = 64\n",
+        "n_topics = 12\n",
+        "n_categories = 2\n",
+        "\n",
+        "\n",
+        "data = pd.DataFrame({\n",
+        "    'category': np.random.choice(n_categories, (n_samples,)),\n",
+        "    'label': np.random.choice(n_labels, (n_samples,)),\n",
+        "})\n",
+        "\n",
+        "features = stats.truncnorm(0,1).rvs(size=(n_samples, n_topics))\n",
+        "\n",
+        "data = pd.concat([data, pd.DataFrame(features)], axis=1)\n",
+        "\n",
+        "data.tail()\n"
+      ],
+      "execution_count": 1,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/html": [
+              "
"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 204
+        },
+        "id": "zNYhsni5Q_yO",
+        "outputId": "d7373ff1-dc2c-436d-9240-703c3fe5461d"
+      },
+      "source": [
+        "import pandas as pd\n",
+        "import numpy as np\n",
+        "import scipy.stats as stats\n",
+        "\n",
+        "import tensorflow as tf\n",
+        "import tensorflow.keras.layers as layers\n",
+        "import tensorflow_probability as tfp\n",
+        "\n",
+        "tfd = tfp.distributions\n",
+        "\n",
+        "# Data\n",
+        "n_samples = 10000\n",
+        "n_labels = 64\n",
+        "n_topics = 12\n",
+        "n_categories = 2\n",
+        "\n",
+        "\n",
+        "data = pd.DataFrame({\n",
+        "    'category': np.random.choice(n_categories, (n_samples,)),\n",
+        "    'label': np.random.choice(n_labels, (n_samples,)),\n",
+        "})\n",
+        "\n",
+        "features = stats.truncnorm(0,1).rvs(size=(n_samples, n_topics))\n",
+        "\n",
+        "data = pd.concat([data, pd.DataFrame(features)], axis=1)\n",
+        "\n",
+        "data.tail()\n"
+      ],
+      "execution_count": 1,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/html": [
+              "
\n",
+              "\n",
+              "
\n",
+              "  \n",
+              "    \n",
+              "      | \n",
+              " | category\n",
+              " | label\n",
+              " | 0\n",
+              " | 1\n",
+              " | 2\n",
+              " | 3\n",
+              " | 4\n",
+              " | 5\n",
+              " | 6\n",
+              " | 7\n",
+              " | 8\n",
+              " | 9\n",
+              " | 10\n",
+              " | 11\n",
+              " | 
\n",
+              "  \n",
+              "  \n",
+              "    \n",
+              "      | 9995\n",
+              " | 0\n",
+              " | 29\n",
+              " | 0.338359\n",
+              " | 0.200292\n",
+              " | 0.366758\n",
+              " | 0.512287\n",
+              " | 0.421813\n",
+              " | 0.412299\n",
+              " | 0.298286\n",
+              " | 0.217234\n",
+              " | 0.876204\n",
+              " | 0.026990\n",
+              " | 0.553567\n",
+              " | 0.579650\n",
+              " | 
\n",
+              "    \n",
+              "      | 9996\n",
+              " | 0\n",
+              " | 48\n",
+              " | 0.733919\n",
+              " | 0.308699\n",
+              " | 0.425183\n",
+              " | 0.713357\n",
+              " | 0.932479\n",
+              " | 0.623409\n",
+              " | 0.803112\n",
+              " | 0.768884\n",
+              " | 0.677390\n",
+              " | 0.269333\n",
+              " | 0.232070\n",
+              " | 0.141412\n",
+              " | 
\n",
+              "    \n",
+              "      | 9997\n",
+              " | 1\n",
+              " | 13\n",
+              " | 0.382530\n",
+              " | 0.146689\n",
+              " | 0.508429\n",
+              " | 0.578849\n",
+              " | 0.205561\n",
+              " | 0.785769\n",
+              " | 0.299319\n",
+              " | 0.671348\n",
+              " | 0.469920\n",
+              " | 0.164159\n",
+              " | 0.239871\n",
+              " | 0.142832\n",
+              " | 
\n",
+              "    \n",
+              "      | 9998\n",
+              " | 1\n",
+              " | 1\n",
+              " | 0.192772\n",
+              " | 0.806608\n",
+              " | 0.066029\n",
+              " | 0.309884\n",
+              " | 0.230482\n",
+              " | 0.110686\n",
+              " | 0.390677\n",
+              " | 0.798374\n",
+              " | 0.489189\n",
+              " | 0.558733\n",
+              " | 0.274848\n",
+              " | 0.120094\n",
+              " | 
\n",
+              "    \n",
+              "      | 9999\n",
+              " | 1\n",
+              " | 29\n",
+              " | 0.570087\n",
+              " | 0.859802\n",
+              " | 0.116288\n",
+              " | 0.976016\n",
+              " | 0.820458\n",
+              " | 0.020163\n",
+              " | 0.373489\n",
+              " | 0.004759\n",
+              " | 0.964626\n",
+              " | 0.957332\n",
+              " | 0.215921\n",
+              " | 0.497364\n",
+              " | 
\n",
+              "  \n",
+              "
\n",
+              "