Source code for nni.nas.execution.cgo.evaluator

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

__all__ = ['MultiModelLightningModule', 'MultiModelTrainer']

import pytorch_lightning
import torch
from pytorch_lightning.strategies import SingleDeviceStrategy
from torch import nn
from torchmetrics import Metric

import nni
from nni.nas.evaluator.pytorch.lightning import LightningModule


class MultiModelLightningModule(LightningModule):
    """The lightning module for a merged "multi-model".

    The output of the multi-model is expected to be a tuple of tensors.
    The tensors will each be passed to a criterion and a metric.
    The losses will be added up for back-propagation, and the metrics will be logged.

    The reported metric will be a list of metrics, one for each model.

    Parameters
    ----------
    criterion
        Loss function.
    metric
        Metric function.
    n_models
        Number of models in the multi-model.
    """

    def __init__(self, criterion: nn.Module, metric: Metric, n_models: int | None = None):
        super().__init__()
        self.criterion = criterion
        self.metric = metric
        self.n_models = n_models

    def _dump(self) -> dict:
        return {
            'criterion': self.criterion,
            'metric': self.metric,
            'n_models': self.n_models,
        }

    @staticmethod
    def _load(criterion: nn.Module, metric: Metric, n_models: int | None = None) -> MultiModelLightningModule:
        return MultiModelLightningModule(criterion, metric, n_models)

    def forward(self, x):
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        multi_loss = []
        for idx, y_hat in enumerate(multi_y_hat):
            loss = self.criterion(y_hat.to("cpu"), y.to("cpu"))
            self.log(f'train_loss_{idx}', loss, prog_bar=True)
            self.log(f'train_metric_{idx}', self.metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            multi_loss.append(loss)
        return sum(multi_loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        for idx, y_hat in enumerate(multi_y_hat):
            self.log(f'val_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            self.log(f'val_metric_{idx}', self.metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        multi_y_hat = self(x)
        if isinstance(multi_y_hat, tuple):
            assert len(multi_y_hat) == self.n_models
        else:
            assert self.n_models == 1
            multi_y_hat = [multi_y_hat]
        for idx, y_hat in enumerate(multi_y_hat):
            self.log(f'test_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
            self.log(f'test_metric_{idx}', self.metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def on_validation_epoch_end(self):
        nni.report_intermediate_result(self._get_validation_metrics())  # type: ignore

    def teardown(self, stage):
        if stage == 'fit':
            nni.report_final_result(self._get_validation_metrics())  # type: ignore

    def _get_validation_metrics(self):
        # TODO: split metric of multiple models?
        assert self.n_models is not None
        return [self.trainer.callback_metrics[f'val_metric_{idx}'].item() for idx in range(self.n_models)]
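
# Illustrative sketch (not part of the module): one plausible way this module is
# constructed. The criterion, metric, and n_models values below are assumptions
# for demonstration only; in CGO the merged multi-model itself is attached by the
# NNI execution engine rather than built by hand.
#
#     from torch import nn
#     from torchmetrics.classification import MulticlassAccuracy
#
#     module = MultiModelLightningModule(
#         criterion=nn.CrossEntropyLoss(),
#         metric=MulticlassAccuracy(num_classes=10),
#         n_models=2,  # the merged model is expected to return a 2-tuple of outputs
#     )
#     # Each element of the tuple is scored separately; the per-model losses are
#     # summed for back-propagation and per-model metrics are logged and reported.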
class _BypassStrategy(SingleDeviceStrategy):
    strategy_name = "single_device"

    def model_to_device(self) -> None:
        pass
@nni.trace
class MultiModelTrainer(pytorch_lightning.Trainer):
    """
    Trainer for cross-graph optimization.

    Parameters
    ----------
    use_cgo
        Whether cross-graph optimization (CGO) is used.
        If it is True, CGO will manage device placement.
        Any device placement from pytorch lightning will be bypassed.
        default: True
    trainer_kwargs
        Optional keyword arguments passed to the trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
    """

    def __init__(self, use_cgo: bool = True, **trainer_kwargs):
        if use_cgo:
            # Accelerator and strategy can both be set at lightning 2.0.
            # if "accelerator" in trainer_kwargs:
            #     raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")

            if 'strategy' in trainer_kwargs:
                raise ValueError("MultiModelTrainer does not support specifying strategy")

            trainer_kwargs['strategy'] = _BypassStrategy()

        super().__init__(**trainer_kwargs)
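
# Illustrative sketch (assumption, not part of the module): constructing the trainer
# with cross-graph optimization enabled. With use_cgo=True the _BypassStrategy above
# is installed so Lightning does not move the merged model between devices; passing
# a `strategy` keyword argument in this mode would raise a ValueError.
#
#     trainer = MultiModelTrainer(use_cgo=True, max_epochs=10)
#     # trainer.fit(module, train_dataloader, val_dataloader)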