{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Port PyTorch Quickstart to NNI\nThis is a modified version of the PyTorch quickstart tutorial.\n\nIt can be run directly and produces exactly the same result as the original version.\n\nIt also supports automatic hyperparameter tuning with an NNI *experiment*, which is described later.\n\nRunning this script directly first is recommended, to verify the environment.\n\nThere are two key differences from the original version:\n\n1. In the *Get optimized hyperparameters* section, it receives generated hyperparameters.\n2. In the *Train model and report accuracy* section, it reports accuracy metrics to NNI.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import nni\nimport torch\nfrom torch import nn\nfrom torch.utils.data import DataLoader\nfrom torchvision import datasets\nfrom torchvision.transforms import ToTensor"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Hyperparameters to be tuned\nThese are the hyperparameters that will be tuned, shown with their default values. The defaults are used unchanged when the script is run directly.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "params = {\n    'features': 512,\n    'lr': 0.001,\n    'momentum': 0,\n}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Get optimized hyperparameters\nIf run directly, `nni.get_next_parameter()` is a no-op and returns an empty dict, so the default values above stay in effect.\nWithin an NNI *experiment*, it instead receives optimized hyperparameters from the tuning algorithm.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "optimized_params = nni.get_next_parameter()\nparams.update(optimized_params)\nprint(params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load dataset\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "training_data = datasets.FashionMNIST(root=\"data\", train=True, download=True, transform=ToTensor())\ntest_data = datasets.FashionMNIST(root=\"data\", train=False, download=True, transform=ToTensor())\n\nbatch_size = 64\n\ntrain_dataloader = DataLoader(training_data, batch_size=batch_size)\ntest_dataloader = DataLoader(test_data, batch_size=batch_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Build model with hyperparameters\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nprint(f\"Using {device} device\")\n\nclass NeuralNetwork(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.flatten = nn.Flatten()\n        # params['features'] sets the width of both hidden layers\n        self.linear_relu_stack = nn.Sequential(\n            nn.Linear(28*28, params['features']),\n            nn.ReLU(),\n            nn.Linear(params['features'], params['features']),\n            nn.ReLU(),\n            nn.Linear(params['features'], 10)\n        )\n\n    def forward(self, x):\n        x = self.flatten(x)\n        logits = self.linear_relu_stack(x)\n        return logits\n\nmodel = NeuralNetwork().to(device)\n\nloss_fn = nn.CrossEntropyLoss()\n# The learning rate and momentum also come from the (possibly tuned) params\noptimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=params['momentum'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Define train and test\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def train(dataloader, model, loss_fn, optimizer):\n    # Run one epoch of training over the whole dataloader.\n    model.train()\n    for X, y in dataloader:\n        X, y = X.to(device), y.to(device)\n        # Forward pass and loss computation\n        pred = model(X)\n        loss = loss_fn(pred, y)\n        # Backpropagation and parameter update\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\ndef test(dataloader, model, loss_fn):\n    # Evaluate the model and return its accuracy as a fraction in [0, 1].\n    size = len(dataloader.dataset)\n    model.eval()\n    correct = 0\n    with torch.no_grad():\n        for X, y in dataloader:\n            X, y = X.to(device), y.to(device)\n            pred = model(X)\n            # Count predictions whose top-scoring class matches the label\n            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n    return correct / size"
      ]
    },
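    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Train model and report accuracy\nAfter each epoch, the test accuracy is reported to NNI as an intermediate result; the accuracy of the last epoch is reported as the final result. The tuning algorithm uses these metrics to suggest better hyperparameters.\n\n"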
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "epochs = 5\nfor t in range(epochs):\n    print(f\"Epoch {t+1}\\n-------------------------------\")\n    train(train_dataloader, model, loss_fn, optimizer)\n    accuracy = test(test_dataloader, model, loss_fn)\n    # Report the accuracy after each epoch as an intermediate result\n    nni.report_intermediate_result(accuracy)\n# Report the last epoch's accuracy as the final result of this trial\nnni.report_final_result(accuracy)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}