# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from itertools import cycle

import torch
import torch.nn as nn
import torch.optim as optim

from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup, to_device
from .mutator import EnasMutator

logger = logging.getLogger(__name__)

class EnasTrainer(Trainer):
    """
    ENAS trainer.

    Every epoch runs two phases in sequence: first the child model weights are
    trained for ``child_steps`` mini-batches on architectures sampled from the
    mutator, then the mutator (RL controller) is trained for ``mutator_steps``
    updates, each aggregated over ``mutator_steps_aggregate`` mini-batches drawn
    from a held-out part of ``dataset_train``.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground truth label, return a loss tensor.
    metrics : callable
        Receives logits and ground truth label, return a dict of metrics.
    reward_function : callable
        Receives logits and ground truth label, return a tensor, which will be fed to RL controller as reward.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    num_epochs : int
        Number of epochs planned for training.
    dataset_train : Dataset
        Dataset for training. Will be split for training weights and architecture weights.
    dataset_valid : Dataset
        Dataset for testing.
    mutator : EnasMutator
        Use when customizing your own mutator or a mutator with customized parameters.
    batch_size : int
        Batch size.
    workers : int
        Workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Step count per logging.
    callbacks : list of Callback
        list of callbacks to trigger at events.
    entropy_weight : float
        Weight of sample entropy loss.
    skip_weight : float
        Weight of skip penalty loss.
    baseline_decay : float
        Decay factor of baseline. New baseline will be equal to
        ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    child_steps : int
        How many mini-batches for model training per epoch.
    mutator_lr : float
        Learning rate for RL controller.
    mutator_steps_aggregate : int
        Number of steps that will be aggregated into one mini-batch for RL controller.
    mutator_steps : int
        Number of mini-batches for each epoch of RL controller learning.
    aux_weight : float
        Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss.
    test_arc_per_epoch : int
        How many architectures are chosen for direct test after each epoch.
    """

    def __init__(self, model, loss, metrics, reward_function,
                 optimizer, num_epochs, dataset_train, dataset_valid,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
                 entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500,
                 mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4,
                 test_arc_per_epoch=1):
        # Fall back to a default EnasMutator built from the model when none is supplied.
        super().__init__(model, mutator if mutator is not None else EnasMutator(model),
                         loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                         batch_size, workers, device, log_frequency, callbacks)
        self.reward_function = reward_function
        # The controller has its own optimizer, separate from the child model's.
        self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
        self.batch_size = batch_size
        self.workers = workers

        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        # Moving-average reward baseline, updated on every controller step.
        self.baseline = 0.
        self.mutator_steps_aggregate = mutator_steps_aggregate
        self.mutator_steps = mutator_steps
        self.child_steps = child_steps
        self.aux_weight = aux_weight
        self.test_arc_per_epoch = test_arc_per_epoch

        self.init_dataloader()

    def init_dataloader(self):
        """
        Build the three data loaders used by the trainer.

        ``dataset_train`` is split by index: the last ``len(dataset_train) // 10``
        samples feed ``valid_loader`` (controller training) and the rest feed
        ``train_loader`` (child weight training). ``dataset_valid`` is served
        whole by ``test_loader``. The train and valid loaders are wrapped in
        ``itertools.cycle`` so that ``next()`` can be called on them indefinitely.
        """
        n_train = len(self.dataset_train)
        split = n_train // 10
        # Slice with an explicit boundary: ``indices[:-split]`` would be empty
        # (and ``indices[-split:]`` everything) when ``split`` is 0.
        boundary = n_train - split
        indices = list(range(n_train))
        if split == 0:
            logger.warning("dataset_train has only %d samples, so the held-out split used to train "
                           "the mutator is empty and no batch can be drawn from it.", n_train)
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:boundary])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[boundary:])
        self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=self.batch_size,
                                                        sampler=train_sampler,
                                                        num_workers=self.workers)
        self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=self.batch_size,
                                                        sampler=valid_sampler,
                                                        num_workers=self.workers)
        self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
                                                       batch_size=self.batch_size,
                                                       num_workers=self.workers)
        # NOTE(review): itertools.cycle stores the items of the first pass and
        # replays them afterwards, so batches are kept in memory and are not
        # re-sampled on later passes.
        self.train_loader = cycle(self.train_loader)
        self.valid_loader = cycle(self.valid_loader)

    def train_one_epoch(self, epoch):
        """
        Run one epoch: train the child model, then train the mutator.

        Parameters
        ----------
        epoch : int
            Zero-based epoch index; only used for log messages here.
        """
        # Sample model and train
        self.model.train()
        self.mutator.eval()
        meters = AverageMeterGroup()
        for step in range(1, self.child_steps + 1):
            x, y = next(self.train_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)
            self.optimizer.zero_grad()

            # Sampling an architecture must not record gradients for the controller.
            with torch.no_grad():
                self.mutator.reset()
            self._write_graph_status()
            logits = self.model(x)

            # A tuple output carries auxiliary-head logits as its second element.
            if isinstance(logits, tuple):
                logits, aux_logits = logits
                aux_loss = self.loss(aux_logits, y)
            else:
                aux_loss = 0.
            metrics = self.metrics(logits, y)
            loss = self.loss(logits, y)
            loss = loss + self.aux_weight * aux_loss
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
            self.optimizer.step()
            metrics["loss"] = loss.item()
            meters.update(metrics)

            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
                            self.num_epochs, step, self.child_steps, meters)

        # Train sampler (mutator)
        self.model.eval()
        self.mutator.train()
        meters = AverageMeterGroup()
        for mutator_step in range(1, self.mutator_steps + 1):
            self.mutator_optim.zero_grad()
            # Gradients of the aggregated steps accumulate before a single optimizer step.
            for step in range(1, self.mutator_steps_aggregate + 1):
                x, y = next(self.valid_loader)
                x, y = to_device(x, self.device), to_device(y, self.device)

                # Sample with gradients enabled; only the child forward pass is gradient-free.
                self.mutator.reset()
                with torch.no_grad():
                    logits = self.model(x)
                self._write_graph_status()
                metrics = self.metrics(logits, y)
                reward = self.reward_function(logits, y)
                if self.entropy_weight:
                    reward += self.entropy_weight * self.mutator.sample_entropy.item()
                self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
                loss = self.mutator.sample_log_prob * (reward - self.baseline)
                if self.skip_weight:
                    loss += self.skip_weight * self.mutator.sample_skip_penalty
                metrics["reward"] = reward
                metrics["loss"] = loss.item()
                metrics["ent"] = self.mutator.sample_entropy.item()
                metrics["log_prob"] = self.mutator.sample_log_prob.item()
                metrics["baseline"] = self.baseline
                metrics["skip"] = self.mutator.sample_skip_penalty

                # Scale after logging so the reported loss is the per-step value.
                loss /= self.mutator_steps_aggregate
                loss.backward()
                meters.update(metrics)

                cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
                if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                    logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs,
                                mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate,
                                meters)

            nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
            self.mutator_optim.step()

    def validate_one_epoch(self, epoch):
        """
        Evaluate ``test_arc_per_epoch`` architectures on the test loader.

        A new architecture is sampled for every batch via ``mutator.reset()``;
        metrics and loss are averaged per outer iteration and logged.

        Parameters
        ----------
        epoch : int
            Zero-based epoch index; only used for log messages here.
        """
        with torch.no_grad():
            for arc_id in range(self.test_arc_per_epoch):
                meters = AverageMeterGroup()
                for x, y in self.test_loader:
                    x, y = to_device(x, self.device), to_device(y, self.device)
                    self.mutator.reset()
                    logits = self.model(x)
                    # Auxiliary-head logits are ignored during evaluation.
                    if isinstance(logits, tuple):
                        logits, _ = logits
                    metrics = self.metrics(logits, y)
                    loss = self.loss(logits, y)
                    metrics["loss"] = loss.item()
                    meters.update(metrics)

                logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
                            epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch,
                            meters.summary())