Supported Pruning Algorithms in NNI

NNI provides several pruning algorithms reproduced from the papers. In pruning v2, NNI splits each pruning algorithm into more fine-grained components. This means users can freely combine components from different algorithms, or easily replace a step of an original algorithm with a component of their own implementation to build their own pruning algorithm.

Right now, algorithms that describe how to generate masks in one step are implemented as pruners, and algorithms that describe how to schedule sparsity across iterations are implemented as iterative pruners.

Level Pruner

This is a basic pruner; some papers call it magnitude pruning or fine-grained pruning.

It masks the weights with the smallest absolute values in each specified layer, up to the sparsity ratio configured in the config list.

Usage

from nni.algorithms.compression.v2.pytorch.pruning import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(model, config_list)
masked_model, masks = pruner.compress()

User configuration for Level Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.LevelPruner(model: torch.nn.modules.module.Module, config_list: List[Dict])[source]
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (List[Dict]) –

    Supported keys:
    • sparsity : Specifies the sparsity of each layer covered by this config.

    • sparsity_per_layer : Equivalent to sparsity.

    • op_types : Operation types to prune.

    • op_names : Operation names to prune.

    • exclude : If set to True, the layers selected by op_types and op_names will be excluded from pruning (see the example below).
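
For example, a config list can combine a pruning config with an exclude config. A minimal sketch, where the layer name fc is a hypothetical example:

# Prune all supported op types at 80% sparsity, but leave the layer
# named 'fc' untouched ('fc' is a hypothetical layer name).
config_list = [
    { 'sparsity': 0.8, 'op_types': ['default'] },
    { 'exclude': True, 'op_names': ['fc'] },
]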

L1 Norm Pruner

L1 norm pruner computes the l1 norm of the layer weight along the first dimension, then prunes the weight blocks with smaller l1 norm values on this dimension. That is, it uses the l1 norm of each filter in a convolution layer, or the l1 norm of each weight row in a linear layer, as the metric value.
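
As a minimal sketch of the metric only (not NNI's internal implementation), the per-filter l1 norm of a Conv2d weight can be computed like this:

import torch

# Hypothetical Conv2d weight of shape (out_channels, in_channels, kH, kW).
weight = torch.randn(64, 32, 3, 3)

# One l1 norm per block on the first dimension, i.e. one value per filter.
l1_metric = weight.abs().reshape(weight.shape[0], -1).sum(dim=1)

# Filters with the smallest metric values are the pruning candidates.
prune_candidates = l1_metric.argsort()[:int(0.8 * weight.shape[0])]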

For more details, please refer to PRUNING FILTERS FOR EFFICIENT CONVNETS.

In addition, L1 norm pruner also supports dependency-aware mode.

Usage

from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = L1NormPruner(model, config_list)
masked_model, masks = pruner.compress()

User configuration for L1 Norm Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.L1NormPruner(model: torch.nn.modules.module.Module, config_list: List[Dict], mode: str = 'normal', dummy_input: Optional[torch.Tensor] = None)[source]
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (List[Dict]) –

    Supported keys:
    • sparsity : Specifies the sparsity of each layer covered by this config.

    • sparsity_per_layer : Equivalent to sparsity.

    • op_types : Conv2d and Linear are supported in L1NormPruner.

    • op_names : Operation names to prune.

    • exclude : If set to True, the layers selected by op_types and op_names will be excluded from pruning.

  • mode (str) – ‘normal’ or ‘dependency_aware’. When pruning in a dependency-aware way, this pruner prunes the model according to both the l1 norm of the weights and the channel or group dependencies of the model. In this way, the pruner forces conv layers that have dependencies to prune the same channels, so the speedup module can better harvest the speed benefit from the pruned model. Note that if mode is set to ‘dependency_aware’, dummy_input cannot be None, because the pruner needs a dummy input to trace the dependencies between the conv layers. A usage sketch follows the parameter list.

  • dummy_input (Optional[torch.Tensor]) – The dummy input used to analyze the topology constraints. Note that dummy_input should be on the same device as the model.
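
A minimal usage sketch of the dependency-aware mode, assuming a model that takes a single image tensor as input:

import torch
from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner

config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
# The dummy input must live on the same device as the model.
dummy_input = torch.rand(1, 3, 224, 224)
pruner = L1NormPruner(model, config_list, mode='dependency_aware', dummy_input=dummy_input)
masked_model, masks = pruner.compress()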

L2 Norm Pruner

L2 norm pruner is a variant of L1 norm pruner. It uses the l2 norm as the metric to determine which weight blocks should be pruned.

L2 norm pruner also supports dependency-aware mode.

Usage

from nni.algorithms.compression.v2.pytorch.pruning import L2NormPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = L2NormPruner(model, config_list)
masked_model, masks = pruner.compress()

User configuration for L2 Norm Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.L2NormPruner(model: torch.nn.modules.module.Module, config_list: List[Dict], mode: str = 'normal', dummy_input: Optional[torch.Tensor] = None)[source]
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (List[Dict]) –

    Supported keys:
    • sparsity : Specifies the sparsity of each layer covered by this config.

    • sparsity_per_layer : Equivalent to sparsity.

    • op_types : Conv2d and Linear are supported in L2NormPruner.

    • op_names : Operation names to prune.

    • exclude : If set to True, the layers selected by op_types and op_names will be excluded from pruning.

  • mode (str) – ‘normal’ or ‘dependency_aware’. When pruning in a dependency-aware way, this pruner prunes the model according to both the l2 norm of the weights and the channel or group dependencies of the model. In this way, the pruner forces conv layers that have dependencies to prune the same channels, so the speedup module can better harvest the speed benefit from the pruned model. Note that if mode is set to ‘dependency_aware’, dummy_input cannot be None, because the pruner needs a dummy input to trace the dependencies between the conv layers.

  • dummy_input (Optional[torch.Tensor]) – The dummy input used to analyze the topology constraints. Note that dummy_input should be on the same device as the model.

FPGM Pruner

FPGM pruner prunes the weight blocks on the first dimension that are closest to the geometric median of all blocks in that layer. In other words, FPGM removes the weight blocks whose contribution is the most replaceable by the remaining blocks.
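
A minimal sketch of the idea (not NNI's implementation): a filter whose total distance to all other filters is smallest lies nearest the geometric median, so it is the most replaceable:

import torch

# Hypothetical Conv2d weight of shape (out_channels, in_channels, kH, kW).
weight = torch.randn(64, 32, 3, 3)
flat = weight.reshape(weight.shape[0], -1)

# Sum of pairwise euclidean distances from each filter to all the others.
distance_sum = torch.cdist(flat, flat, p=2).sum(dim=1)

# Filters with the smallest distance sum sit closest to the geometric median
# and are pruned first.
prune_candidates = distance_sum.argsort()[:int(0.8 * weight.shape[0])]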

For more details, please refer to Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration.

FPGM pruner also supports dependency-aware mode.

Usage

from nni.algorithms.compression.v2.pytorch.pruning import FPGMPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = FPGMPruner(model, config_list)
masked_model, masks = pruner.compress()

User configuration for FPGM Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.FPGMPruner(model: torch.nn.modules.module.Module, config_list: List[Dict], mode: str = 'normal', dummy_input: Optional[torch.Tensor] = None)[source]
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (List[Dict]) –

    Supported keys:
    • sparsity : Specifies the sparsity of each layer covered by this config.

    • sparsity_per_layer : Equivalent to sparsity.

    • op_types : Conv2d and Linear are supported in FPGMPruner.

    • op_names : Operation names to prune.

    • exclude : If set to True, the layers selected by op_types and op_names will be excluded from pruning.

  • mode (str) – ‘normal’ or ‘dependency_aware’. When pruning in a dependency-aware way, this pruner prunes the model according to both the FPGM metric of the weights and the channel or group dependencies of the model. In this way, the pruner forces conv layers that have dependencies to prune the same channels, so the speedup module can better harvest the speed benefit from the pruned model. Note that if mode is set to ‘dependency_aware’, dummy_input cannot be None, because the pruner needs a dummy input to trace the dependencies between the conv layers.

  • dummy_input (Optional[torch.Tensor]) – The dummy input used to analyze the topology constraints. Note that dummy_input should be on the same device as the model.

Slim Pruner

Slim pruner adds sparsity regularization on the scaling factors of batch normalization (BN) layers during training to identify unimportant channels. The channels with small scaling factor values will be pruned.
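
A minimal sketch of the regularization idea, assuming an l1 penalty on the BN scaling factors is added to the task loss (an illustration, not NNI's implementation; the scale value mirrors the pruner's scale parameter):

import torch.nn as nn

def bn_sparsity_loss(model: nn.Module, scale: float = 1e-4):
    # l1 penalty on the scaling factors (gamma) of all BatchNorm2d layers.
    penalty = sum(m.weight.abs().sum()
                  for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    return scale * penalty

# During training: loss = criterion(output, target) + bn_sparsity_loss(model)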

For more details, please refer to Learning Efficient Convolutional Networks through Network Slimming.

Usage

from nni.algorithms.compression.v2.pytorch.pruning import SlimPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['BatchNorm2d'] }]
pruner = SlimPruner(model, config_list, trainer, optimizer, criterion, training_epochs=1)
masked_model, masks = pruner.compress()

User configuration for Slim Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.SlimPruner(model: torch.nn.modules.module.Module, config_list: List[Dict], trainer: Callable[[torch.nn.modules.module.Module, torch.optim.optimizer.Optimizer, Callable], None], optimizer: torch.optim.optimizer.Optimizer, criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], training_epochs: int, scale: float = 0.0001, mode='global')[source]
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (List[Dict]) –

    Supported keys:
    • sparsity : Specifies the sparsity of each layer covered by this config.

    • sparsity_per_layer : Equivalent to sparsity.

    • total_sparsity : Specifies the total sparsity of all layers covered by this config; each layer may end up with a different sparsity.

    • max_sparsity_per_layer : Always used together with total_sparsity. Limits the maximum sparsity of each layer.

    • op_types : Only BatchNorm2d is supported in SlimPruner.

    • op_names : Operation names to prune.

    • exclude : If set to True, the layers selected by op_types and op_names will be excluded from pruning.

  • trainer (Callable[[Module, Optimizer, Callable], None]) –

    A callable used to train the model or just run inference. It takes model, optimizer, and criterion as input. The model will be trained or run inference for training_epochs epochs.

    Example:

    def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
        training = model.training
        model.train(mode=True)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            # If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
            optimizer.step()
        model.train(mode=training)
    

  • optimizer (torch.optim.Optimizer) – The optimizer instance used in trainer. Note that this optimizer might be patched during data collection, so do not use it elsewhere.

  • criterion (Callable[[Tensor, Tensor], Tensor]) – The criterion function used in trainer. It takes the model output and the target value as input and returns the loss.

  • training_epochs (int) – The number of epochs to train the model to sparsify the BN weights.

  • mode (str) – ‘normal’ or ‘global’. When pruning in a global way, all layer weights under the same config are considered together. That means an individual layer may not reach the sparsity set in the config, but the total pruned weight across those layers meets it.

Activation APoZ Rank Pruner

Activation APoZ rank pruner prunes on the first weight dimension, removing the blocks with the smallest importance, where importance is derived from the APoZ (Average Percentage of Zeros) criterion calculated from the output activations of convolution layers, to achieve a preset level of network sparsity. The pruning criterion APoZ is explained in the paper Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures.

The APoZ is defined as:

\(APoZ_{c}^{(i)} = APoZ\left(O_{c}^{(i)}\right)=\frac{\sum_{k}^{N} \sum_{j}^{M} f\left(O_{c, j}^{(i)}(k)=0\right)}{N \times M}\)
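
Here \(f(\cdot)\) evaluates to 1 when the condition holds and 0 otherwise, so APoZ is the average fraction of zero activations a channel produces. A minimal sketch of the metric on a batch of ReLU outputs (not NNI's implementation):

import torch

# Hypothetical ReLU output of shape (batch, channels, H, W).
activation = torch.relu(torch.randn(8, 64, 28, 28))

# Fraction of zero activations per channel, averaged over batch and spatial dims.
apoz = (activation == 0).float().mean(dim=(0, 2, 3))

# Channels that are mostly zero (high APoZ) contribute the least and are
# the pruning candidates.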

Activation APoZ rank pruner also supports dependency-aware mode.

Usage

from nni.algorithms.compression.v2.pytorch.pruning import ActivationAPoZRankPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = ActivationAPoZRankPruner(model, config_list, trainer, optimizer, criterion, training_batches=20)
masked_model, masks = pruner.compress()

User configuration for Activation APoZ Rank Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.ActivationAPoZRankPruner(model: torch.nn.modules.module.Module, config_list: List[Dict], trainer: Callable[[torch.nn.modules.module.Module, torch.optim.optimizer.Optimizer, Callable], None], optimizer: torch.optim.optimizer.Optimizer, criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], training_batches: int, activation: str = 'relu', mode: str = 'normal', dummy_input: Optional[torch.Tensor] = None)[source]

Activation Mean Rank Pruner

Activation mean rank pruner prunes on the first weight dimension, removing the blocks with the smallest importance criterion mean activation, which is calculated from the output activations of convolution layers, to achieve a preset level of network sparsity. The pruning criterion mean activation is explained in section 2.2 of the paper Pruning Convolutional Neural Networks for Resource Efficient Inference.
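
A minimal sketch of the criterion (not NNI's implementation): channels with the smallest mean activation are considered the least important:

import torch

# Hypothetical ReLU output of shape (batch, channels, H, W).
activation = torch.relu(torch.randn(8, 64, 28, 28))

# Mean activation per channel over the batch and spatial dimensions.
mean_activation = activation.mean(dim=(0, 2, 3))

# Channels with the smallest mean activation are the pruning candidates.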

Activation mean rank pruner also supports dependency-aware mode.

Usage

from nni.algorithms.compression.v2.pytorch.pruning import ActivationMeanRankPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = ActivationMeanRankPruner(model, config_list, trainer, optimizer, criterion, training_batches=20)
masked_model, masks = pruner.compress()

User configuration for Activation Mean Rank Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.ActivationMeanRankPruner(model: torch.nn.modules.module.Module, config_list: List[Dict], trainer: Callable[[torch.nn.modules.module.Module, torch.optim.optimizer.Optimizer, Callable], None], optimizer: torch.optim.optimizer.Optimizer, criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], training_batches: int, activation: str = 'relu', mode: str = 'normal', dummy_input: Optional[torch.Tensor] = None)[source]

Taylor FO Weight Pruner

Taylor FO weight pruner prunes on the first weight dimension, based on estimated importance calculated from the first-order Taylor expansion on the weights, to achieve a preset level of network sparsity. The estimated importance is defined as in the paper Importance Estimation for Neural Network Pruning.

\(\widehat{\mathcal{I}}_{\mathcal{S}}^{(1)}(\mathbf{W}) \triangleq \sum_{s \in \mathcal{S}} \mathcal{I}_{s}^{(1)}(\mathbf{W})=\sum_{s \in \mathcal{S}}\left(g_{s} w_{s}\right)^{2}\)
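
Here \(g_s\) is the gradient of the loss with respect to the weight \(w_s\), and \(\mathcal{S}\) is a structural set of weights such as a filter. A minimal sketch of the per-filter importance estimate for one Conv2d layer (not NNI's implementation):

import torch
import torch.nn as nn

conv = nn.Conv2d(32, 64, 3)
x = torch.randn(8, 32, 28, 28)

# Hypothetical loss; in practice this is the task loss on real data.
loss = conv(x).sum()
loss.backward()

# (g_s * w_s)^2, summed over each filter, gives the per-filter importance.
importance = (conv.weight.grad * conv.weight).pow(2).reshape(64, -1).sum(dim=1)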

Taylor FO weight pruner also supports dependency-aware mode.

In addition, we provide a global-sort mode for this pruner, which is aligned with the implementation in the paper.

Usage

from nni.algorithms.compression.v2.pytorch.pruning import TaylorFOWeightPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = TaylorFOWeightPruner(model, config_list, trainer, optimizer, criterion, training_batches=20)
masked_model, masks = pruner.compress()

User configuration for Taylor FO Weight Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.TaylorFOWeightPruner(model: torch.nn.modules.module.Module, config_list: List[Dict], trainer: Callable[[torch.nn.modules.module.Module, torch.optim.optimizer.Optimizer, Callable], None], optimizer: torch.optim.optimizer.Optimizer, criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], training_batches: int, mode: str = 'normal', dummy_input: Optional[torch.Tensor] = None)[source]
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (List[Dict]) –

    Supported keys:
    • sparsity : Specifies the sparsity of each layer covered by this config.

    • sparsity_per_layer : Equivalent to sparsity.

    • total_sparsity : Specifies the total sparsity of all layers covered by this config; each layer may end up with a different sparsity.

    • max_sparsity_per_layer : Always used together with total_sparsity. Limits the maximum sparsity of each layer.

    • op_types : Conv2d and Linear are supported in TaylorFOWeightPruner.

    • op_names : Operation names to prune.

    • exclude : If set to True, the layers selected by op_types and op_names will be excluded from pruning.

  • trainer (Callable[[Module, Optimizer, Callable], None]) –

    A callable used to train the model or just run inference. It takes model, optimizer, and criterion as input; the pruner uses it to collect the data it needs over training_batches batches.

    Example:

    def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
        training = model.training
        model.train(mode=True)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            # If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
            optimizer.step()
        model.train(mode=training)
    

  • optimizer (torch.optim.Optimizer) – The optimizer instance used in trainer. Note that this optimizer might be patched during data collection, so do not use it elsewhere.

  • criterion (Callable[[Tensor, Tensor], Tensor]) – The criterion function used in trainer. It takes the model output and the target value as input and returns the loss.

  • training_batches (int) – The number of batches used to collect the data needed by this pruner.

  • mode (str) –

    ‘normal’, ‘dependency_aware’ or ‘global’.

    When pruning in a dependency-aware way, this pruner prunes the model according to both the Taylor FO importance of the weights and the channel or group dependencies of the model. In this way, the pruner forces conv layers that have dependencies to prune the same channels, so the speedup module can better harvest the speed benefit from the pruned model. Note that if mode is set to ‘dependency_aware’, dummy_input cannot be None, because the pruner needs a dummy input to trace the dependencies between the conv layers.

    When pruning in a global way, all layer weights under the same config are considered together. That means an individual layer may not reach the sparsity set in the config, but the total pruned weight across those layers meets it.

  • dummy_input (Optional[torch.Tensor]) – The dummy input used to analyze the topology constraints. Note that dummy_input should be on the same device as the model.

ADMM Pruner

Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique that decomposes the original nonconvex problem into two subproblems that can be solved iteratively. In the weight pruning problem, these two subproblems are solved via 1) a gradient descent algorithm and 2) Euclidean projection, respectively.

While solving the two subproblems, the weights of the original model are changed. A fine-grained pruning is then applied according to the given config list.

This solution framework applies to both non-structured and various structured pruning schemes.
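
A minimal sketch of the Euclidean projection subproblem under a per-layer sparsity constraint: projecting onto the constraint set amounts to keeping the largest-magnitude weights and zeroing the rest (an illustration, not NNI's implementation):

import torch

def euclidean_projection(weight: torch.Tensor, sparsity: float) -> torch.Tensor:
    # Zero the `sparsity` fraction of weights with the smallest magnitude,
    # keeping the rest; this is the closest point in the constraint set.
    k = int(weight.numel() * sparsity)
    if k == 0:
        return weight.clone()
    threshold = weight.abs().flatten().kthvalue(k).values
    return weight * (weight.abs() > threshold).float()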

For more details, please refer to A Systematic DNN Weight Pruning Framework using Alternating Direction Method of Multipliers.

Usage

from nni.algorithms.compression.v2.pytorch.pruning import ADMMPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = ADMMPruner(model, config_list, trainer, optimizer, criterion, iterations=10, training_epochs=1)
masked_model, masks = pruner.compress()

User configuration for ADMM Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.ADMMPruner(model: torch.nn.modules.module.Module, config_list: List[Dict], trainer: Callable[[torch.nn.modules.module.Module, torch.optim.optimizer.Optimizer, Callable], None], optimizer: torch.optim.optimizer.Optimizer, criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], iterations: int, training_epochs: int)[source]

ADMM (Alternating Direction Method of Multipliers) pruner applies this mathematical optimization technique. The metric used in this pruner is the absolute value of the weight: in each iteration, the weights with small magnitudes are pushed toward zero. Only in the final iteration is the mask generated and applied to the model wrapper.

The original paper: https://arxiv.org/abs/1804.03294.

Parameters
  • model (torch.nn.Module) – Model to be pruned.

  • config_list (List[Dict]) –

    Supported keys:
    • sparsity : Specifies the sparsity of each layer covered by this config.

    • sparsity_per_layer : Equivalent to sparsity.

    • rho : Penalty parameter in the ADMM algorithm.

    • op_types : Operation types to prune.

    • op_names : Operation names to prune.

    • exclude : If set to True, the layers selected by op_types and op_names will be excluded from pruning.

  • trainer (Callable[[Module, Optimizer, Callable], None]) –

    A callable used to train the model or just run inference. It takes model, optimizer, and criterion as input. The model will be trained or run inference for training_epochs epochs.

    Example:

    def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
        training = model.training
        model.train(mode=True)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            # If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
            optimizer.step()
        model.train(mode=training)
    

  • optimizer (torch.optim.Optimizer) – The optimizer instance used in trainer. Note that this optimizer might be patched during data collection, so do not use it elsewhere.

  • criterion (Callable[[Tensor, Tensor], Tensor]) – The criterion function used in trainer. It takes the model output and the target value as input and returns the loss.

  • iterations (int) – The total number of iterations in the ADMM pruning algorithm.

  • training_epochs (int) – The number of epochs to train the model in each iteration.

Linear Pruner

Linear pruner is an iterative pruner that increases sparsity evenly from zero across iterations. For example, if the final sparsity is set to 0.5 and the iteration number is 5, the sparsity used in each iteration is [0, 0.1, 0.2, 0.3, 0.4, 0.5].
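
A minimal sketch of how such a schedule can be generated (an illustration, not NNI's implementation):

# Evenly spaced sparsity values from 0 up to the final sparsity.
final_sparsity, total_iteration = 0.5, 5
schedule = [final_sparsity * i / total_iteration for i in range(total_iteration + 1)]
# -> [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]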

Usage

from nni.algorithms.compression.v2.pytorch.pruning import LinearPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = LinearPruner(model, config_list, pruning_algorithm='l1', total_iteration=10, finetuner=finetuner)
pruner.compress()
_, model, masks, _, _ = pruner.get_best_result()

User configuration for Linear Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.LinearPruner(model: torch.nn.modules.module.Module, config_list: List[Dict], pruning_algorithm: str, total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False, finetuner: Optional[Callable[[torch.nn.modules.module.Module], None]] = None, speed_up: bool = False, dummy_input: Optional[torch.Tensor] = None, evaluator: Optional[Callable[[torch.nn.modules.module.Module], float]] = None, pruning_params: dict = {})[source]
Parameters
  • model (Module) – The origin unwrapped pytorch model to be pruned.

  • config_list (List[Dict]) – The origin config list provided by the user. Note that this config_list directly configures the origin model, so the sparsity provided by the origin_masks should also be recorded in the origin_config_list.

  • pruning_algorithm (str) – Supported pruning algorithms: [‘level’, ‘l1’, ‘l2’, ‘fpgm’, ‘slim’, ‘apoz’, ‘mean_activation’, ‘taylorfo’, ‘admm’]. This iterative pruner will use the corresponding pruner to prune the model in each iteration.

  • total_iteration (int) – The total number of iterations.

  • log_dir (str) – The log directory used to save the results; you can find the best result under this folder.

  • keep_intermediate_result (bool) – Whether to keep the intermediate results, including the intermediate model and masks from each iteration.

  • finetuner (Optional[Callable[[Module], None]]) – The finetuner handles all finetuning logic; it takes a PyTorch module as input and is called in each iteration.

  • speed_up (bool) – If set to True, speed up the model in each iteration.

  • dummy_input (Optional[torch.Tensor]) – If speed_up is True, dummy_input is required to trace the model during speedup.

  • evaluator (Optional[Callable[[Module], float]]) – Evaluates the pruned model and gives a score. If evaluator is None, the best result refers to the latest result.

  • pruning_params (dict) – If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.

AGP Pruner

This is an iterative pruner, in which the sparsity is increased from an initial value \(s_{i}\) (usually 0) to a final value \(s_{f}\) over a span of \(n\) pruning iterations, starting at training step \(t_{0}\) and with pruning frequency \(\Delta t\):

\(s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n \Delta t}\right)^{3} \text { for } t \in\left\{t_{0}, t_{0}+\Delta t, \ldots, t_{0} + n \Delta t\right\}\)
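
A minimal sketch evaluating this schedule (hypothetical values; not NNI's implementation):

# AGP sparsity schedule: a cubic ramp from s_i to s_f.
def agp_sparsity(t, s_i=0.0, s_f=0.8, t_0=0, n=10, delta_t=1):
    progress = (t - t_0) / (n * delta_t)
    return s_f + (s_i - s_f) * (1 - progress) ** 3

# Sparsity grows quickly at first, then saturates near s_f = 0.8:
print([round(agp_sparsity(t), 3) for t in range(0, 11)])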

For more details please refer to To prune, or not to prune: exploring the efficacy of pruning for model compression.

Usage

from nni.algorithms.compression.v2.pytorch.pruning import AGPPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = AGPPruner(model, config_list, pruning_algorithm='l1', total_iteration=10, finetuner=finetuner)
pruner.compress()
_, model, masks, _, _ = pruner.get_best_result()

User configuration for AGP Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.AGPPruner(model: torch.nn.modules.module.Module, config_list: List[Dict], pruning_algorithm: str, total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False, finetuner: Optional[Callable[[torch.nn.modules.module.Module], None]] = None, speed_up: bool = False, dummy_input: Optional[torch.Tensor] = None, evaluator: Optional[Callable[[torch.nn.modules.module.Module], float]] = None, pruning_params: dict = {})[source]
Parameters
  • model (Module) – The origin unwrapped pytorch model to be pruned.

  • config_list (List[Dict]) – The origin config list provided by the user. Note that this config_list directly configures the origin model, so the sparsity provided by the origin_masks should also be recorded in the origin_config_list.

  • pruning_algorithm (str) – Supported pruning algorithms: [‘level’, ‘l1’, ‘l2’, ‘fpgm’, ‘slim’, ‘apoz’, ‘mean_activation’, ‘taylorfo’, ‘admm’]. This iterative pruner will use the corresponding pruner to prune the model in each iteration.

  • total_iteration (int) – The total number of iterations.

  • log_dir (str) – The log directory used to save the results; you can find the best result under this folder.

  • keep_intermediate_result (bool) – Whether to keep the intermediate results, including the intermediate model and masks from each iteration.

  • finetuner (Optional[Callable[[Module], None]]) – The finetuner handles all finetuning logic; it takes a PyTorch module as input and is called in each iteration.

  • speed_up (bool) – If set to True, speed up the model in each iteration.

  • dummy_input (Optional[torch.Tensor]) – If speed_up is True, dummy_input is required to trace the model during speedup.

  • evaluator (Optional[Callable[[Module], float]]) – Evaluates the pruned model and gives a score. If evaluator is None, the best result refers to the latest result.

  • pruning_params (dict) – If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.

Lottery Ticket Pruner

The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks, by Jonathan Frankle and Michael Carbin, provides comprehensive measurement and analysis, and articulates the lottery ticket hypothesis: dense, randomly-initialized, feed-forward networks contain subnetworks (winning tickets) that, when trained in isolation, reach test accuracy comparable to the original network in a similar number of iterations.

In this paper, the authors use the following process to prune a model, called iterative pruning:

  1. Randomly initialize a neural network \(f(x; \theta_0)\) (where \(\theta_0 \sim \mathcal{D}_{\theta}\)).

  2. Train the network for \(j\) iterations, arriving at parameters \(\theta_j\).

  3. Prune \(p\%\) of the parameters in \(\theta_j\), creating a mask \(m\).

  4. Reset the remaining parameters to their values in \(\theta_0\), creating the winning ticket \(f(x; m \odot \theta_0)\).

  5. Repeat steps 2, 3, and 4.

If the configured final sparsity is \(P\) (e.g., 0.8) and there are \(n\) pruning iterations, each iteration prunes \(1-(1-P)^{1/n}\) of the weights that survived the previous round.
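
For example, with \(P = 0.8\) and \(n = 5\), each round prunes \(1 - 0.2^{1/5} \approx 27.5\%\) of the surviving weights. A quick check:

# Per-round pruning rate for final sparsity P over n rounds.
P, n = 0.8, 5
per_round = 1 - (1 - P) ** (1 / n)
print(per_round)  # ~0.2752

remaining = 1.0
for _ in range(n):
    remaining *= 1 - per_round
print(1 - remaining)  # ~0.8, the configured final sparsity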

Usage

from nni.algorithms.compression.v2.pytorch.pruning import LotteryTicketPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = LotteryTicketPruner(model, config_list, pruning_algorithm='l1', total_iteration=10, finetuner=finetuner, reset_weight=True)
pruner.compress()
_, model, masks, _, _ = pruner.get_best_result()

User configuration for Lottery Ticket Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.LotteryTicketPruner(model: torch.nn.modules.module.Module, config_list: List[Dict], pruning_algorithm: str, total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False, finetuner: Optional[Callable[[torch.nn.modules.module.Module], None]] = None, speed_up: bool = False, dummy_input: Optional[torch.Tensor] = None, evaluator: Optional[Callable[[torch.nn.modules.module.Module], float]] = None, reset_weight: bool = True, pruning_params: dict = {})[source]
Parameters
  • model (Module) – The origin unwrapped pytorch model to be pruned.

  • config_list (List[Dict]) – The origin config list provided by the user. Note that this config_list directly configures the origin model, so the sparsity provided by the origin_masks should also be recorded in the origin_config_list.

  • pruning_algorithm (str) – Supported pruning algorithms: [‘level’, ‘l1’, ‘l2’, ‘fpgm’, ‘slim’, ‘apoz’, ‘mean_activation’, ‘taylorfo’, ‘admm’]. This iterative pruner will use the corresponding pruner to prune the model in each iteration.

  • total_iteration (int) – The total number of iterations.

  • log_dir (str) – The log directory used to save the results; you can find the best result under this folder.

  • keep_intermediate_result (bool) – Whether to keep the intermediate results, including the intermediate model and masks from each iteration.

  • finetuner (Optional[Callable[[Module], None]]) – The finetuner handles all finetuning logic; it takes a PyTorch module as input and is called in each iteration.

  • speed_up (bool) – If set to True, speed up the model in each iteration.

  • dummy_input (Optional[torch.Tensor]) – If speed_up is True, dummy_input is required to trace the model during speedup.

  • evaluator (Optional[Callable[[Module], float]]) – Evaluates the pruned model and gives a score. If evaluator is None, the best result refers to the latest result.

  • reset_weight (bool) – If set to True, the model weights will be reset to the original weights at the end of each iteration step.

  • pruning_params (dict) – If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.

Simulated Annealing Pruner

We implement a guided heuristic search method, the Simulated Annealing (SA) algorithm. As mentioned in the paper, this method enhances guided search with prior experience. The enhanced SA technique is based on the observation that a DNN layer with a larger number of weights can often tolerate a higher degree of compression with less impact on overall accuracy. The search procedure is as follows (a minimal code sketch follows the list):

  • Randomly initialize a pruning rate distribution (sparsities).

  • While current_temperature > stop_temperature:

    1. Generate a perturbation to the current distribution.

    2. Perform a fast evaluation on the perturbed distribution.

    3. Accept the perturbation according to the performance and probability; if not accepted, return to step 1.

    4. Cool down: current_temperature <- current_temperature * cool_down_rate.
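
A minimal sketch of this loop, assuming hypothetical user-supplied perturb and evaluate callables (an illustration, not NNI's implementation):

import math
import random

def simulated_annealing(sparsities, evaluate, perturb,
                        start_temperature=100.0, stop_temperature=20.0,
                        cool_down_rate=0.9):
    # `evaluate` scores a sparsity distribution; `perturb` proposes a nearby one.
    temperature = start_temperature
    score = evaluate(sparsities)
    while temperature > stop_temperature:
        # Steps 1-3: retry perturbations until one is accepted.
        while True:
            candidate = perturb(sparsities, temperature)
            candidate_score = evaluate(candidate)
            delta = candidate_score - score
            # Always accept improvements; accept worse moves with a probability
            # that shrinks as the temperature cools.
            if delta > 0 or random.random() < math.exp(delta / temperature):
                sparsities, score = candidate, candidate_score
                break
        # Step 4: cool down.
        temperature *= cool_down_rate
    return sparsities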

For more details, please refer to AutoCompress: An Automatic DNN Structured Pruning Framework for Ultra-High Compression Rates.

Usage

from nni.algorithms.compression.v2.pytorch.pruning import SimulatedAnnealingPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = SimulatedAnnealingPruner(model, config_list, pruning_algorithm='l1', cool_down_rate=0.9, finetuner=finetuner)
pruner.compress()
_, model, masks, _, _ = pruner.get_best_result()

User configuration for Simulated Annealing Pruner

PyTorch

class nni.algorithms.compression.v2.pytorch.pruning.SimulatedAnnealingPruner(model: torch.nn.modules.module.Module, config_list: List[Dict], pruning_algorithm: str, evaluator: Callable[[torch.nn.modules.module.Module], float], start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermediate_result: bool = False, finetuner: Optional[Callable[[torch.nn.modules.module.Module], None]] = None, speed_up: bool = False, dummy_input: Optional[torch.Tensor] = None, pruning_params: dict = {})[source]
Parameters
  • model (Module) – The origin unwrapped pytorch model to be pruned.

  • config_list (List[Dict]) – The origin config list provided by the user. Note that this config_list directly configures the origin model, so the sparsity provided by the origin_masks should also be recorded in the origin_config_list.

  • pruning_algorithm (str) – Supported pruning algorithms: [‘level’, ‘l1’, ‘l2’, ‘fpgm’, ‘slim’, ‘apoz’, ‘mean_activation’, ‘taylorfo’, ‘admm’]. This iterative pruner will use the corresponding pruner to prune the model in each iteration.

  • evaluator (Callable[[Module], float]) – Evaluate the pruned model and give a score.

  • start_temperature (float) – Start temperature of the simulated annealing process.

  • stop_temperature (float) – Stop temperature of the simulated annealing process.

  • cool_down_rate (float) – Cool down rate of the temperature.

  • perturbation_magnitude (float) – Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.

  • log_dir (str) – The log directory used to save the results; you can find the best result under this folder.

  • keep_intermediate_result (bool) – Whether to keep the intermediate results, including the intermediate model and masks from each iteration.

  • finetuner (Optional[Callable[[Module], None]]) – The finetuner handles all finetuning logic; it takes a PyTorch module as input and is called in each iteration.

  • speed_up (bool) – If set to True, speed up the model in each iteration.

  • dummy_input (Optional[torch.Tensor]) – If speed_up is True, dummy_input is required to trace the model during speedup.

  • pruning_params (dict) – If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.