Pruning Speedup

class nni.compression.pytorch.speedup.ModelSpeedup(model, dummy_input, masks_file, map_location=None, batch_dim=0, confidence=8, customized_replace_func=None)

This class speeds up the model using the provided weight masks.

  • model (pytorch model) – The model the user wants to speed up.

  • dummy_input (pytorch tensor, tuple of tensor, list of tensor) – The dummy input for `jit.trace`; users should place it on the correct device. Note: the first dimension of the dummy_input should be the batch size.

  • masks_file (str/dict) – The path to the user-provided mask file, or the mask object itself.

  • map_location (str) – The device on which the masks are placed; same as `map_location` in `torch.load`.

  • batch_dim (int) – The index of the batch dimension in the dummy_input.

  • confidence (int) – The confidence coefficient of the sparsity inference. This value is actually used as the batch size of the dummy_input.

  • customized_replace_func (None/Dict) –

    If customized_replace_func is not None, the given functions are used to replace the corresponding modules. The keys of the dict are operator types, and each value is the replace function for the corresponding operator. A replace function takes two input parameters: the original module, and a tuple of the input mask, output mask, and weight mask. It should prune the module accordingly and return the new module. Here is an example of a replace function (for more examples, refer to …):

    def example_replace(ori_module, masks):
        in_mask, out_mask, weight_mask = masks
        # Prune ori_module into a new, smaller module according to the masks.
        new_small_module = ...  # placeholder: construct the pruned module here
        return new_small_module
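As a concrete illustration of the replace-function interface described above, the sketch below prunes an `nn.Linear`. It is a minimal sketch under stated assumptions, not NNI's built-in replacement logic: the function name `replace_linear_sketch`, the `'Linear'` dict key, and the mask format (1-D keep vectors for inputs and outputs, nonzero meaning "kept") are hypothetical simplifications; NNI's actual masks may have a different layout.

```python
import torch
import torch.nn as nn


def replace_linear_sketch(ori_module: nn.Linear, masks) -> nn.Linear:
    """Hypothetical replace function for nn.Linear.

    ASSUMPTION (simplified, not NNI's actual mask format): in_mask is a 1-D
    tensor of length in_features, out_mask is a 1-D tensor of length
    out_features, weight_mask has the weight's shape; nonzero entries are kept.
    """
    in_mask, out_mask, weight_mask = masks
    in_idx = torch.nonzero(in_mask, as_tuple=True)[0]
    out_idx = torch.nonzero(out_mask, as_tuple=True)[0]
    has_bias = ori_module.bias is not None
    new_module = nn.Linear(len(in_idx), len(out_idx), bias=has_bias)
    with torch.no_grad():
        # Apply the weight mask, then keep only the surviving rows and columns.
        weight = (ori_module.weight * weight_mask)[out_idx][:, in_idx]
        new_module.weight.copy_(weight)
        if has_bias:
            new_module.bias.copy_(ori_module.bias[out_idx])
    return new_module


# Hypothetical mapping from operator type to replace function.
customized_replace_func = {'Linear': replace_linear_sketch}
```

With the pruned input positions zeroed, the small module reproduces the kept outputs of the original one, which is the property a replace function needs to preserve.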


Infer the masks for all layers in the module. This function can be divided into two steps: first, forward inference of the masks; second, backward inference of the masks. These two steps are repeated until the masks of the model no longer change.
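The alternating forward/backward inference can be illustrated as a fixed-point loop. This is a toy sketch, not NNI's implementation: it assumes a simple chain of element-wise nodes that map zero to zero (e.g. ReLU), all with the same output shape, with each mask stored as a list of 0/1 values. The function name `infer_masks` is hypothetical.

```python
from typing import List


def infer_masks(masks: List[List[int]]) -> List[List[int]]:
    """Toy fixed-point mask propagation over a chain of element-wise nodes.

    ASSUMPTION: masks[i] is the 0/1 output mask of node i, every node has the
    same output shape, and every node is element-wise with f(0) == 0, so a
    masked position propagates both downstream and upstream.
    """
    masks = [list(m) for m in masks]  # work on a copy
    changed = True
    while changed:
        changed = False
        # Forward inference: a zero in node i-1's output zeroes node i's output.
        for i in range(1, len(masks)):
            new = [a & b for a, b in zip(masks[i - 1], masks[i])]
            if new != masks[i]:
                masks[i], changed = new, True
        # Backward inference: a position masked downstream is irrelevant upstream.
        for i in range(len(masks) - 2, -1, -1):
            new = [a & b for a, b in zip(masks[i], masks[i + 1])]
            if new != masks[i]:
                masks[i], changed = new, True
    return masks


print(infer_masks([[1, 1, 0, 1], [1, 0, 1, 1], [1, 1, 1, 1]]))
# [[1, 0, 0, 1], [1, 0, 0, 1], [1, 0, 0, 1]]
```

The loop stops once a full forward-plus-backward round leaves every mask unchanged, mirroring the stopping condition described above.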


Perform the initial setup work for the speedup.


Replace all modules whose shape (weights, inputs, or outputs) has changed. Each new module is created with the same arguments as the module it replaces and correctly inherits its weights.

NOTE: nodes of type `func` cannot be replaced because they are not modules. Consequently, one limitation is that no `func` node should need to be replaced.
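To show what "created using the same arguments and inheriting its weights" means in practice, here is a minimal sketch for a `Conv2d` whose output channels were pruned. It is an illustration, not NNI's replacement code: the helper name `shrink_conv2d_out_channels` is hypothetical, and it assumes `groups == 1` and that the kept output channels are given as a 1-D index tensor.

```python
import torch
import torch.nn as nn


def shrink_conv2d_out_channels(conv: nn.Conv2d, keep: torch.Tensor) -> nn.Conv2d:
    """Build a smaller Conv2d that keeps only the output channels in `keep`.

    ASSUMPTION: groups == 1; `keep` is a 1-D LongTensor of output-channel indices.
    """
    assert conv.groups == 1, "this sketch only handles groups == 1"
    # Reuse every constructor argument of the original layer except out_channels.
    new_conv = nn.Conv2d(
        in_channels=conv.in_channels,
        out_channels=len(keep),
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
    )
    with torch.no_grad():
        # Inherit the weights (and bias) of the surviving output channels.
        new_conv.weight.copy_(conv.weight[keep])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias[keep])
    return new_conv
```

Copying the constructor arguments keeps the spatial behavior (stride, padding, dilation) identical, so the new layer's output equals the kept channels of the original output.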

replace_submodule(unique_name, reindex_dim=None, reindex=None)

Replace the submodule according to the inferred sparsity.

  • unique_name (str) – The unique_name of the submodule to replace.

  • reindex_dim (int) – The dimension of the re-index operation.

  • reindex (Reindex) – The index tensor. This is normally None; to reindex the output of this submodule, pass the index through this parameter.
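Swapping a submodule by its unique name usually comes down to walking a dotted path to the parent module and re-assigning the child. The sketch below shows that generic PyTorch technique; it does not cover the reindex logic, the helper name `replace_by_unique_name` is hypothetical, and it assumes `unique_name` is a dotted name as produced by `named_modules()`.

```python
import torch
import torch.nn as nn


def replace_by_unique_name(model: nn.Module, unique_name: str, new_module: nn.Module) -> None:
    """Swap the submodule at a dotted path such as 'features.3' for new_module."""
    *parent_path, attr = unique_name.split('.')
    parent = model
    for part in parent_path:
        # Works for named attributes and for nn.Sequential indices ('0', '1', ...).
        parent = getattr(parent, part)
    # nn.Module.__setattr__ registers new_module under the same child name.
    setattr(parent, attr, new_module)
```

For example, `replace_by_unique_name(model, '1.1', nn.Linear(4, 3))` replaces the second layer inside the second child of an `nn.Sequential`.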


There are two main steps: first, mask/shape inference; second, module replacement.


Update the direct sparsity for the target node. Here, direct sparsity means the sparsity in the output tensor that is caused by sparsity in the input tensors or weight tensors.
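One way to picture direct sparsity is to feed random inputs whose pruned positions are zero through a layer with masked weights, then mark the output positions that are zero for every sample. The sketch below does this for a bias-free `nn.Linear`; it is an illustration under stated assumptions, not NNI's implementation. The helper name `direct_output_mask` is hypothetical, and its `batch_size` argument is only meant to play a role similar to `confidence`, which the text above describes as the batch size of the dummy input.

```python
import torch
import torch.nn as nn


def direct_output_mask(module: nn.Linear, in_mask: torch.Tensor,
                       weight_mask: torch.Tensor, batch_size: int = 8) -> torch.Tensor:
    """Estimate which output features are always zero given input and weight masks.

    ASSUMPTION: module is a bias-free nn.Linear; in_mask has shape
    (in_features,); a nonzero mask entry means "kept".
    """
    assert module.bias is None, "a bias would make masked outputs nonzero"
    with torch.no_grad():
        masked_weight = module.weight * weight_mask
        # Random inputs with the pruned input positions forced to zero.
        x = torch.rand(batch_size, in_mask.numel()) * in_mask
        out = x @ masked_weight.t()
    # An output position is treated as sparse if it is zero for every sample.
    return (out != 0).any(dim=0).float()
```

A larger batch makes it less likely that a position is zero by coincidence, which is why the batch size acts as a confidence knob.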


This function updates the indirect sparsity. To explain what indirect sparsity is, consider two tensors TA and TB and the calculation TC = TA x TB, where TC is also a tensor. Once some values in TA are masked to zero, the corresponding positions in TB are also potentially sparse, because they have no effect on the final output (the gradient at these positions in TB is always 0). This function finds the potential sparsity caused by other sparsity (called indirect sparsity here). This potential sparsity can be found through the gradient.


  • node (NodePy) – The target node whose indirect sparsity is updated.
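The gradient-based idea behind indirect sparsity can be sketched with PyTorch autograd. The example takes `TC = TA x TB` to be a matrix multiplication (an assumption; the text does not say which product is meant), zeroes a column of `TA`, and shows that the matching row of `TB` receives zero gradient. The helper name `indirect_mask_via_grad` is hypothetical and this is not NNI's implementation.

```python
import torch


def indirect_mask_via_grad(ta: torch.Tensor, tb: torch.Tensor) -> torch.Tensor:
    """Mark entries of tb that have no effect on tc = ta @ tb.

    ASSUMPTION: the product is a matrix multiplication. Entries whose gradient
    is exactly zero are reported as indirect sparsity (0 in the returned mask).
    """
    tb = tb.detach().clone().requires_grad_(True)
    tc = ta @ tb
    # Backpropagate an all-ones gradient from tc to tb.
    tc.sum().backward()
    return (tb.grad != 0).float()
```

If column `j` of `TA` is all zeros, then `TB.grad[j] = TA[:, j] @ grad_TC = 0`, so row `j` of `TB` is flagged as indirectly sparse.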