# Source code for nni.nas.space.metrics

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

__all__ = ['Metrics']

import numbers
from typing import Any, Sequence, cast

from nni.typehint import TrialMetric

class Metrics:
    """
    Data structure that manages the metric data (e.g., loss, accuracy, etc.).

    NOTE: Multiple metrics and minimized metrics are not supported in the current iteration.

    Parameters
    ----------
    strict
        Whether to convert the metrics into a float.
        If ``True``, only real numbers (any :class:`numbers.Real`, e.g. ``int``, ``float``,
        numpy scalars) or dicts with a ``"default"`` key are accepted; the ``"default"``
        entry must itself be a real number. Accepted metrics are stored as ``float``.
        If ``False``, metrics are stored as-is without validation.
    """

    def __init__(self, strict: bool = True):
        self.strict = strict
        self._intermediates: list[TrialMetric] = []
        self._final: TrialMetric | None = None

    def __bool__(self) -> bool:
        """Return whether the metrics has been (at least partially) filled.

        A final metric counts as filled as soon as it is set, even when its value is
        falsy (e.g. ``0.0``); only ``None`` means "no final metric".
        """
        return bool(self._intermediates) or self._final is not None

    def __repr__(self) -> str:
        return f"{type(self).__name__}(intermediates=<array of length {len(self.intermediates)}>, final={self.final})"

    def _dump(self) -> dict:
        """Serialize into a plain dict.

        The ``'final'`` key is present only when a final metric is set. The intermediates
        list is a copy, so mutating the result does not affect this object.
        """
        rv: dict[str, Any] = {'intermediates': list(self._intermediates)}
        if self._final is not None:
            rv['final'] = self._final
        return rv

    @classmethod
    def _load(cls, intermediates: list[TrialMetric], final: TrialMetric | None = None) -> Metrics:
        """Reconstruct an instance from the fields produced by :meth:`_dump`.

        The values are stored as given; they are not re-validated or converted.
        """
        rv = cls()
        # Copy so later add_intermediate()/clear() calls don't mutate the caller's list.
        rv._intermediates = list(intermediates)
        rv._final = final
        return rv

    def add_intermediate(self, metric: TrialMetric) -> None:
        """Append an intermediate metric (validated/converted according to ``strict``)."""
        self._intermediates.append(self._canonicalize_metric(metric))

    def clear(self) -> None:
        """Remove all intermediate metrics and unset the final metric."""
        self._intermediates.clear()
        self._final = None

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Metrics) and self._intermediates == other._intermediates and self._final == other._final

    @property
    def intermediates(self) -> Sequence[TrialMetric]:
        """Intermediate metrics in the order they were added."""
        return self._intermediates

    @property
    def final(self) -> TrialMetric | None:
        """The final metric, or ``None`` if it has not been set."""
        return self._final

    @final.setter
    def final(self, metric: TrialMetric) -> None:
        self._final = self._canonicalize_metric(metric)

    def _canonicalize_metric(self, metric: Any) -> TrialMetric:
        """Validate ``metric`` and convert it to ``float`` when ``strict`` is enabled.

        Raises
        ------
        ValueError
            In strict mode, if ``metric`` is a dict without a ``'default'`` key, or if the
            (unwrapped) value is not a real number.
        """
        if not self.strict:
            return cast('TrialMetric', metric)
        if isinstance(metric, dict):
            if 'default' not in metric:
                raise ValueError(f"Metric dict {metric} does not contain key 'default'")
            metric = metric['default']
        # numbers.Real also covers types like np.float32 / np.int64, which are not float/int subclasses.
        if not isinstance(metric, numbers.Real):
            raise ValueError(f"Metric {metric} is not a number")
        return float(metric)