# Source code for nni.algorithms.compression.v2.pytorch.pruning.auto_compress_pruner

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from pathlib import Path
from typing import Dict, List, Callable, Optional

from torch import Tensor
from torch.nn import Module

from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper

from .basic_pruner import ADMMPruner
from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
from .tools import LotteryTicketTaskGenerator

_logger = logging.getLogger(__name__)


class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
    """
    Task generator for AutoCompress.

    A ``LotteryTicketTaskGenerator`` whose ``allocate_sparsity`` step is delegated to an
    internal ``SimulatedAnnealingPruner``: the SA pruner is run on the proposed config
    list and its best config list is returned.

    Optional dict/list arguments default to ``None`` (treated as empty) so that no
    mutable default object is shared between calls or instances.
    """

    def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
                 origin_masks: Optional[Dict[str, Dict[str, Tensor]]] = None, sa_params: Optional[Dict] = None,
                 log_dir: str = '.', keep_intermediate_result: bool = False):
        """
        Parameters
        ----------
        total_iteration : int
            Forwarded to ``LotteryTicketTaskGenerator``.
        origin_model : Module
            The original model, forwarded to the parent generator.
        origin_config_list : List[Dict]
            The user-provided config list, forwarded to the parent generator.
        origin_masks : Optional[Dict[str, Dict[str, Tensor]]]
            Initial masks forwarded to the parent generator. ``None`` means ``{}``.
        sa_params : Optional[Dict]
            Keyword arguments unpacked into ``SimulatedAnnealingPruner``. ``None`` means ``{}``.
        log_dir : str
            Root log directory; the simulated-annealing pruner is given ``Path(log_dir, 'SA')``.
        keep_intermediate_result : bool
            Forwarded to the parent generator.
        """
        origin_masks = {} if origin_masks is None else origin_masks
        sa_params = {} if sa_params is None else sa_params
        # The SA pruner must exist before super().__init__() runs: the parent constructor
        # presumably calls back into reset()/allocate_sparsity(), which use it (TODO confirm).
        self.iterative_pruner = SimulatedAnnealingPruner(model=None,
                                                         config_list=None,
                                                         log_dir=Path(log_dir, 'SA'),
                                                         **sa_params)
        super().__init__(total_iteration=total_iteration,
                         origin_model=origin_model,
                         origin_config_list=origin_config_list,
                         origin_masks=origin_masks,
                         log_dir=log_dir,
                         keep_intermediate_result=keep_intermediate_result)

    def reset(self, model: Module, config_list: Optional[List[Dict]] = None,
              masks: Optional[Dict[str, Dict[str, Tensor]]] = None):
        """
        Warn about fixed-ratio sparsity keys, then delegate to the parent ``reset``.

        Parameters
        ----------
        model : Module
            Model forwarded to the parent ``reset``.
        config_list : Optional[List[Dict]]
            Config list; each entry containing ``'sparsity'`` or ``'sparsity_per_layer'``
            triggers a warning. ``None`` means ``[]``.
        masks : Optional[Dict[str, Dict[str, Tensor]]]
            Masks forwarded to the parent ``reset``. ``None`` means ``{}``.

        Returns
        -------
        Whatever the parent ``reset`` returns.
        """
        config_list = [] if config_list is None else config_list
        masks = {} if masks is None else masks
        # TODO: replace with validation here
        for config in config_list:
            if 'sparsity' in config or 'sparsity_per_layer' in config:
                _logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.')
        return super().reset(model, config_list, masks)

    def _iterative_pruner_reset(self, model: Module, config_list: Optional[List[Dict]] = None,
                                masks: Optional[Dict[str, Dict[str, Tensor]]] = None):
        """Point the SA pruner's log dir at ``<log root>/SA`` and reset it on ``model``."""
        config_list = [] if config_list is None else config_list
        masks = {} if masks is None else masks
        # `_log_dir_root` is presumably set by the parent generator's reset (TODO confirm).
        self.iterative_pruner.task_generator._log_dir = Path(self._log_dir_root, 'SA')
        self.iterative_pruner.reset(model, config_list=config_list, masks=masks)

    def allocate_sparsity(self, new_config_list: List[Dict], model: Module, masks: Dict[str, Dict[str, Tensor]]):
        """
        Run the simulated-annealing search starting from ``new_config_list``.

        Returns
        -------
        List[Dict]
            The config list from the SA pruner's best result (last element of the
            5-tuple returned by ``get_best_result``).
        """
        self._iterative_pruner_reset(model, new_config_list, masks)
        self.iterative_pruner.compress()
        _, _, _, _, config_list = self.iterative_pruner.get_best_result()
        return config_list


