# Source code for nni.nas.pytorch.utils

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from collections import OrderedDict

import numpy as np
import torch

# Backing state for ``global_mutable_counting``; ``_reset_global_mutable_counting`` sets it back to 0.
_counter = 0

# Logger named after this module.
_logger = logging.getLogger(__name__)


def global_mutable_counting():
    """
    Return the next value of a program-level counter.

    The first call (or the first call after ``_reset_global_mutable_counting``)
    returns 1, the following call returns 2, and so on.
    """
    global _counter
    _counter = _counter + 1
    return _counter
def _reset_global_mutable_counting(): """ Reset the global mutable counting to count from 1. Useful when defining multiple models with default keys. """ global _counter _counter = 0
def to_device(obj, device):
    """
    Move a tensor, tuple, list, or dict onto device.

    Containers are handled recursively.
    - Tuples come back as tuples; namedtuples keep their own type and field names.
    - Lists come back as lists, and dicts as dicts with the same keys.
    - ``None``, ``int``, ``float`` and ``str`` values are returned unchanged.

    Parameters
    ----------
    obj : torch.Tensor, tuple, list, dict, int, float, str or None
        The (possibly nested) object to move.
    device : torch.device or str
        Target device, passed straight to ``Tensor.to``.

    Returns
    -------
    Same structure as ``obj`` with every tensor replaced by ``tensor.to(device)``.

    Raises
    ------
    ValueError
        If ``obj`` (or a nested element) has an unsupported type.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, tuple):
        moved = [to_device(t, device) for t in obj]
        # Namedtuples take positional fields; a plain ``tuple(...)`` would silently drop their type.
        if hasattr(obj, "_fields"):
            return type(obj)(*moved)
        return tuple(moved)
    if isinstance(obj, list):
        return [to_device(t, device) for t in obj]
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    # ``None`` commonly marks an absent optional entry (e.g. a missing mask); pass it through.
    if obj is None or isinstance(obj, (int, float, str)):
        return obj
    raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
def to_list(arr):
    """
    Convert tensors, numpy arrays, lists and tuples into plain Python lists.

    Parameters
    ----------
    arr : torch.Tensor, numpy.ndarray, list, tuple or any
        Object to convert.

    Returns
    -------
    list or any
        - Tensors and arrays: their ``tolist()`` result, i.e. nested lists, or a bare
          Python scalar for 0-d inputs.
        - Lists and tuples: a shallow ``list`` copy.
        - Anything else: returned unchanged.
    """
    if torch.is_tensor(arr):
        # ``Tensor.numpy()`` refuses tensors that require grad, so drop autograd history first.
        return arr.detach().cpu().numpy().tolist()
    if isinstance(arr, np.ndarray):
        return arr.tolist()
    if isinstance(arr, (list, tuple)):
        return list(arr)
    return arr
class AverageMeterGroup:
    """
    Average meter group for multiple average meters.

    Meters are stored in insertion order, keyed by metric name. A meter can be
    read as ``group[name]`` or as ``group.name``.
    """

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data):
        """
        Update the meter group with a dict of metrics.
        Non-exist average meters will be automatically created.

        Parameters
        ----------
        data : dict
            Mapping from metric name to new value. Each value is forwarded to the
            ``update`` method of that metric's ``AverageMeter``.
        """
        for k, v in data.items():
            if k not in self.meters:
                # NOTE: ":4f" means minimum field width 4 with the default 6 decimals,
                # not 4 decimals (that would be ":.4f").
                self.meters[k] = AverageMeter(k, ":4f")
            self.meters[k].update(v)

    def __getattr__(self, item):
        # Python only calls ``__getattr__`` after normal lookup has failed.
        # - ``meters`` is read via ``__dict__`` so that an instance whose ``__init__``
        #   has not run (copy/unpickling) does not recurse forever.
        # - Missing names raise AttributeError, as the data model requires, so that
        #   ``hasattr`` and ``getattr(obj, name, default)`` work.
        meters = self.__dict__.get("meters")
        if meters is None or item not in meters:
            raise AttributeError("'%s' object has no attribute '%s'" % (type(self).__name__, item))
        return meters[item]

    def __getitem__(self, item):
        return self.meters[item]

    def __str__(self):
        return " ".join(str(v) for v in self.meters.values())

    def summary(self):
        """
        Return a summary string of group data.
        """
        return " ".join(v.summary() for v in self.meters.values())
class AverageMeter:
    """
    Computes and stores the average and current value.

    Parameters
    ----------
    name : str
        Name to display.
    fmt : str
        Format string to print the values.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """
        Reset the meter.
        """
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Update with value and weight.

        Parameters
        ----------
        val : float or int
            The new value to be accounted in.
        n : int
            The weight of the new value.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero-weight update on an empty meter would otherwise divide by zero; keep the previous average.
        if self.count != 0:
            self.avg = self.sum / self.count

    def __str__(self):
        """
        Return ``"<name> <val> (<avg>)"`` with both numbers formatted using ``fmt``.
        """
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """
        Return ``"<name>: <avg>"`` with the average formatted using ``fmt``.
        """
        fmtstr = '{name}: {avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)
class StructuredMutableTreeNode:
    """
    A structured representation of a search space.

    A search space comes with a root (with `None` stored in its `mutable`), and a
    bunch of children in its `children`. This tree can be seen as a "flattened"
    version of the module tree. Since nested mutable entity is not supported yet,
    the following must be true: each subtree corresponds to a ``MutableScope`` and
    each leaf corresponds to a ``Mutable`` (other than ``MutableScope``).

    Parameters
    ----------
    mutable : nni.nas.pytorch.mutables.Mutable
        The mutable that current node is linked with.
    """

    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        """
        Add a tree node to the children list of current node.

        Returns
        -------
        StructuredMutableTreeNode
            The newly created child node wrapping ``mutable``.
        """
        self.children.append(StructuredMutableTreeNode(mutable))
        return self.children[-1]

    def type(self):
        """
        Return the ``type`` of mutable content.
        """
        return type(self.mutable)

    def __iter__(self):
        # Default iteration is a deduplicated pre-order traversal.
        return self.traverse()

    def traverse(self, order="pre", deduplicate=True, memo=None):
        """
        Return a generator that generates a list of mutables in this tree.

        Parameters
        ----------
        order : str
            pre or post. If pre, current mutable is yield before children. Otherwise after.
        deduplicate : bool
            If true, mutables with the same key will not appear after the first appearance.
        memo : set
            An auxiliary set of keys seen so far, shared across the recursion so that
            deduplication is possible. Created fresh when ``None``.

        Returns
        -------
        generator of Mutable
        """
        if memo is None:
            memo = set()
        # Checked lazily, on the first ``next()`` of the generator.
        assert order in ["pre", "post"]
        if order == "pre":
            yield from self._yield_own_mutable(deduplicate, memo)
        for child in self.children:
            yield from child.traverse(order=order, deduplicate=deduplicate, memo=memo)
        if order == "post":
            yield from self._yield_own_mutable(deduplicate, memo)

    def _yield_own_mutable(self, deduplicate, memo):
        """Yield this node's mutable (skipped if its key is already in ``memo`` when deduplicating) and record its key."""
        if self.mutable is None:
            return
        if not deduplicate or self.mutable.key not in memo:
            memo.add(self.mutable.key)
            yield self.mutable