Source code for nni.common.serializer

import abc
import base64
import collections.abc
import copy
import functools
import inspect
import numbers
import types
import warnings
from io import IOBase
from typing import Any, Dict, List, Optional, TypeVar, Union

import cloudpickle  # use cloudpickle as backend for unserializable types and instances
import json_tricks  # use json_tricks as serializer backend

__all__ = ['trace', 'dump', 'load', 'PayloadTooLarge', 'Translatable', 'Traceable', 'is_traceable']


T = TypeVar('T')


class PayloadTooLarge(Exception):
    pass


class Traceable(abc.ABC):
    """
    A traceable object exposes a ``trace_copy`` method together with ``trace_symbol``, ``trace_args`` and ``trace_kwargs``.
    ``trace_copy`` copies the object for further mutations, while the symbol and arguments record how the object was constructed and enable serialization.
    """
    @abc.abstractmethod
    def trace_copy(self) -> 'Traceable':
        """
        Perform a shallow copy.
        NOTE: NONE of the attributes of the original instance will be preserved; only the traced symbol and arguments are copied.
        This is the method that should be used when you want to "mutate" a serializable object.
        """
        ...

    @property
    @abc.abstractmethod
    def trace_symbol(self) -> Any:
        """
        Symbol object. Could be a class or a function.
        ``get_hybrid_cls_or_func_name`` and ``import_cls_or_func_from_hybrid_name`` are a pair that
        converts the symbol into a string and converts the string back into the symbol.
        """
        ...

    @property
    @abc.abstractmethod
    def trace_args(self) -> List[Any]:
        """
        List of positional arguments passed to symbol. Usually empty if ``kw_only`` is true,
        in which case all the positional arguments are converted into keyword arguments.
        """
        ...

    @property
    @abc.abstractmethod
    def trace_kwargs(self) -> Dict[str, Any]:
        """
        Dict of keyword arguments.
        """
        ...


class Translatable(abc.ABC):
    """
    Inherit this class and implement ``_translate`` when the wrapped class needs a different
    parameter from the wrapper class in its init function.
    """

    @abc.abstractmethod
    def _translate(self) -> Any:
        pass

    @staticmethod
    def _translate_argument(d: Any) -> Any:
        if isinstance(d, Translatable):
            return d._translate()
        return d
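

# Illustrative usage sketch: a minimal ``Translatable`` whose recorded argument
# differs from the value the wrapped callable should receive. The class name and
# values below are invented for illustration only.
class _ExampleFirstChoice(Translatable):
    """Records a list of candidates, but translates to the first candidate."""

    def __init__(self, candidates):
        self.candidates = candidates

    def _translate(self):
        return self.candidates[0]


def _example_translate_argument():
    # ``_argument_processor`` (defined below) relies on this to decide what the
    # wrapped function/class actually receives.
    assert Translatable._translate_argument(_ExampleFirstChoice([4, 8, 16])) == 4
    assert Translatable._translate_argument(42) == 42  # non-translatables pass through unchanged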


def is_traceable(obj: Any) -> bool:
    """
    Check whether an object is a traceable instance (not type).
    """
    return hasattr(obj, 'trace_copy') and \
        hasattr(obj, 'trace_symbol') and \
        hasattr(obj, 'trace_args') and \
        hasattr(obj, 'trace_kwargs') and \
        not inspect.isclass(obj)
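

# Illustrative usage sketch: ``is_traceable`` is true for instances created through
# a traced symbol, not for the traced class itself. ``_ExampleOptimizerConfig`` is
# an invented toy class.
def _example_is_traceable():
    @trace
    class _ExampleOptimizerConfig:
        def __init__(self, lr, momentum=0.9):
            self.lr = lr
            self.momentum = momentum

    config = _ExampleOptimizerConfig(0.1, momentum=0.95)
    assert is_traceable(config)                       # the instance carries trace_* attributes
    assert not is_traceable(_ExampleOptimizerConfig)  # classes are deliberately excluded
    assert config.trace_kwargs == {'lr': 0.1, 'momentum': 0.95}
    return config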


class SerializableObject(Traceable):
    """
    Serializable object is a wrapper around an existing python object that supports dump and load easily.
    It stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
    """

    def __init__(self, symbol: T, args: List[Any], kwargs: Dict[str, Any], call_super: bool = False):
        # use dict to avoid conflicts with user's getattr and setattr
        self.__dict__['_nni_symbol'] = symbol
        self.__dict__['_nni_args'] = args
        self.__dict__['_nni_kwargs'] = kwargs
        self.__dict__['_nni_call_super'] = call_super

        if call_super:
            # call super means that the serializable object is by itself an object of the target class
            super().__init__(
                *[_argument_processor(arg) for arg in args],
                **{kw: _argument_processor(arg) for kw, arg in kwargs.items()}
            )

    def trace_copy(self) -> Union[T, 'SerializableObject']:
        return SerializableObject(
            self.trace_symbol,
            [copy.copy(arg) for arg in self.trace_args],
            {k: copy.copy(v) for k, v in self.trace_kwargs.items()},
        )

    @property
    def trace_symbol(self) -> Any:
        return self._get_nni_attr('symbol')

    @trace_symbol.setter
    def trace_symbol(self, symbol: Any) -> None:
        # for mutation purposes
        self.__dict__['_nni_symbol'] = symbol

    @property
    def trace_args(self) -> List[Any]:
        return self._get_nni_attr('args')

    @trace_args.setter
    def trace_args(self, args: List[Any]):
        self.__dict__['_nni_args'] = args

    @property
    def trace_kwargs(self) -> Dict[str, Any]:
        return self._get_nni_attr('kwargs')

    @trace_kwargs.setter
    def trace_kwargs(self, kwargs: Dict[str, Any]):
        self.__dict__['_nni_kwargs'] = kwargs

    def _get_nni_attr(self, name: str) -> Any:
        return self.__dict__['_nni_' + name]

    def __repr__(self):
        if self._get_nni_attr('call_super'):
            return super().__repr__()
        return 'SerializableObject(' + \
            ', '.join(['type=' + self._get_nni_attr('symbol').__name__] +
                      [repr(d) for d in self._get_nni_attr('args')] +
                      [k + '=' + repr(v) for k, v in self._get_nni_attr('kwargs').items()]) + \
            ')'
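

# Illustrative usage sketch: the stored symbol plus the recorded arguments are
# enough to re-create, or mutate and then re-create, the underlying object.
# ``dict`` is used here purely as a convenient, importable symbol.
def _example_serializable_object():
    obj = SerializableObject(dict, [], {'lr': 0.1, 'epochs': 10})
    # restore the wrapped object from its trace information
    restored = obj.trace_symbol(*obj.trace_args, **obj.trace_kwargs)
    assert restored == {'lr': 0.1, 'epochs': 10}
    # trace_copy produces a fresh wrapper whose arguments can be mutated safely
    mutated = obj.trace_copy()
    mutated.trace_kwargs['lr'] = 0.01
    assert obj.trace_kwargs['lr'] == 0.1
    return restored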


def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> Any:
    # If an object has already been created, this re-injects the necessary trace info into it.

    def getter_factory(x):
        return lambda self: self.__dict__['_nni_' + x]

    def setter_factory(x):
        def setter(self, val):
            self.__dict__['_nni_' + x] = val

        return setter

    def trace_copy(self):
        return SerializableObject(
            self.trace_symbol,
            [copy.copy(arg) for arg in self.trace_args],
            {k: copy.copy(v) for k, v in self.trace_kwargs.items()},
        )

    attributes = {
        'trace_symbol': property(getter_factory('symbol'), setter_factory('symbol')),
        'trace_args': property(getter_factory('args'), setter_factory('args')),
        'trace_kwargs': property(getter_factory('kwargs'), setter_factory('kwargs')),
        'trace_copy': trace_copy
    }

    if hasattr(obj, '__class__') and hasattr(obj, '__dict__'):
        for name, method in attributes.items():
            setattr(obj.__class__, name, method)
    else:
        wrapper = type('wrapper', (Traceable, type(obj)), attributes)
        obj = wrapper(obj)  # pylint: disable=abstract-class-instantiated

    # make obj comply with the Traceable interface, even though we cannot change its base class
    obj.__dict__.update(_nni_symbol=symbol, _nni_args=args, _nni_kwargs=kwargs)

    return obj
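

# Illustrative usage sketch: retrofitting trace info onto an object that was created
# without going through a traced constructor. ``_ExampleRecord`` and its arguments
# are invented for illustration.
def _example_inject_trace_info():
    class _ExampleRecord:
        def __init__(self, name):
            self.name = name

    record = _ExampleRecord('baseline')
    record = inject_trace_info(record, _ExampleRecord, [], {'name': 'baseline'})
    assert is_traceable(record)
    assert record.trace_kwargs == {'name': 'baseline'}
    return record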


def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]:
    """
    Annotate a function or a class if you want to preserve where it comes from.
    This is usually used in the following scenarios:

    1) You care more about the execution configuration than the results, which is usually the case in AutoML.
       For example, you want to mutate the parameters of a function.
    2) Repeated execution is not an issue (e.g., reproducible, execution is fast without side effects).

    When a class/function is annotated, all the instances/calls will return an object as they normally would.
    Although the object might act like a normal object, it's actually a different object with NNI-specific properties.
    One exception is that if your function returns None, it will return an empty traceable object instead,
    which should raise your attention when you want to check whether the None ``is None``.

    When parameters of functions are received, they are first stored, and then a shallow copy will be passed to the
    wrapped function/class. This is to prevent mutable objects from being modified in the wrapped function/class.
    When the function finishes execution, we also record extra information about where this object comes from.
    That's why it's called "trace". When ``nni.dump`` is called, that information will be used, by default.

    If ``kw_only`` is true, try to convert all parameters into kwargs. This is done by inspecting the argument
    list and types. This can be useful to extract semantics, but can be tricky in some corner cases.

    .. warning::

        Generators will be first expanded into a list, and the resulting list will be further passed into the
        wrapped function/class. This might hang when generators produce an infinite sequence. We might introduce
        an API to control this behavior in future.

    Example:

    .. code-block:: python

        @nni.trace
        def foo(bar):
            pass
    """

    def wrap(cls_or_func):
        # already annotated, do nothing
        if getattr(cls_or_func, '_traced', False):
            return cls_or_func
        if isinstance(cls_or_func, type):
            cls_or_func = _trace_cls(cls_or_func, kw_only)
        elif _is_function(cls_or_func):
            cls_or_func = _trace_func(cls_or_func, kw_only)
        else:
            raise TypeError(f'{cls_or_func} of type {type(cls_or_func)} is not supported to be traced. '
                            'File an issue at https://github.com/microsoft/nni/issues if you believe this is a mistake.')
        cls_or_func._traced = True
        return cls_or_func

    # if we're being called as @trace()
    if cls_or_func is None:
        return wrap

    # if we are called without parentheses
    return wrap(cls_or_func)


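# Illustrative usage sketch: with ``kw_only=True`` (the default), positional arguments
# are recorded under their parameter names, which is what makes later mutation by name
# possible. ``_make_schedule`` and its arguments are invented for illustration.
def _example_trace_kw_only():
    @trace
    def _make_schedule(total_steps, warmup=0):
        return {'total_steps': total_steps, 'warmup': warmup}

    schedule = _make_schedule(1000, warmup=100)
    # the returned dict behaves like a normal dict, but also remembers how it was produced
    assert schedule == {'total_steps': 1000, 'warmup': 100}
    assert schedule.trace_kwargs == {'total_steps': 1000, 'warmup': 100}
    return schedule

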
def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_size_limit: int = 4096,
         allow_nan: bool = True, **json_tricks_kwargs) -> Union[str, bytes]:
    """
    Convert a nested data structure to a json string. Save to file if fp is specified.
    Use json-tricks as the main backend. For unhandled cases in json-tricks, use cloudpickle.
    The serializer is not designed for long-term storage use, but rather to copy data between processes.
    The format is also subject to change between NNI releases.

    Parameters
    ----------
    obj : any
        The object to dump.
    fp : file handler or path
        File to write to. Keep it none if you want to dump a string.
    use_trace : bool
        Whether to use the recorded trace information (symbol and arguments) when dumping traceable objects.
    pickle_size_limit : int
        This is set to avoid too long serialization result. Set to -1 to disable size check.
    allow_nan : bool
        Whether to allow nan to be serialized. Different from the default value in json-tricks, our default value is true.
    json_tricks_kwargs : dict
        Other keyword arguments passed to json-tricks (backend), e.g., indent=2.

    Returns
    -------
    str or bytes
        Normally str. Sometimes bytes (if compressed).
    """

    encoders = [
        # we don't need to check for dependencies as many of those have already been required by NNI
        json_tricks.pathlib_encode,         # pathlib is a required dependency for NNI
        json_tricks.pandas_encode,          # pandas is a required dependency
        json_tricks.numpy_encode,           # required
        json_tricks.encoders.enum_instance_encode,
        json_tricks.json_date_time_encode,  # same as json_tricks
        json_tricks.json_complex_encode,
        json_tricks.json_set_encode,
        json_tricks.numeric_types_encode,
        functools.partial(_json_tricks_serializable_object_encode, use_trace=use_trace),
        functools.partial(_json_tricks_func_or_cls_encode, pickle_size_limit=pickle_size_limit),
        functools.partial(_json_tricks_any_object_encode, pickle_size_limit=pickle_size_limit),
    ]

    json_tricks_kwargs['allow_nan'] = allow_nan

    if fp is not None:
        return json_tricks.dump(obj, fp, obj_encoders=encoders, **json_tricks_kwargs)
    else:
        return json_tricks.dumps(obj, obj_encoders=encoders, **json_tricks_kwargs)


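# Illustrative usage sketch: dumping a structure that mixes plain JSON types with a
# traced instance. ``_ExampleTrainerConfig`` is invented; note that symbols which cannot
# be imported by path (e.g., defined in __main__ or a local scope) fall back to cloudpickle
# and are therefore subject to ``pickle_size_limit``.
def _example_dump():
    @trace
    class _ExampleTrainerConfig:
        def __init__(self, batch_size, max_epochs=10):
            self.batch_size = batch_size
            self.max_epochs = max_epochs

    payload = {
        'seed': 42,
        'trainer': _ExampleTrainerConfig(32, max_epochs=5),
    }
    # the traced instance is encoded as {"__symbol__": ..., "__kwargs__": ...}
    json_str = dump(payload)
    assert isinstance(json_str, str)
    return json_str

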
def load(string: Optional[str] = None, *, fp: Optional[Any] = None, ignore_comments: bool = True,
         **json_tricks_kwargs) -> Any:
    """
    Load a JSON string (or a JSON file) and convert it back to a complex data structure.
    At least one of string or fp has to be not none.

    Parameters
    ----------
    string : str
        JSON string to parse. Can be set to none if fp is used.
    fp : str
        File path to load JSON from. Can be set to none if string is used.
    ignore_comments : bool
        Remove comments (starting with ``#`` or ``//``). Default is true.

    Returns
    -------
    any
        The loaded object.
    """
    assert string is not None or fp is not None

    # see encoders in ``dump`` for explanation
    hooks = [
        json_tricks.pathlib_hook,
        json_tricks.pandas_hook,
        json_tricks.json_numpy_obj_hook,
        json_tricks.decoders.EnumInstanceHook(),
        json_tricks.json_date_time_hook,
        json_tricks.json_complex_hook,
        json_tricks.json_set_hook,
        json_tricks.numeric_types_hook,
        _json_tricks_serializable_object_decode,
        _json_tricks_func_or_cls_decode,
        _json_tricks_any_object_decode
    ]

    # to bypass a deprecation warning in json-tricks
    json_tricks_kwargs['ignore_comments'] = ignore_comments

    if string is not None:
        if isinstance(string, IOBase):
            raise TypeError(f'Expect a string, found a {string}. If you intend to use a file, use `nni.load(fp=file)`')
        return json_tricks.loads(string, obj_pairs_hooks=hooks, **json_tricks_kwargs)
    else:
        return json_tricks.load(fp, obj_pairs_hooks=hooks, **json_tricks_kwargs)


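# Illustrative usage sketch: a dump/load round trip. Plain JSON types pass through
# unchanged; traced instances are re-created by calling ``trace(symbol)(*args, **kwargs)``
# again (see ``_json_tricks_serializable_object_decode`` below).
def _example_dump_load_roundtrip():
    original = {'lr': 0.1, 'steps': [1, 2, 3]}
    restored = load(dump(original))
    assert restored == original
    return restored

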
def _trace_cls(base, kw_only, call_super=True):
    # the implementation to trace a class is to store a copy of init arguments
    # this won't support classes that define a customized __new__, but should work for most cases
    class wrapper(SerializableObject, base):
        def __init__(self, *args, **kwargs):
            # store a copy of initial parameters
            args, kwargs = _formulate_arguments(base.__init__, args, kwargs, kw_only, is_class_init=True)

            # calling serializable object init to initialize the full object
            super().__init__(symbol=base, args=args, kwargs=kwargs, call_super=call_super)

    _copy_class_wrapper_attributes(base, wrapper)

    return wrapper


def _trace_func(func, kw_only):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # similar to class, store parameters here
        args, kwargs = _formulate_arguments(func, args, kwargs, kw_only)

        # it's not clear whether this wrapper can handle all the types in python
        # There are many cases here: https://docs.python.org/3/reference/datamodel.html
        # but it looks like we have handled the most commonly used cases
        res = func(
            *[_argument_processor(arg) for arg in args],
            **{kw: _argument_processor(arg) for kw, arg in kwargs.items()}
        )

        if res is None:
            # don't call super, makes no sense.
            # an empty serializable object is "none". Don't check it though.
            res = SerializableObject(func, args, kwargs, call_super=False)
        elif hasattr(res, '__class__') and hasattr(res, '__dict__'):
            # the result is an instance of a class with a writable __dict__: inject the trace interface directly
            # needs to be done before primitive types because there could be inheritance here.
            res = inject_trace_info(res, func, args, kwargs)
        elif isinstance(res, (collections.abc.Callable, types.ModuleType, IOBase)):
            raise TypeError(f'Try to add trace info to {res}, but functions and modules are not supported.')
        elif isinstance(res, (numbers.Number, collections.abc.Sequence, collections.abc.Set, collections.abc.Mapping)):
            # handle primitive types like int, str, set, dict, tuple
            # NOTE: simple types including none, bool, int, float, list, tuple, dict
            # will be directly captured by python json encoder
            # and thus not possible to restore the trace parameters after dump and reload.
            # this is a known limitation.
            res = inject_trace_info(res, func, args, kwargs)
        else:
            raise TypeError(f'Try to add trace info to {res}, but the type "{type(res)}" is unknown. '
                            'Please file an issue at https://github.com/microsoft/nni/issues')

        return res

    return wrapper


def _copy_class_wrapper_attributes(base, wrapper):
    _MISSING = '_missing'

    # assign magic attributes like __module__, __qualname__, __doc__
    for k in functools.WRAPPER_ASSIGNMENTS:
        v = getattr(base, k, _MISSING)
        if v is not _MISSING:
            try:
                setattr(wrapper, k, v)
            except AttributeError:
                pass

    wrapper.__wrapped__ = base


def _argument_processor(arg):
    # 1) translate
    # handle cases like ValueChoice
    # This is needed because sometimes the recorded arguments are meant to be different from what the wrapped object receives.
    arg = Translatable._translate_argument(arg)
    # 2) prevent the stored parameters from being mutated by the wrapped class.
    # an example: https://github.com/microsoft/nni/issues/4329
    if isinstance(arg, (collections.abc.MutableMapping, collections.abc.MutableSequence, collections.abc.MutableSet)):
        arg = copy.copy(arg)
    return arg


def _formulate_single_argument(arg):
    # this is different from the argument processor:
    # it directly applies the transformation on the stored arguments

    # expand generator into list
    # Note that some iterables (such as range(10)) are not generators and will not be expanded here.
    if isinstance(arg, types.GeneratorType):
        arg = list(arg)

    return arg


def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
    # This is to formulate the arguments and make them well-formed.
    if kw_only:
        # get arguments passed to a function, and save it as a dict
        argname_list = list(inspect.signature(func).parameters.keys())
        if is_class_init:
            argname_list = argname_list[1:]
        full_args = {}

        # match positional arguments with the argument name list
        # the argument name list should be at least as long as args, because the rest can be passed as kwargs
        assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
        for argname, value in zip(argname_list, args):
            full_args[argname] = value

        # use kwargs to override
        full_args.update(kwargs)

        args, kwargs = [], full_args

    args = [_formulate_single_argument(arg) for arg in args]
    kwargs = {k: _formulate_single_argument(arg) for k, arg in kwargs.items()}

    return list(args), kwargs


def _is_function(obj: Any) -> bool:
    # https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
    return isinstance(obj, (types.FunctionType, types.BuiltinFunctionType, types.MethodType, types.BuiltinMethodType))


def _import_cls_or_func_from_name(target: str) -> Any:
    if target is None:
        return None
    path, identifier = target.rsplit('.', 1)
    module = __import__(path, globals(), locals(), [identifier])
    return getattr(module, identifier)


def _strip_trace_type(traceable: Any) -> Any:
    if getattr(traceable, '_traced', False):
        return traceable.__wrapped__
    return traceable


def _get_cls_or_func_name(cls_or_func: Any) -> str:
    module_name = cls_or_func.__module__
    if module_name == '__main__':
        raise ImportError('Cannot use a path to identify something from __main__.')
    full_name = module_name + '.' + cls_or_func.__name__

    try:
        imported = _import_cls_or_func_from_name(full_name)
        # ignores the differences in trace
        if _strip_trace_type(imported) != _strip_trace_type(cls_or_func):
            raise ImportError(f'Imported {imported} is not same as expected. The function might be dynamically created.')
    except ImportError:
        raise ImportError(f'Import {cls_or_func.__name__} from "{module_name}" failed.')

    return full_name


def get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096) -> str:
    try:
        name = _get_cls_or_func_name(cls_or_func)
        # import success, use a path format
        return 'path:' + name
    except (ImportError, AttributeError):
        b = cloudpickle.dumps(cls_or_func)
        if len(b) > pickle_size_limit:
            raise ValueError(f'Pickle too large when trying to dump {cls_or_func}. '
                             'Please try to raise pickle_size_limit if you insist.')
        # fallback to cloudpickle
        return 'bytes:' + base64.b64encode(b).decode()


def import_cls_or_func_from_hybrid_name(s: str) -> Any:
    if s.startswith('bytes:'):
        b = base64.b64decode(s.split(':', 1)[-1])
        return cloudpickle.loads(b)
    if s.startswith('path:'):
        s = s.split(':', 1)[-1]
    return _import_cls_or_func_from_name(s)


def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> str:
    if not isinstance(cls_or_func, type) and not _is_function(cls_or_func):
        # not a function or class, continue
        return cls_or_func

    return {
        '__nni_type__': get_hybrid_cls_or_func_name(cls_or_func, pickle_size_limit)
    }


def _json_tricks_func_or_cls_decode(s: Dict[str, Any]) -> Any:
    if isinstance(s, dict) and '__nni_type__' in s:
        s = s['__nni_type__']
        return import_cls_or_func_from_hybrid_name(s)
    return s


def _json_tricks_serializable_object_encode(obj: Any, primitives: bool = False, use_trace: bool = True) -> Dict[str, Any]:
    # Encodes a serializable object instance to json.

    # do nothing to instances that are not serializable objects, or when trace is not used
    if not use_trace or not is_traceable(obj):
        return obj

    if isinstance(obj.trace_symbol, property):
        # commonly made mistake when users forget to call the traced function/class.
        warnings.warn(f'The symbol of {obj} is found to be a property. Did you forget to create the instance with ``xx(...)``?')

    ret = {'__symbol__': get_hybrid_cls_or_func_name(obj.trace_symbol)}
    if obj.trace_args:
        ret['__args__'] = obj.trace_args
    if obj.trace_kwargs:
        ret['__kwargs__'] = obj.trace_kwargs
    return ret


def _json_tricks_serializable_object_decode(obj: Dict[str, Any]) -> Any:
    if isinstance(obj, dict) and '__symbol__' in obj:
        symbol = import_cls_or_func_from_hybrid_name(obj['__symbol__'])
        args = obj.get('__args__', [])
        kwargs = obj.get('__kwargs__', {})
        return trace(symbol)(*args, **kwargs)
    return obj


def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> Any:
    # We want to use this to replace the class instance encoder in json-tricks.
    # Therefore the coverage should be roughly the same.
    if isinstance(obj, list) or isinstance(obj, dict):
        return obj

    if hasattr(obj, '__class__') and (hasattr(obj, '__dict__') or hasattr(obj, '__slots__')):
        b = cloudpickle.dumps(obj)
        if len(b) > pickle_size_limit > 0:
            raise PayloadTooLarge(f'Pickle too large when trying to dump {obj}. This might be caused by classes that are '
                                  'not decorated by @nni.trace. Another option is to force bytes pickling and '
                                  'try to raise pickle_size_limit.')
        # use base64 to dump a bytes array
        return {
            '__nni_obj__': base64.b64encode(b).decode()
        }

    return obj


def _json_tricks_any_object_decode(obj: Dict[str, Any]) -> Any:
    if isinstance(obj, dict) and '__nni_obj__' in obj:
        obj = obj['__nni_obj__']
        b = base64.b64decode(obj)
        return cloudpickle.loads(b)
    return obj
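

# Illustrative usage sketch: the "hybrid" name format used above. Importable symbols are
# stored as a ``path:`` string; everything else falls back to a base64-encoded cloudpickle
# payload with a ``bytes:`` prefix. ``math.sqrt`` is just a convenient importable function.
def _example_hybrid_name_roundtrip():
    import math
    name = get_hybrid_cls_or_func_name(math.sqrt)
    assert name == 'path:math.sqrt'
    assert import_cls_or_func_from_hybrid_name(name) is math.sqrt
    return name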