Source code for nni.compression.torch.pruning.one_shot

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from schema import And, Optional
from .constants import MASKER_DICT
from ..utils.config_validation import CompressorSchema
from ..compressor import Pruner

__all__ = ['LevelPruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner',
           'TaylorFOWeightFilterPruner', 'ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']

logger = logging.getLogger('torch pruner')

class OneshotPruner(Pruner):
    """
    Prune model to an exact pruning level for one time.
    """

    def __init__(self, model, config_list, pruning_algorithm='level', optimizer=None, **algo_kwargs):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List of pruning configs
        pruning_algorithm : str
            Name of the algorithm used to prune the model
        optimizer : torch.optim.Optimizer
            Optimizer used to train the model
        algo_kwargs : dict
            Additional keyword arguments passed to the pruning algorithm's masker class
        """

        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)
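        # The actual pruning logic is delegated to a masker object, looked up
        # by algorithm name in MASKER_DICT (defined in .constants).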
        self.masker = MASKER_DICT[pruning_algorithm](model, self, **algo_kwargs)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List of pruning configs
        """
        schema = CompressorSchema([{
            'sparsity': And(float, lambda n: 0 < n < 1),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)
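
        # An illustrative config that satisfies this schema (the op name is
        # made up, not taken from any particular model):
        #     [{'sparsity': 0.5, 'op_types': ['Conv2d'], 'op_names': ['conv1']}]
        # A sparsity outside (0, 1), or a non-float value, is rejected.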

    def calc_mask(self, wrapper, wrapper_idx=None):
        """
        Calculate the mask of given layer
        Parameters
        ----------
        wrapper : Module
            the module to instrument the compression operation
        wrapper_idx: int
            index of this wrapper in pruner's all wrappers
        Returns
        -------
        dict
            dictionary for storing masks, keys of the dict:
            'weight_mask':  weight mask tensor
            'bias_mask': bias mask tensor (optional)
        """
        if wrapper.if_calculated:
            return None

        sparsity = wrapper.config['sparsity']
        masks = self.masker.calc_mask(sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)

        # If masker.calc_mask returns None, the mask could not be calculated
        # this time; leave if_calculated False so it can be retried later.
        if masks is not None:
            wrapper.if_calculated = True
        return masks

class LevelPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The target sparsity to compress the operations to.
            - op_types : Operation types to prune.
    optimizer : torch.optim.Optimizer
        Optimizer used to train the model
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, pruning_algorithm='level', optimizer=optimizer)

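# A minimal usage sketch for LevelPruner, assuming `net` is a torch.nn.Module
# defined elsewhere ('default' selects the default weighted op types):
#
#     from nni.compression.torch import LevelPruner
#     config_list = [{'sparsity': 0.8, 'op_types': ['default']}]
#     pruner = LevelPruner(net, config_list)
#     model = pruner.compress()
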
class SlimPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The target sparsity to compress the operations to.
            - op_types : Only BatchNorm2d is supported in Slim Pruner.
    optimizer : torch.optim.Optimizer
        Optimizer used to train the model
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, pruning_algorithm='slim', optimizer=optimizer)

    def validate_config(self, model, config_list):
        schema = CompressorSchema([{
            'sparsity': And(float, lambda n: 0 < n < 1),
            'op_types': ['BatchNorm2d'],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)

        if len(config_list) > 1:
            logger.warning('Slim pruner only supports 1 configuration')

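# A usage sketch for SlimPruner, assuming `net` contains BatchNorm2d layers;
# note the single-configuration limitation warned about above:
#
#     from nni.compression.torch import SlimPruner
#     config_list = [{'sparsity': 0.7, 'op_types': ['BatchNorm2d']}]
#     pruner = SlimPruner(net, config_list)
#     model = pruner.compress()
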
class _StructuredFilterPruner(OneshotPruner):
    def __init__(self, model, config_list, pruning_algorithm, optimizer=None, **algo_kwargs):
        super().__init__(model, config_list, pruning_algorithm=pruning_algorithm, optimizer=optimizer, **algo_kwargs)

    def validate_config(self, model, config_list):
        schema = CompressorSchema([{
            'sparsity': And(float, lambda n: 0 < n < 1),
            'op_types': ['Conv2d'],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)

class L1FilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The target sparsity to compress the operations to.
            - op_types : Only Conv2d is supported in L1FilterPruner.
    optimizer : torch.optim.Optimizer
        Optimizer used to train the model
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer)

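# A usage sketch for L1FilterPruner, assuming `net` contains Conv2d layers;
# L2FilterPruner and FPGMPruner below are used the same way:
#
#     from nni.compression.torch import L1FilterPruner
#     config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
#     pruner = L1FilterPruner(net, config_list)
#     model = pruner.compress()
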
class L2FilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The target sparsity to compress the operations to.
            - op_types : Only Conv2d is supported in L2FilterPruner.
    optimizer : torch.optim.Optimizer
        Optimizer used to train the model
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, pruning_algorithm='l2', optimizer=optimizer)

class FPGMPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The target sparsity to compress the operations to.
            - op_types : Only Conv2d is supported in FPGM Pruner.
    optimizer : torch.optim.Optimizer
        Optimizer used to train the model
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, pruning_algorithm='fpgm', optimizer=optimizer)

class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The percentage of convolutional filters to be pruned.
            - op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner.
    optimizer : torch.optim.Optimizer
        Optimizer used to train the model
    """

    def __init__(self, model, config_list, optimizer=None, statistics_batch_num=1):
        super().__init__(model, config_list, pruning_algorithm='taylorfo', optimizer=optimizer,
                         statistics_batch_num=statistics_batch_num)

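# A usage sketch for TaylorFOWeightFilterPruner, which takes the training
# optimizer so first-order gradient statistics can be collected while training
# runs; `net` and `optimizer` are assumed to be defined elsewhere:
#
#     from nni.compression.torch import TaylorFOWeightFilterPruner
#     config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
#     pruner = TaylorFOWeightFilterPruner(net, config_list, optimizer,
#                                         statistics_batch_num=1)
#     model = pruner.compress()
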
class ActivationAPoZRankFilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The percentage of convolutional filters to be pruned.
            - op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner.
    optimizer : torch.optim.Optimizer
        Optimizer used to train the model
    """

    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
        super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer,
                         activation=activation, statistics_batch_num=statistics_batch_num)

class ActivationMeanRankFilterPruner(_StructuredFilterPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned
    config_list : list
        Supported keys:
            - sparsity : The percentage of convolutional filters to be pruned.
            - op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner.
    optimizer : torch.optim.Optimizer
        Optimizer used to train the model
    """

    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
        super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer,
                         activation=activation, statistics_batch_num=statistics_batch_num)

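# A usage sketch for the activation-based pruners, assuming `net` and
# `optimizer` are defined elsewhere; activation statistics are gathered over
# `statistics_batch_num` training batches before masks are produced:
#
#     from nni.compression.torch import ActivationMeanRankFilterPruner
#     config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
#     pruner = ActivationMeanRankFilterPruner(net, config_list, optimizer,
#                                             activation='relu',
#                                             statistics_batch_num=1)
#     model = pruner.compress()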