{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Loss Functions\n",
    "\n",
    "https://sparrow.dev/cross-entropy-loss-in-pytorch/"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Single-label binary\n",
    "x = torch.randn(10)\n",
    "yhat = torch.sigmoid(x)\n",
    "y = torch.randint(2, (10,), dtype=torch.float)\n",
    "loss = nn.BCELoss()(yhat, y)\n",
    "\n",
    "# Single-label binary with automatic sigmoid\n",
    "loss = nn.BCEWithLogitsLoss()(x, y)\n",
    "\n",
    "# Single-label categorical\n",
    "x = torch.randn(10, 5)\n",
    "y = torch.randint(5, (10,))\n",
    "loss = nn.CrossEntropyLoss()(x, y)\n",
    "\n",
    "# Multi-label categorical\n",
    "x = torch.randn(10, 5)\n",
    "yhat = torch.sigmoid(x)\n",
    "y = torch.randint(2, (10, 5), dtype=torch.float)\n",
    "loss = nn.BCELoss()(y, y)\n",
    "\n",
    "# Multi-label categorical with automatic sigmoid\n",
    "loss = nn.BCEWithLogitsLoss()(x, y)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Data loading and training loop\n",
    "\n",
    "- dataset\n",
    "- data loader\n",
    "- epochs and minibatches"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import TensorDataset, DataLoader, random_split\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from sklearn.metrics import accuracy_score, explained_variance_score, r2_score\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns; sns.set()\n",
    "\n",
    "# data params\n",
    "n_samples = 100\n",
    "n_predictors = 10\n",
    "n_outcomes = 2\n",
    "\n",
    "# split params\n",
    "train_size = int(n_samples * .8)\n",
    "test_size = n_samples - train_size\n",
    "\n",
    "# training params\n",
    "n_epoches = 100\n",
    "batch_size = 20\n",
    "\n",
    "# DATA\n",
    "X = torch.rand((n_samples, n_predictors))\n",
    "y = torch.rand((n_samples, n_outcomes))\n",
    "\n",
    "# Tensorboard logger\n",
    "logger = SummaryWriter()\n",
    "\n",
    "dataset = TensorDataset(X, y)\n",
    "train_subset, test_subset = random_split(dataset, lengths=(train_size,test_size))\n",
    "\n",
    "X_test, y_test = dataset[test_subset.indices]\n",
    "\n",
    "model = nn.Linear(n_predictors, n_outcomes)\n",
    "logger.add_graph(model, X)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "criterion = nn.MSELoss()\n",
    "\n",
    "for epoch in range(n_epoches):\n",
    "\n",
    "  # train\n",
    "  model.train()\n",
    "  train_epoch_loss = torch.tensor(0.)\n",
    "  for X_batch, y_batch in DataLoader(train_subset, batch_size=batch_size):\n",
    "    model.zero_grad()\n",
    "    y_pred = model(X_batch)\n",
    "    loss = criterion(y_batch, y_pred)\n",
    "    logger.add_scalar('loss/train', loss.detach(), epoch)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "  # eval\n",
    "  model.eval()\n",
    "  test_epoch_accuracy = torch.tensor(0.)\n",
    "  with torch.no_grad():\n",
    "    y_pred = model(X_test)\n",
    "    loss = criterion(y_test, y_pred)\n",
    "    logger.add_scalar('loss/test', loss.detach(), epoch)\n",
    "    \n",
    "    ev = explained_variance_score(y_test, y_pred)\n",
    "    logger.add_scalar('explained_variance/test', ev, epoch)"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.9.4",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.9.4 64-bit ('py3': conda)"
  },
  "interpreter": {
   "hash": "5ddcf14c786c671500c086f61f0b66d0417d6c58ff12753e346e191a84f72b84"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}