# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import queue
import logging
import copy
import torch
import torch.nn as nn
from nni.common.graph_utils import build_module_graph
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict
from nni.compression.pytorch.utils.utils import get_module_by_name
from .compress_modules import replace_module
from .infer_mask import AutoMaskInference
from .jit_translate import jit_to_python_function
from ..utils import rand_like_with_shape
# Module-level logger shared by everything in this file. The level is forced
# to INFO at import time, so the progress messages emitted through
# `_logger.info` are not filtered out by this logger itself.
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class ModelSpeedup:
    """
    Speed up a model according to the provided weight masks.

    Parameters
    ----------
    model : pytorch model
        The model user wants to speed up
    dummy_input : pytorch tensor, tuple of tensor, list of tensor
        Note: The first dimension of the dummy_input should be the batchsize.
        The dummy input for ```jit.trace```, users should put it on the right
        device.
    masks_file : str/dict
        The path of user provided mask file, or the mask object
    map_location : str
        the device on which masks are placed, same to map_location in ```torch.load```
    batch_dim : int
        the index of batch dimension in the dummy_input
    confidence : int
        the confidence coefficient of the sparsity inference. This value is
        actually used as the batchsize of the dummy_input.
    """

    def __init__(self, model, dummy_input, masks_file, map_location=None,
                 batch_dim=0, confidence=8):
        assert confidence > 1
        # The auto inference will change the values of the parameters in the model
        # so we need to make a copy before the mask inference
        self.ori_state_dict = copy.deepcopy(model.state_dict())
        self.bound_model = model
        self.inferred_masks = dict()  # key: module_name, value: ModuleMasks
        self.batch_dim = batch_dim
        self.dummy_input, self.device = self._random_model_input(
            dummy_input, confidence, batch_dim)
        self.torch_graph = build_module_graph(model, self.dummy_input)
        # dict object to save the auto inference objects of the submodules
        self.auto_inferences = {}
        # the index dict to find the corresponding torch._C.Value object
        # according to the debug name
        # we need the dummy_input to infer the mask automatically, so we save
        # the indexes from tensor's debugname to the torch._C.Value object.
        self.debugname_to_value = {}
        # load the mask tensor to the same device with the dummy_input
        # self.masks saves the mask tensors pruned by the user and the inferred
        # masks of the other modules
        if isinstance(masks_file, str) and os.path.exists(masks_file):
            # NOTE(review): torch.load unpickles the file content, so the
            # mask file must come from a trusted source.
            self.masks = torch.load(
                masks_file, map_location if map_location is not None else str(self.device))
        elif isinstance(masks_file, dict):
            # the dict is used as-is (not copied); inferred masks are later
            # written into this same object
            self.masks = masks_file
        else:
            raise Exception('Please provide the mask or the path of the mask file')
        self.constant = {}
        # self.internal_result saves the internal output of the submodules
        self.internal_result = {}

    def _random_model_input(self, dummy_input, confidence, batch_dim):
        """
        Get the new random dummy input according to the original dummy_input
        and confidence, batch_dim.

        Parameters
        ----------
        dummy_input: Tensor or list/dict of Tensors
            The dummy_input given by the user.
        confidence: int
            The new batch size of the generated dummy_input.
        batch_dim: int
            The index of the batch dimension.

        Returns
        -------
        new_dummy_input: Tensor or list/dict of Tensors
            The generated dummy_input for mask inference. A tuple input is
            returned as a list.
        device: torch.device
            The device of the generated dummy_inputs (taken from the first
            tensor when several tensors are given).

        Raises
        ------
        TypeError
            If dummy_input is not a tensor, list, tuple or dict.
        ValueError
            If dummy_input is an empty list/tuple/dict.
        AssertionError
            If an element is not a tensor, or the sizes of the first
            dimension of the elements differ.
        """
        input_errmsg = 'Only support the tensor, list/tuple/dict of tensors as input'
        batch_errmsg = ('The first dimension should be batchsize '
                        'and the batchsize of all inputs should be the same!')

        def _resample(t_input):
            # same shape as t_input except that the batch dimension is
            # replaced by the confidence
            input_shape = list(t_input.size())
            input_shape[batch_dim] = confidence
            return rand_like_with_shape(input_shape, t_input)

        if isinstance(dummy_input, torch.Tensor):
            return _resample(dummy_input), dummy_input.device

        # Some models may use list/dict of tensors as input, for example transformers
        if isinstance(dummy_input, (tuple, list)):
            tensors = list(dummy_input)
        elif isinstance(dummy_input, dict):
            tensors = list(dummy_input.values())
        else:
            raise TypeError(input_errmsg)
        if not tensors:
            raise ValueError('The dummy_input should contain at least one tensor')

        # validate everything before generating any random tensor
        for t_input in tensors:
            assert isinstance(t_input, torch.Tensor), input_errmsg
        old_batchsize = tensors[0].size(0)
        for t_input in tensors:
            assert t_input.size(0) == old_batchsize, batch_errmsg
        device = tensors[0].device

        if isinstance(dummy_input, dict):
            new_dummy_input = {
                in_name: _resample(t_input) for in_name, t_input in dummy_input.items()}
        else:
            new_dummy_input = [_resample(t_input) for t_input in tensors]
        return new_dummy_input, device

    def _prepare_dummy_input(self, node):
        """
        Prepare the dummy_input for the auto mask inference.

        Parameters
        ----------
        node: NodePyGroup

        Returns
        -------
        dummy_input: list
            List of tensors that will be used as input for the target node.
        debugnames: list of strs
            Debugnames of the dummy_inputs.
        """
        _logger.debug('Prepare auto mask inference for node: %s',
                      node.unique_name)
        dummy_input = []
        debugnames = []
        for _input in node.inputs:
            if _input not in self.internal_result:
                # if the input debug name is not in self.internal_result,
                # then this is not an output tensor of any predecessor
                # node. It is an attribute of the submodule, such as
                # weight or bias, etc. We skip these tensors.
                # This is because the graph utils do not merge the
                # `prim::GetAttr` node of the weight/bias tensor into the key
                # node (such as `conv`): only nodes that have the same scope
                # name are merged, and the scope name of the corresponding
                # prim::GetAttr node of the `weight` tensor is None.
                continue
            # The detach operation here is for the in-place operation. We cannot
            # directly call backward on the output tensor of an in-place operator.
            dummy_input.append(self.internal_result[_input].detach())
            debugnames.append(_input)
        return dummy_input, debugnames

    def update_direct_sparsity(self, node):
        """
        Update the direct sparsity for the target node. Here the direct sparsity
        means the sparsity in the output tensor that is caused by the sparsity
        in the input tensors/weight tensors.

        Parameters
        ----------
        node: NodePyGroup
            The target node whose output mask, output constant and output
            value are computed and stored on this object.
        """
        # this name is consistent with the name returned by named_modules()
        module_name = node.name
        _logger.info('Update mask for %s', module_name)
        unique_name = node.unique_name
        dummy_input, input_debugname = self._prepare_dummy_input(node)
        # get the input mask from self.masks
        # Note: the input masks of the successor nodes are
        # already created by the predecessor node
        in_masks = [self.masks[debugname] for debugname in input_debugname]
        in_constants = [self.constant[debugname]
                        for debugname in input_debugname]
        if node.type == 'func':
            # we cannot get the runnable function directly from the jit traced
            # graph, so we translate it back to python function. Note: the function
            # is applicable to both cpu/gpu devices, the output tensors will be on the
            # same device of the input tensors
            func = jit_to_python_function(node, self)
            if func is None:
                # no need to infer the sparsity for this node
                self.auto_inferences[unique_name] = None
                return
            # function doesn't have weights
            _auto_infer = AutoMaskInference(
                func, dummy_input, in_masks, in_constants=in_constants,
                batch_dim=self.batch_dim)
        else:
            weight_mask = None
            if module_name in self.masks:
                weight_mask = self.masks[module_name]
            _, module = get_module_by_name(self.bound_model, module_name)
            _auto_infer = AutoMaskInference(
                module, dummy_input, in_masks, weight_mask, in_constants=in_constants,
                state_dict=copy.deepcopy(module.state_dict()), batch_dim=self.batch_dim)
        self.auto_inferences[unique_name] = _auto_infer
        _auto_infer.name = node.unique_name
        _auto_infer.update_direct_sparsity()
        # also save the input debug names into the auto_infer
        _auto_infer.input_debugname = input_debugname
        # update the mask tensor and the internal output of the submodules
        # after manually unpacking the tuple/list of tensors, the number of the outputs
        # of each node should always be one (except for the TupleUnpack node at the end
        # of the whole model)
        assert len(
            node.outputs) == 1, 'The number of the output should be one after the Tuple unpacked manually'
        out_debugname = node.outputs[0]
        # update the output mask into self.masks
        self.masks[out_debugname] = _auto_infer.output_mask
        self.constant[out_debugname] = _auto_infer.out_constant
        # update the output result into self.internal_result, so that
        # the successor nodes can take these output tensors as inputs.
        self.internal_result[out_debugname] = _auto_infer.output
        # update the parameter mask of the node
        self.masks[module_name] = _auto_infer.weight_mask

    def update_indirect_sparsity(self, node):
        """
        This function will update the indirect sparsity. To explain what
        indirect sparsity is: suppose there are two tensors TA and TB, and
        we perform the calculation TC = TA x TB in which TC is also a tensor.
        Once some values in TA are masked to zeros, then the corresponding
        positions in TB are also potential sparsities, because these have no
        effect on the final output (the gradient of these positions in TB equals
        0 all the time). This function is to find the potential sparsity caused
        by other sparsity (we call it indirect sparsity here). Basically we can find
        these potential sparsities through the gradient.

        Parameters
        ----------
        node: the NodePy
            The target node to update the indirect sparsity
        """
        unique_name = node.unique_name
        auto_infer = self.auto_inferences.get(unique_name)
        if auto_infer is None:
            _logger.warning(
                'Note: %s does not have corresponding mask inference object', node.name)
            return
        # the auto inference object was created during the forward pass,
        # directly update it
        _logger.info(
            'Update the indirect sparsity for the %s', unique_name)
        auto_infer.update_indirect_sparsity()
        # pass the gradient to the predecessor nodes
        for in_id, tin in enumerate(auto_infer.dummy_input):
            debug_name = auto_infer.input_debugname[in_id]
            last_output = self.internal_result[debug_name]
            # TODO what if last output is tuple/list of tensor
            if last_output.grad is None:
                last_output.grad = tin.grad
            elif tin.grad is not None:
                last_output.grad.data += tin.grad.data
            # else: this input received no gradient from the current node;
            # keep the gradient already accumulated from the other
            # successors instead of overwriting it with None.

    def _vnode_to_value(self, c_node):
        """
        Translate the C Value node into the values/tensors.

        Parameters
        ----------
        c_node: torch._C.Value
            The value node of the traced graph.

        Returns
        -------
        A random tensor (on self.device) with the shape recorded in the
        tensor type of the node, or the constant carried by the node when
        it is not a tensor.

        Raises
        ------
        AssertionError
            If c_node is not a torch._C.Value, or it is neither a tensor nor
            a constant value.
        """
        errmsg = "Only support the torch._C.Value type"
        assert isinstance(c_node, torch._C.Value), errmsg
        if isinstance(c_node.type(), torch._C.TensorType):
            shape = tuple(c_node.type().sizes())
            dtype = c_node.type().scalarType()
            # TODO should use a more general way to get the input
            if dtype.startswith(('Float', 'Double')):
                return torch.rand(shape).to(self.device)
            # This small range is due to the `ReLU6`, we will add
            # the manual specific mask inference rule for several
            # ops in the future, so that we can remove the constraint.
            return torch.randint(0, 10, shape, device=self.device)
        value = c_node.toIValue()
        # TODO support more kinds of value node
        # currently only support the tensors and constant values
        errmsg = "Doesn't support convert %s to values" % str(c_node.type())
        assert value is not None, errmsg
        return value

    def infer_modules_masks(self):
        """
        Infer the mask for all layers in the module. This function runs two
        passes over the graph in topological order: first, the forward
        inference of the masks (direct sparsity); second, the backward
        inference of the masks (indirect sparsity).
        """
        # unpack the tensor tuple/list before the mask inference
        self.torch_graph.unpack_manually()
        # put the input tensors of the whole graph into the internal_result
        for name, nodeio in self.torch_graph.nodes_py.nodes_io.items():
            if nodeio.input_or_output != 'input':
                continue
            # TODO if we can find the corresponding relation between the value node
            # and the dummy_inputs, we can use the inputs value in the dummy_input
            value = self._vnode_to_value(self.debugname_to_value[name])
            self.internal_result[name] = value
            # create the mask tensor for the input value
            if isinstance(value, torch.Tensor):
                self.masks[name] = torch.ones_like(value)
                self.constant[name] = torch.zeros_like(value)
        # count the degree for the node in the graph
        in_degree = {}
        out_degree = {}
        visit_queue = queue.Queue()
        for node in self.torch_graph.nodes_py.nodes_op:
            successors = self.torch_graph.find_successors(node.unique_name)
            out_degree[node.unique_name] = len(successors)
            predecessors = self.torch_graph.find_predecessors(node.unique_name)
            in_degree[node.unique_name] = len(predecessors)
            if in_degree[node.unique_name] == 0:
                visit_queue.put(node)
        # Forward mask inference: a node is visited once all of its
        # predecessors have been visited
        while not visit_queue.empty():
            curnode = visit_queue.get()
            self.update_direct_sparsity(curnode)
            for successor in self.torch_graph.find_successors(curnode.unique_name):
                in_degree[successor] -= 1
                if in_degree[successor] == 0:
                    visit_queue.put(self.torch_graph.name_to_node[successor])
        # Backward mask inference: start from the nodes without successors
        for unique_name, degree in out_degree.items():
            if degree == 0:
                visit_queue.put(self.torch_graph.name_to_node[unique_name])
        while not visit_queue.empty():
            curnode = visit_queue.get()
            self.update_indirect_sparsity(curnode)
            for predecessor in self.torch_graph.find_predecessors(curnode.unique_name):
                out_degree[predecessor] -= 1
                if out_degree[predecessor] == 0:
                    visit_queue.put(self.torch_graph.name_to_node[predecessor])

    def replace_compressed_modules(self):
        """
        Replace all the modules that have changed (weights/inputs/output) shape.
        The new module is created using the same arguments of the to-be-replaced module,
        and correctly inherits its weights.

        NOTE: ```func``` type cannot be replaced as it is not a module, thus, one limitation
        is that ```func``` should not be required to be replaced.
        """
        with torch.no_grad():
            for unique_name in self.auto_inferences:
                self.replace_submodule(unique_name)

    def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
        """
        Replace the submodule according to the inferred sparsity.

        Parameters
        ----------
        unique_name: str
            The unique_name of the submodule to replace.
        reindex_dim: int
            The dimension of the re-index operation.
        reindex: Reindex
            The index tensor. Normally this variable is None. If we want to reindex the
            output of this submodule, we can pass the index by this parameter.

        Returns
        -------
        The new submodule, or None when the node is a ```func``` node or a
        reused module that would need the padding operator.

        Raises
        ------
        RuntimeError
            If the module type has no registered replacement, or the node
            type is neither 'module' nor 'func'.
        """
        class ReindexModule(nn.Module):
            """
            ReindexModule is used to resolve the mask conflict when replacing the submodule.
            Basically, we can use two ways to resolve the mask conflict: (1) unmask some
            values (will introduce more computation overhead) (2) reindex and pad the output
            tensor of the target op (introduce more memory access overhead). Currently this
            method is shutdown, in the future, we will merge these two methods into a graph
            pass which is used to resolve the mask conflict.
            """

            def __init__(self, ori_module, reindex_dim, reindex):
                super(ReindexModule, self).__init__()
                self.ori_module = ori_module
                self.reindex_dim = reindex_dim
                self.reindex = reindex
                # full slices for the leading dimensions, the index tensor
                # for the re-indexed dimension
                tmp_index = [slice(None, None) for _ in range(reindex_dim + 1)]
                tmp_index[reindex_dim] = reindex
                self.t_index = tuple(tmp_index)

            def forward(self, x):
                tmpout = self.ori_module(x)
                shape = list(tmpout.size())
                shape[self.reindex_dim] = self.reindex.size(0)
                # The buffer must not be created as a leaf that requires
                # grad: writing into such a leaf in place is rejected by
                # autograd. The indexed assignment below already connects
                # `out` to the graph of `tmpout`. The dtype follows the
                # output of the wrapped module.
                out = torch.zeros(tuple(shape), device=tmpout.device,
                                  dtype=tmpout.dtype)
                out[self.t_index] = tmpout
                return out

        assert unique_name in self.auto_inferences
        g_node = self.torch_graph.name_to_node[unique_name]
        _logger.debug("replace %s, in %s type, with op_type %s",
                      unique_name, g_node.type, g_node.op_type)
        auto_infer = self.auto_inferences[unique_name]
        if g_node.type == 'module':
            if g_node.unique_name in self.torch_graph.reused_module:
                if reindex_dim is not None:
                    _logger.warning(
                        'Cannot replace a reused module with padding operator!!')
                    return None
            super_module, leaf_module = get_module_by_name(
                self.bound_model, g_node.name)
            m_type = g_node.op_type
            if m_type not in replace_module:
                raise RuntimeError(
                    "Has not supported replacing the module: `{}`".format(m_type))
            _logger.info("replace module (name: %s, op_type: %s)",
                         g_node.name, m_type)
            compressed_module = replace_module[m_type](
                leaf_module, auto_infer.get_masks())
            new_submodule = compressed_module
            # attribute name of the leaf module inside its parent module
            attr_name = g_node.name.split('.')[-1]
            if reindex_dim is None:
                setattr(super_module, attr_name, compressed_module)
            elif reindex is not None:
                # reindex the output of this submodule and replace the original module
                new_submodule = ReindexModule(
                    compressed_module, reindex_dim, reindex)
                setattr(super_module, attr_name, new_submodule)
            # NOTE: when reindex_dim is given without reindex, the compressed
            # module is returned but not installed into the model.
            return new_submodule
        elif g_node.type == 'func':
            _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
                         unique_name, g_node.op_type)
            return None
        else:
            raise RuntimeError("Unsupported node type: {}".format(g_node.type))

    def initialize_speedup(self):
        """
        Do some initial work for speedup.

        Builds the mapping from the debug name of each tensor to its value
        node in the traced graph, and registers the model itself as the
        value of the first ClassType input of the graph.
        """
        traced_graph = self.torch_graph.trace.graph
        for node in traced_graph.nodes():
            # the first value node seen for a debug name wins
            for _input in node.inputs():
                self.debugname_to_value.setdefault(_input.debugName(), _input)
            for _output in node.outputs():
                self.debugname_to_value.setdefault(_output.debugName(), _output)
        # put the model itself into internal_result to perform the
        # value inference for the 'prim::GetAttr', the first ClassType
        # of the whole graph is the model class
        for graph_input in traced_graph.inputs():
            if graph_input.type().kind() == 'ClassType':
                self.internal_result[graph_input.debugName()] = self.bound_model
                break

    def speedup_model(self):
        """
        There are basically two steps: first, do mask/shape inference,
        second, replace modules.

        The model is switched to evaluation mode for the whole procedure and
        its original training flag is restored afterwards, even when one of
        the steps raises.
        """
        _logger.info("start to speed up the model")
        self.initialize_speedup()
        training = self.bound_model.training
        # set to the evaluation mode
        self.bound_model.train(False)
        try:
            # TODO suppose to fix the conflict after the sparsity propagation
            # which is more elegant
            fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
            _logger.info("infer module masks...")
            self.infer_modules_masks()
            _logger.info('resolve the mask conflict')
            # load the original state dict before replacing the modules,
            # because the mask inference changed the parameter values
            self.bound_model.load_state_dict(self.ori_state_dict)
            _logger.info("replace compressed modules...")
            # the mask conflict should be already resolved
            self.replace_compressed_modules()
        finally:
            self.bound_model.train(training)
        _logger.info("speedup done")