# Source code for nni.nas.evaluator.pytorch.lightning

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import warnings
from pathlib import Path
from typing import Any, Dict, Union, Optional, List, Callable, Type

import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as nn_functional
import torch.optim as optim
import torchmetrics
import torchmetrics.classification
import torch.utils.data as torch_data

import nni
from nni.common.serializer import is_traceable
try:
    from .cgo import trainer as cgo_trainer
    cgo_import_failed = False
except ImportError:
    cgo_import_failed = True

from nni.nas.evaluator import Evaluator
from nni.typehint import Literal


# Public API of this module (what ``from ... import *`` exposes).
__all__ = [
    'LightningModule',
    'Trainer',
    'DataLoader',
    'Lightning',
    'Classification',
    'Regression',
    'SupervisedLearningModule',
    'ClassificationModule',
    'RegressionModule',
]

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)


class LightningModule(pl.LightningModule):
    """
    Basic wrapper of generated model.
    Lightning modules used in NNI should inherit this class.

    It's a subclass of ``pytorch_lightning.LightningModule``.
    See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
    """

    #: An indicator of whether current module is running in a multi-trial experiment or a one-shot.
    #: This flag should be automatically set by experiments when they start to run.
    running_mode: Literal['multi', 'oneshot'] = 'multi'

    def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
        """Attach the inner model (architecture) to train / evaluate.

        Parameters
        ----------
        model : callable or nn.Module
            Either an ``nn.Module`` instance, which is stored as is, or a
            zero-argument callable that is invoked to build the module.
        """
        self.model = model if isinstance(model, nn.Module) else model()
