Supported Pruning Algorithms on NNI

We provide several pruning algorithms that support fine-grained weight pruning and structured filter pruning. Fine-grained pruning generally results in unstructured models, which need specialized hardware or software to speed up the sparse network. Filter pruning achieves acceleration by removing entire filters. Some pruning algorithms use a one-shot method that prunes weights all at once based on an importance metric (the model must then be fine-tuned to compensate for the loss of accuracy). Other pruning algorithms prune weights iteratively during optimization and control the pruning schedule; these include some automatic pruning algorithms.

One-shot Pruning

Iterative Pruning

Others

Level Pruner

This is a basic one-shot pruner: you set a target sparsity level, expressed as a fraction (0.6 means that 60% of the weight parameters are pruned).

We first sort the weights in the specified layer by their absolute values, and then mask the smallest-magnitude weights to zero until the desired sparsity level is reached.
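
The sort-and-mask step can be sketched in plain Python as follows. This is only an illustration of the idea, not NNI's implementation; the helper name `level_mask` and the flat-list representation of a layer's weights are assumptions made for the example.

```python
def level_mask(weights, sparsity):
    """Return a 0/1 mask that zeroes the smallest-magnitude weights.

    `weights` is a flat list of floats and `sparsity` is a fraction in [0, 1].
    Both the function and its signature are illustrative, not part of NNI.
    """
    num_prune = int(len(weights) * sparsity)
    # Indices ordered from the smallest to the largest absolute value.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:num_prune])
    return [0 if i in pruned else 1 for i in range(len(weights))]


weights = [0.5, -0.1, 0.03, -0.8, 0.2]
mask = level_mask(weights, sparsity=0.6)
masked_weights = [w * m for w, m in zip(weights, mask)]
print(mask)  # three of the five weights are masked
```

With a sparsity of 0.6 and five weights, the three weights of smallest magnitude are masked and the two largest survive.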

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import LevelPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(model, config_list)
pruner.compress()

User configuration for Level Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.LevelPruner(model, config_list)
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (list) –

    Supported keys:
    • sparsity : The target sparsity that the specified operations are compressed to.

    • op_types : Operation types to prune.

TensorFlow

Slim Pruner

This is a one-shot pruner that adds sparsity regularization on the scaling factors of batch normalization (BN) layers during training to identify unimportant channels. Channels with small scaling factors are pruned. For more details, please refer to ‘Learning Efficient Convolutional Networks through Network Slimming’.
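
The two ingredients described above, an L1 penalty on the BN scaling factors and a cut on their magnitudes, can be sketched as follows. This is a simplified illustration rather than NNI's code: the function names, the dictionary layout, and the use of a single global threshold across all BN layers are assumptions made for the example.

```python
def l1_penalty(bn_scales, scale=0.0001):
    """Sparsity regularization term that would be added to the training loss."""
    return scale * sum(abs(g) for layer in bn_scales.values() for g in layer)


def slim_masks(bn_scales, sparsity):
    """Mask channels whose |scaling factor| is at or below a global threshold.

    Ties at the threshold are all masked, so slightly more than `sparsity`
    of the channels can be removed when scaling factors repeat.
    """
    all_scales = sorted(abs(g) for layer in bn_scales.values() for g in layer)
    num_prune = int(len(all_scales) * sparsity)
    if num_prune == 0:
        return {name: [1] * len(layer) for name, layer in bn_scales.items()}
    threshold = all_scales[num_prune - 1]
    return {name: [0 if abs(g) <= threshold else 1 for g in layer]
            for name, layer in bn_scales.items()}


# Hypothetical scaling factors of two BatchNorm2d layers after sparsity training.
bn_scales = {"bn1": [0.9, 0.01, 0.4, 0.02], "bn2": [0.05, 0.7, 0.3, 0.6]}
penalty = l1_penalty(bn_scales)
masks = slim_masks(bn_scales, sparsity=0.5)
print(masks)
```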

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import SlimPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['BatchNorm2d'] }]
pruner = SlimPruner(model, config_list, optimizer, trainer, criterion)
pruner.compress()

User configuration for Slim Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.SlimPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_epochs=10, scale=0.0001, dependency_aware=False, dummy_input=None)
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (list) –

    Supported keys:
    • sparsity : The target sparsity that the specified operations are compressed to.

    • op_types : Only BatchNorm2d is supported in Slim Pruner.

  • optimizer (torch.optim.Optimizer) – Optimizer used to train model

  • trainer (function) – Function used to sparsify BatchNorm2d scaling factors. Users should write this function as a normal function that trains the PyTorch model and takes model, optimizer, criterion, epoch as function arguments.

  • criterion (function) – Function used to calculate the loss between the target and the output. For example, you can use torch.nn.CrossEntropyLoss() as input.

  • sparsifying_training_epochs (int) – The number of channel sparsity regularization training epochs before pruning.

  • scale (float) – Penalty parameter for sparsification.

  • dependency_aware (bool) – Whether to prune the model in a dependency-aware way. If True, this pruner prunes the model according to the L2 norm of the weights and the channel dependency or group dependency of the model. The pruner then forces conv layers that depend on each other to prune the same channels, so the speedup module can better harvest the speed benefit from the pruned model. Note that if this flag is True, dummy_input cannot be None, because the pruner needs a dummy input to trace the dependencies between the conv layers.

  • dummy_input (torch.Tensor) – The dummy input used to analyze the topology constraints. Note that dummy_input should be on the same device as the model.

Reproduced Experiment

We reproduced one of the experiments in Learning Efficient Convolutional Networks through Network Slimming: as in the paper, we pruned 70% of the channels in VGGNet for CIFAR-10, which removes 88.5% of the parameters. Our experimental results are as follows:

Model            Error (paper/ours)   Parameters   Pruned
VGGNet           6.34/6.69            20.04M
Pruned-VGGNet    6.20/6.34            2.03M        88.5%

The experiment code can be found at examples/model_compress/pruning/basic_pruners_torch.py

python basic_pruners_torch.py --pruner slim --model vgg19 --sparsity 0.7 --speed-up

FPGM Pruner

This is a one-shot pruner that prunes the filters closest to the geometric median of the filters in a layer. FPGM chooses the filters with the most replaceable contribution. For more details, please refer to Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration.

We also provide a dependency-aware mode for this pruner to get better speedup from the pruning. Please refer to dependency-aware for more details.
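
As a rough illustration of the "most replaceable" criterion, the sketch below scores every filter by its summed Euclidean distance to all other filters in the layer and selects the filters with the smallest score, since they lie closest to the geometric median. The function name and the flat-list filter representation are assumptions for the example; NNI's implementation works on tensors and may differ in detail.

```python
import math


def fpgm_select(filters, sparsity):
    """Return the indices of the filters to prune (illustrative helper).

    Each filter is a flat list of weights. A filter whose summed distance to
    all other filters is small is treated as the most replaceable.
    """
    def distance(a, b):
        return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

    scores = [sum(distance(f, g) for g in filters) for f in filters]
    num_prune = int(len(filters) * sparsity)
    order = sorted(range(len(filters)), key=lambda i: scores[i])
    return sorted(order[:num_prune])


# Four toy filters; the first two are nearly identical.
filters = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
to_prune = fpgm_select(filters, sparsity=0.5)
print(to_prune)
```

In this toy layer the two nearly identical filters are selected, even though their norms are not the smallest, which is the difference between this criterion and a norm-based one.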

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import FPGMPruner
config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d']
}]
pruner = FPGMPruner(model, config_list)
pruner.compress()

User configuration for FPGM Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.FPGMPruner(model, config_list, dependency_aware=False, dummy_input=None)
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (list) –

    Supported keys:
    • sparsity : The target sparsity that the specified operations are compressed to.

    • op_types : Only Conv2d is supported in FPGM Pruner.

  • dependency_aware (bool) – Whether to prune the model in a dependency-aware way. If True, this pruner prunes the model according to the L2 norm of the weights and the channel dependency or group dependency of the model. The pruner then forces conv layers that depend on each other to prune the same channels, so the speedup module can better harvest the speed benefit from the pruned model. Note that if this flag is True, dummy_input cannot be None, because the pruner needs a dummy input to trace the dependencies between the conv layers.

  • dummy_input (torch.Tensor) – The dummy input used to analyze the topology constraints. Note that dummy_input should be on the same device as the model.

L1Filter Pruner

This is a one-shot pruner that prunes the filters with the smallest L1 norm of the weights in the convolution layers.

For more details, please refer to PRUNING FILTERS FOR EFFICIENT CONVNETS.

We also provide a dependency-aware mode for the L1FilterPruner. For more details, please refer to dependency-aware mode.
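
A minimal sketch of the ranking used by this kind of pruner is shown below, following the criterion of the cited paper: each filter is scored by the sum of the absolute values of its weights (its L1 norm), and the lowest-scoring filters are selected. The helper name and the list-based filter layout are assumptions for the example.

```python
def l1_filter_select(filters, sparsity):
    """Return (norms, indices to prune) for one convolution layer.

    Each filter is a flat list of weights; this helper is illustrative only.
    """
    norms = [sum(abs(w) for w in f) for f in filters]
    num_prune = int(len(filters) * sparsity)
    order = sorted(range(len(filters)), key=lambda i: norms[i])
    return norms, sorted(order[:num_prune])


filters = [[0.5, -0.5], [0.1, 0.0], [-0.3, 0.2], [1.0, 1.0]]
norms, to_prune = l1_filter_select(filters, sparsity=0.5)
print(norms, to_prune)
```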

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = L1FilterPruner(model, config_list)
pruner.compress()

User configuration for L1Filter Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.L1FilterPruner(model, config_list, dependency_aware=False, dummy_input=None)
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (list) –

    Supported keys:
    • sparsity : The target sparsity that the specified operations are compressed to.

    • op_types : Only Conv2d is supported in L1FilterPruner.

  • dependency_aware (bool) – Whether to prune the model in a dependency-aware way. If True, this pruner prunes the model according to the L2 norm of the weights and the channel dependency or group dependency of the model. The pruner then forces conv layers that depend on each other to prune the same channels, so the speedup module can better harvest the speed benefit from the pruned model. Note that if this flag is True, dummy_input cannot be None, because the pruner needs a dummy input to trace the dependencies between the conv layers.

  • dummy_input (torch.Tensor) – The dummy input used to analyze the topology constraints. Note that dummy_input should be on the same device as the model.

Reproduced Experiment

We reproduced one of the experiments in PRUNING FILTERS FOR EFFICIENT CONVNETS with L1FilterPruner: we pruned VGG-16 for CIFAR-10 to VGG-16-pruned-A from the paper, which removes 64% of the parameters. Our experimental results are as follows:

Model              Error (paper/ours)   Parameters   Pruned
VGG-16             6.75/6.49            1.5x10^7
VGG-16-pruned-A    6.60/6.47            5.4x10^6     64.0%

The experiment code can be found at examples/model_compress/pruning/basic_pruners_torch.py

python basic_pruners_torch.py --pruner l1filter --model vgg16 --speed-up

L2Filter Pruner

This is a structured pruning algorithm that prunes the filters with the smallest L2 norm of the weights. It is implemented as a one-shot pruner.

We also provide a dependency-aware mode for this pruner to get better speedup from the pruning. Please refer to dependency-aware for more details.

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import L2FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
pruner = L2FilterPruner(model, config_list)
pruner.compress()

User configuration for L2Filter Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.L2FilterPruner(model, config_list, dependency_aware=False, dummy_input=None)
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (list) –

    Supported keys:
    • sparsity : The target sparsity that the specified operations are compressed to.

    • op_types : Only Conv2d is supported in L2FilterPruner.

  • dependency_aware (bool) – Whether to prune the model in a dependency-aware way. If True, this pruner prunes the model according to the L2 norm of the weights and the channel dependency or group dependency of the model. The pruner then forces conv layers that depend on each other to prune the same channels, so the speedup module can better harvest the speed benefit from the pruned model. Note that if this flag is True, dummy_input cannot be None, because the pruner needs a dummy input to trace the dependencies between the conv layers.

  • dummy_input (torch.Tensor) – The dummy input used to analyze the topology constraints. Note that dummy_input should be on the same device as the model.
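
To make the dependency-aware idea more concrete, the sketch below picks one shared set of output channels for conv layers that depend on each other (for example, layers whose outputs are added together). It is a simplified picture under stated assumptions: the per-channel L2 norms are simply summed across the dependent layers, and the helper name and data layout are hypothetical. NNI's actual rule for combining layers may differ.

```python
import math


def common_channels_to_prune(layers, sparsity):
    """Pick one shared set of output channels for mutually dependent layers.

    `layers` maps a layer name to its filters (one flat weight list per output
    channel). All layers must have the same number of output channels.
    """
    num_channels = len(next(iter(layers.values())))
    combined = [0.0] * num_channels
    for filters in layers.values():
        for c, f in enumerate(filters):
            # Accumulate the L2 norm of channel c over all dependent layers.
            combined[c] += math.sqrt(sum(w * w for w in f))
    num_prune = int(num_channels * sparsity)
    order = sorted(range(num_channels), key=lambda c: combined[c])
    return sorted(order[:num_prune])


layers = {
    "conv_a": [[3.0, 4.0], [0.1, 0.0], [1.0, 0.0], [0.0, 0.2]],
    "conv_b": [[0.0, 1.0], [0.0, 2.0], [0.1, 0.0], [0.0, 0.5]],
}
shared = common_channels_to_prune(layers, sparsity=0.5)
print(shared)
```

Pruned independently, conv_a would drop channels 1 and 3 while conv_b would drop channels 2 and 3; with a shared decision both layers drop the same two channels, which is what lets the speedup step physically remove them.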


ActivationAPoZRankFilter Pruner

ActivationAPoZRankFilter Pruner prunes the filters with the lowest importance, as measured by the APoZ criterion calculated from the output activations of convolution layers, to achieve a preset level of network sparsity. The pruning criterion APoZ is explained in the paper Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures.

The APoZ is defined as:

\(APoZ_{c}^{(i)} = APoZ\left(O_{c}^{(i)}\right)=\frac{\sum_{k}^{N} \sum_{j}^{M} f\left(O_{c, j}^{(i)}(k)=0\right)}{N \times M}\)
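
In this formula, following the cited paper, \(f(\cdot)\) is 1 when its argument is true and 0 otherwise, N is the number of input examples, and M is the number of output values the channel produces per example. A channel with a high APoZ outputs mostly zeros and is therefore considered less important. The computation can be sketched as follows; the helper name and the nested-list layout are assumptions for the example.

```python
def apoz(activations):
    """Average Percentage of Zeros for one channel.

    `activations[k][j]` is the j-th output value of the channel (after the
    activation function) for the k-th of N input examples.
    """
    total = sum(len(example) for example in activations)
    zeros = sum(1 for example in activations for value in example if value == 0)
    return zeros / total


# Two toy channels observed on N = 2 examples with M = 4 outputs each.
channel_a = [[0.0, 0.0, 1.2, 0.0], [0.0, 0.3, 0.0, 0.0]]
channel_b = [[0.7, 0.0, 1.1, 0.4], [0.2, 0.9, 0.0, 0.5]]
print(apoz(channel_a), apoz(channel_b))  # channel_a is zero far more often
```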

We also provide a dependency-aware mode for this pruner to get better speedup from the pruning. Please refer to dependency-aware for more details.

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import ActivationAPoZRankFilterPruner
config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d']
}]
pruner = ActivationAPoZRankFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1)
pruner.compress()

Note: ActivationAPoZRankFilterPruner is used to prune convolutional layers within deep neural networks, therefore the op_types field supports only convolutional layers.

You can view the example for more information.

User configuration for ActivationAPoZRankFilter Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.ActivationAPoZRankFilterPruner(model, config_list, optimizer, trainer, criterion, activation='relu', sparsifying_training_batches=1, dependency_aware=False, dummy_input=None)
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (list) –

    Supported keys:
    • sparsity : The fraction of convolutional filters to be pruned.

    • op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner.

  • optimizer (torch.optim.Optimizer) – Optimizer used to train model

  • trainer (function) – Function used to train the model. Users should write this function as a normal function that trains the PyTorch model and takes model, optimizer, criterion, epoch as function arguments.

  • criterion (function) – Function used to calculate the loss between the target and the output. For example, you can use torch.nn.CrossEntropyLoss() as input.

  • activation (str) – The activation type.

  • sparsifying_training_batches (int) – The number of batches used to collect the contributions. Note that this number must be less than the number of batches in one epoch.

  • dependency_aware (bool) – Whether to prune the model in a dependency-aware way. If True, this pruner prunes the model according to the L2 norm of the weights and the channel dependency or group dependency of the model. The pruner then forces conv layers that depend on each other to prune the same channels, so the speedup module can better harvest the speed benefit from the pruned model. Note that if this flag is True, dummy_input cannot be None, because the pruner needs a dummy input to trace the dependencies between the conv layers.

  • dummy_input (torch.Tensor) – The dummy input used to analyze the topology constraints. Note that dummy_input should be on the same device as the model.


ActivationMeanRankFilter Pruner

ActivationMeanRankFilterPruner prunes the filters with the lowest importance, as measured by the mean activation criterion calculated from the output activations of convolution layers, to achieve a preset level of network sparsity. The pruning criterion mean activation is explained in section 2.2 of the paper Pruning Convolutional Neural Networks for Resource Efficient Inference. Other pruning criteria mentioned in this paper will be supported in a future release.

We also provide a dependency-aware mode for this pruner to get better speedup from the pruning. Please refer to dependency-aware for more details.

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import ActivationMeanRankFilterPruner
config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d']
}]
pruner = ActivationMeanRankFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1)
pruner.compress()

Note: ActivationMeanRankFilterPruner is used to prune convolutional layers within deep neural networks, therefore the op_types field supports only convolutional layers.

You can view the example for more information.

User configuration for ActivationMeanRankFilter Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.ActivationMeanRankFilterPruner(model, config_list, optimizer, trainer, criterion, activation='relu', sparsifying_training_batches=1, dependency_aware=False, dummy_input=None)
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (list) –

    Supported keys:
    • sparsity : The fraction of convolutional filters to be pruned.

    • op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner.

  • optimizer (torch.optim.Optimizer) – Optimizer used to train model.

  • trainer (function) – Function used to train the model. Users should write this function as a normal function that trains the PyTorch model and takes model, optimizer, criterion, epoch as function arguments.

  • criterion (function) – Function used to calculate the loss between the target and the output. For example, you can use torch.nn.CrossEntropyLoss() as input.

  • activation (str) – The activation type.

  • sparsifying_training_batches (int) – The number of batches used to collect the contributions. Note that this number must be less than the number of batches in one epoch.

  • dependency_aware (bool) – Whether to prune the model in a dependency-aware way. If True, this pruner prunes the model according to the L2 norm of the weights and the channel dependency or group dependency of the model. The pruner then forces conv layers that depend on each other to prune the same channels, so the speedup module can better harvest the speed benefit from the pruned model. Note that if this flag is True, dummy_input cannot be None, because the pruner needs a dummy input to trace the dependencies between the conv layers.

  • dummy_input (torch.Tensor) – The dummy input used to analyze the topology constraints. Note that dummy_input should be on the same device as the model.


TaylorFOWeightFilter Pruner

TaylorFOWeightFilter Pruner prunes convolutional layers based on an estimated importance calculated from the first-order Taylor expansion on the weights, to achieve a preset level of network sparsity. The estimated importance of filters is defined as in the paper Importance Estimation for Neural Network Pruning. Other pruning criteria mentioned in this paper will be supported in a future release.

\(\widehat{\mathcal{I}}_{\mathcal{S}}^{(1)}(\mathbf{W}) \triangleq \sum_{s \in \mathcal{S}} \mathcal{I}_{s}^{(1)}(\mathbf{W})=\sum_{s \in \mathcal{S}}\left(g_{s} w_{s}\right)^{2}\)

We also provide a dependency-aware mode for this pruner to get better speedup from the pruning. Please refer to dependency-aware for more details.

In addition, we provide a global-sort mode for this pruner, which is aligned with the paper's implementation. Set the parameter ‘global_sort’ to True when instantiating TaylorFOWeightFilterPruner.
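
The importance formula and the effect of the global-sort mode can be illustrated with the sketch below. Each filter is scored by the sum over its weights of (gradient × weight) squared; filters are then ranked either within each layer or across all layers at once. The function names and the dictionary layout are assumptions for the example, and gradients are supplied by hand instead of being collected during training.

```python
def taylor_fo_importance(weights, grads):
    """Importance of one filter: sum over its weights of (gradient * weight)^2."""
    return sum((g * w) ** 2 for g, w in zip(grads, weights))


def select_filters(layers, sparsity, global_sort=False):
    """Return the (layer, filter index) pairs to prune (illustrative helper)."""
    scores = {name: [taylor_fo_importance(w, g) for w, g in filters]
              for name, filters in layers.items()}
    if global_sort:
        # Rank every filter of every layer in a single list.
        ranked = sorted((s, name, i) for name, ss in scores.items()
                        for i, s in enumerate(ss))
        num_prune = int(len(ranked) * sparsity)
        return sorted((name, i) for _, name, i in ranked[:num_prune])
    pruned = []
    for name, ss in scores.items():
        # Rank the filters of each layer separately.
        order = sorted(range(len(ss)), key=lambda i: ss[i])
        pruned += [(name, i) for i in order[:int(len(ss) * sparsity)]]
    return sorted(pruned)


# Each filter is a (weights, gradients) pair; the values are made up.
layers = {
    "conv1": [([1.0, 1.0], [1.0, 1.0]), ([2.0, 0.0], [1.0, 0.0])],
    "conv2": [([0.1, 0.1], [1.0, 1.0]), ([0.5, 0.0], [1.0, 0.0])],
}
per_layer = select_filters(layers, sparsity=0.5)
global_choice = select_filters(layers, sparsity=0.5, global_sort=True)
print(per_layer, global_choice)
```

With per-layer ranking each layer loses its own weakest filter, while global sorting removes both filters of conv2 because they score lowest overall.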

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import TaylorFOWeightFilterPruner
config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d']
}]
pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1)
pruner.compress()

User configuration for TaylorFOWeightFilter Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1, dependency_aware=False, dummy_input=None, global_sort=False)
Parameters
  • model (torch.nn.Module) – Model to be pruned

  • config_list (list) –

    Supported keys:
    • sparsity : The fraction of convolutional filters to be pruned.

    • op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner.

  • optimizer (torch.optim.Optimizer) – Optimizer used to train model

  • trainer (function) – Function used to train the model while the contributions are collected. Users should write this function as a normal function that trains the PyTorch model and takes model, optimizer, criterion, epoch as function arguments.

  • criterion (function) – Function used to calculate the loss between the target and the output. For example, you can use torch.nn.CrossEntropyLoss() as input.

  • sparsifying_training_batches (int) – The number of batches used to collect the contributions. Note that this number must be less than the number of batches in one epoch.

  • dependency_aware (bool) – Whether to prune the model in a dependency-aware way. If True, this pruner prunes the model according to the L2 norm of the weights and the channel dependency or group dependency of the model. The pruner then forces conv layers that depend on each other to prune the same channels, so the speedup module can better harvest the speed benefit from the pruned model. Note that if this flag is True, dummy_input cannot be None, because the pruner needs a dummy input to trace the dependencies between the conv layers.

  • dummy_input (torch.Tensor) – The dummy input used to analyze the topology constraints. Note that dummy_input should be on the same device as the model.

  • global_sort (bool) – Whether to prune the model in a global-sort way; currently only TaylorFOWeightFilterPruner supports this. If True, channel contributions are sorted globally across layers, and whether a specific channel is pruned depends on this global information.


AGP Pruner

This is an iterative pruner in which the sparsity is increased from an initial sparsity value \(s_{i}\) (usually 0) to a final sparsity value \(s_{f}\) over a span of n pruning steps, starting at training step \(t_{0}\) and with pruning frequency \(\Delta t\):

\(s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n \Delta t}\right)^{3} \text { for } t \in\left\{t_{0}, t_{0}+\Delta t, \ldots, t_{0} + n \Delta t\right\}\)

For more details please refer to To prune, or not to prune: exploring the efficacy of pruning for model compression.
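
The schedule formula can be evaluated directly. The sketch below is a plain transcription of it; the function name `agp_sparsity` and the example values (0% to 80% sparsity over n = 10 steps with \(\Delta t = 1\)) are chosen for illustration only.

```python
def agp_sparsity(t, s_i, s_f, t0, n, delta_t):
    """Target sparsity at training step t under the cubic AGP schedule."""
    progress = (t - t0) / (n * delta_t)
    return s_f + (s_i - s_f) * (1 - progress) ** 3


schedule = [agp_sparsity(t, s_i=0.0, s_f=0.8, t0=0, n=10, delta_t=1)
            for t in range(0, 11)]
for step, sparsity in enumerate(schedule):
    print(step, round(sparsity, 4))
```

The schedule starts at \(s_{i}\), ends at \(s_{f}\), and rises fastest in the early steps: halfway through, the target sparsity is already 0.7 of the final 0.8.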

Usage

You can prune all weights from 0% to 80% sparsity in 10 epochs with the code below.

PyTorch code

import torch
from nni.algorithms.compression.pytorch.pruning import AGPPruner
config_list = [{
    'sparsity': 0.8,
    'op_types': ['default']
}]

# load a pretrained model or train a model before using a pruner
# model = MyModel()
# model.load_state_dict(torch.load('mycheckpoint.pth'))

# AGP pruner prunes model while fine tuning the model by adding a hook on
# optimizer.step(), so an optimizer is required to prune the model.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)

# trainer and criterion are user-defined functions (see the parameter descriptions)
pruner = AGPPruner(model, config_list, optimizer, trainer, criterion, pruning_algorithm='level')
pruner.compress()

AGP pruner uses the LevelPruner algorithm to prune the weights by default. However, you can set the pruning_algorithm parameter to other values to use other pruning algorithms:

  • level: LevelPruner

  • slim: SlimPruner

  • l1: L1FilterPruner

  • l2: L2FilterPruner

  • fpgm: FPGMPruner

  • taylorfo: TaylorFOWeightFilterPruner

  • apoz: ActivationAPoZRankFilterPruner

  • mean_activation: ActivationMeanRankFilterPruner

User configuration for AGP Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.AGPPruner(model, config_list, optimizer, trainer, criterion, num_iterations=10, epochs_per_iteration=1, pruning_algorithm='level')
Parameters
  • model (torch.nn.Module) – Model to be pruned.

  • config_list (list) –

    Supported keys:
    • sparsity : The target sparsity that the specified operations are compressed to.

    • op_types : See supported type in your specific pruning algorithm.

  • optimizer (torch.optim.Optimizer) – Optimizer used to train model.

  • trainer (function) – Function to train the model

  • criterion (function) – Function used to calculate the loss between the target and the output. For example, you can use torch.nn.CrossEntropyLoss() as input.

  • num_iterations (int) – Total number of iterations in the pruning process. The mask is calculated at the end of each iteration.

  • epochs_per_iteration (int) – The number of training epochs for each iteration.

  • pruning_algorithm (str) – Algorithm used to prune the model; choose from [‘level’, ‘slim’, ‘l1’, ‘l2’, ‘fpgm’, ‘taylorfo’, ‘apoz’, ‘mean_activation’], by default level.


NetAdapt Pruner

NetAdapt allows a user to automatically simplify a pretrained network to meet a resource budget. Given the overall sparsity, NetAdapt automatically generates the sparsity distribution among different layers by iterative pruning.

For more details, please refer to NetAdapt: Platform-Aware Neural Network Adaptation for Mobile Applications.
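
The iterative search can be pictured with the simplified sketch below. In each iteration every layer is tried in turn: its sparsity is raised just enough to remove a fixed share of the network's weights, the candidate is scored, and the best-scoring candidate is kept. This is only an outline under stated assumptions: the function `netadapt_sketch`, the layer sizes, and the toy evaluator are hypothetical, and the single `evaluator` call stands in for pruning with the base algorithm, short-term fine-tuning, and evaluating the real model.

```python
def netadapt_sketch(layer_sizes, target_sparsity, evaluator,
                    sparsity_per_iteration=0.05):
    """Greedily distribute an overall sparsity budget across layers."""
    total = sum(layer_sizes.values())
    sparsities = {name: 0.0 for name in layer_sizes}

    def overall(s):
        return sum(s[name] * layer_sizes[name] for name in layer_sizes) / total

    while overall(sparsities) < target_sparsity - 1e-9:
        budget = sparsity_per_iteration * total  # weights to remove this iteration
        best = None
        for name, size in layer_sizes.items():
            candidate = dict(sparsities)
            candidate[name] = min(1.0, candidate[name] + budget / size)
            if candidate[name] == sparsities[name]:
                continue  # this layer is already fully pruned
            score = evaluator(candidate)
            if best is None or score > best[0]:
                best = (score, candidate)
        if best is None:
            break  # nothing left to prune
        sparsities = best[1]
    return sparsities


# Toy setup: conv1 is small but sensitive, conv2 is large and robust.
layer_sizes = {"conv1": 100, "conv2": 400}
sensitivity = {"conv1": 1.0, "conv2": 0.2}


def toy_evaluator(sparsities):
    """Made-up accuracy proxy that penalizes pruning sensitive layers more."""
    return 1.0 - sum(sensitivity[n] * s ** 2 for n, s in sparsities.items())


result = netadapt_sketch(layer_sizes, 0.5, toy_evaluator)
print(result)
```

In this toy setup the whole budget ends up in the large, robust layer, which is the kind of uneven per-layer distribution the pruner is meant to discover.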

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import NetAdaptPruner
config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d']
}]
pruner = NetAdaptPruner(model, config_list, short_term_fine_tuner=short_term_fine_tuner, evaluator=evaluator, base_algo='l1', experiment_data_dir='./')
pruner.compress()

You can view the example for more information.

User configuration for NetAdapt Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.NetAdaptPruner(model, config_list, short_term_fine_tuner, evaluator, optimize_mode='maximize', base_algo='l1', sparsity_per_iteration=0.05, experiment_data_dir='./')

A PyTorch implementation of the NetAdapt compression algorithm.

Parameters
  • model (pytorch model) – The model to be pruned.

  • config_list (list) –

    Supported keys:
    • sparsity : The target overall sparsity.

    • op_types : The operation type to prune.

  • short_term_fine_tuner (function) –

    Function to short-term fine-tune the masked model. It should take model as its only required parameter and fine-tune the model for a short time after each pruning iteration. Example:

    import torch

    def short_term_fine_tuner(model, epoch=3):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        train_loader = ...
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        model.train()
        for _ in range(epoch):
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
    

  • evaluator (function) –

    Function to evaluate the masked model. It should take model as its only parameter and return a scalar value. Example:

    def evaluator(model):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        val_loader = ...
        model.eval()
        correct = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # get the index of the max log-probability
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        accuracy = correct / len(val_loader.dataset)
        return accuracy
    

  • optimize_mode (str) – optimize mode, maximize or minimize, by default maximize.

  • base_algo (str) – Base pruning algorithm. level, l1, l2 or fpgm, by default l1. Given the sparsity distribution among the ops, the assigned base_algo is used to decide which filters/channels/weights to prune.

  • sparsity_per_iteration (float) – sparsity to prune in each iteration.

  • experiment_data_dir (str) – PATH to save experiment data, including the config_list generated for the base pruning algorithm and the performance of the pruned model.

SimulatedAnnealing Pruner

We implement a guided heuristic search method, the Simulated Annealing (SA) algorithm, enhanced with guided search based on prior experience. The enhanced SA technique is based on the observation that a DNN layer with more weights can often be compressed to a higher degree with less impact on overall accuracy.

  • Randomly initialize a pruning rate distribution (sparsities).

  • While current_temperature > stop_temperature:

    1. Generate a perturbation to the current distribution.

    2. Perform a fast evaluation on the perturbed distribution.

    3. Accept the perturbation according to the performance and probability; if not accepted, return to step 1.

    4. Cool down: current_temperature <- current_temperature * cool_down_rate
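
The loop above can be sketched in plain Python as follows. This is an illustrative outline, not NNI's implementation: the function name, the toy evaluator, the clipping of sparsities to [0, 1], and the exponential acceptance rule are assumptions, and the rescaling of the sparsities to match the target overall sparsity is left out. The loop runs while the temperature is above stop_temperature, matching the defaults of 100 and 20 with a cool-down rate of 0.9.

```python
import math
import random


def simulated_annealing(num_layers, evaluator, start_temperature=100.0,
                        stop_temperature=20.0, cool_down_rate=0.9,
                        perturbation_magnitude=0.35, seed=0):
    """Search per-layer sparsities that maximize `evaluator` (illustrative)."""
    rng = random.Random(seed)
    # Randomly initialize a pruning rate distribution.
    current = [rng.random() for _ in range(num_layers)]
    current_score = evaluator(current)
    best, best_score = list(current), current_score
    temperature = start_temperature
    while temperature > stop_temperature:
        while True:
            # Step 1: perturb; the magnitude shrinks as the temperature drops.
            magnitude = perturbation_magnitude * temperature / start_temperature
            candidate = [min(1.0, max(0.0, s + rng.uniform(-magnitude, magnitude)))
                         for s in current]
            # Step 2: fast evaluation of the perturbed distribution.
            score = evaluator(candidate)
            # Step 3: always accept improvements, sometimes accept regressions.
            delta = current_score - score
            if delta <= 0 or rng.random() < math.exp(-delta / temperature):
                current, current_score = candidate, score
                break
        if current_score > best_score:
            best, best_score = list(current), current_score
        # Step 4: cool down.
        temperature *= cool_down_rate
    return best, best_score


def toy_evaluator(sparsities):
    """Made-up score that peaks when every layer has sparsity 0.5."""
    return -sum((s - 0.5) ** 2 for s in sparsities)


best, best_score = simulated_annealing(3, toy_evaluator)
print(best, best_score)
```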

For more details, please refer to AutoCompress: An Automatic DNN Structured Pruning Framework for Ultra-High Compression Rates.

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import SimulatedAnnealingPruner
config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d']
}]
pruner = SimulatedAnnealingPruner(model, config_list, evaluator=evaluator, base_algo='l1', cool_down_rate=0.9, experiment_data_dir='./')
pruner.compress()

You can view the example for more information.

User configuration for SimulatedAnnealing Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.SimulatedAnnealingPruner(model, config_list, evaluator, optimize_mode='maximize', base_algo='l1', start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35, experiment_data_dir='./')

A PyTorch implementation of the Simulated Annealing compression algorithm.

Parameters
  • model (pytorch model) – The model to be pruned.

  • config_list (list) –

    Supported keys:
    • sparsity : The target overall sparsity.

    • op_types : The operation type to prune.

  • evaluator (function) –

    Function to evaluate the pruned model. This function should include model as the only parameter, and returns a scalar value. Example:

    def evaluator(model):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        val_loader = ...
        model.eval()
        correct = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # get the index of the max log-probability
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        accuracy = correct / len(val_loader.dataset)
        return accuracy
    

  • optimize_mode (str) – Optimize mode, maximize or minimize, by default maximize.

  • base_algo (str) – Base pruning algorithm. level, l1, l2 or fpgm, by default l1. Given the sparsity distribution among the ops, the assigned base_algo is used to decide which filters/channels/weights to prune.

  • start_temperature (float) – Start temperature of the simulated annealing process.

  • stop_temperature (float) – Stop temperature of the simulated annealing process.

  • cool_down_rate (float) – Cool down rate of the temperature.

  • perturbation_magnitude (float) – Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.

  • experiment_data_dir (string) – PATH to save experiment data, including the config_list generated for the base pruning algorithm, the performance of the pruned model and the pruning history.

AutoCompress Pruner

In each round, AutoCompressPruner prunes the model by the same sparsity so as to achieve the overall sparsity:

1. Generate the sparsity distribution using SimulatedAnnealingPruner.
2. Perform ADMM-based structured pruning to generate the pruning result for the next round.
   Here we use `speedup` to perform real pruning.
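
As a worked example of pruning by the same sparsity in every round, assume each round removes the same fraction of the weights that remain after the previous round. The per-round sparsity s then satisfies (1 - s)^num_iterations = 1 - overall_sparsity. The helper below solves this relation; it illustrates that assumption and is not taken from NNI's source.

```python
def sparsity_per_round(overall_sparsity, num_iterations):
    """Per-round sparsity such that repeated pruning reaches the overall target."""
    return 1 - (1 - overall_sparsity) ** (1 / num_iterations)


per_round = sparsity_per_round(0.5, 3)
remaining = (1 - per_round) ** 3  # fraction of weights left after 3 rounds
print(round(per_round, 4), round(remaining, 4))
```

For an overall sparsity of 0.5 reached in 3 rounds, each round prunes roughly 20.6% of the remaining weights.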

For more details, please refer to AutoCompress: An Automatic DNN Structured Pruning Framework for Ultra-High Compression Rates.

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import AutoCompressPruner
config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d']
}]
pruner = AutoCompressPruner(
    model, config_list, trainer=trainer, evaluator=evaluator,
    dummy_input=dummy_input, num_iterations=3, optimize_mode='maximize', base_algo='l1',
    cool_down_rate=0.9, admm_num_iterations=30, admm_epochs_per_iteration=5, experiment_data_dir='./')
pruner.compress()

You can view the example for more information.

User configuration for AutoCompress Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.AutoCompressPruner(model, config_list, trainer, evaluator, dummy_input, criterion=CrossEntropyLoss(), num_iterations=3, optimize_mode='maximize', base_algo='l1', start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35, admm_num_iterations=30, admm_epochs_per_iteration=5, row=0.0001, experiment_data_dir='./')

A PyTorch implementation of the AutoCompress pruning algorithm.

Parameters
  • model (pytorch model) – The model to be pruned.

  • config_list (list) –

    Supported keys:
    • sparsity : The target overall sparsity.

    • op_types : The operation type to prune.

  • trainer (function) – Function used for the first subproblem of ADMM Pruner. Users should write this function as a normal function that trains the PyTorch model and takes model, optimizer, criterion, epoch as function arguments.

  • criterion (function) – Function used to calculate the loss between the target and the output. By default, we use CrossEntropyLoss. For example, you can use torch.nn.CrossEntropyLoss() as input.

  • evaluator (function) –

    Function to evaluate the pruned model. It should take model as its only parameter and return a scalar value. Example:

    def evaluator(model):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        val_loader = ...
        model.eval()
        correct = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # get the index of the max log-probability
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        accuracy = correct / len(val_loader.dataset)
        return accuracy
    

  • dummy_input (pytorch tensor) – The dummy input for `jit.trace`. Users should put it on the right device before passing it in.

  • num_iterations (int) – Number of overall iterations.

  • optimize_mode (str) – optimize mode, maximize or minimize, by default maximize.

  • base_algo (str) – Base pruning algorithm. level, l1, l2 or fpgm, by default l1. Given the sparsity distribution among the ops, the assigned base_algo is used to decide which filters/channels/weights to prune.

  • start_temperature (float) – Start temperature of the simulated annealing process.

  • stop_temperature (float) – Stop temperature of the simulated annealing process.

  • cool_down_rate (float) – Cool down rate of the temperature.

  • perturbation_magnitude (float) – Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.

  • admm_num_iterations (int) – Number of iterations of ADMM Pruner.

  • admm_epochs_per_iteration (int) – Training epochs of the first optimization subproblem of ADMMPruner.

  • row (float) – Penalty parameters for ADMM training.

  • experiment_data_dir (string) – PATH to store temporary experiment data.
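
To make the roles of start_temperature, stop_temperature, cool_down_rate and perturbation_magnitude more concrete, here is a minimal sketch of a geometric cooling schedule using the default values from the signature above. The helper annealing_schedule is hypothetical and not part of NNI, and the linear scaling of the perturbation magnitude with the temperature ratio is an assumption made for illustration only.

```python
def annealing_schedule(start_temperature=100.0, stop_temperature=20.0,
                       cool_down_rate=0.9, perturbation_magnitude=0.35):
    """Yield (temperature, magnitude) pairs for a geometric cooling schedule.

    Assumption (not confirmed by the documentation): the perturbation
    magnitude scales linearly with current_temperature / start_temperature.
    """
    temperature = start_temperature
    while temperature > stop_temperature:
        magnitude = temperature / start_temperature * perturbation_magnitude
        yield temperature, magnitude
        # Cool down by a constant factor after every step.
        temperature *= cool_down_rate


schedule = list(annealing_schedule())
print(f"{len(schedule)} annealing steps")
for temperature, magnitude in schedule[:3]:
    print(f"T = {temperature:.1f}, perturbation magnitude = {magnitude:.4f}")
```

Under these assumptions, a smaller cool_down_rate or a higher stop_temperature shortens the search, while a larger perturbation_magnitude makes the early sparsity proposals more aggressive.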

AMC Pruner

AMC pruner leverages reinforcement learning to provide the model compression policy. This learning-based compression policy outperforms conventional rule-based compression policies by achieving a higher compression ratio, better preserving accuracy, and reducing human labor.

For more details, please refer to AMC: AutoML for Model Compression and Acceleration on Mobile Devices.

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import AMCPruner
config_list = [{
        'op_types': ['Conv2d', 'Linear']
    }]
pruner = AMCPruner(model, config_list, evaluator, val_loader, flops_ratio=0.5)
pruner.compress()

See the example for more information.

User configuration for AMC Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.AMCPruner(model, config_list, evaluator, val_loader, suffix=None, model_type='mobilenet', dataset='cifar10', flops_ratio=0.5, lbound=0.2, rbound=1.0, reward='acc_reward', n_calibration_batches=60, n_points_per_layer=10, channel_round=8, hidden1=300, hidden2=300, lr_c=0.001, lr_a=0.0001, warmup=100, discount=1.0, bsize=64, rmsize=100, window_length=1, tau=0.01, init_delta=0.5, delta_decay=0.99, max_episode_length=1000000000.0, output_dir='./logs', debug=False, train_episode=800, epsilon=50000, seed=None)[source]

A PyTorch implementation of AMC: AutoML for Model Compression and Acceleration on Mobile Devices. (https://arxiv.org/pdf/1802.03494.pdf)

Parameters
  • model (nn.Module) – The model to be pruned.

  • config_list (list) –

    Configuration list to configure layer pruning. Supported keys:
    • op_types : Operation types to be pruned.

    • op_names : Operation names to be pruned.

  • evaluator (function) –

    Function to evaluate the pruned model. The prototype of the function:

    def evaluator(val_loader, model):
        ...
        return acc

  • val_loader (torch.utils.data.DataLoader) – Data loader of the validation dataset.

  • suffix (str) – Suffix to help you remember what experiment you ran. Default: None.

  Parameters for the pruning environment:

  • model_type (str) – Model type to prune. Currently ‘mobilenet’ and ‘mobilenetv2’ are supported. Default: mobilenet

  • flops_ratio (float) – Ratio of FLOPs to preserve. Default: 0.5

  • lbound (float) – Minimum weight preserve ratio for each layer. Default: 0.2

  • rbound (float) – Maximum weight preserve ratio for each layer. Default: 1.0

  • reward (function) –

    Reward function type. Default: acc_reward
    • acc_reward : accuracy * 0.01

    • acc_flops_reward : -(100 - accuracy) * 0.01 * np.log(flops)

  Parameters for channel pruning:

  • n_calibration_batches (int) – Number of batches used to extract layer information. Default: 60

  • n_points_per_layer (int) – Number of feature points per layer. Default: 10

  • channel_round (int) – Round the number of channels to a multiple of channel_round. Default: 8

  Parameters for the DDPG agent:

  • hidden1 (int) – Number of hidden units in the first fully connected layer. Default: 300

  • hidden2 (int) – Number of hidden units in the second fully connected layer. Default: 300

  • lr_c (float) – Learning rate for the critic. Default: 1e-3

  • lr_a (float) – Learning rate for the actor. Default: 1e-4

  • warmup (int) – Number of episodes without training that only fill the replay memory. During warmup episodes, random actions are used for pruning. Default: 100

  • discount (float) – Discount applied to the next Q value when computing the deep Q value target. Default: 1.0

  • bsize (int) – Minibatch size for training the DDPG agent. Default: 64

  • rmsize (int) – Memory size for each layer. Default: 100

  • window_length (int) – Replay buffer window length. Default: 1

  • tau (float) – Moving average coefficient for the target network, used by soft_update. Default: 0.01

  Noise parameters:

  • init_delta (float) – Initial variance of the truncated normal distribution.

  • delta_decay (float) – Delta decay during exploration.

  Parameters for training the DDPG agent:

  • max_episode_length (int) – Maximum episode length.

  • output_dir (str) – Output directory to save log files and model files. Default: ./logs

  • debug (boolean) – Debug mode.

  • train_episode (int) – Train iterations each timestep. Default: 800

  • epsilon (int) – Linear decay of the exploration policy. Default: 50000

  • seed (int) – Random seed to set for reproducing the experiment. Default: None
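
The two reward types listed for the reward parameter can be written out as plain functions. This is only a sketch of the formulas quoted above, not the library's code: the function signatures are hypothetical, accuracy is assumed to be given in percent (which the term 100 - accuracy suggests), and math.log stands in for np.log.

```python
import math


def acc_reward(accuracy: float) -> float:
    """Reward based on accuracy only (accuracy assumed to be in percent)."""
    return accuracy * 0.01


def acc_flops_reward(accuracy: float, flops: float) -> float:
    """Reward that penalizes both the error rate and the FLOPs count."""
    error = 100.0 - accuracy
    return -error * 0.01 * math.log(flops)


# Hypothetical numbers, for illustration only.
print("acc_reward       :", acc_reward(70.5))
print("acc_flops_reward :", acc_flops_reward(70.5, flops=285e6))
```

With acc_reward the agent only sees accuracy, so the FLOPs budget has to be enforced separately (for example through flops_ratio). With acc_flops_reward, a larger model at the same accuracy gets a more negative reward.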

Reproduced Experiment

We reproduced one of the experiments in AMC: AutoML for Model Compression and Acceleration on Mobile Devices, in which MobileNet is pruned to 50% FLOPS on ImageNet. Our experiment results are as follows:

Model        Top 1 acc. (paper/ours)    Top 5 acc. (paper/ours)    FLOPS
MobileNet    70.5% / 69.9%              89.3% / 89.1%              50%

The experiment code can be found at examples/model_compress/pruning/

ADMM Pruner

Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique that decomposes the original nonconvex problem into two subproblems that can be solved iteratively. In the weight pruning problem, these two subproblems are solved via 1) a gradient descent algorithm and 2) Euclidean projection, respectively.

During the process of solving these two subproblems, the weights of the original model will be changed. A one-shot pruner will then be applied to prune the model according to the given config list.

This solution framework applies to both non-structured pruning and different variations of structured pruning schemes.

For more details, please refer to A Systematic DNN Weight Pruning Framework using Alternating Direction Method of Multipliers.
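
The alternation between the two subproblems can be illustrated with a toy problem in plain Python. This is a sketch under stated assumptions, not NNI's implementation: the training loss is replaced by the quadratic f(w) = 0.5 * ||w - target||^2, the function names are hypothetical, and rho stands for the ADMM penalty parameter.

```python
def project_to_sparsity(values, num_keep):
    """Euclidean projection onto vectors with at most `num_keep` nonzeros:
    keep the largest-magnitude entries and zero out the rest."""
    order = sorted(range(len(values)), key=lambda i: abs(values[i]), reverse=True)
    keep = set(order[:num_keep])
    return [v if i in keep else 0.0 for i, v in enumerate(values)]


def admm_prune(target, num_keep, rho=1.0, num_iterations=30, gd_steps=50, lr=0.1):
    """Toy ADMM loop for the stand-in loss f(w) = 0.5 * ||w - target||^2."""
    w = list(target)                      # dense weights
    z = project_to_sparsity(w, num_keep)  # sparse auxiliary variable
    u = [0.0] * len(w)                    # scaled dual variable
    for _ in range(num_iterations):
        # Subproblem 1: gradient descent on f(w) + rho/2 * ||w - z + u||^2.
        for _ in range(gd_steps):
            w = [wi - lr * ((wi - ti) + rho * (wi - zi + ui))
                 for wi, ti, zi, ui in zip(w, target, z, u)]
        # Subproblem 2: Euclidean projection of w + u onto the sparse set.
        z = project_to_sparsity([wi + ui for wi, ui in zip(w, u)], num_keep)
        # Dual update: accumulate the remaining gap between w and z.
        u = [ui + wi - zi for ui, wi, zi in zip(u, w, z)]
    return w, z


w, z = admm_prune(target=[3.0, -0.1, 0.5, -2.0], num_keep=2)
print("dense weights :", [round(v, 3) for v in w])
print("sparse weights:", z)
```

In this toy run the dense weights are gradually pulled toward the sparse pattern, which is why applying a one-shot pruner afterwards costs little accuracy.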

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import ADMMPruner
config_list = [{
            'sparsity': 0.8,
            'op_types': ['Conv2d'],
            'op_names': ['conv1']
        }, {
            'sparsity': 0.92,
            'op_types': ['Conv2d'],
            'op_names': ['conv2']
        }]
pruner = ADMMPruner(model, config_list, trainer, num_iterations=30, epochs_per_iteration=5)
pruner.compress()

See the example for more information.

User configuration for ADMM Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.ADMMPruner(model, config_list, trainer, criterion=CrossEntropyLoss(), num_iterations=30, epochs_per_iteration=5, row=0.0001, base_algo='l1')[source]

A PyTorch implementation of the ADMM Pruner algorithm.

Parameters
  • model (torch.nn.Module) – Model to be pruned.

  • config_list (list) – List of pruning configs.

  • trainer (function) – Function used for the first subproblem. Users should write this function as a normal function that trains the PyTorch model and takes model, optimizer, criterion, epoch as function arguments.

  • criterion (function) – Function used to calculate the loss between the target and the output. By default, we use CrossEntropyLoss in ADMMPruner. For example, you can use torch.nn.CrossEntropyLoss() as input.

  • num_iterations (int) – Total number of iterations in pruning process. We will calculate mask after we finish all iterations in ADMMPruner.

  • epochs_per_iteration (int) – Training epochs of the first subproblem.

  • row (float) – Penalty parameters for ADMM training.

  • base_algo (str) – Base pruning algorithm. level, l1, l2 or fpgm, by default l1. Given the sparsity distribution among the ops, the assigned base_algo is used to decide which filters/channels/weights to prune.

Lottery Ticket Hypothesis

The paper The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks, by Jonathan Frankle and Michael Carbin, provides comprehensive measurement and analysis, and articulates the lottery ticket hypothesis: dense, randomly-initialized, feed-forward networks contain subnetworks (winning tickets) that – when trained in isolation – reach test accuracy comparable to the original network in a similar number of iterations.

In this paper, the authors use the following process to prune a model, called iterative pruning:

  1. Randomly initialize a neural network f(x;theta_0) (where theta_0 ~ D_theta).

  2. Train the network for j iterations, arriving at parameters theta_j.

  3. Prune p% of the parameters in theta_j, creating a mask m.

  4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).

  5. Repeat steps 2, 3, and 4.
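
The loop described in these steps can be sketched on a flat list of weights. This is a toy illustration only: train is a placeholder that merely perturbs the weights instead of running real training, the helper names are hypothetical, and p is fixed at 20% per round.

```python
import random


def train(weights, mask, rng):
    """Placeholder for real training: perturb the unmasked weights."""
    return [w + rng.gauss(0.0, 0.1) if m else 0.0 for w, m in zip(weights, mask)]


def prune_smallest(weights, mask, fraction):
    """Mask out `fraction` of the surviving weights with the smallest magnitude."""
    alive = [i for i, m in enumerate(mask) if m]
    num_prune = int(len(alive) * fraction)
    to_prune = set(sorted(alive, key=lambda i: abs(weights[i]))[:num_prune])
    return [m and i not in to_prune for i, m in enumerate(mask)]


rng = random.Random(0)
theta_0 = [rng.gauss(0.0, 1.0) for _ in range(100)]  # step 1: random init
mask = [True] * len(theta_0)
weights = list(theta_0)

for round_idx in range(3):                           # step 5: repeat
    weights = train(weights, mask, rng)              # step 2: train
    mask = prune_smallest(weights, mask, 0.2)        # step 3: prune p% = 20%
    # Step 4: reset the survivors to their values in theta_0.
    weights = [w0 if m else 0.0 for w0, m in zip(theta_0, mask)]
    print(f"round {round_idx + 1}: {sum(mask)} weights remain")
```

Note that the pruning decision uses the trained magnitudes, while the surviving weights restart from theta_0 in every round.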

If the configured final sparsity is P (e.g., 0.8) and there are n rounds of iterative pruning, each round prunes 1-(1-P)^(1/n) of the weights that survive the previous round.
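
As a worked example of this formula, the short sketch below computes the per-round rate for P = 0.8 and n = 5 and checks that the rounds compound to the final sparsity. The function name per_round_prune_rate is hypothetical and only used for illustration.

```python
def per_round_prune_rate(final_sparsity: float, prune_iterations: int) -> float:
    """Fraction of surviving weights pruned in each round: 1-(1-P)^(1/n)."""
    return 1.0 - (1.0 - final_sparsity) ** (1.0 / prune_iterations)


rate = per_round_prune_rate(final_sparsity=0.8, prune_iterations=5)
print(f"per-round prune rate: {rate:.4f}")

remaining = 1.0
for round_idx in range(1, 6):
    # Each round removes `rate` of the weights that are still alive.
    remaining *= 1.0 - rate
    print(f"after round {round_idx}: sparsity = {1.0 - remaining:.4f}")
```

For these values each round prunes roughly 27.5% of the surviving weights, and the sparsity after the fifth round reaches the configured 0.8.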

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import LotteryTicketPruner
config_list = [{
    'prune_iterations': 5,
    'sparsity': 0.8,
    'op_types': ['default']
}]
pruner = LotteryTicketPruner(model, config_list, optimizer)
pruner.compress()
for _ in pruner.get_prune_iterations():
    pruner.prune_iteration_start()
    for epoch in range(epoch_num):
        ...

The above configuration means that there are 5 rounds of iterative pruning. Because the 5 rounds are executed in the same run, LotteryTicketPruner needs model and optimizer (and ``lr_scheduler``, if one is used) to reset their states every time a new prune iteration starts. Please use get_prune_iterations to get the pruning iterations, and invoke prune_iteration_start at the beginning of each iteration. epoch_num should be large enough for the model to converge, because the hypothesis is that the performance (accuracy) obtained in later rounds with high sparsity can be comparable to that obtained in the first round.

User configuration for LotteryTicket Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.LotteryTicketPruner(model, config_list, optimizer=None, lr_scheduler=None, reset_weights=True)[source]
Parameters
  • model (pytorch model) – The model to be pruned

  • config_list (list) –

    Supported keys:
    • prune_iterations : The number of rounds for the iterative pruning.

    • sparsity : The final sparsity when the compression is done.

  • optimizer (pytorch optimizer) – The optimizer for the model

  • lr_scheduler (pytorch lr scheduler) – The lr scheduler for the model if used

  • reset_weights (bool) – Whether to reset the weights and optimizer at the beginning of each round.

Reproduced Experiment

We try to reproduce the experiment result of the fully connected network on MNIST using the same configuration as in the paper. The code can be found here. In this experiment, we prune 10 times, and after each pruning we train the pruned model for 50 epochs.

The above figure shows the result of the fully connected network. round0-sparsity-0.0 is the performance without pruning. Consistent with the paper, pruning around 80% of the weights obtains performance similar to the unpruned network and converges a little faster. If too much is pruned, e.g., more than 94%, the accuracy becomes lower and convergence becomes a little slower. One difference from the paper is that the trend in the paper's data is clearer.

Sensitivity Pruner

For each round, SensitivityPruner prunes the model based on the sensitivity to the accuracy of each layer until meeting the final configured sparsity of the whole model:

1. Analyze the sensitivity of each layer in the current state of the model.
2. Prune each layer according to the sensitivity.

For more details, please refer to Learning both Weights and Connections for Efficient Neural Networks.

Usage

PyTorch code

from nni.algorithms.compression.pytorch.pruning import SensitivityPruner
config_list = [{
        'sparsity': 0.5,
        'op_types': ['Conv2d']
    }]
pruner = SensitivityPruner(model, config_list, finetuner=fine_tuner, evaluator=evaluator)
# eval_args and finetune_args are the parameters passed to the evaluator and finetuner respectively
pruner.compress(eval_args=[model], finetune_args=[model])

User configuration for Sensitivity Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.SensitivityPruner(model, config_list, evaluator, finetuner=None, base_algo='l1', sparsity_proportion_calc=None, sparsity_per_iter=0.1, acc_drop_threshold=0.05, checkpoint_dir=None)[source]

This function prunes the model based on the sensitivity of each layer.

Parameters
  • model (torch.nn.Module) – model to be compressed

  • evaluator (function) –

    Validation function for the model. This function should return the accuracy on the validation dataset. The input parameters of evaluator can be specified in the eval_args and eval_kwargs parameters of the compress function if needed. Example:

    def evaluator(model):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        val_loader = ...
        model.eval()
        correct = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # get the index of the max log-probability
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        accuracy = correct / len(val_loader.dataset)
        return accuracy

  • finetuner (function) –

    Finetune function for the model. This parameter is optional; if it is not None, the sensitivity pruner will finetune the model after pruning in each iteration. The input parameters of finetuner can be specified in the finetune_args and finetune_kwargs parameters of the compress function if needed. Example:

    def finetuner(model, epoch=3):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        train_loader = ...
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        model.train()
        for _ in range(epoch):
            for _, (data, target) in enumerate(train_loader):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

  • base_algo (str) – base pruning algorithm. level, l1, l2 or fpgm, by default l1.

  • sparsity_proportion_calc (function) – This function generates the sparsity proportion between the conv layers according to the sensitivity analysis results. We provide a default function to quantify the sparsity proportion according to the sensitivity analysis results. Users can also customize this function according to their needs. The input of this function is a dict, for example: {‘conv1’ : {0.1: 0.9, 0.2 : 0.8}, ‘conv2’ : {0.1: 0.9, 0.2 : 0.8}}, in which ‘conv1’ is the name of a conv layer, and 0.1: 0.9 means that when the sparsity of conv1 is 0.1 (10%), the model’s validation accuracy equals 0.9.

  • sparsity_per_iter (float) – The sparsity of the model that the pruner tries to prune in each iteration.

  • acc_drop_threshold (float) – The hyperparameter used to quantify the sensitivity of each layer.

  • checkpoint_dir (str) – The dir path to save the checkpoints during the pruning.
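
To illustrate how a sensitivity dict of the shape described for sparsity_proportion_calc might be turned into per-layer sparsities, here is a minimal sketch. The function max_sparsity_within_drop, its signature, and the baseline accuracy of 0.92 are all hypothetical; this is not NNI's default function, only one plausible way to combine the sensitivity results with acc_drop_threshold.

```python
def max_sparsity_within_drop(sensitivities, baseline_acc, acc_drop_threshold=0.05):
    """For each layer, pick the largest tested sparsity whose accuracy stays
    within `acc_drop_threshold` of `baseline_acc` (0.0 if none qualifies)."""
    result = {}
    for layer, curve in sensitivities.items():
        # `curve` maps a tested sparsity to the accuracy measured at it.
        allowed = [sparsity for sparsity, acc in curve.items()
                   if baseline_acc - acc <= acc_drop_threshold]
        result[layer] = max(allowed, default=0.0)
    return result


# Same dict layout as in the parameter description; the values are made up.
sensitivities = {
    'conv1': {0.1: 0.9, 0.2: 0.8},
    'conv2': {0.1: 0.9, 0.2: 0.88},
}
print(max_sparsity_within_drop(sensitivities, baseline_acc=0.92))
```

In this made-up example conv1 is the more sensitive layer, so it is assigned a lower sparsity than conv2.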

Transformer Head Pruner

Transformer Head Pruner is a tool designed for pruning attention heads from the models belonging to the Transformer family. The following image from Efficient Transformers: A Survey gives a good overview of the general structure of the Transformer.

Typically, each attention layer in Transformer models consists of four weights: three projection matrices for query, key, and value, and an output projection matrix. The outputs of the first three matrices contain the projected results for all heads. Normally, the results are then reshaped so that each head performs its attention computation independently. The final results are concatenated before being fed into the output projection. Therefore, when an attention head is pruned, the weights corresponding to that head in the three projection matrices are pruned. The weights in the output projection corresponding to the head’s output are pruned as well. In our implementation, we calculate and apply masks to the four matrices together.
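
The correspondence between a pruned head and the affected weights can be sketched with plain index arithmetic. The helpers below are hypothetical and not part of NNI. They assume that heads are laid out contiguously along the projected dimension, and that the resulting mask applies to the output rows of the query, key, and value weights and to the input columns of the output projection weight (torch.nn.Linear stores its weight as [out_features, in_features]).

```python
def head_index_range(head: int, head_hidden_dim: int) -> range:
    """Indices of the projected units that belong to one attention head,
    assuming heads are laid out contiguously."""
    return range(head * head_hidden_dim, (head + 1) * head_hidden_dim)


def head_masks(num_heads: int, head_hidden_dim: int, pruned_heads):
    """Return a 0/1 mask over the num_heads * head_hidden_dim projected units.

    Assumption: the same mask selects rows of the Q/K/V weights and
    columns of the output projection weight.
    """
    mask = [1] * (num_heads * head_hidden_dim)
    for head in pruned_heads:
        for idx in head_index_range(head, head_hidden_dim):
            mask[idx] = 0
    return mask


# BERT-base-like sizes: 12 heads of dimension 64 (hidden size 768).
mask = head_masks(num_heads=12, head_hidden_dim=64, pruned_heads=[1, 5])
print("masked units:", mask.count(0), "of", len(mask))
print("head 1 covers units", head_index_range(1, 64))
```

Because one mask is shared by all four matrices, pruning a head removes a consistent slice from each of them.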

Note: currently, the pruner can only handle models with projection weights written as separate Linear modules, i.e., it expects four Linear modules corresponding to query, key, value, and an output projections. Therefore, in the config_list, you should either write ['Linear'] for the op_types field, or write names corresponding to Linear modules for the op_names field. For instance, the Huggingface transformers are supported, but torch.nn.Transformer is not.

The pruner implements the following algorithm:

Repeat for each pruning iteration (1 for one-shot pruning):
   1. Calculate importance scores for each head in each specified layer using a specific criterion.
   2. Sort heads locally or globally, and prune out some heads with lowest scores. The number of pruned heads is determined according to the sparsity specified in the config.
   3. If the specified pruning iteration is larger than 1 (iterative pruning), finetune the model for a while before the next pruning iteration.

Currently, the following head sorting criteria are supported:

  • “l1_weight”: rank heads by the L1-norm of weights of the query, key, and value projection matrices.

  • “l2_weight”: rank heads by the L2-norm of weights of the query, key, and value projection matrices.

  • “l1_activation”: rank heads by the L1-norm of their attention computation output.

  • “l2_activation”: rank heads by the L2-norm of their attention computation output.

  • “taylorfo”: rank heads by l1 norm of the output of attention computation * gradient for this output. Check more details in this paper and this one.

We support local sorting (i.e., sorting heads within a layer) and global sorting (sorting all heads together), which you can control by setting the global_sort parameter. Note that if global_sort=True is passed, all weights must have the same sparsity in the config list. However, this does not mean that each layer will be pruned to the specified sparsity. The sparsity value will be interpreted as a global sparsity, and each layer is likely to have a different sparsity after pruning by global sort. As a reminder, we found that if global sorting is used, it is usually helpful to use an iterative pruning scheme, interleaving pruning with intermediate finetuning, since global sorting often results in non-uniform sparsity distributions, which makes the model more susceptible to forgetting.
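
The difference between local and global sorting can be seen on a small set of made-up head scores. The function heads_to_prune below is a hypothetical sketch, not the pruner's actual code; it only shows how the same sparsity value leads to uniform per-layer pruning with local sorting and to non-uniform pruning with global sorting.

```python
def heads_to_prune(scores, sparsity, global_sort=False):
    """Return {layer: sorted indices of the heads to prune}.

    `scores` maps a layer name to a list of importance scores, one per head.
    """
    pruned = {layer: [] for layer in scores}
    if global_sort:
        # Rank all heads of all layers together and prune the lowest ones.
        ranked = sorted((score, layer, head)
                        for layer, layer_scores in scores.items()
                        for head, score in enumerate(layer_scores))
        num_prune = int(len(ranked) * sparsity)
        for _, layer, head in ranked[:num_prune]:
            pruned[layer].append(head)
    else:
        # Rank heads within each layer and prune the same share in every layer.
        for layer, layer_scores in scores.items():
            order = sorted(range(len(layer_scores)), key=lambda h: layer_scores[h])
            pruned[layer] = order[:int(len(layer_scores) * sparsity)]
    return {layer: sorted(heads) for layer, heads in pruned.items()}


# Made-up scores for two layers with four heads each.
scores = {'layer0': [0.9, 0.1, 0.5, 0.7], 'layer1': [0.2, 0.3, 0.05, 0.15]}
print("local :", heads_to_prune(scores, sparsity=0.5))
print("global:", heads_to_prune(scores, sparsity=0.5, global_sort=True))
```

With these scores, local sorting removes two heads from each layer, whereas global sorting removes one head from layer0 and three from layer1 while keeping the overall sparsity at 0.5.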

In our implementation, we support two ways to group the four weights in the same layer together. You can either pass a nested list containing the names of these modules as the pruner’s initialization parameters (usage below), or simply pass a dummy input instead and the pruner will run torch.jit.trace to group the weights (experimental feature). However, if you would like to assign different sparsity to each layer, you can only use the first option, i.e., passing names of the weights to the pruner (see usage below). Also, note that we require the weights belonging to the same layer to have the same sparsity.

Usage

Suppose we want to prune a BERT model from the Huggingface implementation, which has the following architecture (obtained by calling print(model)). Note that we only show the first layer of the repeated layers in the encoder’s ModuleList layer.

Usage Example: one-shot pruning, assigning sparsity 0.5 to the first six layers and sparsity 0.25 to the last six layers (PyTorch code). Note that:

  • Here we specify op_names in the config list to assign different sparsity to different layers.

  • Meanwhile, we pass attention_name_groups to the pruner so that the pruner may group together the weights belonging to the same attention layer.

  • Since in this example we want to do one-shot pruning, the num_iterations parameter is set to 1, and the parameter epochs_per_iteration is ignored. If you would like to do iterative pruning instead, you can set the num_iterations parameter to the number of pruning iterations, and the epochs_per_iteration parameter to the number of finetuning epochs between two iterations.

  • The arguments trainer and optimizer are only used when we want to do iterative pruning, or the ranking criterion is taylorfo. Here these two parameters are ignored by the pruner.

  • The argument forward_runner is only used when the ranking criterion is l1_activation or l2_activation. Here this parameter is ignored by the pruner.

from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner
attention_name_groups = list(zip(["encoder.layer.{}.attention.self.query".format(i) for i in range(12)],
                                 ["encoder.layer.{}.attention.self.key".format(i) for i in range(12)],
                                 ["encoder.layer.{}.attention.self.value".format(i) for i in range(12)],
                                 ["encoder.layer.{}.attention.output.dense".format(i) for i in range(12)]))

kwargs = {"ranking_criterion": "l1_weight",
          "global_sort": False,
          "num_iterations": 1,
          "epochs_per_iteration": 1,    # this is ignored when num_iterations = 1
          "head_hidden_dim": 64,
          "attention_name_groups": attention_name_groups,
          "trainer": trainer,
          "optimizer": optimizer,
          "forward_runner": forward_runner
          }
config_list = [{
     "sparsity": 0.5,
     "op_types": ["Linear"],
     "op_names": [x for layer in attention_name_groups[:6] for x in layer]      # first six layers
}, {
     "sparsity": 0.25,
     "op_types": ["Linear"],
     "op_names": [x for layer in attention_name_groups[6:] for x in layer]      # last six layers
}]

pruner = TransformerHeadPruner(model, config_list, **kwargs)
pruner.compress()

In addition to this usage guide, we provide a more detailed example of pruning BERT (Huggingface implementation) for transfer learning on the tasks from the GLUE benchmark. Please find it on this page. To run the example, first make sure that you have installed the packages transformers and datasets. Then, you may start by running the following command:

./run.sh gpu_id glue_task

By default, the code will download a pretrained BERT language model, and then finetune it for several epochs on the downstream GLUE task. Then, the TransformerHeadPruner will be used to prune out heads from each layer by a certain criterion (by default, the code lets the pruner use magnitude ranking, and prunes out 50% of the heads in each layer in a one-shot manner). Finally, the pruned model will be finetuned on the downstream task for several epochs. You can check the details of pruning from the logs printed out by the example. You can also experiment with different pruning settings by changing the parameters in run.sh, or directly changing the config_list in transformer_pruning.py.

User configuration for Transformer Head Pruner

PyTorch

class nni.algorithms.compression.pytorch.pruning.TransformerHeadPruner(model, config_list, head_hidden_dim, attention_name_groups=None, dummy_input=None, ranking_criterion='l1_weight', global_sort=False, num_iterations=1, epochs_per_iteration=1, optimizer=None, trainer=None, criterion=None, forward_runner=None, **algo_kwargs)[source]

A pruner specialized for pruning attention heads in models belonging to the transformer family.

Parameters
  • model (torch.nn.Module) – Model to be pruned. Expect a model from transformers library (e.g., BertModel). This pruner can work with other customized transformer models, but some ranking modes might fail.

  • config_list (list) –

    Supported keys:
    • sparsity : This is to specify the sparsity operations to be compressed to.

    • op_types : Optional. Operation types to prune. (Should be ‘Linear’ for this pruner.)

    • op_names : Optional. Operation names to prune.

  • head_hidden_dim (int) – Dimension of the hidden dimension of each attention head. (e.g., 64 for BERT) We assume that this head_hidden_dim is constant across the entire model.

  • attention_name_groups (list (Optional)) – List of groups of names for weights of each attention layer. Each element should be a four-element list, with the first three corresponding to Q_proj, K_proj, V_proj (in any order) and the last one being output_proj.

  • dummy_input (torch.Tensor (Optional)) – Input to model’s forward method, used to infer module grouping if attention_name_groups is not specified. This tensor is used by the underlying torch.jit.trace to infer the module graph.

  • ranking_criterion (str) –

    The criterion for ranking attention heads. Currently we support:
    • l1_weight: l1 norm of Q_proj, K_proj, and V_proj

    • l2_weight: l2 norm of Q_proj, K_proj, and V_proj

    • l1_activation: l1 norm of the output of attention computation

    • l2_activation: l2 norm of the output of attention computation

    • taylorfo: l1 norm of the output of attention computation * gradient for this output

      (check more details in the masker documentation)

  • global_sort (bool) – Whether to rank the heads globally or locally before deciding which heads to prune.

  • num_iterations (int) – Number of pruning iterations. Defaults to 1 (one-shot pruning). If num_iterations > 1, the pruner will split the sparsity specified in config_list uniformly and assign a fraction to each pruning iteration.

  • epochs_per_iteration (int) – Number of finetuning epochs before the next pruning iteration. Only used when num_iterations > 1. If num_iterations is 1, then no finetuning is performed by the pruner after pruning.

  • optimizer (torch.optim.Optimizer) – Optimizer used to train model

  • trainer (function) – Function used to finetune the model between pruning iterations. Only used when num_iterations > 1 or ranking_criterion is ‘taylorfo’. Users should write this function as a normal function to train the PyTorch model and include model, optimizer, criterion, epoch as function arguments. Note that the trainer is also used for collecting gradients for pruning if ranking_criterion is ‘taylorfo’. In that case, epoch=None will be passed.

  • criterion (function) – Function used to calculate the loss between the target and the output. Only used when num_iterations > 1 or ranking_criterion is ‘taylorfo’. For example, you can use torch.nn.CrossEntropyLoss() as input.

  • forward_runner (function) – Function used to perform a “dry run” on the model on the entire train/validation dataset in order to collect data for pruning required by the criteria ‘l1_activation’ or ‘l2_activation’. Only used when ranking_criterion is ‘l1_activation’ or ‘l2_activation’. Users should write this function as a normal function that accepts a PyTorch model and runs forward on the model using the entire train/validation dataset. This function is not expected to perform any backpropagation or parameter updates.