Evaluator
TorchEvaluator
- class nni.compression.TorchEvaluator(training_func, optimizers, training_step, lr_schedulers=None, dummy_input=None, evaluating_func=None)
TorchEvaluator is the Evaluator for native PyTorch users. Please refer to the Compression Evaluator for the evaluator initialization example.
- Parameters:
training_func (_TRAINING_FUNC) –
The training function is used to train the model; note that this is an entire optimization training loop. The training function has three required parameters, model, optimizers and training_step, and three optional parameters, lr_schedulers, max_steps and max_epochs.

Let's explain these six parameters that NNI passes in. In most cases, users don't need to care about them; simply treat these six parameters as the original parameters of the training process.

The model is a wrapped version of the original model; it has a similar structure to the model to be pruned, so it can share the training function with the original model. optimizers are re-initialized from the optimizers passed to the evaluator and the wrapped model's parameters. training_step is likewise based on the training_step passed to the evaluator; it might be modified by the compressor during model compression. If users use lr_schedulers in the training_func, NNI will re-initialize the lr_schedulers with the re-initialized optimizers.

max_steps is the NNI training duration limitation. It allows the pruner (or quantizer) to control the number of training steps. The user-implemented training_func should respect max_steps by stopping the training loop after max_steps is reached. The pruner may pass None to max_steps when it only controls max_epochs.

max_epochs is similar to max_steps; the only difference is that it controls the number of training epochs. The user-implemented training_func should respect max_epochs by stopping the training loop after max_epochs is reached. The pruner may pass None to max_epochs when it only controls max_steps.

Note that when the pruner passes None to both max_steps and max_epochs, it treats training_func as a model fine-tuning function, and users should assign proper values to max_steps and max_epochs themselves.

    def training_func(model: torch.nn.Module, optimizers: torch.optim.Optimizer,
                      training_step: Callable[[Any, Any], torch.Tensor],
                      lr_schedulers: _LRScheduler | None = None,
                      max_steps: int | None = None,
                      max_epochs: int | None = None,
                      *args, **kwargs):
        ...
        total_epochs = max_epochs if max_epochs else 20
        total_steps = max_steps if max_steps else 1000000
        current_steps = 0
        ...
        for epoch in range(total_epochs):
            ...
            if current_steps >= total_steps:
                return
Note that the optimizers and lr_schedulers passed to the training_func have the same type as the optimizers and lr_schedulers passed to the evaluator: a single torch.optim.Optimizer / torch.optim.lr_scheduler._LRScheduler instance or a list of them.

optimizers –
A single traced optimizer instance or a list of optimizers traced by nni.trace.

NNI may modify the torch.optim.Optimizer member function step and/or optimize compressed models, so NNI needs the ability to re-initialize the optimizer. nni.trace can record the initialization parameters of a function/class, which can then be used by NNI to re-initialize the optimizer for a new but structurally similar model.

E.g. traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters()).

training_step (_TRAINING_STEP) –
A callable function whose first input argument should be batch, and whose output should contain the loss. Three kinds of outputs are supported: a single loss, a tuple whose first element is the loss, or a dict containing the key loss.

    def training_step(batch, model, ...):
        inputs, labels = batch
        output = model(inputs)
        ...
        loss = loss_func(output, labels)
        return loss
lr_schedulers (SCHEDULER | List[SCHEDULER] | None) –
Optional. A single traced lr_scheduler instance or a list of lr_schedulers traced by nni.trace. For the same reason as with optimizers, NNI needs the traced lr_scheduler to re-initialize it.

E.g. traced_lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1).

dummy_input (Any | None) – Optional. The dummy_input is used to trace the graph; it is the same as example_inputs in torch.jit.trace.

evaluating_func (_EVALUATING_FUNC | None) – Optional. A function that takes the model as input and returns an evaluation metric. This is the function used to evaluate the compressed model's performance. The input is a model and the output is a float metric or a dict (the dict should contain the key default with a float value). NNI takes the float number as the model score and assumes that a higher score means better performance. If you want to provide additional information, please put it into a dict; NNI will take the value of the key default as the evaluation metric.
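For instance, a minimal evaluating_func might look like the sketch below; the accuracy computation and test_dataloader are illustrative assumptions from user code, not part of the NNI API:

    import torch

    def evaluating_func(model: torch.nn.Module) -> dict:
        # Hypothetical evaluation loop; `test_dataloader` is assumed to exist in user code.
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in test_dataloader:
                preds = model(inputs).argmax(dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.numel()
        # NNI reads the value under the key `default` as the model score.
        return {'default': correct / total}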
Notes
It is also worth noting that not all the arguments of TorchEvaluator must be provided. Some pruners (or quantizers) only require evaluating_func because they do not train the model, while others only require training_func. Please refer to each pruner's (or quantizer's) documentation to check the required arguments. It is fine to provide more arguments than the pruner (or quantizer) needs.
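Putting the pieces together, a typical initialization might look like the following sketch. The model class, the dataloader, the learning rate, and the single-optimizer assumption are illustrative; only the argument order of TorchEvaluator and the nni.trace pattern come from the API above.

    import nni
    import torch
    import torch.nn.functional as F
    from nni.compression import TorchEvaluator

    model = MyModel()  # hypothetical user model

    def training_step(batch, model):
        inputs, labels = batch
        return F.cross_entropy(model(inputs), labels)

    def training_func(model, optimizers, training_step,
                      lr_schedulers=None, max_steps=None, max_epochs=None):
        # Assumes a single optimizer was passed to the evaluator and that
        # `train_dataloader` exists in user code.
        total_epochs = max_epochs if max_epochs else 3
        current_step = 0
        for _ in range(total_epochs):
            for batch in train_dataloader:
                loss = training_step(batch, model)
                optimizers.zero_grad()
                loss.backward()
                optimizers.step()
                current_step += 1
                if max_steps and current_step >= max_steps:
                    return

    # Trace the optimizer class so NNI can re-initialize it for the wrapped model.
    traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters(), lr=1e-3)

    evaluator = TorchEvaluator(training_func, traced_optimizer, training_step,
                               evaluating_func=evaluating_func)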
LightningEvaluator
- class nni.compression.LightningEvaluator(trainer, data_module, dummy_input=None)
LightningEvaluator is the Evaluator based on PyTorchLightning. It is very friendly to users who are familiar with PyTorchLightning or who already have training/validation/testing code written in PyTorchLightning. The only requirement is to use nni.trace to trace the Trainer and LightningDataModule.

Additionally, please make sure that the Optimizer class and LR_Scheduler class used in LightningModule.configure_optimizers() are also traced by nni.trace.

Please refer to the Compression Evaluator for the evaluator initialization example.
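For example, tracing the optimizer and scheduler classes inside configure_optimizers might look like the sketch below; the learning rate, the scheduler choice, and the module itself are illustrative assumptions:

    import nni
    import torch
    import pytorch_lightning as pl

    class MyLightningModule(pl.LightningModule):
        # Other LightningModule methods (training_step, test_step, ...) omitted.

        def configure_optimizers(self):
            # Trace the classes with nni.trace, then call them as usual.
            optimizer = nni.trace(torch.optim.Adam)(self.parameters(), lr=1e-3)
            scheduler = nni.trace(torch.optim.lr_scheduler.ExponentialLR)(optimizer, gamma=0.1)
            return [optimizer], [scheduler]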
- Parameters:
trainer (pl.Trainer) – PyTorch Lightning Trainer. It should be traced by NNI, e.g., trainer = nni.trace(pl.Trainer)(...).

data_module (pl.LightningDataModule) – PyTorch Lightning LightningDataModule. It should be traced by NNI, e.g., data_module = nni.trace(pl.LightningDataModule)(...).

dummy_input (Any | None) – The dummy_input is used to trace the graph. If dummy_input is not given, the data in data_module.train_dataloader() will be used.
Notes
If the test metric is needed by NNI, please make sure to log the metric with the key default in LightningModule.test_step().
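A minimal initialization sketch, assuming the MyLightningModule above and a hypothetical user-defined MyDataModule; only the nni.trace wrapping and the LightningEvaluator arguments come from the API above:

    import nni
    import pytorch_lightning as pl
    from nni.compression import LightningEvaluator

    # Wrap the classes with nni.trace so NNI can re-initialize them.
    trainer = nni.trace(pl.Trainer)(max_epochs=3)
    data_module = nni.trace(MyDataModule)(batch_size=32)  # hypothetical LightningDataModule

    evaluator = LightningEvaluator(trainer, data_module)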
TransformersEvaluator
- class nni.compression.TransformersEvaluator(trainer, dummy_input=None)
TransformersEvaluator is for users who use a Huggingface transformers.trainer.Trainer.

Here is an example of using transformers.trainer.Trainer to initialize an evaluator:

    from transformers.trainer import Trainer

    # wrap the Trainer class with nni.trace
    trainer = nni.trace(Trainer)(model=model)
    evaluator = TransformersEvaluator(trainer)

    # if you want to use a customized optimizer & lr_scheduler,
    # please also wrap the Optimizer & _LRScheduler class
    optimizer = nni.trace(Adam)(...)
    lr_scheduler = nni.trace(LambdaLR)(...)
    trainer = nni.trace(Trainer)(model=model, ..., optimizers=(optimizer, lr_scheduler))
    evaluator = TransformersEvaluator(trainer)
- Parameters:
trainer (HFTrainer) – An nni.trace(transformers.trainer.Trainer) instance. The trainer will be re-initialized inside the evaluator, so wrapping it with nni.trace is required in order to obtain the initialization arguments.

dummy_input (Any | None) – Optional. The dummy_input is used to trace the graph; it is the same as example_inputs in torch.jit.trace.
DeepspeedTorchEvaluator
- class nni.compression.DeepspeedTorchEvaluator(training_func, training_step, deepspeed, optimizer=None, lr_scheduler=None, resume_from_checkpoint_args=None, dummy_input=None, evaluating_func=None)
The DeepspeedTorchEvaluator is an evaluator designed specifically for native PyTorch users who are utilizing DeepSpeed.
- Parameters:
training_func (_TRAINING_FUNC) –
The training function is used to train the model; note that this is an entire optimization training loop. The training function has three required parameters, model, optimizers and training_step, and three optional parameters, lr_schedulers, max_steps and max_epochs.

Let's explain these six parameters that NNI passes in. In most cases, users don't need to care about them; simply treat these six parameters as the original parameters of the training process.

The model is a wrapped version of the original model; it has a similar structure to the model to be pruned, so it can share the training function with the original model. optimizers are re-initialized from the optimizers passed to the evaluator and the wrapped model's parameters. training_step is likewise based on the training_step passed to the evaluator; it might be modified by the compressor during model compression. If users use lr_schedulers in the training_func, NNI will re-initialize the lr_schedulers with the re-initialized optimizers.

max_steps is the NNI training duration limitation. It allows the pruner (or quantizer) to control the number of training steps. The user-implemented training_func should respect max_steps by stopping the training loop after max_steps is reached. The pruner may pass None to max_steps when it only controls max_epochs.

max_epochs is similar to max_steps; the only difference is that it controls the number of training epochs. The user-implemented training_func should respect max_epochs by stopping the training loop after max_epochs is reached. The pruner may pass None to max_epochs when it only controls max_steps.

Note that when the pruner passes None to both max_steps and max_epochs, it treats training_func as a model fine-tuning function, and users should assign proper values to max_steps and max_epochs themselves.

    def training_func(model: DeepSpeedEngine, optimizers: torch.optim.Optimizer,
                      training_step: Callable[[Any, Any], torch.Tensor],
                      lr_schedulers: _LRScheduler | None = None,
                      max_steps: int | None = None,
                      max_epochs: int | None = None,
                      *args, **kwargs):
        ...
        total_epochs = max_epochs if max_epochs else 20
        total_steps = max_steps if max_steps else 1000000
        current_steps = 0
        ...
        for epoch in range(total_epochs):
            ...
            model.backward(loss)
            model.step()
            if current_steps >= total_steps:
                return
Note that the optimizers and lr_schedulers passed to the training_func have the same type as the optimizers and lr_schedulers passed to the evaluator: a single torch.optim.Optimizer / torch.optim.lr_scheduler._LRScheduler instance or a list of them.

training_step (_TRAINING_STEP) –
A callable function whose first input argument should be batch, and whose output should contain the loss. Three kinds of outputs are supported: a single loss, a tuple whose first element is the loss, or a dict containing the key loss.

    def training_step(batch, model, ...):
        inputs, labels = batch
        output = model(inputs)
        ...
        loss = loss_func(output, labels)
        return loss
deepspeed (str | Dict) – The DeepSpeed configuration, which contains the parameters needed by DeepSpeed, such as train_batch_size, among others. A sketch of such a configuration is given below.
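A minimal sketch of a DeepSpeed configuration dict; the specific values (and which optional sections you enable) are illustrative assumptions and should be adapted to your DeepSpeed setup:

    # Illustrative DeepSpeed configuration; values are assumptions, not defaults.
    ds_config = {
        'train_batch_size': 64,
        'gradient_accumulation_steps': 1,
        'fp16': {'enabled': True},
        'zero_optimization': {'stage': 1},
    }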
optimizer – Optional. A single traced optimizer instance, or a function that takes the model parameters as input and returns an optimizer instance. NNI may modify the torch.optim.Optimizer member function step and/or optimize compressed models, so NNI needs the ability to re-initialize the optimizer. nni.trace can record the initialization parameters of a function/class, which can then be used by NNI to re-initialize the optimizer for a new but structurally similar model. E.g. traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters()).
lr_scheduler – Optional. A single traced lr_scheduler instance, or a function that takes the model parameters and the optimizer as input and returns an lr_scheduler instance. For the same reason as with optimizers, NNI needs the traced lr_scheduler to re-initialize it. E.g. traced_lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1).

resume_from_checkpoint_args (Dict | None) –
Optional. Used in the deepspeed_init process to load models saved during training with DeepSpeed. The seven elements of resume_from_checkpoint_args are explained below (see the sketch after this list).

load_dir: The directory to load the checkpoint from.

tag: Checkpoint tag used as a unique identifier for the checkpoint; if not provided, will attempt to load the tag in the 'latest' file.

load_module_strict: Optional. Boolean to strictly enforce that the keys in the state_dict of the module and the checkpoint match.

load_optimizer_states: Optional. Boolean to load the training optimizer states from the checkpoint.

load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from the checkpoint.

load_module_only: Optional. Boolean to load only the model weights from the checkpoint.

custom_load_fn: Optional. Custom model load function.
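A minimal sketch of such a dict, assuming a checkpoint previously saved by DeepSpeed under a hypothetical directory:

    # Hypothetical checkpoint location and tag; adapt to your own save path.
    resume_from_checkpoint_args = {
        'load_dir': './checkpoints',
        'tag': 'latest',
        'load_module_strict': True,
        'load_optimizer_states': True,
        'load_lr_scheduler_states': True,
        'load_module_only': False,
    }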
dummy_input (Any | None) –
Optional. The dummy_input is used to trace the graph; it is the same as example_inputs in torch.jit.trace.

evaluating_func (_EVALUATING_FUNC | None) – Optional. A function that takes the model as input and returns an evaluation metric. This is the function used to evaluate the compressed model's performance. The input is a model and the output is a float metric or a dict (the dict should contain the key default with a float value). NNI takes the float number as the model score and assumes that a higher score means better performance. If you want to provide additional information, please put it into a dict; NNI will take the value of the key default as the evaluation metric.
Notes
It is also worth noting that not all the arguments of DeepspeedTorchEvaluator must be provided. Some pruners (or quantizers) only require evaluating_func because they do not train the model, while others only require training_func. Please refer to each pruner's (or quantizer's) documentation to check the required arguments. It is fine to provide more arguments than the pruner (or quantizer) needs.
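Putting it together, a possible initialization sketch; the model class, the dataloader, the learning rate, and the tiny configuration dict are illustrative assumptions, while the argument order of DeepspeedTorchEvaluator and the nni.trace pattern come from the API above:

    import nni
    import torch
    import torch.nn.functional as F
    from nni.compression import DeepspeedTorchEvaluator

    model = MyModel()  # hypothetical user model

    def training_step(batch, model):
        inputs, labels = batch
        return F.cross_entropy(model(inputs), labels)

    def training_func(model, optimizers, training_step,
                      lr_schedulers=None, max_steps=None, max_epochs=None):
        # Assumes `train_dataloader` exists in user code.
        total_epochs = max_epochs if max_epochs else 3
        current_step = 0
        for _ in range(total_epochs):
            for batch in train_dataloader:
                loss = training_step(batch, model)
                # The wrapped model is a DeepSpeedEngine, which handles
                # backward and optimizer stepping itself.
                model.backward(loss)
                model.step()
                current_step += 1
                if max_steps and current_step >= max_steps:
                    return

    # Minimal illustrative DeepSpeed configuration.
    ds_config = {'train_batch_size': 64}

    traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters(), lr=1e-3)

    evaluator = DeepspeedTorchEvaluator(training_func, training_step, ds_config,
                                        optimizer=traced_optimizer)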