# Traced aliases of the Lightning trainer and the PyTorch DataLoader.
# Evaluator arguments need to be serialized, so users should build trainers and
# dataloaders from these names; ``Lightning.__init__`` checks the trainer with
# ``is_traceable`` and warns about untraced DataLoaders.
Trainer = nni.trace(pl.Trainer)
Trainer.__doc__ = """ Traced version of ``pytorch_lightning.Trainer``. See https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html """
DataLoader = nni.trace(torch_data.DataLoader)
DataLoader.__doc__ = """ Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html """
@nni.trace
class Lightning(Evaluator):
    """
    Delegate the whole training to PyTorch Lightning.

    Since the arguments passed to the initialization need to be serialized, ``LightningModule``, ``Trainer`` or ``DataLoader``
    in this file should be used. Another option is to hide dataloader in the Lightning module, in which case,
    dataloaders are not required for this class to work.

    Following the programming style of Lightning, metrics sent to NNI should be obtained from ``callback_metrics`` in trainer.
    Two hooks are added at the end of validation epoch and the end of ``fit``, respectively.
    The metric name and type depend on the specific task.

    .. warning::

       The Lightning evaluator is stateful. If you try to use a previous Lightning evaluator,
       please note that the inner ``lightning_module`` and ``trainer`` will be reused.

    Parameters
    ----------
    lightning_module
        Lightning module that defines the training logic.
    trainer
        Lightning trainer that handles the training.
    train_dataloaders
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
        It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
        If it is ``None``, :meth:`fit` only runs ``trainer.validate()``.
    val_dataloaders
        Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloader method this will be skipped.
        It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
    train_dataloader
        Deprecated alias of ``train_dataloaders``. When given, it replaces ``train_dataloaders``
        and a ``DeprecationWarning`` is emitted.
    fit_kwargs
        Keyword arguments passed to ``trainer.fit()`` (or to ``trainer.validate()`` when
        ``train_dataloaders`` is ``None``).
    """

    def __init__(self, lightning_module: LightningModule, trainer: Trainer,
                 train_dataloaders: Optional[Any] = None,
                 val_dataloaders: Optional[Any] = None,
                 train_dataloader: Optional[Any] = None,
                 fit_kwargs: Optional[Dict[str, Any]] = None):
        assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
        if train_dataloader is not None:
            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
            train_dataloaders = train_dataloader
        if cgo_import_failed:
            assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
        else:
            # this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
            assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
                f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
        # Untraced DataLoaders still work for training, but cannot be serialized; only warn.
        if not _check_dataloader(train_dataloaders):
            warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
                          f'import DataLoader from {__name__}: {train_dataloaders}',
                          RuntimeWarning)
        if not _check_dataloader(val_dataloaders):
            warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
                          f'import DataLoader from {__name__}: {val_dataloaders}',
                          RuntimeWarning)
        self.module = lightning_module
        self.trainer = trainer
        self.train_dataloaders = train_dataloaders
        self.val_dataloaders = val_dataloaders
        self.fit_kwargs = fit_kwargs or {}

    @staticmethod
    def _load(ir):
        """Rebuild a :class:`Lightning` evaluator from the dict produced by :meth:`_dump`."""
        # ``.get`` keeps dumps produced before ``fit_kwargs`` was serialized loadable.
        return Lightning(ir['module'], ir['trainer'], ir['train_dataloaders'], ir['val_dataloaders'],
                         fit_kwargs=ir.get('fit_kwargs'))

    def _dump(self):
        """Serialize the constructor state into a dict understood by :meth:`_load`."""
        return {
            'type': self.__class__,
            'module': self.module,
            'trainer': self.trainer,
            'train_dataloaders': self.train_dataloaders,
            'val_dataloaders': self.val_dataloaders,
            'fit_kwargs': self.fit_kwargs,
        }

    def _execute(self, model_cls):
        return self.fit(model_cls)

    @property
    def train_dataloader(self):
        """Deprecated alias of ``train_dataloaders``."""
        warnings.warn('train_dataloader is deprecated, please use `train_dataloaders`.', DeprecationWarning, stacklevel=2)
        # Previously this property only warned and returned None; return the real value.
        return self.train_dataloaders

    def __eq__(self, other):
        # Compare via optional ``function`` / ``arguments`` attributes. They are not set in this class;
        # presumably they are attached by ``nni.trace`` — TODO confirm. An attribute missing on both
        # sides counts as equal; missing on only one side counts as unequal.
        eq_func = False
        eq_args = False
        if other is None:
            return False
        if hasattr(self, "function") and hasattr(other, "function"):
            eq_func = getattr(self, "function") == getattr(other, "function")
        elif not (hasattr(self, "function") or hasattr(other, "function")):
            eq_func = True

        if hasattr(self, "arguments") and hasattr(other, "arguments"):
            eq_args = getattr(self, "arguments") == getattr(other, "arguments")
        elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
            eq_args = True

        return eq_func and eq_args

    def fit(self, model):
        """
        Fit the model with provided dataloader, with Lightning trainer.

        If ``train_dataloaders`` is not provided, ``trainer.validate()`` will be called.

        Parameters
        ----------
        model : nn.Module
            The model to fit. It is handed to ``LightningModule.set_model``, so a zero-argument
            callable returning an ``nn.Module`` is accepted as well.

        Returns
        -------
        Whatever ``trainer.validate()`` or ``trainer.fit()`` returns.
        """
        self.module.set_model(model)
        if self.train_dataloaders is None:
            _logger.info('Train dataloaders are missing. Skip to validation.')
            return self.trainer.validate(self.module, self.val_dataloaders, **self.fit_kwargs)
        if self.val_dataloaders is None:
            _logger.warning('Validation dataloaders are missing.')
        return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders, **self.fit_kwargs)
def _check_dataloader(dataloader):
    """Check whether every PyTorch ``DataLoader`` nested in ``dataloader`` is traceable.

    Lists, tuples and dict values are inspected recursively. A ``torch.utils.data.DataLoader``
    passes only if ``is_traceable`` accepts it; any other object (including ``None``) passes.
    """
    if isinstance(dataloader, (list, tuple)):
        return all(_check_dataloader(d) for d in dataloader)
    if isinstance(dataloader, dict):
        return all(_check_dataloader(v) for v in dataloader.values())
    if isinstance(dataloader, torch_data.DataLoader):
        return is_traceable(dataloader)
    return True


### The following are some commonly used Lightning modules ###

