Pruning Speedup

class nni.compression.pytorch.speedup.ModelSpeedup(model, dummy_input, masks_file, map_location=None, batch_dim=0, confidence=8, customized_replacers=None, customized_replace_func=None)

This class speeds up a model using the provided weight masks.

  • model (pytorch model) – The model the user wants to speed up

  • dummy_input (pytorch tensor, tuple of tensor, list of tensor) – The dummy input for `jit.trace`; users should place it on the correct device. Note: the first dimension of the dummy_input should be the batch size.

  • masks_file (str/dict) – The path of a user-provided mask file, or the mask object itself

  • map_location (str) – The device on which the masks are placed; same as map_location in `torch.load`

  • batch_dim (int) – The index of the batch dimension in the dummy_input

  • confidence (int) – The confidence coefficient of the sparsity inference. This value is actually used as the batch size of the dummy_input.

  • customized_replacers (list of Replacer) – A list of Replacer objects. A module that contains no other module is called a leaf-module, and a module that contains other modules is called a hyper-module; replacers are used to replace hyper-modules. The difference between a replacer and a replace function is that a replacer can replace a hyper-module more efficiently, while a replace function is used to replace a leaf-module. In ModelSpeedup.compress, the replacers are called first to replace the hyper-modules, and then all leaf-modules are replaced by replace functions.
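The replacement order described for customized_replacers can be sketched in plain Python. This is a hypothetical model, not NNI code: `Module`, `descendants`, and `speedup_order` are illustrative names, each module is identified by a unique name, and a replacer is reduced to the set of hyper-module names it claims. It only shows that replacers run first on hyper-modules and that replace functions then handle the leaf-modules the replacers did not already cover.

```python
from typing import Callable, List, Optional, Set


class Module:
    """Hypothetical stand-in for torch.nn.Module: a unique name plus children."""

    def __init__(self, name: str, children: Optional[List["Module"]] = None):
        self.name = name
        self.children = children or []

    def is_leaf(self) -> bool:
        # A leaf-module contains no other modules; otherwise it is a hyper-module.
        return not self.children


def descendants(m: Module) -> List[Module]:
    """All modules nested under m (excluding m itself), depth first."""
    out: List[Module] = []
    for c in m.children:
        out.append(c)
        out.extend(descendants(c))
    return out


def speedup_order(root: Module,
                  replacer_targets: Set[str],
                  replace_func: Callable[[Module], str]) -> List[str]:
    """Return the replacement actions in the order they would be applied."""
    log: List[str] = []
    covered: Set[str] = set()
    all_modules = [root] + descendants(root)

    # Phase 1: (hypothetical) replacers rewrite whole hyper-modules first.
    for m in all_modules:
        if m.is_leaf() or m.name in covered:
            continue
        if m.name in replacer_targets:
            log.append(f"replacer:{m.name}")
            covered.update(d.name for d in descendants(m))

    # Phase 2: replace functions handle every leaf-module not already covered.
    for m in all_modules:
        if m.is_leaf() and m.name not in covered:
            log.append(replace_func(m))
    return log


model = Module("net", [
    Module("block", [Module("conv1"), Module("bn1")]),
    Module("fc"),
])
print(speedup_order(model, {"block"}, lambda m: f"func:{m.name}"))
# ['replacer:block', 'func:fc']
```

Here the replacer takes over the whole `block`, so `conv1` and `bn1` are never visited by a replace function; with no replacers, every leaf-module would go through the replace function instead.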


Infer the masks for all layers in the module. This function works in two steps: first, forward inference of the masks; second, backward inference of the masks. The two steps are repeated until the masks of the model no longer change.
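The repeated forward/backward inference can be illustrated with a minimal fixed-point sketch. It assumes, purely for illustration, that the model is a linear chain of channel-wise elementwise nodes, so a channel pruned anywhere in the chain can be pruned everywhere; NNI's actual inference works on the traced graph with real tensors. The helper names `forward_pass`, `backward_pass`, and `infer_masks` are hypothetical.

```python
from typing import List, Tuple

Mask = List[int]  # 1 = channel kept, 0 = channel pruned


def forward_pass(masks: List[Mask]) -> bool:
    """Propagate pruned channels from each node to the next one."""
    changed = False
    for i in range(1, len(masks)):
        new = [a & b for a, b in zip(masks[i], masks[i - 1])]
        if new != masks[i]:
            masks[i], changed = new, True
    return changed


def backward_pass(masks: List[Mask]) -> bool:
    """Propagate pruned channels from each node back to the previous one."""
    changed = False
    for i in range(len(masks) - 2, -1, -1):
        new = [a & b for a, b in zip(masks[i], masks[i + 1])]
        if new != masks[i]:
            masks[i], changed = new, True
    return changed


def infer_masks(masks: List[Mask]) -> Tuple[List[Mask], int]:
    """Repeat forward + backward inference until no mask changes."""
    masks = [list(m) for m in masks]  # work on a copy
    rounds = 0
    while True:
        rounds += 1
        changed = forward_pass(masks)
        changed = backward_pass(masks) or changed
        if not changed:
            return masks, rounds


masks, rounds = infer_masks([[1, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 0]])
print(masks)   # [[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]]
print(rounds)  # 2
```

The final round changes nothing and only confirms the fixed point, which is the stopping condition described above.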


Do some initial work for speedup.


Replace all modules whose weight, input, or output shapes have changed. Each new module is created with the same arguments as the module it replaces and correctly inherits its weights.

NOTE: A `func` type cannot be replaced because it is not a module. One limitation, therefore, is that no `func` may need to be replaced.
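A minimal sketch of how a replacement can rebuild a module from the original constructor arguments and inherit the surviving weights. It uses a hypothetical plain-Python `Linear` stand-in rather than `torch.nn.Linear`, and assumes the input and output masks (1 = keep, 0 = pruned) come from the mask inference step; `replace_linear` is an illustrative name, not an NNI function.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Linear:
    """Hypothetical stand-in for torch.nn.Linear with plain-list parameters."""
    in_features: int
    out_features: int
    bias: bool = True
    weight: List[List[float]] = field(default_factory=list)  # [out][in]
    bias_values: List[float] = field(default_factory=list)


def replace_linear(old: Linear, in_mask: List[int], out_mask: List[int]) -> Linear:
    """Build a smaller Linear that keeps only the unmasked rows and columns."""
    keep_in = [j for j, keep in enumerate(in_mask) if keep]
    keep_out = [i for i, keep in enumerate(out_mask) if keep]
    # Same constructor arguments as the original module, with the sizes shrunk.
    new = Linear(in_features=len(keep_in), out_features=len(keep_out), bias=old.bias)
    # Inherit only the weights that survive both masks.
    new.weight = [[old.weight[i][j] for j in keep_in] for i in keep_out]
    if old.bias:
        new.bias_values = [old.bias_values[i] for i in keep_out]
    return new


old = Linear(3, 2, True, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [0.5, -0.5])
new = replace_linear(old, in_mask=[1, 0, 1], out_mask=[0, 1])
print(new)
# Linear(in_features=2, out_features=1, bias=True, weight=[[4.0, 6.0]], bias_values=[-0.5])
```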


There are two steps: first, mask/shape inference; second, module replacement.


Update the direct sparsity for the target node. Here, direct sparsity means the sparsity in the output tensor that is caused by sparsity in the input tensors or weight tensors.
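One way to detect direct sparsity, assumed in this sketch, is empirical: feed random inputs whose pruned positions are exactly zero and mark every output position that stays zero across all samples. A layer is modeled as a plain Python function on lists; `infer_direct_sparsity` is a hypothetical helper, and `num_samples` plays a role similar to the batch size that `confidence` controls.

```python
import random
from typing import Callable, List

Vector = List[float]


def infer_direct_sparsity(fn: Callable[[Vector], Vector], in_mask: List[int],
                          num_samples: int = 8, seed: int = 0) -> List[int]:
    """Mark an output position as pruned (0) if it is zero for every sample."""
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")
    rng = random.Random(seed)
    out_mask: List[int] = []
    for _ in range(num_samples):
        # Random positive input; pruned positions are forced to exactly zero.
        x = [rng.uniform(0.1, 1.0) * m for m in in_mask]
        nonzero = [1 if v != 0 else 0 for v in fn(x)]
        # An output position is kept if it was nonzero in any sample.
        out_mask = nonzero if not out_mask else [a | b for a, b in zip(out_mask, nonzero)]
    return out_mask


W = [[1.0, 0.0, 2.0],
     [0.0, 0.0, 0.0],   # weight row fully masked: sparsity caused by the weight
     [0.0, 3.0, 0.0]]   # reads only input 1: sparse when input 1 is masked


def linear(x: Vector) -> Vector:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


print(infer_direct_sparsity(linear, [1, 0, 1]))  # [1, 0, 0]
```

Output position 1 is zero because of the weight mask, and output position 2 because of the input mask; both count as direct sparsity of this node.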


This function updates the indirect sparsity. To explain indirect sparsity, consider two tensors TA and TB and the calculation TC = TA x TB, where TC is also a tensor. Once some values in TA are masked to zero, the corresponding positions in TB become potential sparsity as well, because they have no effect on the final output (the gradient at these positions in TB is always 0). This function finds the potential sparsity caused by other sparsity, which we call indirect sparsity here. Basically, such potential sparsity can be found through the gradient.


node (NodePy) – The target node whose indirect sparsity is updated
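For the TC = TA x TB example, taking x as matrix multiplication and L as a scalar loss, the gradient of TB is TA^T @ dL/dTC, so a column of TA that is entirely masked zeroes out the matching row of TB's gradient. The sketch below computes that gradient by hand instead of with autograd; `matmul`, `transpose`, and `indirect_row_mask` are hypothetical helpers, and it assumes the unmasked values do not happen to cancel each other out.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product: (n x k) @ (k x m) -> (n x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def indirect_row_mask(ta: Matrix, grad_tc: Matrix) -> List[int]:
    """Return 0 for each row of TB that is indirectly sparse, 1 otherwise.

    For TC = TA @ TB, dL/dTB = TA^T @ dL/dTC. With a non-cancelling upstream
    gradient, a row of TB whose gradient is all zero has no effect on TC.
    """
    grad_tb = matmul(transpose(ta), grad_tc)
    return [0 if all(v == 0 for v in row) else 1 for row in grad_tb]


TA = [[1.0, 0.0, 2.0],
      [3.0, 0.0, 4.0]]      # column 1 of TA is masked to zero
GRAD_TC = [[1.0, 1.0],
           [1.0, 1.0]]      # upstream gradient dL/dTC, shape 2 x 2
print(indirect_row_mask(TA, GRAD_TC))  # [1, 0, 1]
```

Row 1 of TB only ever multiplies the masked column of TA, so it can be pruned without changing TC.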