Port PyTorch Quickstart to NNI
This is a modified version of the PyTorch quickstart.
It can be run directly and produces the same result as the original version.
In addition, it can be auto-tuned with an NNI experiment, which is described later.
We recommend running this script directly first to verify your environment.
There are two key differences from the original version:
In the "Get optimized hyperparameters" part, it receives hyperparameters generated by the tuner.
In the "Train model and report accuracy" part, it reports accuracy metrics to NNI.
import nni
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
Hyperparameters to be tuned
These are the hyperparameters that will be tuned. The values below are the defaults used when the script runs on its own.
params = {
'features': 512,
'lr': 0.001,
'momentum': 0,
}
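Inside an NNI experiment, the tuner samples these keys from a search space that is defined outside this script. A minimal sketch of such a search space, written as a Python dict in NNI's `_type`/`_value` format; the candidate values and ranges below are illustrative assumptions, not values taken from this tutorial.

```python
# Hypothetical search space for the three hyperparameters above.
# The "_type"/"_value" layout follows NNI's search space format;
# the concrete candidates and ranges are assumptions for illustration.
search_space = {
    'features': {'_type': 'choice', '_value': [128, 256, 512, 1024]},
    'lr': {'_type': 'loguniform', '_value': [0.0001, 0.1]},
    'momentum': {'_type': 'uniform', '_value': [0, 1]},
}

# Defaults from the script, used when no tuner is attached.
params = {
    'features': 512,
    'lr': 0.001,
    'momentum': 0,
}

# Every tunable key should also have a default, so the script still runs standalone.
assert set(search_space) == set(params)
```

Keeping the search space keys identical to the `params` keys is what lets the later `params.update(...)` call replace defaults one for one.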
Get optimized hyperparameters
If the script is run directly, nni.get_next_parameter() is a no-op and returns an empty dict, so the defaults above are kept.
When the script runs inside an NNI experiment, it returns optimized hyperparameters from the tuning algorithm.
optimized_params = nni.get_next_parameter()
params.update(optimized_params)
print(params)
Out:
{'features': 512, 'lr': 0.001, 'momentum': 0}
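The merge works because `dict.update` only overwrites keys that appear in its argument. A small stand-alone sketch that simulates both cases without NNI; the helper `merge` and the suggested values are hypothetical, and unlike the tutorial (which updates `params` in place) it returns a copy so both results can be compared.

```python
# Defaults declared in the script.
params = {'features': 512, 'lr': 0.001, 'momentum': 0}

def merge(defaults, suggested):
    """Return a copy of defaults with tuner-suggested values applied on top (hypothetical helper)."""
    merged = dict(defaults)
    merged.update(suggested)
    return merged

# Standalone run: nothing is suggested, so every default survives.
standalone = merge(params, {})
# Inside an experiment: a made-up suggestion overrides two of the keys.
tuned = merge(params, {'lr': 0.01, 'momentum': 0.9})

print(standalone)  # {'features': 512, 'lr': 0.001, 'momentum': 0}
print(tuned)       # {'features': 512, 'lr': 0.01, 'momentum': 0.9}
```

Any key the tuner does not return (here `features`) keeps its default value.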
Load dataset
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
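FashionMNIST has 60,000 training and 10,000 test images, and `DataLoader` keeps the final, smaller batch by default (`drop_last=False`). A quick arithmetic check of how many batches each loader yields with `batch_size = 64`:

```python
import math

batch_size = 64
num_train, num_test = 60_000, 10_000  # FashionMNIST split sizes

# With drop_last=False the last batch holds whatever samples remain.
train_batches = math.ceil(num_train / batch_size)
test_batches = math.ceil(num_test / batch_size)
last_train_batch = num_train - (train_batches - 1) * batch_size
last_test_batch = num_test - (test_batches - 1) * batch_size

print(train_batches, last_train_batch)  # 938 32
print(test_batches, last_test_batch)    # 157 16
```

Because the last test batch is smaller, `test` below divides the correct count by the dataset size rather than by `batch_size * num_batches`.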
Build model with hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, params['features']),
            nn.ReLU(),
            nn.Linear(params['features'], params['features']),
            nn.ReLU(),
            nn.Linear(params['features'], 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
model = NeuralNetwork().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=params['momentum'])
Out:
Using cpu device
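The `features` hyperparameter sets the width of both hidden layers, so it directly controls model size. A pure-Python sketch that counts the weights and biases of the three `nn.Linear` layers for a given width; it does not use torch, but the same number should come out of `sum(p.numel() for p in model.parameters())`.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a single nn.Linear(n_in, n_out) layer."""
    return n_in * n_out + n_out

def count_params(features):
    """Trainable parameters of NeuralNetwork for a given 'features' width."""
    return (linear_params(28 * 28, features)      # input -> hidden 1
            + linear_params(features, features)   # hidden 1 -> hidden 2
            + linear_params(features, 10))        # hidden 2 -> 10 classes

print(count_params(512))  # 669706, the default model
```

Since the middle layer grows with `features ** 2`, doubling the width roughly triples the parameter count in this range, which is one reason width is worth tuning rather than simply maximizing.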
Define train and test
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation and one SGD step with the tuned lr and momentum
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    return correct
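`optimizer.step()` in `train` is where the tuned `lr` and `momentum` take effect. A scalar sketch of the update rule documented for `torch.optim.SGD`, assuming its defaults `dampening=0`, `nesterov=False` and no weight decay: the velocity buffer starts as the first gradient, then `buf = momentum * buf + grad` and `p -= lr * buf`. The function `sgd_momentum` and the toy objective are hypothetical.

```python
def sgd_momentum(grad_fn, p, lr, momentum, steps):
    """Minimise a 1-D function with SGD + momentum; returns every visited point."""
    buf = None
    trajectory = [p]
    for _ in range(steps):
        g = grad_fn(p)
        # First step uses the raw gradient; later steps blend in the old buffer.
        buf = g if buf is None else momentum * buf + g
        p = p - lr * buf
        trajectory.append(p)
    return trajectory

def grad_square(p):
    """Gradient of the toy objective f(p) = p ** 2."""
    return 2 * p

plain = sgd_momentum(grad_square, 1.0, lr=0.1, momentum=0.0, steps=2)
heavy = sgd_momentum(grad_square, 1.0, lr=0.1, momentum=0.9, steps=2)
print(plain)
print(heavy)
```

With `momentum=0` the rule reduces to plain gradient descent; with `momentum=0.9` the second step is more than twice as long, which is why `lr` and `momentum` interact and are tuned together.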
Train model and report accuracy
Report accuracy metrics to NNI so that the tuning algorithm can suggest better hyperparameters: an intermediate result after every epoch, and the last epoch's accuracy as the final result.
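The accuracy returned by `test` is the number of correctly classified test samples divided by the dataset size, so it always lies in [0, 1]. A pure-Python sketch of that computation, mirroring `(pred.argmax(1) == y).sum()` on two hypothetical batches of 3-class logits (the real model outputs 10 classes):

```python
def batch_correct(logits, labels):
    """Count rows whose highest-scoring class equals the label."""
    correct = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(predicted == label)
    return correct

# Made-up logits and labels, for illustration only.
batches = [
    ([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]], [0, 2]),
    ([[0.1, 0.2, 3.0]], [2]),
]
correct = sum(batch_correct(logits, labels) for logits, labels in batches)
size = sum(len(labels) for _, labels in batches)
accuracy = correct / size  # the kind of value passed to nni.report_*_result
print(correct, size, accuracy)
```

Here two of three samples are classified correctly, so the reported value would be two thirds.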
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    accuracy = test(test_dataloader, model, loss_fn)
    nni.report_intermediate_result(accuracy)
nni.report_final_result(accuracy)
Out:
Epoch 1
-------------------------------
[2022-03-21 01:09:37] INFO (nni/MainThread) Intermediate result: 0.461 (Index 0)
Epoch 2
-------------------------------
[2022-03-21 01:09:42] INFO (nni/MainThread) Intermediate result: 0.5529 (Index 1)
Epoch 3
-------------------------------
[2022-03-21 01:09:47] INFO (nni/MainThread) Intermediate result: 0.6155 (Index 2)
Epoch 4
-------------------------------
[2022-03-21 01:09:52] INFO (nni/MainThread) Intermediate result: 0.6345 (Index 3)
Epoch 5
-------------------------------
[2022-03-21 01:09:56] INFO (nni/MainThread) Intermediate result: 0.6505 (Index 4)
[2022-03-21 01:09:56] INFO (nni/MainThread) Final result: 0.6505
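Intermediate results let the tuning side judge a trial before it finishes; NNI, for example, includes a median-stop assessor. The sketch below is a simplified, hypothetical version of that idea, not NNI's implementation: stop the running trial if its best result so far is below the median of the running averages of finished trials at the same step. The current trial reuses the first three accuracies logged above; the finished-trial histories are made up.

```python
from statistics import median

def should_stop(current, completed, step):
    """Simplified median stopping rule for a metric that is maximised.

    current:   intermediate results of the running trial so far
    completed: intermediate-result lists of finished trials
    step:      0-based index of the latest intermediate result
    """
    # Running average of each finished trial up to and including `step`.
    averages = [sum(h[:step + 1]) / (step + 1) for h in completed if len(h) > step]
    if not averages:
        return False  # nothing to compare against yet
    best_so_far = max(current[:step + 1])
    return best_so_far < median(averages)

# The run logged above, as seen after its third epoch.
current = [0.461, 0.5529, 0.6155]
# Hypothetical histories of two finished trials.
completed = [[0.70, 0.78, 0.82], [0.65, 0.74, 0.80]]
print(should_stop(current, completed, step=2))
```

With these made-up histories the median running average is about 0.748, so a trial whose best accuracy is 0.6155 would be stopped early and its remaining epochs saved for other hyperparameter settings.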
Total running time of the script: 0 minutes 24.441 seconds