class SupervisedLearningModule(LightningModule):
    """
    Generic supervised-learning Lightning module: one criterion, a dict of metrics and one optimizer.

    Batches are unpacked as ``x, y = batch`` (assumes (input, target) pairs — TODO confirm for custom data).
    Loss and every metric are logged as ``train_*``, ``val_*`` and ``test_*``.

    Parameters
    ----------
    criterion
        Loss class (not an instance); it is instantiated with no arguments.
    metrics
        Mapping from metric name to ``torchmetrics.Metric``; logged under ``<stage>_<name>``.
    learning_rate
        Passed as ``lr`` to the optimizer.
    weight_decay
        Passed as ``weight_decay`` to the optimizer.
    optimizer
        Optimizer class (not an instance).
    export_onnx
        ``None`` or ``True``: export to ``model.onnx`` under ``NNI_OUTPUT_DIR`` (or the current directory).
        A path: export there. Any other falsy value: disable export.
    """

    trainer: pl.Trainer

    def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, torchmetrics.Metric],
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 export_onnx: Union[Path, str, bool, None] = None):
        super().__init__()
        self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
        self.criterion = criterion()
        self.optimizer = optimizer
        self.metrics = nn.ModuleDict(metrics)
        if export_onnx is None or export_onnx is True:
            self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
        elif export_onnx:
            self.export_onnx = Path(export_onnx)
        else:
            self.export_onnx = None

    def forward(self, x):
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        for name, metric in self.metrics.items():
            self.log('train_' + name, metric(y_hat, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        if self.running_mode == 'multi' and self.export_onnx is not None:
            # parents=True: NNI_OUTPUT_DIR (or a user-supplied path) may be nested and not exist yet.
            self.export_onnx.parent.mkdir(parents=True, exist_ok=True)
            try:
                self.to_onnx(self.export_onnx, x, export_params=True)
            except RuntimeError as e:
                warnings.warn(f'ONNX conversion failed. As a result, you might not be able to use visualization. Error message: {e}')
            # Attempt the export only once, whether it succeeded or not.
            self.export_onnx = None

        self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
        for name, metric in self.metrics.items():
            self.log('val_' + name, metric(y_hat, y), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.log('test_loss', self.criterion(y_hat, y), prog_bar=True)
        for name, metric in self.metrics.items():
            self.log('test_' + name, metric(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)  # type: ignore

    def on_validation_epoch_end(self):
        if not self.trainer.sanity_checking and self.running_mode == 'multi' and nni.get_current_parameter() is not None:
            # Don't report metric when sanity checking
            nni.report_intermediate_result(self._get_validation_metrics())

    def on_fit_end(self):
        self._final_report()

    def on_validation_end(self):
        # The sanity-check pass runs validation before training; its metrics must not be
        # reported as the final result (mirrors the guard in on_validation_epoch_end).
        if self.trainer.sanity_checking:
            return
        self._final_report()

    def _final_report(self):
        if self.running_mode == 'multi' and nni.get_current_parameter() is not None:
            nni.report_final_result(self._get_validation_metrics())

    def _get_validation_metrics(self):
        """Return the single ``val_<metric>`` value as a float, or a dict of all of them."""
        if len(self.metrics) == 1:
            metric_name = next(iter(self.metrics))
            return self.trainer.callback_metrics['val_' + metric_name].item()
        else:
            warnings.warn('Multiple metrics without "default" is not supported by current framework.')
            return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}


class _AccuracyWithLogits(torchmetrics.Accuracy):
    # Only for torchmetrics < 0.11: applies softmax over the last dim before updating accuracy.
    def update(self, pred, target):
        return super().update(nn_functional.softmax(pred, dim=-1), target)  # type: ignore
@nni.trace
class ClassificationModule(SupervisedLearningModule):
    """
    Supervised learning module for classification, tracking accuracy under the ``acc`` key.

    With torchmetrics < 0.11, accuracy is computed on softmax-normalized logits and
    ``num_classes`` is not used. With newer torchmetrics, a multiclass accuracy is built
    and ``num_classes`` is mandatory.
    """

    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 export_onnx: bool = True,
                 num_classes: Optional[int] = None):
        from packaging.version import Version

        legacy_torchmetrics = Version(torchmetrics.__version__) < Version('0.11.0')
        if legacy_torchmetrics:
            # Older version accepts num_classes = None
            accuracy = _AccuracyWithLogits()  # type: ignore # pylint: disable=no-value-for-parameter
        elif num_classes is None:
            raise ValueError('num_classes must be specified for torchmetrics >= 0.11. '
                             'Please either specify it or use an older version of torchmetrics.')
        else:
            accuracy = torchmetrics.Accuracy('multiclass', num_classes=num_classes)

        super().__init__(criterion, {'acc': accuracy},  # type: ignore
                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
                         export_onnx=export_onnx)
class Classification(Lightning):
    """
    Evaluator that is used for classification.

    Available callback metrics in :class:`Classification` are:

    - train_loss
    - train_acc
    - val_loss
    - val_acc

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloaders : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloader method this will be skipped.
    export_onnx : bool
        If true, model will be exported to ``model.onnx`` (under ``NNI_OUTPUT_DIR`` if set) during validation,
        which usually happens before training starts because of Lightning's sanity check. default: true
    train_dataloader : DataLoader
        Deprecated alias of ``train_dataloaders``. When given, it replaces ``train_dataloaders``
        and a ``DeprecationWarning`` is emitted.
    num_classes : int
        Number of classes for classification task.
        Required for torchmetrics >= 0.11.0. default: None
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.

    Examples
    --------
    >>> evaluator = Classification()

    To use customized criterion and optimizer:

    >>> evaluator = Classification(nn.MultiMarginLoss, optimizer=torch.optim.SGD)

    Extra keyword arguments will be passed to trainer, some of which might be necessary to enable GPU acceleration:

    >>> evaluator = Classification(accelerator='gpu', devices=2, strategy='ddp')
    """

    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 train_dataloaders: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 export_onnx: bool = True,
                 train_dataloader: Optional[DataLoader] = None,
                 num_classes: Optional[int] = None,
                 **trainer_kwargs):
        if train_dataloader is not None:
            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
            train_dataloaders = train_dataloader
        module = ClassificationModule(criterion=criterion, learning_rate=learning_rate,
                                      weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx,
                                      num_classes=num_classes)
        super().__init__(module, Trainer(**trainer_kwargs),
                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
@nni.trace
class RegressionModule(SupervisedLearningModule):
    """Supervised learning module for regression, tracking mean squared error under the ``mse`` key."""

    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 export_onnx: bool = True):
        regression_metrics = {'mse': torchmetrics.MeanSquaredError()}
        super().__init__(
            criterion,
            regression_metrics,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            optimizer=optimizer,
            export_onnx=export_onnx,
        )
class Regression(Lightning):
    """
    Evaluator that is used for regression.

    Available callback metrics in :class:`Regression` are:

    - train_loss
    - train_mse
    - val_loss
    - val_mse

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.MSELoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloaders : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloader method this will be skipped.
    export_onnx : bool
        If true, model will be exported to ``model.onnx`` (under ``NNI_OUTPUT_DIR`` if set) during validation,
        which usually happens before training starts because of Lightning's sanity check. default: true
    train_dataloader : DataLoader
        Deprecated alias of ``train_dataloaders``. When given, it replaces ``train_dataloaders``
        and a ``DeprecationWarning`` is emitted.
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.

    Examples
    --------
    >>> evaluator = Regression()

    Extra keyword arguments will be passed to trainer, some of which might be necessary to enable GPU acceleration:

    >>> evaluator = Regression(accelerator='gpu', devices=1)
    """

    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 train_dataloaders: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 export_onnx: bool = True,
                 train_dataloader: Optional[DataLoader] = None,
                 **trainer_kwargs):
        if train_dataloader is not None:
            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
            train_dataloaders = train_dataloader
        module = RegressionModule(criterion=criterion, learning_rate=learning_rate,
                                  weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
        super().__init__(module, Trainer(**trainer_kwargs),
                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
# Alias for backwards compatibility: the old underscore-prefixed names point at the public classes.
_SupervisedLearningModule, _ClassificationModule, _RegressionModule = (
    SupervisedLearningModule,
    ClassificationModule,
    RegressionModule,
)