# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from collections import defaultdict
import numpy as np
import torch
from .base_mutator import BaseMutator
from .mutables import LayerChoice, InputChoice
from .utils import to_list
# Module-level logger, named after this module (used for graph and decision logging).
logger = logging.getLogger(name=__name__)
class Mutator(BaseMutator):
    """
    Mutator that caches the decisions produced by :meth:`sample_search` and applies them
    when layer choices are forwarded.

    Subclasses implement :meth:`sample_search` (decisions used during search) and
    :meth:`sample_final` (decisions used by :meth:`export`).
    """

    def __init__(self, model):
        super().__init__(model)
        # Decisions from the latest `sample_search()` call, keyed by mutable key.
        self._cache = dict()
        # Only True while `graph()` is tracing: every candidate op is then executed so that
        # the traced graph contains the whole supernet.
        self._connect_all = False

    def sample_search(self):
        """
        Override this method to iterate over mutables and make decisions for search.

        Returns
        -------
        dict
            A mapping from key of mutables to decisions.
        """
        raise NotImplementedError

    def sample_final(self):
        """
        Override this method to iterate over mutables and make decisions that are final
        for export and retraining.

        Returns
        -------
        dict
            A mapping from key of mutables to decisions.
        """
        raise NotImplementedError

    def reset(self):
        """
        Reset the mutator by calling `sample_search` to resample (for search). Stores the result in
        a local variable so that `on_forward_layer_choice` and `on_forward_input_choice` can use the
        decision directly.
        """
        self._cache = self.sample_search()

    def export(self):
        """
        Resample (for final) and return results.

        Only ``LayerChoice`` and ``InputChoice`` mutables are exported. Each decision is converted
        into a human-readable form (candidate names or indices, unwrapped when exactly one
        candidate is chosen).

        Returns
        -------
        dict
            A mapping from key of mutables to decisions.

        Raises
        ------
        KeyError
            If ``sample_final()`` returns no decision for an exported mutable.
        ValueError
            If ``sample_final()`` returns keys that were not consumed by any exported mutable.
        """
        # Work on a shallow copy: keys are popped below, and the subclass may still hold
        # (or have cached) the very dict it returned.
        sampled = dict(self.sample_final())
        result = dict()
        for mutable in self.mutables:
            if not isinstance(mutable, (LayerChoice, InputChoice)):
                # not supported as built-in
                continue
            result[mutable.key] = self._convert_mutable_decision_to_human_readable(mutable, sampled.pop(mutable.key))
        if sampled:
            raise ValueError("Unexpected keys returned from 'sample_final()': %s" % list(sampled.keys()))
        return result

    def status(self):
        """
        Return current selection status of mutator.

        Returns
        -------
        dict
            A mapping from key of mutables to decisions. All weights (boolean type and float type)
            are converted into real number values. Numpy arrays and tensors are converted into list.
        """
        data = dict()
        for k, v in self._cache.items():
            if torch.is_tensor(v):
                v = v.detach().cpu().numpy()
            if isinstance(v, np.ndarray):
                v = v.astype(np.float32).tolist()
            data[k] = v
        return data

    def graph(self, inputs):
        """
        Return model supernet graph.

        Parameters
        ----------
        inputs: tuple of tensor
            Inputs that will be fed into the network.

        Returns
        -------
        dict
            Containing ``node``, in Tensorboard GraphDef format.
            Additional key ``mutable`` is a map from key to list of modules.
        """
        if not torch.__version__.startswith("1.4"):
            logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.")
        from nni._graph_utils import build_graph
        from google.protobuf import json_format
        # protobuf should be installed as long as tensorboard is installed
        try:
            # Execute every candidate while tracing so the whole supernet shows up in the graph.
            self._connect_all = True
            graph_def, _ = build_graph(self.model, inputs, verbose=False)
            result = json_format.MessageToDict(graph_def)
        finally:
            self._connect_all = False

        # `mutable` is to map the keys to a list of corresponding modules.
        # A key can be linked to multiple modules, use `deduplicate=False` to find them all.
        result["mutable"] = defaultdict(list)
        for mutable in self.mutables.traverse(deduplicate=False):
            # A module will be represented in the format of
            # [{"type": "Net", "name": ""}, {"type": "Cell", "name": "cell1"}, {"type": "Conv2d", "name": "conv"}]
            # which will be concatenated into Net/Cell[cell1]/Conv2d[conv] in frontend.
            # This format is aligned with the scope name jit gives.
            modules = mutable.name.split(".")
            path = [
                {"type": self.model.__class__.__name__, "name": ""}
            ]
            m = self.model
            for module in modules:
                m = getattr(m, module)
                path.append({
                    "type": m.__class__.__name__,
                    "name": module
                })
            result["mutable"][mutable.key].append(path)
        return result

    def on_forward_layer_choice(self, mutable, *args, **kwargs):
        """
        By default, this method retrieves the decision obtained previously, and selects certain operations.
        Only operations with non-zero weight will be executed. The results will be added to a list.
        Then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`.

        Parameters
        ----------
        mutable : LayerChoice
            Layer choice module.
        args : list of torch.Tensor
            Inputs
        kwargs : dict
            Inputs

        Returns
        -------
        tuple of torch.Tensor and torch.Tensor
            Output and mask.
        """
        if self._connect_all:
            # Graph tracing: run every candidate and report an all-True mask.
            return self._all_connect_tensor_reduction(mutable.reduction,
                                                      [op(*args, **kwargs) for op in mutable]), \
                torch.ones(len(mutable)).bool()

        def _map_fn(op, args, kwargs):
            return op(*args, **kwargs)

        mask = self._get_decision(mutable)
        assert len(mask) == len(mutable), \
            "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable))
        out, mask = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def _select_with_mask(self, map_fn, candidates, mask):
        """
        Select masked tensors and return a list of tensors.

        A boolean mask keeps the candidates whose entry is True. A numeric mask keeps the
        candidates with a non-zero entry and multiplies each mapped output by that entry.

        Parameters
        ----------
        map_fn : function
            Convert candidates to target candidates. Can be simply identity.
        candidates : list of torch.Tensor
            Tensor list to apply the decision on.
        mask : list-like object
            Can be a list, an numpy array or a tensor (recommended). Needs to
            have the same length as ``candidates``.

        Returns
        -------
        tuple of list of torch.Tensor and torch.Tensor
            Output and mask.

        Raises
        ------
        ValueError
            If the mask is of an unsupported container or element type (including an empty list).
        """
        # Classify the mask by container first, so `.type()`/`.dtype` are only touched on
        # objects that have them.
        if torch.is_tensor(mask):
            is_bool = mask.dtype == torch.bool
            is_weight = mask.is_floating_point()
        elif isinstance(mask, np.ndarray):
            is_bool = mask.dtype == np.bool_
            is_weight = mask.dtype in (np.float32, np.float64, np.int32, np.int64)
        elif isinstance(mask, list) and len(mask) >= 1:
            is_bool = isinstance(mask[0], bool)
            # `bool` is a subclass of `int`, so the boolean interpretation must win.
            is_weight = not is_bool and isinstance(mask[0], (float, int))
        else:
            is_bool = is_weight = False

        if is_bool:
            out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
        elif is_weight:
            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
        else:
            raise ValueError("Unrecognized mask '%s'" % (mask,))
        if not torch.is_tensor(mask):
            mask = torch.tensor(mask)  # pylint: disable=not-callable
        return out, mask

    def _tensor_reduction(self, reduction_type, tensor_list):
        """
        Reduce the selected outputs according to ``reduction_type``.

        ``"none"`` returns the list unchanged. An empty list yields ``None`` and a single
        element is returned as-is. Otherwise ``"sum"``, ``"mean"`` or ``"concat"`` (along dim 1)
        is applied. Any other policy raises ``ValueError``.
        """
        if reduction_type == "none":
            return tensor_list
        if not tensor_list:
            return None  # empty. return None for now
        if len(tensor_list) == 1:
            return tensor_list[0]
        if reduction_type == "sum":
            return sum(tensor_list)
        if reduction_type == "mean":
            return sum(tensor_list) / len(tensor_list)
        if reduction_type == "concat":
            return torch.cat(tensor_list, dim=1)
        raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))

    def _all_connect_tensor_reduction(self, reduction_type, tensor_list):
        """
        Reduce the outputs of all candidates while in connect-all (graph tracing) mode.

        ``"none"`` returns the list and ``"concat"`` concatenates along dim 1. Every other
        policy, including ``"mean"``, is reduced by summation.
        """
        if reduction_type == "none":
            return tensor_list
        if reduction_type == "concat":
            return torch.cat(tensor_list, dim=1)
        return torch.stack(tensor_list).sum(0)

    def _get_decision(self, mutable):
        """
        By default, this method checks whether `mutable.key` is already in the decision cache,
        and returns the result without double-check.

        Parameters
        ----------
        mutable : Mutable

        Returns
        -------
        object

        Raises
        ------
        ValueError
            If no decision for ``mutable.key`` has been cached (e.g. `reset()` was not called).
        """
        if mutable.key not in self._cache:
            raise ValueError("\"{}\" not found in decision cache.".format(mutable.key))
        result = self._cache[mutable.key]
        logger.debug("Decision %s: %s", mutable.key, result)
        return result

    def _convert_mutable_decision_to_human_readable(self, mutable, sampled):
        """
        Convert a sampled decision into a readable form for :meth:`export`.

        If every value is 0/1, the decision is turned into the list of chosen candidate names
        (when names are unique and, for layer choices, not purely numeric) or chosen indices.
        A single chosen candidate is unwrapped from its list. Any other decision is returned
        as the plain list produced by ``to_list``.
        """
        # Assert the existence of mutable.key in returned architecture.
        # Also check if there is anything extra.
        multihot_list = to_list(sampled)
        converted = None
        # If it's a boolean array, we can do optimization.
        if all(t == 0 or t == 1 for t in multihot_list):
            if isinstance(mutable, LayerChoice):
                assert len(multihot_list) == len(mutable), \
                    "Results returned from 'sample_final()' (%s: %s) either too short or too long." \
                    % (mutable.key, multihot_list)
                # check if all modules have different names and they indeed have names
                if len(set(mutable.names)) == len(mutable) and not all(d.isdigit() for d in mutable.names):
                    converted = [name for i, name in enumerate(mutable.names) if multihot_list[i]]
                else:
                    converted = [i for i, chosen in enumerate(multihot_list) if chosen]
            if isinstance(mutable, InputChoice):
                assert len(multihot_list) == mutable.n_candidates, \
                    "Results returned from 'sample_final()' (%s: %s) either too short or too long." \
                    % (mutable.key, multihot_list)
                # check if all input candidates have different names
                if len(set(mutable.choose_from)) == mutable.n_candidates:
                    converted = [name for i, name in enumerate(mutable.choose_from) if multihot_list[i]]
                else:
                    converted = [i for i, chosen in enumerate(multihot_list) if chosen]
        if converted is not None:
            # if only one element, then remove the bracket
            if len(converted) == 1:
                converted = converted[0]
        else:
            # not a multi-hot decision for a supported mutable: keep the raw list
            converted = multihot_list
        return converted