# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import warnings

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import SubsetRandomSampler, DataLoader

from nni.nas.oneshot.pytorch.enas import ReinforceController, ReinforceField

from ..interface import BaseOneShotTrainer
from .random import PathSamplingLayerChoice, PathSamplingInputChoice
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, to_device

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

class EnasTrainer(BaseOneShotTrainer):
    """
    ENAS trainer.

    Each epoch first trains the model weights on one half of the dataset with
    architectures sampled from the controller, then trains the controller on
    the other half using the reward obtained by the sampled architectures.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground truth label, return a loss tensor.
    metrics : callable
        Receives logits and ground truth label, return a dict of metrics.
    reward_function : callable
        Receives logits and ground truth label, return a tensor, which will be fed to RL controller as reward.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    num_epochs : int
        Number of epochs planned for training.
    dataset : Dataset
        Dataset for training. Will be split for training weights and architecture weights.
    batch_size : int
        Batch size.
    workers : int
        Workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``. When ``None``, CUDA is used if
        available, otherwise CPU.
    log_frequency : int
        Step count per logging. ``None`` disables progress logging.
    grad_clip : float
        Gradient clipping. Set to 0 to disable. Default: 5.
    entropy_weight : float
        Weight of sample entropy loss.
    skip_weight : float
        Weight of skip penalty loss.
    baseline_decay : float
        Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    ctrl_lr : float
        Learning rate for RL controller.
    ctrl_steps_aggregate : int
        Number of steps that will be aggregated into one mini-batch for RL controller.
    ctrl_kwargs : dict
        Optional kwargs that will be passed to :class:`ReinforceController`.
    """

    def __init__(self, model, loss, metrics, reward_function,
                 optimizer, num_epochs, dataset,
                 batch_size=64, workers=4, device=None, log_frequency=None,
                 grad_clip=5., entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
                 ctrl_lr=0.00035, ctrl_steps_aggregate=20, ctrl_kwargs=None):
        warnings.warn('EnasTrainer is deprecated. Please use strategy.ENAS instead.', DeprecationWarning)

        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        self.num_epochs = num_epochs
        self.dataset = dataset
        self.batch_size = batch_size
        self.workers = workers
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        self.log_frequency = log_frequency

        # Swap every layer/input choice for its path-sampling counterpart;
        # the replaced modules are collected as (name, module) pairs.
        self.nas_modules = []
        replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
        replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
        # NOTE(review): the loop body was missing from the extracted source.
        # Batches are moved to ``self.device`` during training, so the model
        # and its NAS modules are moved there as well - confirm against upstream.
        for _, module in self.nas_modules:
            module.to(self.device)
        self.model.to(self.device)

        # One controller field per NAS module, built once (not once per loop
        # iteration) so the attribute also exists when there are no NAS modules.
        self.nas_fields = [
            ReinforceField(name, len(module),
                           isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
            for name, module in self.nas_modules
        ]
        self.controller = ReinforceController(self.nas_fields, **(ctrl_kwargs or {}))

        self.grad_clip = grad_clip
        self.reward_function = reward_function
        self.ctrl_optim = optim.Adam(self.controller.parameters(), lr=ctrl_lr)

        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.baseline = 0.
        self.ctrl_steps_aggregate = ctrl_steps_aggregate

        self.init_dataloader()

    def init_dataloader(self):
        """Split the dataset in two and build ``train_loader`` / ``valid_loader``.

        The first ``len(dataset) - len(dataset) // 2`` indices feed the model
        weights, the remaining ``len(dataset) // 2`` feed the controller. Each
        part is drawn through a ``SubsetRandomSampler``.
        """
        n_train = len(self.dataset)
        split = n_train // 2
        indices = list(range(n_train))
        # Slice from an explicit boundary: ``indices[:-split]`` is empty and
        # ``indices[-split:]`` is the whole list when ``split`` is 0.
        boundary = n_train - split
        train_sampler = SubsetRandomSampler(indices[:boundary])
        valid_sampler = SubsetRandomSampler(indices[boundary:])
        self.train_loader = DataLoader(self.dataset,
                                       batch_size=self.batch_size,
                                       sampler=train_sampler,
                                       num_workers=self.workers)
        self.valid_loader = DataLoader(self.dataset,
                                       batch_size=self.batch_size,
                                       sampler=valid_sampler,
                                       num_workers=self.workers)

    def _train_model(self, epoch):
        """Run one pass over ``train_loader`` updating the model weights.

        A new architecture is sampled from the controller for every batch.

        Parameters
        ----------
        epoch : int
            Zero-based epoch index, used only for log messages.
        """
        self.model.train()
        self.controller.eval()
        meters = AverageMeterGroup()
        for step, (x, y) in enumerate(self.train_loader):
            x, y = to_device(x, self.device), to_device(y, self.device)
            self.optimizer.zero_grad()

            self._resample()
            logits = self.model(x)
            metrics = self.metrics(logits, y)
            loss = self.loss(logits, y)
            loss.backward()
            if self.grad_clip > 0:
                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.optimizer.step()
            metrics['loss'] = loss.item()
            meters.update(metrics)

            if self.log_frequency is not None and step % self.log_frequency == 0:
                _logger.info('Model Epoch [%d/%d] Step [%d/%d] %s', epoch + 1,
                             self.num_epochs, step + 1, len(self.train_loader), meters)

    def _train_controller(self, epoch):
        """Run one pass over ``valid_loader`` updating the controller.

        Controller gradients are accumulated over ``ctrl_steps_aggregate``
        batches before each optimizer step. A trailing group with fewer
        batches than that never triggers a step.

        Parameters
        ----------
        epoch : int
            Zero-based epoch index, used only for log messages.
        """
        self.model.eval()
        self.controller.train()
        meters = AverageMeterGroup()
        self.ctrl_optim.zero_grad()
        for ctrl_step, (x, y) in enumerate(self.valid_loader):
            x, y = to_device(x, self.device), to_device(y, self.device)

            self._resample()
            # The model is only evaluated here; no gradients flow through it.
            with torch.no_grad():
                logits = self.model(x)
            metrics = self.metrics(logits, y)
            reward = self.reward_function(logits, y)
            if self.entropy_weight:
                reward += self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
            # Exponential moving average of the reward.
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            # Sampled log-probability scaled by the baseline-adjusted reward.
            loss = self.controller.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                loss += self.skip_weight * self.controller.sample_skip_penalty
            metrics['reward'] = reward
            metrics['loss'] = loss.item()
            metrics['ent'] = self.controller.sample_entropy.item()  # type: ignore
            metrics['log_prob'] = self.controller.sample_log_prob.item()  # type: ignore
            metrics['baseline'] = self.baseline
            metrics['skip'] = self.controller.sample_skip_penalty

            # Scale so the accumulated gradient averages over the group.
            loss /= self.ctrl_steps_aggregate
            loss.backward()
            meters.update(metrics)

            if (ctrl_step + 1) % self.ctrl_steps_aggregate == 0:
                if self.grad_clip > 0:
                    nn.utils.clip_grad_norm_(self.controller.parameters(), self.grad_clip)
                self.ctrl_optim.step()
                self.ctrl_optim.zero_grad()

            if self.log_frequency is not None and ctrl_step % self.log_frequency == 0:
                _logger.info('RL Epoch [%d/%d] Step [%d/%d] %s', epoch + 1, self.num_epochs,
                             ctrl_step + 1, len(self.valid_loader), meters)

    def _resample(self):
        """Sample from the controller and store each module's choice on ``module.sampled``."""
        result = self.controller.resample()
        for name, module in self.nas_modules:
            module.sampled = result[name]

    def fit(self):
        """Train for ``num_epochs`` epochs: model weights first, then the controller."""
        for i in range(self.num_epochs):
            self._train_model(i)
            self._train_controller(i)

    def export(self):
        """Put the controller in eval mode and return one ``resample()`` result, without gradients."""
        self.controller.eval()
        with torch.no_grad():
            return self.controller.resample()