class AutoCompressPruner(IterativePruner):
    """
    Iterative pruner combining simulated-annealing sparsity allocation with ADMM pruning.

    The task generator (``AutoCompressTaskGenerator``) uses a ``SimulatedAnnealingPruner``
    to allocate sparsity, and the per-task pruning is performed by an ``ADMMPruner``.

    Parameters
    ----------
    model : Module
        The origin unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The origin config list provided by the user.
    total_iteration : int
        The total iteration number.
    admm_params : Dict
        The parameters passed to the ADMMPruner. The caller's dict is not modified.

        - trainer : Callable[[Module, Optimizer, Callable], None].
            A callable function used to train model or just inference. Take model, optimizer,
            criterion as input. The model will be trained or inferenced `training_epochs` epochs.
        - traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
            The traced optimizer instance which the optimizer class is wrapped by nni.trace.
            E.g. ``traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())``.
        - criterion : Callable[[Tensor, Tensor], Tensor].
            The criterion function used in trainer. Take model output and target value as input,
            and return the loss.
        - iterations : int.
            The total iteration number in admm pruning algorithm.
        - training_epochs : int.
            The epoch number for training model in each iteration.
    sa_params : Dict
        The parameters passed to the SimulatedAnnealingPruner.

        - evaluator : Callable[[Module], float]. Required.
            Evaluate the pruned model and give a score.
        - start_temperature : float. Default: `100`.
            Start temperature of the simulated annealing process.
        - stop_temperature : float. Default: `20`.
            Stop temperature of the simulated annealing process.
        - cool_down_rate : float. Default: `0.9`.
            Cooldown rate of the temperature.
        - perturbation_magnitude : float. Default: `0.35`.
            Initial perturbation magnitude to the sparsities. The magnitude decreases with
            current temperature.
        - pruning_algorithm : str. Default: `'level'`.
            Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz',
            'mean_activation', 'taylorfo', 'admm'].
        - pruning_params : Dict. Default: `{}`.
            If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
    log_dir : str
        The log directory used to save the result, you can find the best result under this folder.
    keep_intermediate_result : bool
        If keeping the intermediate result, including intermediate model and masks during each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetune logic, takes a pytorch module as input.
        It will be called at the end of each iteration, usually for neutralizing the accuracy
        loss brought by the pruning in this iteration.
    speed_up : bool
        If set True, speed up the model at the end of each iteration to make the pruned model compact.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required for tracing the model in speed up.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and give a score.
    """

    def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
                 sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False,
                 dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None):
        task_generator = AutoCompressTaskGenerator(total_iteration=total_iteration,
                                                   origin_model=model,
                                                   origin_config_list=config_list,
                                                   sa_params=sa_params,
                                                   log_dir=log_dir,
                                                   keep_intermediate_result=keep_intermediate_result)
        # Shallow-copy so the conversion below does not overwrite 'traced_optimizer' in the
        # caller's dict (reusing it would otherwise hand over the converted value).
        admm_params = dict(admm_params)
        if 'traced_optimizer' in admm_params:
            admm_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, admm_params['traced_optimizer'])
        # Built without model/config list; presumably bound per task by the IterativePruner base (TODO confirm).
        pruner = ADMMPruner(None, None, **admm_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up,
                         dummy_input=dummy_input, evaluator=evaluator, reset_weight=False)