Source code for nni.retiarii.serializer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import abc
import functools
import inspect
import types
from typing import Any

import json_tricks

from .utils import get_importable_name, get_module_name, import_, reset_uid


def get_init_parameters_or_fail(obj, silently=False):
    if hasattr(obj, '_init_parameters'):
        return obj._init_parameters
    elif silently:
        return None
    else:
        raise ValueError(f'Object {obj} needs to be serializable but `_init_parameters` is not available. '
                         'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
                         'If it is a customized module, please to decorate it with @basic_unit. '
                         'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
                         'try to use serialize or @serialize_cls.')


### This is a patch of json-tricks to make it more useful to us ###


def _serialize_class_instance_encode(obj, primitives=False):
    assert not primitives, 'Encoding with primitives is not supported.'
    try:  # FIXME: raise error
        if hasattr(obj, '__class__'):
            return {
                '__type__': get_importable_name(obj.__class__),
                'arguments': get_init_parameters_or_fail(obj)
            }
    except ValueError:
        pass
    return obj


def _serialize_class_instance_decode(obj):
    if isinstance(obj, dict) and '__type__' in obj and 'arguments' in obj:
        return import_(obj['__type__'])(**obj['arguments'])
    return obj


def _type_encode(obj, primitives=False):
    assert not primitives, 'Encoding with primitives is not supported.'
    if isinstance(obj, type):
        return {'__typename__': get_importable_name(obj, relocate_module=True)}
    if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
        # This is not reliable for cases like closure, `open`, or objects that is callable but not intended to be serialized.
        # https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
        return {'__typename__': get_importable_name(obj, relocate_module=True)}
    return obj


def _type_decode(obj):
    if isinstance(obj, dict) and '__typename__' in obj:
        return import_(obj['__typename__'])
    return obj


json_loads = functools.partial(json_tricks.loads, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode])
json_dumps = functools.partial(json_tricks.dumps, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode])
json_load = functools.partial(json_tricks.load, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode])
json_dump = functools.partial(json_tricks.dump, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode])

### End of json-tricks patch ###


class Translatable(abc.ABC):
    """
    Inherit this class and implement ``translate`` when the inner class needs a different
    parameter from the wrapper class in its init function.
    """

    @abc.abstractmethod
    def _translate(self) -> Any:
        pass


def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False, stop_parsing=True):
    class wrapper(cls):
        def __init__(self, *args, **kwargs):
            self._stop_parsing = stop_parsing
            if reset_mutation_uid:
                reset_uid('mutation')
            if store_init_parameters:
                argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:]
                full_args = {}
                full_args.update(kwargs)

                assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
                for argname, value in zip(argname_list, args):
                    full_args[argname] = value

                # translate parameters
                args = list(args)
                for i, value in enumerate(args):
                    if isinstance(value, Translatable):
                        args[i] = value._translate()
                for i, value in kwargs.items():
                    if isinstance(value, Translatable):
                        kwargs[i] = value._translate()

                self._init_parameters = full_args
            else:
                self._init_parameters = {}

            super().__init__(*args, **kwargs)

    wrapper.__module__ = get_module_name(cls)
    wrapper.__name__ = cls.__name__
    wrapper.__qualname__ = cls.__qualname__
    wrapper.__init__.__doc__ = cls.__init__.__doc__

    return wrapper


def serialize_cls(cls):
    """
    To create an serializable class.
    """
    return _create_wrapper_cls(cls)


def transparent_serialize(cls):
    """
    Wrap a module but does not record parameters. For internal use only.
    """
    return _create_wrapper_cls(cls, store_init_parameters=False)


[docs]def serialize(cls, *args, **kwargs): """ To create an serializable instance inline without decorator. For example, .. code-block:: python self.op = serialize(MyCustomOp, hidden_units=128) """ return serialize_cls(cls)(*args, **kwargs)
def basic_unit(cls): """ To wrap a module as a basic unit, to stop it from parsing and make it mutate-able. """ import torch.nn as nn assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.' return serialize_cls(cls) def model_wrapper(cls): """ Wrap the model if you are using pure-python execution engine. The wrapper serves two purposes: 1. Capture the init parameters of python class so that it can be re-instantiated in another process. 2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios. """ return _create_wrapper_cls(cls, reset_mutation_uid=True, stop_parsing=False)