Source code for nni.algorithms.compression.v2.pytorch.pruning.iterative_pruner

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Dict, List, Callable, Optional

from torch import Tensor
from torch.nn import Module

from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper

from .basic_pruner import (
    LevelPruner,
    L1NormPruner,
    L2NormPruner,
    FPGMPruner,
    SlimPruner,
    ActivationAPoZRankPruner,
    ActivationMeanRankPruner,
    TaylorFOWeightPruner,
    ADMMPruner
)
from .basic_scheduler import PruningScheduler
from .tools import (
    LinearTaskGenerator,
    AGPTaskGenerator,
    LotteryTicketTaskGenerator,
    SimulatedAnnealingTaskGenerator
)

_logger = logging.getLogger(__name__)

__all__ = ['LinearPruner', 'AGPPruner', 'LotteryTicketPruner', 'SimulatedAnnealingPruner']


PRUNER_DICT = {
    'level': LevelPruner,
    'l1': L1NormPruner,
    'l2': L2NormPruner,
    'fpgm': FPGMPruner,
    'slim': SlimPruner,
    'apoz': ActivationAPoZRankPruner,
    'mean_activation': ActivationMeanRankPruner,
    'taylorfo': TaylorFOWeightPruner,
    'admm': ADMMPruner
}
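
# The iterative pruners below look up the underlying one-shot pruner by name.
# Illustration only (this mirrors what the `__init__` methods below actually do):
#
#     pruner_cls = PRUNER_DICT['l1']      # -> L1NormPruner
#     pruner = pruner_cls(None, None)     # model and config_list are bound later by the scheduler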


class IterativePruner(PruningScheduler):
    def _wrap_model(self):
        """
        Deprecated function.
        """
        _logger.warning('Nothing will happen when calling this function. '
                        'This pruner is an iterative pruner and does not directly wrap the model.')

    def _unwrap_model(self):
        """
        Deprecated function.
        """
        _logger.warning('Nothing will happen when calling this function. '
                        'This pruner is an iterative pruner and does not directly wrap the model.')

    def export_model(self, *args, **kwargs):
        """
        Deprecated function.
        """
        _logger.warning('Nothing will happen when calling this function. '
                        'The best result (and the intermediate results, if kept) during iteration '
                        'can be found under `log_dir` (default: `.`).')


class LinearPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The original unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The original config list provided by the user.
    pruning_algorithm : str
        Supported pruning algorithms: ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
    total_iteration : int
        The total number of iterations.
    log_dir : str
        The log directory used to save the results; you can find the best result under this folder.
    keep_intermediate_result : bool
        If True, keep the intermediate results, including the intermediate model and masks of each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all fine-tuning logic and takes a pytorch module as input.
        It will be called at the end of each iteration, usually to recover the accuracy loss caused by pruning in that iteration.
    speed_up : bool
        If set True, speed up the model at the end of each iteration to make the pruned model compact.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required for tracing the model during speed up.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and return a score. If evaluator is None, the best result refers to the latest result.
    pruning_params : Dict
        If the chosen pruning_algorithm takes extra parameters, pass them in as a dict.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False,
                 dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None,
                 pruning_params: Dict = {}):
        task_generator = LinearTaskGenerator(total_iteration=total_iteration,
                                             origin_model=model,
                                             origin_config_list=config_list,
                                             log_dir=log_dir,
                                             keep_intermediate_result=keep_intermediate_result)
        if 'traced_optimizer' in pruning_params:
            pruning_params['traced_optimizer'] = \
                OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up,
                         dummy_input=dummy_input, evaluator=evaluator, reset_weight=False)
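
# A minimal usage sketch (editor's illustration, not part of the original module).
# `model` is assumed to be a user-defined pytorch model and `finetune` a function
# that takes the model and fine-tunes it briefly. The config below prunes 50% of
# all Conv2d weights, with the sparsity ramped up linearly over 5 iterations:
#
#     from nni.algorithms.compression.v2.pytorch.pruning import LinearPruner
#
#     config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
#     pruner = LinearPruner(model, config_list, pruning_algorithm='l1',
#                           total_iteration=5, finetuner=finetune)
#     pruner.compress()
#     _, pruned_model, masks, _, _ = pruner.get_best_result()
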
class AGPPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The original unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The original config list provided by the user.
    pruning_algorithm : str
        Supported pruning algorithms: ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
    total_iteration : int
        The total number of iterations.
    log_dir : str
        The log directory used to save the results; you can find the best result under this folder.
    keep_intermediate_result : bool
        If True, keep the intermediate results, including the intermediate model and masks of each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all fine-tuning logic and takes a pytorch module as input.
        It will be called at the end of each iteration, usually to recover the accuracy loss caused by pruning in that iteration.
    speed_up : bool
        If set True, speed up the model at the end of each iteration to make the pruned model compact.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required for tracing the model during speed up.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and return a score. If evaluator is None, the best result refers to the latest result.
    pruning_params : Dict
        If the chosen pruning_algorithm takes extra parameters, pass them in as a dict.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False,
                 dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None,
                 pruning_params: Dict = {}):
        task_generator = AGPTaskGenerator(total_iteration=total_iteration,
                                          origin_model=model,
                                          origin_config_list=config_list,
                                          log_dir=log_dir,
                                          keep_intermediate_result=keep_intermediate_result)
        if 'traced_optimizer' in pruning_params:
            pruning_params['traced_optimizer'] = \
                OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up,
                         dummy_input=dummy_input, evaluator=evaluator, reset_weight=False)
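
# AGPPruner differs from LinearPruner only in its per-iteration sparsity schedule:
# the AGP schedule raises sparsity quickly in early iterations and tapers off
# toward the target. A sketch with the same hypothetical `model` / `finetune` as above:
#
#     from nni.algorithms.compression.v2.pytorch.pruning import AGPPruner
#
#     config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
#     pruner = AGPPruner(model, config_list, pruning_algorithm='l1',
#                        total_iteration=10, finetuner=finetune)
#     pruner.compress()
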
class LotteryTicketPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The original unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The original config list provided by the user.
    pruning_algorithm : str
        Supported pruning algorithms: ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
    total_iteration : int
        The total number of iterations.
    log_dir : str
        The log directory used to save the results; you can find the best result under this folder.
    keep_intermediate_result : bool
        If True, keep the intermediate results, including the intermediate model and masks of each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all fine-tuning logic and takes a pytorch module as input.
        It will be called at the end of each iteration if reset_weight is False, otherwise at the beginning of each iteration.
    speed_up : bool
        If set True, speed up the model at the end of each iteration to make the pruned model compact.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required for tracing the model during speed up.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and return a score. If evaluator is None, the best result refers to the latest result.
    reset_weight : bool
        If set True, the model weights will be reset to the original model weights at the end of each iteration step.
    pruning_params : Dict
        If the chosen pruning_algorithm takes extra parameters, pass them in as a dict.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False,
                 dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None,
                 reset_weight: bool = True, pruning_params: Dict = {}):
        task_generator = LotteryTicketTaskGenerator(total_iteration=total_iteration,
                                                    origin_model=model,
                                                    origin_config_list=config_list,
                                                    log_dir=log_dir,
                                                    keep_intermediate_result=keep_intermediate_result)
        if 'traced_optimizer' in pruning_params:
            pruning_params['traced_optimizer'] = \
                OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up,
                         dummy_input=dummy_input, evaluator=evaluator, reset_weight=reset_weight)
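
# Lottery-ticket sketch: with reset_weight=True (the default), the surviving
# weights are rewound to the original model weights after each pruning round, so
# the finetuner runs at the beginning of each iteration instead of at the end.
# `model` / `finetune` are hypothetical, as above:
#
#     from nni.algorithms.compression.v2.pytorch.pruning import LotteryTicketPruner
#
#     config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
#     pruner = LotteryTicketPruner(model, config_list, pruning_algorithm='l1',
#                                  total_iteration=10, finetuner=finetune,
#                                  reset_weight=True)
#     pruner.compress()
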
class SimulatedAnnealingPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The original unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The original config list provided by the user.
    evaluator : Callable[[Module], float]
        Evaluate the pruned model and return a score.
    start_temperature : float
        Start temperature of the simulated annealing process.
    stop_temperature : float
        Stop temperature of the simulated annealing process.
    cool_down_rate : float
        Cool down rate of the temperature.
    perturbation_magnitude : float
        Initial perturbation magnitude applied to the sparsities. The magnitude decreases with the current temperature.
    pruning_algorithm : str
        Supported pruning algorithms: ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
    pruning_params : Dict
        If the chosen pruning_algorithm takes extra parameters, pass them in as a dict.
    log_dir : str
        The log directory used to save the results; you can find the best result under this folder.
    keep_intermediate_result : bool
        If True, keep the intermediate results, including the intermediate model and masks of each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all fine-tuning logic and takes a pytorch module as input. It will be called in each iteration.
    speed_up : bool
        If set True, speed up the model at the end of each iteration to make the pruned model compact.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required for tracing the model during speed up.
    """

    def __init__(self, model: Module, config_list: List[Dict], evaluator: Callable[[Module], float],
                 start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
                 perturbation_magnitude: float = 0.35, pruning_algorithm: str = 'level',
                 pruning_params: Dict = {}, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False,
                 dummy_input: Optional[Tensor] = None):
        task_generator = SimulatedAnnealingTaskGenerator(origin_model=model,
                                                         origin_config_list=config_list,
                                                         start_temperature=start_temperature,
                                                         stop_temperature=stop_temperature,
                                                         cool_down_rate=cool_down_rate,
                                                         perturbation_magnitude=perturbation_magnitude,
                                                         log_dir=log_dir,
                                                         keep_intermediate_result=keep_intermediate_result)
        if 'traced_optimizer' in pruning_params:
            pruning_params['traced_optimizer'] = \
                OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up,
                         dummy_input=dummy_input, evaluator=evaluator, reset_weight=False)
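
# Simulated-annealing sketch: the evaluator is required here because the task
# generator uses its score to accept or reject each perturbed sparsity allocation.
# `model` and `evaluate` (a function returning e.g. validation accuracy as a float)
# are hypothetical:
#
#     from nni.algorithms.compression.v2.pytorch.pruning import SimulatedAnnealingPruner
#
#     config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
#     pruner = SimulatedAnnealingPruner(model, config_list, evaluator=evaluate,
#                                       start_temperature=100, stop_temperature=20,
#                                       cool_down_rate=0.9)
#     pruner.compress()
#     _, pruned_model, masks, score, _ = pruner.get_best_result()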