Pruning Speedup

class nni.compression.pytorch.speedup.v2.ModelSpeedup(model, dummy_input, masks_or_file, map_location=None, batch_dim=0, batch_size=8, customized_mask_updaters=None, customized_replacers=None, graph_module=None, garbage_collect_values=True, logger=None)

This class speeds up a model using the provided weight masks: each masked module is replaced by a new, smaller dense module. ModelSpeedup uses a concrete trace based on torch.fx to obtain the graph. Note that tracing may fail if the model contains stochastic structure.

Parameters:

  • model (torch.nn.Module) – The model to speed up.

  • dummy_input (Any) – A tensor or a tuple; the dummy input used to execute the model.

  • masks_or_file (Any) – The path to a user-provided masks file, or the masks object itself.

  • map_location (Any) – The device on which the masks are placed; same as map_location in `torch.load`.

  • batch_dim (int) – The index of the batch dimension in dummy_input.

  • batch_size (int) – The batch size used for sparsity inference. This value is used as the batch size of the dummy_input.

  • customized_mask_updaters (List[MaskUpdater] | None) – A list of MaskUpdater. NNI automatically infers sparsity based on the data distribution in the forward and backward passes. If some special operations cause automatic sparsity inference to go wrong, users can manually set mask inference rules for those operations so that the masks are inferred correctly.

  • customized_replacers (List[Replacer] | None) – A list of Replacer. A replacer replaces the original module with a compressed module. Users can customize the replacement logic by implementing their own replacer. The replacers in this list are executed sequentially, before NNI's built-in replacement logic runs.

  • graph_module (GraphModule | None) – A torch.fx.GraphModule. If ModelSpeedup's default concrete trace does not meet your needs, you can pass in a torch.fx.GraphModule directly instead.

  • garbage_collect_values (bool) – If True, NNI deletes cached values once they are no longer used.

  • logger (logging.Logger | None) – The logger to use. If None, NNI uses its default logger.

Note: Backwards-compatibility for this API is guaranteed.
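A masks object for masks_or_file can be prepared and saved with plain PyTorch. The sketch below assumes a nested dict layout {module_name: {target_name: mask}}, with 1 marking kept entries and 0 marking pruned ones; this layout is an assumption, so check the format your NNI pruner actually produces. The module name "fc" is illustrative. map_location has the same meaning as in torch.load; here the masks are loaded onto the CPU.

```python
import os
import tempfile

import torch

# Assumed mask layout: {module_name: {target_name: mask_tensor}}; 1 = keep, 0 = pruned.
fc = torch.nn.Linear(4, 6)
row_mask = torch.tensor([1, 0, 1, 1, 0, 1], dtype=torch.float32)
masks = {
    "fc": {
        # Broadcast the per-output-row mask over the (out_features, in_features) weight.
        "weight": row_mask.unsqueeze(1).expand_as(fc.weight).clone(),
        "bias": row_mask.clone(),
    }
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "masks.pth")
    torch.save(masks, path)
    # map_location decides where the loaded tensors are placed, as in torch.load.
    loaded = torch.load(path, map_location="cpu")

print(loaded["fc"]["weight"].shape)  # torch.Size([6, 4])
```

Either the in-memory dict or a saved file path would then be passed as masks_or_file. For example, assuming the `speedup_model()` entry point shown in NNI's quick-start examples, the call would be `ModelSpeedup(model, dummy_input, masks).speedup_model()`.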

placeholder(target, args, kwargs)

Override the execution for ‘placeholder’ ops.


Backwards-compatibility for this API is guaranteed.
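placeholder(target, args, kwargs) has the same signature as the `placeholder` hook of `torch.fx.Interpreter`. The following is a generic sketch of what overriding that hook involves, not NNI's actual override: the interpreter records the shape of each graph input as it is bound. `Net` and `LoggingInterpreter` are illustrative names.

```python
import torch
import torch.fx


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))


class LoggingInterpreter(torch.fx.Interpreter):
    """Sketch: override 'placeholder' to observe each graph input as it is bound."""

    def __init__(self, module):
        super().__init__(module)
        self.seen_inputs = {}

    def placeholder(self, target, args, kwargs):
        # The default implementation returns the next argument passed to run().
        value = super().placeholder(target, args, kwargs)
        self.seen_inputs[target] = tuple(value.shape)  # target is the argument name
        return value


gm = torch.fx.symbolic_trace(Net())
interp = LoggingInterpreter(gm)
out = interp.run(torch.randn(8, 4))
print(interp.seen_inputs)  # {'x': (8, 4)}
```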


Propagate normally to obtain information about intermediate variables, such as tensor shapes and dtypes. Default action: execute each node and store its output in node_info.output_origin (the intermediate variable as assigned) and node_info.output_inplace (the intermediate variable after in-place ops).
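The difference between the two stored views can be sketched with a `torch.fx.Interpreter` subclass that overrides `run_node`. For each tensor it keeps a cloned snapshot taken when the value is produced (playing the role of output_origin), a live reference that later in-place ops will mutate (playing the role of output_inplace), and the tensor's shape and dtype. `RecordingInterpreter` and its attributes are hypothetical stand-ins for NNI's node_info fields.

```python
import torch
import torch.fx


class Net(torch.nn.Module):
    def forward(self, x):
        y = x * 2
        y.add_(1)  # in-place op: mutates y after it was assigned
        return y


class RecordingInterpreter(torch.fx.Interpreter):
    """Hypothetical sketch: record per-node tensor values, shapes and dtypes."""

    def __init__(self, module):
        super().__init__(module)
        self.output_origin = {}   # snapshot taken at assignment time
        self.output_inplace = {}  # live reference; reflects later in-place ops
        self.meta = {}            # node name -> (shape, dtype)

    def run_node(self, n):
        result = super().run_node(n)
        if isinstance(result, torch.Tensor):
            self.output_origin[n.name] = result.detach().clone()
            self.output_inplace[n.name] = result
            self.meta[n.name] = (tuple(result.shape), result.dtype)
        return result


gm = torch.fx.symbolic_trace(Net())
interp = RecordingInterpreter(gm)
out = interp.run(torch.ones(2, 3))
print(interp.meta["mul"])  # ((2, 3), torch.float32)
# The snapshot still holds 2.0; the live reference saw the add_ and now holds 3.0.
print(interp.output_origin["mul"][0, 0].item(), interp.output_inplace["mul"][0, 0].item())
```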


Replace all modules whose shapes (weights, inputs, or outputs) have changed. Each new module is created with the same arguments as the module it replaces and correctly inherits its weights.

NOTE: The `func` type cannot be replaced because it is not a module. One limitation, therefore, is that a `func` must not require replacement.
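As a minimal pure-PyTorch sketch of what building a new module with the same arguments and inheriting its weights can look like, the helper below takes a torch.nn.Linear with an output-feature mask and creates a smaller dense Linear that keeps only the unmasked rows. `shrink_linear_out` is hypothetical and not part of NNI. A real speedup must also shrink the inputs of downstream modules, which this sketch ignores.

```python
import torch


def shrink_linear_out(linear: torch.nn.Linear, out_mask: torch.Tensor) -> torch.nn.Linear:
    """Hypothetical helper: keep only output features where out_mask is True."""
    keep = out_mask.nonzero(as_tuple=True)[0]  # indices of kept output rows
    # Reuse the original constructor arguments, except the reduced out_features.
    new = torch.nn.Linear(
        linear.in_features,
        len(keep),
        bias=linear.bias is not None,
        device=linear.weight.device,
        dtype=linear.weight.dtype,
    )
    with torch.no_grad():
        # Inherit the surviving weights (and bias) from the original module.
        new.weight.copy_(linear.weight[keep])
        if linear.bias is not None:
            new.bias.copy_(linear.bias[keep])
    return new


torch.manual_seed(0)
fc = torch.nn.Linear(4, 6)
mask = torch.tensor([True, False, True, True, False, True])
small = shrink_linear_out(fc, mask)
x = torch.randn(2, 4)
print(small.weight.shape)  # torch.Size([4, 4])
# The dense module reproduces the kept outputs of the original layer.
print(torch.allclose(small(x), fc(x)[:, mask], atol=1e-6))  # True
```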

store_attr(path, obj)


Backwards-compatibility for this API is guaranteed.
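The docstring does not describe store_attr. The sketch below assumes, from its name and signature only, that it stores obj at a dotted attribute path such as "0.fc" (the naming used by named_modules()), and it demonstrates that technique on a plain nn.Module. This standalone function and its root argument are hypothetical. NNI's method takes no root argument and may behave differently.

```python
import torch


def store_attr(root: torch.nn.Module, path: str, obj) -> None:
    """Hypothetical sketch: set `obj` at a dotted attribute path under `root`."""
    *parents, name = path.split(".")
    target = root
    for part in parents:
        target = getattr(target, part)  # walk down, e.g. "0" then into the Block
    # nn.Module.__setattr__ registers a Module value as a submodule.
    setattr(target, name, obj)


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 6)


model = torch.nn.Sequential(Block(), torch.nn.ReLU())
store_attr(model, "0.fc", torch.nn.Linear(4, 3))
print(model[0].fc.out_features)  # 3
```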


Detect whether a tensor should be treated as an intermediate tensor.
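The docstring does not say how intermediate tensors are recognized. One plausible heuristic, assumed here and not confirmed as NNI's rule, ties the check to the batch_dim and batch_size constructor parameters. Under this heuristic, a tensor counts as intermediate if it has a batch dimension of the expected size, which separates activations from weights and constants. `looks_intermediate` is a hypothetical name.

```python
import torch


def looks_intermediate(obj, batch_dim: int = 0, batch_size: int = 8) -> bool:
    """Hypothetical heuristic: a tensor is intermediate if it carries a batch
    dimension (at index batch_dim) whose size equals batch_size."""
    return (
        isinstance(obj, torch.Tensor)
        and obj.dim() > batch_dim
        and obj.size(batch_dim) == batch_size
    )


print(looks_intermediate(torch.zeros(8, 3)))  # True: activation-like
print(looks_intermediate(torch.zeros(3, 4)))  # False: e.g. a weight
print(looks_intermediate(3.0))                # False: not a tensor
```

This check is cheap but not foolproof. For example, a weight whose dimension at batch_dim happens to equal batch_size would be misclassified.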