Source code for nni.nas.pytorch.spos.trainer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import torch
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup

from .mutator import SPOSSupernetTrainingMutator

logger = logging.getLogger(__name__)


class SPOSSupernetTrainer(Trainer):
    """
    This trainer trains a supernet that can be used for evolution search.

    Parameters
    ----------
    model : nn.Module
        Model with mutables.
    mutator : Mutator
        A mutator object that has been initialized with the model.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
    metrics : callable
        Returns a dict that maps metrics keys to metrics data.
    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of epochs of training.
    train_loader : iterable
        Data loader of training. Raises ``StopIteration`` when one epoch is exhausted.
    valid_loader : iterable
        Data loader of validation. Raises ``StopIteration`` when one epoch is exhausted.
    batch_size : int
        Batch size.
    workers : int
        Number of workers used for data preprocessing. Not used by this trainer and may be
        removed in the future.
    device : torch.device
        Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``,
        the trainer automatically detects a GPU and prefers it over CPU.
    log_frequency : int
        Number of mini-batches between metric logs.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.
    """

    def __init__(self, model, loss, metrics,
                 optimizer, num_epochs, train_loader, valid_loader,
                 mutator=None, batch_size=64, workers=4, device=None,
                 log_frequency=None, callbacks=None):
        assert torch.cuda.is_available()
        super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model),
                         loss, metrics, optimizer, num_epochs, None, None,
                         batch_size, workers, device, log_frequency, callbacks)

        self.train_loader = train_loader
        self.valid_loader = valid_loader
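
    # Each call to ``self.mutator.reset()`` below samples a fresh single-path
    # architecture from the supernet, so the shared weights are trained and
    # evaluated under uniform random path sampling -- the core idea of
    # Single Path One-Shot (SPOS).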

    def train_one_epoch(self, epoch):
        self.model.train()
        meters = AverageMeterGroup()
        for step, (x, y) in enumerate(self.train_loader):
            # Move inputs to the training device; CUDA is asserted in __init__.
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            self.mutator.reset()
            logits = self.model(x)
            loss = self.loss(logits, y)
            loss.backward()
            self.optimizer.step()

            metrics = self.metrics(logits, y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
                            self.num_epochs, step + 1, len(self.train_loader), meters)

    def validate_one_epoch(self, epoch):
        self.model.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            for step, (x, y) in enumerate(self.valid_loader):
                x, y = x.to(self.device), y.to(self.device)
                self.mutator.reset()
                logits = self.model(x)
                loss = self.loss(logits, y)
                metrics = self.metrics(logits, y)
                metrics["loss"] = loss.item()
                meters.update(metrics)
                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
                                self.num_epochs, step + 1, len(self.valid_loader), meters)
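

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). ``MySupernet`` is a
# hypothetical model built with NNI mutables (e.g. LayerChoice); the dataset,
# metrics function, and hyperparameters below are illustrative assumptions.
# The base ``Trainer.train()`` drives the per-epoch hooks defined above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    def accuracy(logits, targets):
        # Top-1 accuracy, returned as the dict this trainer expects.
        return {"acc1": (logits.argmax(dim=-1) == targets).float().mean().item()}

    model = MySupernet()  # hypothetical: any nn.Module containing NNI mutables
    transform = transforms.Compose([transforms.ToTensor()])
    train_set = datasets.CIFAR10("./data", train=True, download=True, transform=transform)
    valid_set = datasets.CIFAR10("./data", train=False, download=True, transform=transform)

    trainer = SPOSSupernetTrainer(
        model=model,
        loss=nn.CrossEntropyLoss(),
        metrics=accuracy,
        optimizer=torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9),
        num_epochs=120,
        train_loader=DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4),
        valid_loader=DataLoader(valid_set, batch_size=64, num_workers=4),
        log_frequency=50,
    )
    # Omitting ``mutator`` falls back to SPOSSupernetTrainingMutator(model),
    # which samples paths uniformly at random at every step.
    trainer.train()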