Pruning Speedup

class nni.compression.pytorch.speedup.ModelSpeedup(model, dummy_input, masks_file, map_location=None, batch_dim=0, confidence=8)[source]

This class speeds up the model using the provided weight masks.

Parameters
  • model (pytorch model) -- The model the user wants to speed up

  • dummy_input (pytorch tensor, tuple of tensor, list of tensor) -- The dummy input for `jit.trace`; users should place it on the correct device. Note: the first dimension of the dummy_input should be the batch size.

  • masks_file (str/dict) -- The path of the user-provided mask file, or the mask object itself

  • map_location (str) -- The device on which the masks are placed; same as `map_location` in `torch.load`

  • batch_dim (int) -- The index of the batch dimension in the dummy_input

  • confidence (int) -- The confidence coefficient of the sparsity inference. This value is actually used as the batch size of the dummy_input.
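The role of `confidence` can be illustrated with a small pure-Python sketch. It assumes, as a hypothetical simplification rather than NNI's actual inference code, that an output is marked sparse only if it is zero for every one of `confidence` random input samples; `relu_layer` and `infer_output_mask` are illustrative names. With a batch of one, an output that is zero merely by chance (for example, a negative pre-activation clipped by ReLU) could be wrongly marked as sparse; a larger batch makes such mistakes less likely.

```python
import random
from typing import List

Matrix = List[List[float]]


def relu_layer(weights: Matrix, x: List[float]) -> List[float]:
    """Hypothetical layer: y_j = max(0, sum_i W[j][i] * x[i])."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in weights]


def infer_output_mask(weights: Matrix, in_features: int, confidence: int,
                      seed: int = 0) -> List[int]:
    """Keep output j (mask 1) unless it is zero for all `confidence` random samples.

    Assumption for illustration: `confidence` plays the role of the batch size
    of the random inputs used for sparsity inference.
    """
    rng = random.Random(seed)
    mask = [0] * len(weights)
    for _ in range(confidence):  # one random sample per "batch" entry
        x = [rng.uniform(-1.0, 1.0) for _ in range(in_features)]
        for j, value in enumerate(relu_layer(weights, x)):
            if value != 0.0:
                mask[j] = 1
    return mask


# Row 1 has all-zero weights (fully masked); rows 0 and 2 are live.
W = [[1.0, -1.0], [0.0, 0.0], [-0.5, 2.0]]
print(infer_output_mask(W, in_features=2, confidence=8))  # → [1, 0, 1]
```

Row 1 produces zero for every sample and is marked sparse, while rows 0 and 2 are kept as soon as any sample yields a non-zero value.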

infer_modules_masks()[source]

Infer the masks for all layers in the module. This function works in two steps: first, forward inference of the masks; second, backward inference of the masks. These two steps are repeated until the masks of the model no longer change.
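The forward/backward loop can be sketched on a chain of bias-free linear layers in pure Python. This is a hypothetical simplification, not NNI's implementation: each tensor carries a per-unit mask, the forward step kills a hidden unit that no live input reaches, the backward step kills a hidden unit that no live output consumes, and the loop stops once a full round changes nothing. The name `infer_masks` is illustrative.

```python
from typing import List

Matrix = List[List[float]]  # weights[j][k] connects input k to output j


def infer_masks(weights: List[Matrix], max_iters: int = 100) -> List[List[bool]]:
    """Hypothetical fixed-point mask inference for bias-free linear layers.

    Returns one mask per hidden tensor (True = kept). The model input and
    the final output are never pruned in this sketch.
    """
    n = len(weights)
    # masks[0] is the model input; masks[i + 1] is the output of layer i.
    masks = [[True] * len(weights[0][0])] + [[True] * len(w) for w in weights]
    for _ in range(max_iters):
        changed = False
        # Forward step: a hidden unit is dead if no live input reaches it.
        for i in range(n - 1):
            for j, row in enumerate(weights[i]):
                if masks[i + 1][j] and not any(
                        w != 0 and masks[i][k] for k, w in enumerate(row)):
                    masks[i + 1][j] = False
                    changed = True
        # Backward step: a hidden unit is dead if no live output consumes it.
        for i in range(n - 1, 0, -1):
            for k in range(len(masks[i])):
                if masks[i][k] and not any(
                        row[k] != 0 and masks[i + 1][j]
                        for j, row in enumerate(weights[i])):
                    masks[i][k] = False
                    changed = True
        if not changed:  # the masks have stopped changing
            break
    return masks[1:n]


weights = [
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],  # layer 0: 2 -> 3
    [[0.0, 0.0, 5.0], [1.0, 1.0, 0.0]],    # layer 1: 3 -> 2
    [[0.0, 1.0]],                          # layer 2: 2 -> 1
]
print(infer_masks(weights))  # → [[True, True, False], [False, True]]
```

Unit 0 of the second hidden tensor is never read by the last layer, so the backward step removes it; that leaves unit 2 of the first hidden tensor without a live consumer, so it is removed as well.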

initialize_speedup()[source]

Perform the initial setup work required for the speedup.

replace_compressed_modules()[source]

Replace all modules whose shape (weights/inputs/outputs) has changed. Each new module is created with the same arguments as the module it replaces and correctly inherits its weights.

NOTE: Nodes of type `func` cannot be replaced because they are not modules. One limitation, therefore, is that no `func` node may require replacement.
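The weight inheritance described above can be sketched with a hypothetical, dependency-free `Linear` stand-in (not `torch.nn.Linear`): the replacement is built from the same constructor arguments with `in_features` and `out_features` reduced, and it copies only the rows and columns that survive the masks. The helper `shrink_linear` is an illustrative name. In PyTorch, the analogous step would construct a new `torch.nn.Linear` and copy the selected slices of `weight` and `bias`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Linear:
    """Hypothetical stand-in for a linear layer: y = W x + b."""
    in_features: int
    out_features: int
    weight: List[List[float]]  # shape: out_features x in_features
    bias: Optional[List[float]] = None

    def __call__(self, x: List[float]) -> List[float]:
        y = [sum(w * xi for w, xi in zip(row, x)) for row in self.weight]
        if self.bias is not None:
            y = [v + b for v, b in zip(y, self.bias)]
        return y


def shrink_linear(module: Linear, kept_in: List[int], kept_out: List[int]) -> Linear:
    """Build a smaller Linear with the same arguments, inheriting the kept weights."""
    weight = [[module.weight[o][i] for i in kept_in] for o in kept_out]
    bias = None if module.bias is None else [module.bias[o] for o in kept_out]
    return Linear(len(kept_in), len(kept_out), weight, bias)


old_fc = Linear(3, 2, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [0.5, -0.5])
# Suppose mask inference pruned input 1 and output 0.
new_fc = shrink_linear(old_fc, kept_in=[0, 2], kept_out=[1])
print(new_fc.weight, new_fc.bias, new_fc([1.0, 1.0]))  # → [[4.0, 6.0]] [-0.5] [9.5]
```

The surviving output of `new_fc` matches output 1 of `old_fc` whenever the pruned input is zero, which is what "correctly inherits its weights" requires.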

replace_submodule(unique_name, reindex_dim=None, reindex=None)[source]

Replace the submodule according to the inferred sparsity.

Parameters
  • unique_name (str) -- The unique_name of the submodule to replace.

  • reindex_dim (int) -- The dimension of the re-index operation.

  • reindex (Reindex) -- The index tensor. Normally this is None; to reindex the output of this submodule, pass the index through this parameter.
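One plausible reading of `reindex_dim` and `reindex`, stated here as an assumption, is that the shrunken submodule produces fewer values along `reindex_dim`, and `reindex` lists the original positions those values should be written back to, with all other positions zero-filled. The hypothetical helper `reindex_output` below sketches that scatter on 2-D nested lists; in PyTorch, `Tensor.index_copy_(dim, index, source)` performs the same kind of write into a preallocated zero tensor.

```python
from typing import List

Matrix = List[List[float]]


def reindex_output(pruned: Matrix, reindex_dim: int, reindex: List[int],
                   full_size: int) -> Matrix:
    """Scatter a pruned 2-D output back to its original size (hypothetical helper).

    Slice i of `pruned` along `reindex_dim` is written to position reindex[i];
    every other position along that dimension is filled with zeros.
    """
    rows, cols = len(pruned), len(pruned[0])
    if reindex_dim == 1:
        out = [[0.0] * full_size for _ in range(rows)]
        for r in range(rows):
            for c, target in enumerate(reindex):
                out[r][target] = pruned[r][c]
    elif reindex_dim == 0:
        out = [[0.0] * cols for _ in range(full_size)]
        for r, target in enumerate(reindex):
            out[target] = list(pruned[r])
    else:
        raise ValueError("this sketch only handles 2-D outputs")
    return out


# Two channels survived pruning; they originally sat at positions 0 and 3 of 4.
print(reindex_output([[1.0, 2.0], [3.0, 4.0]], 1, [0, 3], 4))
# → [[1.0, 0.0, 0.0, 2.0], [3.0, 0.0, 0.0, 4.0]]
```

Restoring the original shape this way lets downstream modules that were not replaced keep consuming a tensor of the size they expect.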

speedup_model()[source]

The speedup consists of two steps: first, perform mask/shape inference; second, replace the modules.
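The two steps can be illustrated end to end with a hypothetical two-layer ReLU network written with plain lists. The sketch assumes bias-free layers and a mask that zeroes out every incoming weight of one hidden unit; step 1 infers that this unit always outputs zero, and step 2 removes the corresponding row of the first layer and column of the second. The names `matvec`, `forward`, and `speedup` are illustrative.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def forward(w1: Matrix, w2: Matrix, x: List[float]) -> List[float]:
    hidden = [max(0.0, h) for h in matvec(w1, x)]  # ReLU
    return matvec(w2, hidden)


def speedup(w1: Matrix, w2: Matrix) -> Tuple[Matrix, Matrix]:
    # Step 1 (mask inference): a hidden unit whose incoming weights were all
    # masked to zero always outputs ReLU(0) = 0, so it can be removed.
    kept = [j for j, row in enumerate(w1) if any(v != 0 for v in row)]
    # Step 2 (module replacement): drop the dead rows of w1 and the
    # matching columns of w2.
    new_w1 = [w1[j] for j in kept]
    new_w2 = [[row[j] for j in kept] for row in w2]
    return new_w1, new_w2


w1 = [[1.0, -2.0], [0.0, 0.0], [0.5, 0.5]]  # hidden unit 1 was masked out
w2 = [[1.0, 3.0, -1.0]]
small_w1, small_w2 = speedup(w1, w2)
x = [2.0, 0.5]
print(forward(w1, w2, x), forward(small_w1, small_w2, x))  # → [-0.25] [-0.25]
```

The bias-free assumption matters: with a bias b, a dead unit would output ReLU(b) rather than zero, and that constant would have to be folded into the next layer's bias before the unit could be removed.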

update_direct_sparsity(node)[source]

Update the direct sparsity for the target node. Here, direct sparsity means the sparsity in the output tensor that is caused by the sparsity in the input tensors or weight tensors.
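For a linear operation y = W x, direct sparsity can be sketched with a hypothetical rule written for illustration: output j can be non-zero only if some unmasked input k meets an unmasked weight W[j][k]. Masks are 0/1 lists, and `direct_output_mask` is an illustrative name, not an NNI function.

```python
from typing import List


def direct_output_mask(input_mask: List[int], weight_mask: List[List[int]]) -> List[int]:
    """Hypothetical direct-sparsity rule for y = W x.

    Output j is kept (1) only if at least one position k has both an
    unmasked weight W[j][k] and an unmasked input x[k].
    """
    return [int(any(wm and im for wm, im in zip(row, input_mask)))
            for row in weight_mask]


input_mask = [1, 0, 1]    # input 1 is masked
weight_mask = [
    [1, 1, 0],
    [0, 1, 0],            # only keeps the weight that reads the masked input
    [0, 0, 1],
]
print(direct_output_mask(input_mask, weight_mask))  # → [1, 0, 1]
```

Output 1 becomes sparse even though one of its weights is kept, because that weight only touches a masked input: the sparsity arises jointly from the input and weight masks.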

update_indirect_sparsity(node)[source]

This function updates the indirect sparsity. To explain what indirect sparsity is, suppose there are two tensors TA and TB, and we compute TC = TA * TB (element-wise), where TC is also a tensor. Once some values in TA are masked to zero, the corresponding positions in TB also become potential sparsity, because they have no effect on the final output (the gradient at these positions in TB is always 0). This function finds the potential sparsity caused by other sparsity, which we call indirect sparsity here. Essentially, this potential sparsity can be found through the gradient.

Parameters
  • node (NodePy) -- The target node whose indirect sparsity is updated
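The gradient-based idea can be sketched for the element-wise case TC = TA * TB. Instead of autograd, the hypothetical helper `indirect_mask` below swaps in a finite-difference check: it perturbs each position of TB and marks the position as sparse if no output changes. This is exact here because 0 * (b + eps) equals 0 * b; for general operations a real gradient computation, as described above, is the more reliable tool.

```python
from typing import Callable, List


def indirect_mask(fn: Callable[[List[float]], List[float]], x: List[float],
                  eps: float = 1e-3) -> List[int]:
    """Mark x[i] as sparse (0) when perturbing it leaves every output of fn unchanged.

    A finite-difference stand-in for the gradient check: a position whose
    perturbation has no effect on the output has zero gradient.
    """
    base = fn(x)
    mask = []
    for i in range(len(x)):
        bumped = list(x)
        bumped[i] += eps
        changed = any(a != b for a, b in zip(fn(bumped), base))
        mask.append(int(changed))
    return mask


ta = [1.5, 0.0, -2.0, 0.0]  # TA after masking: positions 1 and 3 are zero
tb = [0.3, 0.7, 1.1, -0.4]


def elementwise_product(t: List[float]) -> List[float]:
    """TC = TA * TB, with TA held fixed."""
    return [a * b for a, b in zip(ta, t)]


print(indirect_mask(elementwise_product, tb))  # → [1, 0, 1, 0]
```

Positions 1 and 3 of TB are found to be indirectly sparse purely because the matching positions of TA were masked, even though TB itself carries no mask there.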