# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import logging
from pathlib import Path
import queue
from typing import List
import torch
from nni.common.graph_utils import build_module_graph
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict
from nni.compression.pytorch.utils.utils import get_module_by_name
from .compress_modules import replace_module
from .infer_mask import AutoMaskInference
from .jit_translate import jit_to_python_function
from .replacer import Replacer, DefaultReplacer
from ..utils import rand_like_with_shape, check_ddp_model, reset_ddp_model
from ..utils.attr import has_nested_attr, get_nested_attr
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class ModelSpeedup:
"""
This class speeds up the model with the provided weight masks.
Parameters
----------
model : pytorch model
The model the user wants to speed up.
dummy_input : pytorch tensor, tuple of tensor, list of tensor
The dummy input for ```jit.trace```; users should put it on the right device.
Note: the first dimension of the dummy_input should be the batch size.
masks_file : str/dict
The path of user provided mask file, or the mask object
map_location : str
The device on which masks are placed, same as map_location in ```torch.load```.
batch_dim : int
The index of the batch dimension in the dummy_input.
confidence : int
The confidence coefficient of the sparsity inference. This value is
actually used as the batch size of the dummy_input.
customized_replacers : list of Replacer
We call a ``Module`` that contains no sub-``Module`` a leaf-module, and a ``Module``
that contains sub-``Module``s a hyper-module; a replacer is used to replace hyper-modules.
The difference between a replacer and a replace function is that a replacer can perform
more efficient replacements on a hyper-module, while a replace function is used to replace
a leaf-module. In ``ModelSpeedup.compress``, the replacers are called first to replace the
hyper-modules, before all leaf-modules are replaced by replace functions.
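Examples
--------
A minimal usage sketch, assuming ``model`` is a pruned pytorch model, ``masks`` is the
mask dict exported by an NNI pruner, and ``dummy_input`` is a dummy tensor placed on the
model's device (the import path may differ across nni versions):

>>> from nni.compression.pytorch import ModelSpeedup
>>> ModelSpeedup(model, dummy_input, masks).speedup_model()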
"""
def __init__(self, model, dummy_input, masks_file, map_location=None,
batch_dim=0, confidence=8, customized_replacers=None, customized_replace_func=None):
assert confidence > 1
# The auto inference will change the values of the parameters in the model
# so we need to make a copy before the mask inference
self.ori_state_dict = copy.deepcopy(model.state_dict())
self.bound_model = model
self.is_ddp_model, self.ddp_params = check_ddp_model(self.bound_model)
self.inferred_masks = dict() # key: module_name, value: ModuleMasks
self.batch_dim = batch_dim
self.confidence = confidence
self.dummy_input, self.device = self._random_model_input(
dummy_input, confidence, batch_dim)
self.torch_graph = build_module_graph(model, self.dummy_input)
# dict object to save the auto inference objects of the submodules
self.auto_inferences = {}
# the index dict to find the corresponding torch._C.Value object
# according to the debug name
# we need the dummy_input to infer the mask automatically, so we save
# the indexes from tensor's debugname to the torch._C.Value object.
self.debugname_to_value = {}
# load the mask tensors to the same device as the dummy_input
# self.masks saves the mask tensors provided by the user and the inferred
# masks of the other modules
if isinstance(masks_file, (str, Path)) and Path(masks_file).exists():
self.masks = torch.load(
masks_file, map_location if map_location is not None else str(self.device))
elif isinstance(masks_file, dict):
self.masks = masks_file
else:
raise Exception('Please provide the mask or the path of the mask file')
self.constant = {}
# self.internal_result saves the internal outputs of the submodules
self.internal_result = {}
self.default_replacer = DefaultReplacer(replace_module)
self.customized_replacers: List[Replacer] = customized_replacers if customized_replacers is not None else []
if customized_replace_func is not None:
warn_msg = '`customized_replace_func` has been deprecated, please use `customized_replacers` instead, '
warn_msg += 'it can easily be converted to a replacer by '
warn_msg += 'customized_replacers=[DefaultReplacer(customized_replace_func)]'
_logger.warning(warn_msg)
self.customized_replacers.append(DefaultReplacer(customized_replace_func))
def _random_model_input(self, dummy_input, confidence, batch_dim):
"""
Get the new random dummy input according to the original dummy_input
and confidence, batch_dim.
Parameters
----------
dummy_input: Tensor or list/dict of Tensors
The dummy_input given by the user.
confidence: int
The new batch size of the generated dummy_input.
batch_dim: int
The index of the batch dimension.
Returns
-------
new_dummy_input: Tensor or list/dict of Tensors
The generated dummy_input for mask inference.
device: torch.device
The device of the generated dummy_inputs
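Examples
--------
An illustrative sketch of the re-sampling done here, using the same helper this method
relies on and assuming ``rand_like_with_shape(shape, like)`` returns a tensor of the
given shape (as its usage below suggests); the concrete shape is only an example:

>>> dummy = torch.rand(2, 3, 32, 32)                     # user-provided dummy input, batch size 2
>>> rand_like_with_shape([8, 3, 32, 32], dummy).shape    # batch_dim=0 resized to confidence=8
torch.Size([8, 3, 32, 32])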
"""
input_errmsg = 'Only support tensor, list/tuple/dict of tensors as input'
# Some models may use a list of tensors as input, for example transformers
new_dummy_input, device = None, None
if isinstance(dummy_input, torch.Tensor):
input_shape = list(dummy_input.size())
# set the batch size to the confidence value
input_shape[batch_dim] = confidence
new_dummy_input = rand_like_with_shape(input_shape, dummy_input)
device = dummy_input.device
elif isinstance(dummy_input, (tuple, list)):
# else if the dummy input is list/tuple
new_dummy_input = []
old_batchsize = dummy_input[0].size(0)
device = dummy_input[0].device
for _, t_input in enumerate(dummy_input):
assert isinstance(t_input, torch.Tensor), input_errmsg
# This check is too strict...
# assert t_input.size(0) == old_batchsize, 'The first dimension should be batchsize\
# and the batchsize of all inputs should be the same!'
input_shape = list(t_input.size())
if t_input.size(0) == old_batchsize:
input_shape[batch_dim] = confidence
else:
assert confidence == old_batchsize, 'Input tensor first dimension is not the batch size; \
in this situation, confidence should be equal to the first dimension size of the first dummy input tensor.'
# rand_func = torch.randint if t_input.dtype
new_dummy_input.append(
rand_like_with_shape(input_shape, t_input))
elif isinstance(dummy_input, dict):
new_dummy_input = {}
tmp_key = list(dummy_input.keys())[0]
old_batchsize = dummy_input[tmp_key].size(0)
device = dummy_input[tmp_key].device
for in_name, t_input in dummy_input.items():
assert isinstance(t_input, torch.Tensor), input_errmsg
# This check is too strict...
# assert old_batchsize == t_input.size(0), 'The first dimension should be batchsize\
# and the batchsize of all inputs should be the same!'
input_shape = list(t_input.size())
if t_input.size(0) == old_batchsize:
input_shape[batch_dim] = confidence
else:
assert confidence == old_batchsize, 'Input tensor first dimension is not the batch size; \
in this situation, confidence should be equal to the first dimension size of the first dummy input tensor.'
new_dummy_input[in_name] = rand_like_with_shape(
input_shape, t_input)
else:
raise TypeError(input_errmsg)
return new_dummy_input, device
def _prepare_dummy_input(self, node):
"""
Prepare the dummy_input for the auto mask inference.
Parameters
----------
node: NodePyGroup
Returns
-------
dummy_input: list
List of tensors that will be used as input for the target node.
debugnames: list of strs
Debugnames of the dummy_inputs.
"""
_logger.debug('Prepare auto mask inference for node: %s',
node.unique_name)
# prepare the inputs and outputs mask for this node,
# if there is already a mask in self.masks, then use
# the original mask tensor, else create a new one.
inputs_name = node.inputs
# build the dummy_input and in_masks for the target node
dummy_input = []
debugnames = []
for _input in inputs_name:
if _input not in self.internal_result:
# if the input debug name is not in self.internal_result,
# then this tensor isn't an output tensor of any predecessor
# node. It is an attribute of the submodule, such as
# weight or bias, etc. We will skip these tensors.
# If we don't want this specific judgement here, we can merge
# the `prim::GetAttr` node of the weight/bias tensor into the key
# node, such as `conv`.
# This is caused by the `merge_module_node` function in
# _graph_utils.py, because it doesn't merge the prim::GetAttr
# node into the key node. In the current version of _graph_utils.py,
# we only merge the nodes that have the same scope name, however,
# the scope name of the corresponding prim::GetAttr node of the `weight` tensor
# is None.
continue
# The detach operation here is for the in-place operation. We cannot
# directly call backward on the output tensor of an in-place operator.
if isinstance(self.internal_result[_input], torch.Tensor):
dummy_input.append(self.internal_result[_input].detach())
else:
dummy_input.append(self.internal_result[_input])
debugnames.append(_input)
return dummy_input, debugnames
def update_direct_sparsity(self, node):
"""
Update the direct sparsity for the target node. Here the direct sparsity
means the sparsity in the output tensor that is caused by the sparsity
in the input tensors/weight tensors.
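
Examples
--------
A tiny illustration of direct sparsity (not part of the speedup flow): an all-zero
weight row directly yields an all-zero output column.

>>> linear = torch.nn.Linear(4, 3, bias=False)
>>> with torch.no_grad():
...     linear.weight[1] = 0.
>>> bool((linear(torch.rand(2, 4))[:, 1] == 0).all())
True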
"""
# this name is consistent with the name returned by named_modules()
module_name = node.name
_logger.info('Update mask for %s', module_name)
unique_name = node.unique_name
dummy_input, input_debugname = self._prepare_dummy_input(node)
# get the input mask from self.masks
# Note: the input masks of the successor nodes are
# already created by the predecessor nodes
in_masks = [self.masks[debugname] for debugname in input_debugname]
in_constants = [self.constant[debugname]
for debugname in input_debugname]
if node.type == 'func':
# we cannot get the runnable function directly from the jit traced
# graph, so we translate it back to a python function. Note: the function
# is applicable to both cpu/gpu devices, and the output tensors will be on the
# same device as the input tensors
func = jit_to_python_function(node, self)
if func is None:
# no need to infer the sparsity for this node
self.auto_inferences[unique_name] = None
return
# function doesn't have weights
_auto_infer = AutoMaskInference(
func, dummy_input, self, in_masks, in_constants=in_constants)
else:
weight_mask = None
if module_name in self.masks:
weight_mask = self.masks[module_name]
_, module = get_module_by_name(self.bound_model, module_name)
_auto_infer = AutoMaskInference(
module, dummy_input, self, in_masks, weight_mask, in_constants=in_constants,
state_dict=copy.deepcopy(module.state_dict()))
self.auto_inferences[unique_name] = _auto_infer
_auto_infer.name = node.unique_name
_auto_infer.update_direct_sparsity()
# also save the input debug names into the auto_infer
_auto_infer.input_debugname = input_debugname
# update the mask tensor and the internal output of the submodules
# after manually unpacking the tuple/list of tensors, the number of outputs
# of each node should always be one (except for the TupleUnpack node at the end
# of the whole model)
assert len(
node.outputs) == 1, 'The number of outputs should be one after the tuple is unpacked manually'
out_debugname = node.outputs[0]
# update the output mask into self.masks
self.masks[out_debugname] = _auto_infer.output_mask
self.constant[out_debugname] = _auto_infer.out_constant
# update the output result into self.internal_result, so that
# the successor nodes can take these output tensors as inputs.
self.internal_result[out_debugname] = _auto_infer.output
# update the parameter mask of the node
self.masks[module_name] = _auto_infer.weight_mask
def update_indirect_sparsity(self, node):
"""
This function will update the indirect sparsity. To explain what
indirect sparsity is: for example, there are two tensors TA and TB, and
we perform the calculation TC = TA x TB, in which TC is also a tensor.
Once some values in TA are masked to zeros, the corresponding
positions in TB are also potential sparsity, because they have no
effect on the final output (the gradient of these positions in TB is always
0). This function is to find the potential sparsity caused by other
sparsity (we call it indirect sparsity here). Basically, we can find
this potential sparsity through gradients; see the small example below.
Parameters
----------
node : NodePyGroup
The target node to update the indirect sparsity for.
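
Examples
--------
A tiny gradient-based illustration (not part of the speedup flow): when a column of TA
is masked to zero, the corresponding row of TB receives zero gradient, i.e. it is a
potential (indirect) sparsity.

>>> TA = torch.tensor([[1., 0.], [2., 0.]])   # the second column is masked to zero
>>> TB = torch.rand(2, 3, requires_grad=True)
>>> (TA @ TB).sum().backward()
>>> TB.grad[1]
tensor([0., 0., 0.])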
"""
unique_name = node.unique_name
if unique_name in self.auto_inferences and self.auto_inferences[unique_name] is not None:
# if the auto inference object is already in self.auto_inferences, then
# directly update the previous one
# self.auto_inferences[unique_name].update()
_logger.info(
'Update the indirect sparsity for the %s', unique_name)
auto_infer = self.auto_inferences[unique_name]
auto_infer.update_indirect_sparsity()
# pass the gradient to the predecessor nodes
for in_id, tin in enumerate(auto_infer.dummy_input):
debug_name = auto_infer.input_debugname[in_id]
last_output = self.internal_result[debug_name]
# if isinstance(last_output, torch.Tensor):
# TODO what if last output is tuple/list of tensor
if last_output.grad is not None and tin.grad is not None:
last_output.grad.data += tin.grad.data
elif last_output.grad is None:
last_output.grad = tin.grad
elif last_output.grad is not None and tin.grad is None:
# for example, tin.view(batch, tin.size(1)/2, tin.size(2)*2)
# the size operation of tin will have no gradient
continue
else:
_logger.warning(
'Note: %s does not have a corresponding mask inference object', node.name)
def _vnode_to_value(self, c_node):
"""
Translate the torch._C.Value node into values/tensors.
"""
errmsg = "Only support the torch._C.Value type"
assert isinstance(c_node, torch._C.Value), errmsg
if isinstance(c_node.type(), torch._C.TensorType):
shape = tuple(c_node.type().sizes())
dtype = c_node.type().scalarType()
# TODO should use a more general way to get the input
if dtype.startswith('Float') or dtype.startswith('Double'):
return torch.rand(shape).to(self.device)
else:
# This small range is due to `ReLU6`; we will add
# manually-specified mask inference rules for several
# ops in the future, so that we can remove this constraint.
return torch.randint(0, 10, shape, device=self.device)
else:
value = c_node.toIValue()
# TODO support more kinds of value node
errmsg = "Doesn't support convert %s to values", str(c_node.type())
# currently only support the tensors and constant values
assert value is not None, errmsg
return value
def infer_modules_masks(self):
"""
Infer the masks for all layers in the module. This function can be divided into
two steps: first, forward inference of the masks; second, backward inference
of the masks. We keep repeating these two steps until the masks of the model do
not change.
"""
# unpack the tensor tuple/list before the mask inference
self.torch_graph.unpack_manually()
# find the input/output tensors of the whole graph
graph_input = []
graph_output = []
for name, nodeio in self.torch_graph.nodes_py.nodes_io.items():
if nodeio.input_or_output == 'input':
graph_input.append((name, nodeio))
# also put the graph input tensor into the internal_result
# TODO if we can find the corresponding relation between the value node
# and the dummy_inputs, we can use the inputs value in the dummy_input
value = self._vnode_to_value(self.debugname_to_value[name])
self.internal_result[name] = value
# create the mask tensor for the input value
if isinstance(self.internal_result[name], torch.Tensor):
self.masks[name] = torch.ones_like(value)
self.constant[name] = torch.zeros_like(value)
elif nodeio.input_or_output == 'output':
graph_output.append((name, nodeio))
# count the in/out degrees for the nodes in the graph
in_degree = {}
out_degree = {}
visit_queue = queue.Queue()
for node in self.torch_graph.nodes_py.nodes_op:
successors = self.torch_graph.find_successors(node.unique_name)
out_degree[node.unique_name] = len(successors)
predecessors = set(self.torch_graph.find_predecessors(node.unique_name))
in_degree[node.unique_name] = len(predecessors)
if in_degree[node.unique_name] == 0:
visit_queue.put(node)
# Forward mask inference
while not visit_queue.empty():
curnode = visit_queue.get()
# forward mask inference for curnode
self.update_direct_sparsity(curnode)
successors = self.torch_graph.find_successors(curnode.unique_name)
for successor in successors:
in_degree[successor] -= 1
if in_degree[successor] == 0:
visit_queue.put(self.torch_graph.name_to_node[successor])
# backward mask inference
for unique_name in out_degree:
if out_degree[unique_name] == 0:
visit_queue.put(self.torch_graph.name_to_node[unique_name])
while not visit_queue.empty():
curnode = visit_queue.get()
self.update_indirect_sparsity(curnode)
predecessors = set(self.torch_graph.find_predecessors(
curnode.unique_name))
for predecessor in predecessors:
out_degree[predecessor] -= 1
if out_degree[predecessor] == 0:
visit_queue.put(self.torch_graph.name_to_node[predecessor])
def replace_compressed_modules(self):
"""
Replace all the modules whose (weight/input/output) shapes have changed.
The new module is created with the same arguments as the to-be-replaced module,
and correctly inherits its weights.
NOTE: a ```func``` type node cannot be replaced as it is not a module, thus, one limitation
is that a ```func``` node should not require replacement.
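
Examples
--------
A minimal sketch of the idea behind the replacement (the real logic lives in
``replace_module`` and the replacers), assuming output channels ``[0, 2]`` of a small
conv are kept by the mask:

>>> old_conv = torch.nn.Conv2d(4, 4, 3)
>>> kept = [0, 2]                                  # output channels kept by the mask
>>> new_conv = torch.nn.Conv2d(4, len(kept), 3)    # the smaller, replaced module
>>> new_conv.weight.data = old_conv.weight.data[kept].clone()
>>> new_conv.bias.data = old_conv.bias.data[kept].clone()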
"""
with torch.no_grad():
for replacer in self.customized_replacers:
replacer.replace_modules(self.bound_model, self.auto_inferences)
self.default_replacer.replace_modules(self.bound_model, self.auto_inferences)
for unique_name in self.auto_inferences:
if has_nested_attr(self.bound_model, unique_name):
module = get_nested_attr(self.bound_model, unique_name)
if isinstance(module, torch.nn.Module):
err_msg = f"Has not supported replacing module with type: {type(module)}, "
err_msg += f"you could report an issue at https://github.com/microsoft/nni. "
err_msg += f"If you know how to replace {type(module)}, "
err_msg += f"you could implement module replacement by passing in"
err_msg += f"`customized_replace_func` to `{self.__class__.__name__}`. "
err_msg += f"You are welcome to contribute back to nni as native support "
err_msg += f"if you have implemented the replacement function, "
err_msg += f"so that more users can benefit from your contributions."
_logger.error(err_msg)
def initialize_speedup(self):
"""
Do some initial work for speedup.
"""
# initialize the self.debugname_to_value
# build a mapping table from the debug name of the tensor
# to its value node in the graph
traced_graph = self.torch_graph.trace.graph
for node in traced_graph.nodes():
for _input in node.inputs():
debug_name = _input.debugName()
if debug_name not in self.debugname_to_value:
self.debugname_to_value[debug_name] = _input
for _output in node.outputs():
debug_name = _output.debugName()
if debug_name not in self.debugname_to_value:
self.debugname_to_value[debug_name] = _output
# put the model itself into internal_result to perform the
# value inference for the 'prim::GetAttr', the first ClassType
# of the whole graph is the model class
for graph_input in traced_graph.inputs():
if graph_input.type().kind() == 'ClassType':
self.internal_result[graph_input.debugName()
] = self.bound_model
break
def speedup_model(self):
"""
There are basically two steps: first, do mask/shape inference,
second, replace modules.
"""
_logger.info("start to speedup the model")
self.initialize_speedup()
training = self.bound_model.training
# set to the evaluation mode
self.bound_model.train(False)
# TODO it is supposed to fix the conflict after the sparsity propagation,
# which is more elegant
fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
_logger.info("infer module masks...")
self.infer_modules_masks()
_logger.info('resolve the mask conflict')
# sometimes, mask conflict will happen during infer masks
# fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
# load the original state dict before replacing the modules
self.bound_model.load_state_dict(self.ori_state_dict)
_logger.info("replace compressed modules...")
# the mask conflict should be already resolved
self.replace_compressed_modules()
self.bound_model.train(training)
_logger.info("speedup done")
if self.is_ddp_model:
self.bound_model = reset_ddp_model(self.bound_model, self.ddp_params)
return self.bound_model