# Source code for nni.nas.evaluator.functional
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from typing import ClassVar
from nni.common.serializer import SerializableObject
from .evaluator import MutableEvaluator
class FunctionalEvaluator(MutableEvaluator):
    """
    Functional evaluator that directly takes a function and thus should be general.

    See :class:`~nni.nas.evaluator.Evaluator` for instructions on how to write this function.

    Attributes
    ----------
    function
        The callable used to evaluate a model.
        It is invoked as ``function(model, **arguments)``.
    arguments
        Keyword arguments for the function other than model.
    """

    # The functional evaluator has already been equipped with "trace" functionality.
    # It shouldn't be traced again when wrapped with `nni.trace`.
    _traced: ClassVar[bool] = True

    def __init__(self, function, **kwargs):
        self.function = function
        # All extra keyword arguments are forwarded to ``function`` in :meth:`evaluate`.
        self.arguments = kwargs

    def extra_repr(self):
        """Describe the wrapped function and its keyword arguments for ``repr``."""
        # The returned text must be parenthesis-balanced on its own; the enclosing
        # ``Name(...)`` parentheses are expected to be added by the caller.
        # NOTE(review): assumes the base-class ``__repr__`` wraps this string — confirm.
        return f"{self.function!r}, arguments={self.arguments!r}"

    # NOTE: FunctionalEvaluator implements the traceable interface by itself,
    # so that it doesn't need the `nni.trace` decorator.
    # But I guess it works with the decorator as well.

    @property
    def trace_symbol(self):
        """The symbol to call when re-creating this evaluator: the class itself."""
        return self.__class__

    @property
    def trace_args(self):
        """Positional arguments for re-creation; empty because everything is passed by keyword."""
        return []

    @property
    def trace_kwargs(self):
        """
        Keyword arguments for re-creation.

        ``function`` is passed by keyword, and the stored ``arguments`` are flattened
        alongside it, mirroring the ``__init__(function, **kwargs)`` signature.
        A new dict is built on every access.
        """
        return {
            'function': self.function,
            **self.arguments,
        }

    def trace_copy(self):
        """Return a ``SerializableObject`` built from this evaluator's trace symbol and keyword arguments."""
        return SerializableObject(self.__class__, [], self.trace_kwargs)

    def evaluate(self, model):
        """
        Evaluate ``model`` by calling ``function(model, **arguments)``.

        Returns whatever the wrapped function returns.
        """
        return self.function(model, **self.arguments)