Distiller

DynamicLayerwiseDistiller

class nni.compression.distillation.DynamicLayerwiseDistiller(model: Module, config_list: List[Dict], evaluator: Evaluator, teacher_model: Module, teacher_predict: Callable[[Any, Module], Tensor], origin_loss_lambda: float = 1.0, existed_wrappers: Dict[str, ModuleWrapper] | None = None)[source]

Each student model distillation target (i.e., the output of a layer in the student model) is linked to a list of teacher model distillation targets in this distiller. During distillation, each student target computes a distillation loss against each of its linked teacher targets and keeps the minimum of these losses as its distillation loss for the current step. The final distillation loss is the sum of each student target's distillation loss multiplied by its lambda, and the final training loss is the original loss multiplied by origin_loss_lambda plus the final distillation loss.
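
As a clarification of the loss composition above, the following minimal sketch (not the distiller's internal implementation; all names are illustrative) shows how the per-target minimum losses and the original loss are combined:

    import torch

    def final_training_loss(origin_loss: torch.Tensor,
                            candidate_losses_per_target: list,  # one list of candidate distillation losses per student target
                            lambdas: list,                      # the 'lambda' configured for each student target
                            origin_loss_lambda: float = 1.0) -> torch.Tensor:
        distill_loss = torch.zeros_like(origin_loss)
        for losses_i, lambda_i in zip(candidate_losses_per_target, lambdas):
            # each student target keeps only the minimum loss among its linked teacher targets
            distill_loss = distill_loss + lambda_i * torch.stack(losses_i).min()
        return origin_loss_lambda * origin_loss + distill_loss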

Parameters:
  • model (torch.nn.Module) -- The student model to be distilled.

  • config_list (List[Dict]) --

    Config list to configure how to distill. For common keys, please refer to the Compression Config Specification; an example config_list is shown in the usage sketch after this parameter list.

    Specific keys:

    • 'lambda': By default, 1. A scaling factor that controls the loss scale; the final loss used during training is (origin_loss_lambda * origin_loss + sum(lambda_i * distill_loss_i)), where i indexes the distillation targets. The higher the value of lambda, the greater the contribution of the corresponding distillation target to the loss.

    • 'link': By default, 'auto'. Either 'auto', a teacher module name, or a list of teacher module names; the named teacher module(s) are aligned with the student module(s) configured in this config. If 'auto' is set, the student module name is used as the link, which usually requires the teacher model and the student model to be isomorphic.

    • 'apply_method': By default, 'mse'. 'mse' and 'kl' are currently supported. 'mse' is the MSE loss, usually used to distill hidden states; 'kl' is the KL-divergence loss, usually used to distill logits.

  • evaluator (Evaluator) --

    NNI uses the evaluator to intervene in the model training process and thereby perform training-aware model compression. All training-aware model compression will use the evaluator as the entry point for intervening in training in the future. Usually you just need to wrap some classes with nni.trace or package the training process as a function to initialize the evaluator. Please refer to Compression Evaluator for a full tutorial on how to initialize an evaluator.

    The following are three simple examples. If you use native PyTorch, please refer to TorchEvaluator; if you use pytorch_lightning, please refer to LightningEvaluator; if you use the HuggingFace transformers Trainer, please refer to TransformersEvaluator:

    # LightningEvaluator example
    import nni
    import pytorch_lightning
    from pytorch_lightning.loggers import TensorBoardLogger

    lightning_trainer = nni.trace(pytorch_lightning.Trainer)(max_epochs=1, max_steps=50, logger=TensorBoardLogger(...))
    lightning_data_module = nni.trace(pytorch_lightning.LightningDataModule)(...)
    
    from nni.compression import LightningEvaluator
    evaluator = LightningEvaluator(lightning_trainer, lightning_data_module)
    
    # TorchEvaluator example
    import torch
    import torch.nn.functional as F
    
    # The user-customized `training_step` should follow this parameter signature:
    # the first argument is `batch`, the second is `model`,
    # and the return value of `training_step` should be the loss, a tuple whose first element is the loss,
    # or a dict with the key 'loss'.
    def training_step(batch, model, *args, **kwargs):
        input_data, target = batch
        result = model(input_data)
        return F.nll_loss(result, target)
    
    # The user-customized `training_model` should follow this parameter signature:
    # (model, optimizer, `training_step`, lr_scheduler, max_steps, max_epochs, ...),
    # and note that `training_step` should be defined outside of `training_model`.
    def training_model(model, optimizer, training_step, lr_scheduler, max_steps, max_epochs, *args, **kwargs):
        # max_steps and max_epochs might be None, which means unlimited training time,
        # so we need to set a default termination condition here (by default, total_epochs=10, total_steps=100000).
        total_epochs = max_epochs if max_epochs else 10
        total_steps = max_steps if max_steps else 100000
        current_step = 0
    
        # init dataloader
        train_dataloader = ...
    
        for epoch in range(total_epochs):
            ...
            for batch in train_dataloader:
                optimizer.zero_grad()
                loss = training_step(batch, model)
                loss.backward()
                optimizer.step()
                current_step += 1
                if current_step >= total_steps:
                    return
            lr_scheduler.step()
    
    import nni
    traced_optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01)
    
    from nni.compression import TorchEvaluator
    evaluator = TorchEvaluator(training_func=training_model, optimizers=traced_optimizer, training_step=training_step)
    
    # TransformersEvaluator example
    from transformers.trainer import Trainer
    trainer = nni.trace(Trainer)(model=model, args=training_args)
    
    from nni.compression import TransformersEvaluator
    evaluator = TransformersEvaluator(trainer)
    

  • teacher_model (torch.nn.Module) -- The distillation teacher model.

  • teacher_predict (Callable[[Any, torch.nn.Module], torch.Tensor]) --

    A callable with two inputs (batch, model) that runs the teacher model on the batch and returns its predictions.

    Example:

    def teacher_predict(batch, teacher_model):
        return teacher_model(**batch)
    

  • origin_loss_lambda (float) -- A scaling factor to control the original loss scale.
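
A minimal end-to-end usage sketch for DynamicLayerwiseDistiller is shown below. It assumes the student model, teacher model, and evaluator are created as described above; the module name 'fc', the hyperparameter values, and the compress(...) arguments are illustrative and may need to be adapted to your model and NNI version.

    from nni.compression.distillation import DynamicLayerwiseDistiller

    def teacher_predict(batch, teacher_model):
        input_data, _ = batch
        return teacher_model(input_data)

    config_list = [{
        'op_names': ['fc'],        # student module(s) whose output is a distillation target
        'lambda': 0.1,             # weight of this distillation target in the total loss
        'link': 'auto',            # or an explicit teacher module name, e.g. 'fc'
        'apply_method': 'kl',      # 'kl' for logits, 'mse' for hidden states
    }]

    distiller = DynamicLayerwiseDistiller(model, config_list, evaluator, teacher_model,
                                          teacher_predict, origin_loss_lambda=1.0)
    distiller.compress(max_steps=None, max_epochs=5)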

Adaptive1dLayerwiseDistiller

class nni.compression.distillation.Adaptive1dLayerwiseDistiller(model: Module, config_list: List[Dict], evaluator: Evaluator, teacher_model: Module, teacher_predict: Callable[[Any, Module], Tensor], origin_loss_lambda: float = 1.0, existed_wrappers: Dict[str, ModuleWrapper] | None = None)[source]

This distiller adaptively aligns the last dimension of each student distillation target with that of its teacher distillation target by adding a trainable torch.nn.Linear between them. (If the last dimensions of the student and teacher targets are already aligned, no new linear layer is added.)

Note that this distiller requires Adaptive1dLayerwiseDistiller.track_forward(...) to be called first, so that the shape of each distillation target is recorded and the linear layers are initialized, before calling Adaptive1dLayerwiseDistiller.compress(...), as shown in the sketch below.
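
A minimal sketch of that calling order is shown below. It assumes the student model, teacher model, evaluator, and teacher_predict are prepared as described in the parameter list that follows; the module names, hyperparameter values, and compress(...) arguments are illustrative, and the '...' placeholder stands for whatever inputs Adaptive1dLayerwiseDistiller.track_forward expects in your NNI version.

    from nni.compression.distillation import Adaptive1dLayerwiseDistiller

    config_list = [{
        'op_names': ['encoder.fc'],      # student module to distill (illustrative name)
        'link': ['teacher_encoder.fc'],  # explicit teacher module name(s); 'auto' works for isomorphic models
        'lambda': 0.1,
        'apply_method': 'mse',
    }]

    distiller = Adaptive1dLayerwiseDistiller(model, config_list, evaluator, teacher_model,
                                             teacher_predict, origin_loss_lambda=1.0)
    # run a tracked forward pass first so the distiller can record each target's shape
    # and create the aligning torch.nn.Linear layers where the last dimensions differ
    distiller.track_forward(...)
    distiller.compress(max_steps=None, max_epochs=5)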

Parameters:
  • model (torch.nn.Module) -- The student model to be distilled.

  • config_list (List[Dict]) --

    Config list to configure how to distill. For common keys, please refer to the Compression Config Specification.

    Specific keys:

    • 'lambda': By default, 1. A scaling factor that controls the loss scale; the final loss used during training is (origin_loss_lambda * origin_loss + sum(lambda_i * distill_loss_i)), where i indexes the distillation targets. The higher the value of lambda, the greater the contribution of the corresponding distillation target to the loss.

    • 'link': By default, 'auto'. Either 'auto', a teacher module name, or a list of teacher module names; the named teacher module(s) are aligned with the student module(s) configured in this config. If 'auto' is set, the student module name is used as the link, which usually requires the teacher model and the student model to be isomorphic.

    • 'apply_method': By default, 'mse'. 'mse' and 'kl' are currently supported. 'mse' is the MSE loss, usually used to distill hidden states; 'kl' is the KL-divergence loss, usually used to distill logits.

  • evaluator (Evaluator) --

    NNI uses the evaluator to intervene in the model training process and thereby perform training-aware model compression. All training-aware model compression will use the evaluator as the entry point for intervening in training in the future. Usually you just need to wrap some classes with nni.trace or package the training process as a function to initialize the evaluator. Please refer to Compression Evaluator for a full tutorial on how to initialize an evaluator.

    The following are three simple examples. If you use native PyTorch, please refer to TorchEvaluator; if you use pytorch_lightning, please refer to LightningEvaluator; if you use the HuggingFace transformers Trainer, please refer to TransformersEvaluator:

    # LightningEvaluator example
    import nni
    import pytorch_lightning
    from pytorch_lightning.loggers import TensorBoardLogger

    lightning_trainer = nni.trace(pytorch_lightning.Trainer)(max_epochs=1, max_steps=50, logger=TensorBoardLogger(...))
    lightning_data_module = nni.trace(pytorch_lightning.LightningDataModule)(...)
    
    from nni.compression import LightningEvaluator
    evaluator = LightningEvaluator(lightning_trainer, lightning_data_module)
    
    # TorchEvaluator example
    import torch
    import torch.nn.functional as F
    
    # The user-customized `training_step` should follow this parameter signature:
    # the first argument is `batch`, the second is `model`,
    # and the return value of `training_step` should be the loss, a tuple whose first element is the loss,
    # or a dict with the key 'loss'.
    def training_step(batch, model, *args, **kwargs):
        input_data, target = batch
        result = model(input_data)
        return F.nll_loss(result, target)
    
    # The user-customized `training_model` should follow this parameter signature:
    # (model, optimizer, `training_step`, lr_scheduler, max_steps, max_epochs, ...),
    # and note that `training_step` should be defined outside of `training_model`.
    def training_model(model, optimizer, training_step, lr_scheduler, max_steps, max_epochs, *args, **kwargs):
        # max_steps and max_epochs might be None, which means unlimited training time,
        # so we need to set a default termination condition here (by default, total_epochs=10, total_steps=100000).
        total_epochs = max_epochs if max_epochs else 10
        total_steps = max_steps if max_steps else 100000
        current_step = 0
    
        # init dataloader
        train_dataloader = ...
    
        for epoch in range(total_epochs):
            ...
            for batch in train_dataloader:
                optimizer.zero_grad()
                loss = training_step(batch, model)
                loss.backward()
                optimizer.step()
                current_step += 1
                if current_step >= total_steps:
                    return
            lr_scheduler.step()
    
    import nni
    traced_optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01)
    
    from nni.compression import TorchEvaluator
    evaluator = TorchEvaluator(training_func=training_model, optimizers=traced_optimizer, training_step=training_step)
    
    # TransformersEvaluator example
    from transformers.trainer import Trainer
    trainer = nni.trace(Trainer)(model=model, args=training_args)
    
    from nni.compression import TransformersEvaluator
    evaluator = TransformersEvaluator(trainer)
    

  • teacher_model (torch.nn.Module) -- The distillation teacher model.

  • teacher_predict (Callable[[Any, torch.nn.Module], torch.Tensor]) --

    A callable with two inputs (batch, model) that runs the teacher model on the batch and returns its predictions.

    Example:

    def teacher_predict(batch, teacher_model):
        return teacher_model(**batch)
    

  • origin_loss_lambda (float) -- A scaling factor to control the original loss scale.