Source code for nni.nas.evaluator.pytorch.lightning

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import warnings
from pathlib import Path
from typing import Any, Dict, Union, Optional, List, Type

import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as nn_functional
import torch.optim as optim
import torchmetrics
import torchmetrics.classification
import torch.utils.data as torch_data

import nni
from nni.common.serializer import is_traceable
from nni.nas.evaluator import MutableEvaluator

__all__ = [
    'LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression',
    'SupervisedLearningModule', 'ClassificationModule', 'RegressionModule',
]

_logger = logging.getLogger(__name__)


class LightningModule(pl.LightningModule):
    """
    Basic wrapper of generated model. Lightning modules used in NNI should inherit this class.

    It's a subclass of ``pytorch_lightning.LightningModule``.
    See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html

    See :class:`SupervisedLearningModule` as an example.
    """

    @property
    def model(self) -> nn.Module:
        """The inner model (architecture) to train / evaluate.

        It will only be available after calling :meth:`set_model`.
        """
        model = getattr(self, '_model', None)
        if model is None:
            raise RuntimeError('Model is not set. Please call set_model() first.')
        return model
    def set_model(self, model: nn.Module) -> None:
        """Set the inner model (architecture) to train / evaluate.

        As there is no explicit method to "unset" a model, the model is left in the
        lightning module after the method is called. We don't recommend relying on this behavior.
        """
        if not isinstance(model, nn.Module):
            raise TypeError('model must be an instance of nn.Module')
        self._model = model
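
# Illustrative sketch (not part of NNI): a minimal subclass of the wrapper above.
# The class name and training logic here are assumptions; the only contract is to
# inherit `LightningModule` and reach the searched architecture through
# `self.model`, which the evaluator injects via `set_model()` before training.
class _ExampleModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        # `self.model` raises RuntimeError until `set_model()` has been called.
        return self.criterion(self.model(x), y)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)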
Trainer = nni.trace(pl.Trainer)
Trainer.__doc__ = """
Traced version of ``pytorch_lightning.Trainer``. See https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
"""

DataLoader = nni.trace(torch_data.DataLoader)
DataLoader.__doc__ = """
Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
"""
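
# Illustrative sketch (not part of NNI): the traced `DataLoader` records its init
# arguments so the loader can be re-instantiated inside a spawned trial, which is
# what `_check_dataloader` below verifies via `is_traceable`.
def _example_traced_dataloader(dataset: torch_data.Dataset) -> DataLoader:
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    assert is_traceable(loader)  # would fail for a plain torch.utils.data.DataLoader
    return loader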
@nni.trace
class Lightning(MutableEvaluator):
    """
    Delegate the whole training to PyTorch Lightning.

    Since the arguments passed to the initialization need to be serialized,
    ``LightningModule``, ``Trainer`` or ``DataLoader`` in this file should be used.
    Another option is to hide the dataloaders inside the Lightning module, in which case
    dataloaders are not required for this class to work.

    Following the programming style of Lightning, metrics sent to NNI should be obtained from
    ``callback_metrics`` in trainer. Two hooks are added at the end of the validation epoch
    and at the end of ``fit``, respectively. The metric name and type depend on the specific task.

    .. warning::

       The Lightning evaluator is stateful. If you reuse a previous Lightning evaluator,
       note that the inner ``lightning_module`` and ``trainer`` will be reused.

    Parameters
    ----------
    lightning_module
        Lightning module that defines the training logic.
    trainer
        Lightning trainer that handles the training.
    train_dataloaders
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method, this will be skipped.
        It can be any type of dataloader supported by Lightning.
    val_dataloaders
        Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method, this will be skipped.
        It can be any type of dataloader supported by Lightning.
    datamodule
        Used in ``trainer.fit()``. See `Lightning DataModule <https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html>`__.
    fit_kwargs
        Keyword arguments passed to ``trainer.fit()``.
    detect_interrupt
        Lightning has a `graceful shutdown <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ mechanism.
        It does not terminate the whole program (but only the training) when a KeyboardInterrupt is received.
        Setting this to ``True`` will re-raise the KeyboardInterrupt in the main process,
        so that the whole program can be terminated.

    Examples
    --------
    Users should define a Lightning module that inherits :class:`LightningModule`,
    use :class:`Trainer` and :class:`DataLoader` from ``nni.nas.evaluator.pytorch``,
    and make them parameters of this evaluator::

        import nni
        from nni.nas.evaluator.pytorch.lightning import Lightning, LightningModule, Trainer, DataLoader
    """

    def __init__(self, lightning_module: LightningModule, trainer: Trainer,
                 train_dataloaders: Optional[Any] = None,
                 val_dataloaders: Optional[Any] = None,
                 train_dataloader: Optional[Any] = None,
                 datamodule: Optional[pl.LightningDataModule] = None,
                 fit_kwargs: Optional[Dict[str, Any]] = None,
                 detect_interrupt: bool = True):
        assert isinstance(lightning_module, LightningModule), \
            f'Lightning module must be an instance of {__name__}.LightningModule.'
        if train_dataloader is not None:
            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
            train_dataloaders = train_dataloader
        if not (isinstance(trainer, pl.Trainer) and is_traceable(trainer)):
            raise TypeError(f'Trainer must be imported from {__name__}, but found {trainer.__class__.__qualname__}')
        if not _check_dataloader(train_dataloaders):
            warnings.warn(f'When using training service to spawn trials, please try to wrap PyTorch DataLoader with nni.trace or '
                          f'import DataLoader from {__name__}: {train_dataloaders}',
                          RuntimeWarning)
        if not _check_dataloader(val_dataloaders):
            warnings.warn(f'When using training service to spawn trials, please try to wrap PyTorch DataLoader with nni.trace or '
                          f'import DataLoader from {__name__}: {val_dataloaders}',
                          RuntimeWarning)
        self.module = lightning_module
        self.trainer = trainer
        self.train_dataloaders = train_dataloaders
        self.val_dataloaders = val_dataloaders
        self.datamodule = datamodule
        self.fit_kwargs = fit_kwargs or {}
        self.detect_interrupt = detect_interrupt

    def evaluate(self, model):
        """
        Fit the model with the provided dataloaders, with the Lightning trainer.

        If ``train_dataloaders`` is not provided, ``trainer.validate()`` will be called.

        Parameters
        ----------
        model
            The model to fit.
        """
        if self.is_mutable():
            raise RuntimeError('Mutable evaluator must be `freeze()`-ed before evaluation.')
        self.module.set_model(model)
        if self.datamodule is not None:
            _logger.info('Fit with datamodule. Train and valid dataloaders will be ignored.')
            rv = self.trainer.fit(self.module, self.datamodule, **self.fit_kwargs)
        elif self.train_dataloaders is None and self.val_dataloaders is not None:
            _logger.info('Only validation dataloaders are available. Skip to validation.')
            rv = self.trainer.validate(self.module, self.val_dataloaders, **self.fit_kwargs)
        else:
            if self.val_dataloaders is None:
                _logger.warning('Validation dataloaders are missing. Safe to ignore this warning when using one-shot strategy.')
            rv = self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders, **self.fit_kwargs)

        if self.detect_interrupt:
            from pytorch_lightning.trainer.states import TrainerStatus
            if self.trainer.state.status == TrainerStatus.INTERRUPTED:
                _logger.warning('Trainer status is detected to be interrupted.')
                raise KeyboardInterrupt('Trainer status is detected to be interrupted.')

        return rv

    @property
    def train_dataloader(self):
        warnings.warn('`train_dataloader` is deprecated, please use `train_dataloaders`.', DeprecationWarning)
        return self.train_dataloaders

    def __eq__(self, other):
        if not isinstance(other, Lightning):
            return False
        return self.module == other.module and self.trainer == other.trainer and \
            self.train_dataloaders == other.train_dataloaders and self.val_dataloaders == other.val_dataloaders and \
            self.fit_kwargs == other.fit_kwargs

    def __repr__(self):
        return f'{self.__class__.__name__}({self.module}, {self.trainer}, train_dataloaders={self.train_dataloaders}, ' \
               f'val_dataloaders={self.val_dataloaders}, fit_kwargs={self.fit_kwargs})'

    def fit(self, model):
        warnings.warn('`fit` is deprecated, please use `evaluate`.', DeprecationWarning)
        return self.evaluate(model)
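
# Illustrative sketch (not part of NNI): wiring the pieces together. Any
# `LightningModule` subclass from this file works as the first argument;
# `Trainer` kwargs are ordinary pytorch_lightning.Trainer arguments.
def _example_lightning_usage(module: LightningModule,
                             train_set: torch_data.Dataset,
                             val_set: torch_data.Dataset) -> Lightning:
    evaluator = Lightning(
        module,
        Trainer(max_epochs=10),
        train_dataloaders=DataLoader(train_set, batch_size=64, shuffle=True),
        val_dataloaders=DataLoader(val_set, batch_size=64),
    )
    # `evaluator.evaluate(model)` then calls `trainer.fit()` and reports the
    # callback metrics back to NNI via the hooks defined below.
    return evaluator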
def _check_dataloader(dataloader):
    # Check the type of dataloader recursively.
    if isinstance(dataloader, list):
        return all(_check_dataloader(d) for d in dataloader)
    if isinstance(dataloader, dict):
        return all(_check_dataloader(v) for v in dataloader.values())
    if isinstance(dataloader, torch_data.DataLoader):
        return is_traceable(dataloader)
    return True


### The following are some commonly used Lightning modules ###

class SupervisedLearningModule(LightningModule):

    trainer: pl.Trainer

    def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, torchmetrics.Metric],
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 export_onnx: Union[Path, str, bool, None] = None):
        super().__init__()
        self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
        self.criterion = criterion()
        self.optimizer = optimizer
        self.metrics = nn.ModuleDict(metrics)

        if export_onnx is None or export_onnx is True:
            self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
        elif export_onnx:
            self.export_onnx = Path(export_onnx)
        else:
            self.export_onnx = None

    def forward(self, x):
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        for name, metric in self.metrics.items():
            self.log('train_' + name, metric(y_hat, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        if self.export_onnx is not None:
            self.export_onnx.parent.mkdir(exist_ok=True)
            try:
                self.to_onnx(self.export_onnx, x, export_params=True)
            except RuntimeError as e:
                warnings.warn(f'ONNX conversion failed. As a result, you might not be able to use visualization. Error message: {e}')
            self.export_onnx = None

        self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
        for name, metric in self.metrics.items():
            self.log('val_' + name, metric(y_hat, y), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.log('test_loss', self.criterion(y_hat, y), prog_bar=True)
        for name, metric in self.metrics.items():
            self.log('test_' + name, metric(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)  # type: ignore

    def on_validation_epoch_end(self):
        if nni.get_current_parameter() is not None and not self.trainer.sanity_checking:
            # Don't report metrics when sanity checking.
            nni.report_intermediate_result(self._get_result_for_report())

    def on_fit_end(self):
        # Inline import to avoid errors with unsupported lightning versions.
        from pytorch_lightning.trainer.states import TrainerFn
        if self.trainer.state.fn == TrainerFn.FITTING:
            self._final_report()

    def on_validation_end(self):
        from pytorch_lightning.trainer.states import TrainerFn
        if self.trainer.state.fn == TrainerFn.VALIDATING:
            self._final_report()

    def _final_report(self):
        if nni.get_current_parameter() is not None:
            nni.report_final_result(self._get_result_for_report())

    def _get_result_for_report(self):
        stage = 'val'
        if not self.trainer.val_dataloaders:
            _logger.debug('No validation dataloader. Use results on training set instead.')
            stage = 'train'
        if len(self.metrics) == 1:
            metric_name = next(iter(self.metrics))
            return self.trainer.callback_metrics[f'{stage}_{metric_name}'].item()
        else:
            warnings.warn('Multiple metrics without "default" is not supported by current framework.')
            return {name: self.trainer.callback_metrics[f'{stage}_{name}'].item() for name in self.metrics}


class _AccuracyWithLogits(torchmetrics.Accuracy):
    # Only for torchmetrics < 0.11
    def update(self, pred, target):
        return super().update(nn_functional.softmax(pred, dim=-1), target)  # type: ignore
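
# Illustrative sketch (not part of NNI): `SupervisedLearningModule` only needs a
# criterion class and a dict of torchmetrics instances. Each metric named here
# surfaces as `train_<name>` / `val_<name>` in callback_metrics, which is how
# `_get_result_for_report` above finds the value to send to NNI.
def _example_custom_supervised_module() -> SupervisedLearningModule:
    return SupervisedLearningModule(
        nn.L1Loss,
        {'mae': torchmetrics.MeanAbsoluteError()},  # reported as train_mae / val_mae
        learning_rate=1e-3,
    )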
@nni.trace
class ClassificationModule(SupervisedLearningModule):
    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 export_onnx: bool = False,
                 num_classes: Optional[int] = None):
        from packaging.version import Version
        if Version(torchmetrics.__version__) < Version('0.11.0'):
            # Older versions accept num_classes = None
            metrics = {'acc': _AccuracyWithLogits()}  # type: ignore # pylint: disable=no-value-for-parameter
        else:
            if num_classes is None:
                raise ValueError('num_classes must be specified for torchmetrics >= 0.11. '
                                 'Please either specify it or use an older version of torchmetrics.')
            metrics = {'acc': torchmetrics.Accuracy('multiclass', num_classes=num_classes)}
        super().__init__(criterion, metrics,  # type: ignore
                         learning_rate=learning_rate, weight_decay=weight_decay,
                         optimizer=optimizer, export_onnx=export_onnx)
@nni.trace
class Classification(Lightning):
    """
    Evaluator that is used for classification.

    Available callback metrics in :class:`Classification` are:

    - train_loss
    - train_acc
    - val_loss
    - val_acc

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloaders : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method, this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method, this will be skipped.
    datamodule
        Used in ``trainer.fit()``. See `Lightning DataModule <https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html>`__.
    export_onnx : bool
        If true, model will be exported to ``model.onnx`` before training starts. default: false
    num_classes : int
        Number of classes for the classification task. Required for torchmetrics >= 0.11.0. default: None
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.

    Examples
    --------
    >>> evaluator = Classification()

    To use a customized criterion and optimizer:

    >>> evaluator = Classification(nn.NLLLoss, optimizer=torch.optim.SGD)

    Extra keyword arguments will be passed to the trainer,
    some of which might be necessary to enable GPU acceleration:

    >>> evaluator = Classification(accelerator='gpu', devices=2, strategy='ddp')
    """

    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 train_dataloaders: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 datamodule: Optional[pl.LightningDataModule] = None,
                 export_onnx: bool = False,
                 train_dataloader: Optional[DataLoader] = None,
                 num_classes: Optional[int] = None,
                 **trainer_kwargs):
        if train_dataloader is not None:
            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
            train_dataloaders = train_dataloader
        module = ClassificationModule(criterion=criterion, learning_rate=learning_rate,
                                      weight_decay=weight_decay, optimizer=optimizer,
                                      export_onnx=export_onnx, num_classes=num_classes)
        super().__init__(module, Trainer(**trainer_kwargs),
                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders,
                         datamodule=datamodule)
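
# Illustrative sketch (not part of NNI): a GPU-ready classification evaluator.
# `num_classes` is mandatory with torchmetrics >= 0.11; extra keyword arguments
# are forwarded to the Trainer.
def _example_classification_usage(train_set: torch_data.Dataset,
                                  val_set: torch_data.Dataset) -> Classification:
    return Classification(
        num_classes=10,
        train_dataloaders=DataLoader(train_set, batch_size=64, shuffle=True),
        val_dataloaders=DataLoader(val_set, batch_size=64),
        max_epochs=10,           # forwarded to Trainer
        accelerator='gpu',
        devices=1,
    )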
@nni.trace
class RegressionModule(SupervisedLearningModule):
    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 export_onnx: bool = False):
        super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError()},
                         learning_rate=learning_rate, weight_decay=weight_decay,
                         optimizer=optimizer, export_onnx=export_onnx)
@nni.trace
class Regression(Lightning):
    """
    Evaluator that is used for regression.

    Available callback metrics in :class:`Regression` are:

    - train_loss
    - train_mse
    - val_loss
    - val_mse

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.MSELoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloaders : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method, this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method, this will be skipped.
    datamodule
        Used in ``trainer.fit()``. See `Lightning DataModule <https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html>`__.
    export_onnx : bool
        If true, model will be exported to ``model.onnx`` before training starts. default: false
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.

    Examples
    --------
    >>> evaluator = Regression()

    Extra keyword arguments will be passed to the trainer,
    some of which might be necessary to enable GPU acceleration:

    >>> evaluator = Regression(accelerator='gpu', devices=1)
    """

    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: Type[optim.Optimizer] = optim.Adam,
                 train_dataloaders: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 datamodule: Optional[pl.LightningDataModule] = None,
                 export_onnx: bool = False,
                 train_dataloader: Optional[DataLoader] = None,
                 **trainer_kwargs):
        if train_dataloader is not None:
            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
            train_dataloaders = train_dataloader
        module = RegressionModule(criterion=criterion, learning_rate=learning_rate,
                                  weight_decay=weight_decay, optimizer=optimizer,
                                  export_onnx=export_onnx)
        super().__init__(module, Trainer(**trainer_kwargs),
                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders,
                         datamodule=datamodule)
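
# Illustrative sketch (not part of NNI): regression mirrors classification but
# defaults to MSELoss and reports `val_mse`. Swapping in another criterion class
# (here nn.L1Loss) works the same way as for Classification.
def _example_regression_usage(train_set: torch_data.Dataset,
                              val_set: torch_data.Dataset) -> Regression:
    return Regression(
        criterion=nn.L1Loss,
        train_dataloaders=DataLoader(train_set, batch_size=64, shuffle=True),
        val_dataloaders=DataLoader(val_set, batch_size=64),
        max_epochs=10,           # forwarded to Trainer
    )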
# Aliases for backwards compatibility
_SupervisedLearningModule = SupervisedLearningModule
_ClassificationModule = ClassificationModule
_RegressionModule = RegressionModule