# Source code for nni.nas.strategy.utils

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

__all__ = ['DeduplicationHelper', 'DuplicationError', 'RetrySamplingHelper']

import logging
from typing import Any, Type, TypeVar, Callable

from nni.mutable import SampleValidationError

_logger = logging.getLogger(__name__)

T = TypeVar('T')


def _to_hashable(obj):
    """Recursively convert ``obj`` into a form that can be stored in a set.

    Dicts become a frozenset of ``(key, converted value)`` pairs and lists become
    a tuple of converted elements. Any other object is returned unchanged.
    """
    if isinstance(obj, list):
        return tuple(map(_to_hashable, obj))
    if not isinstance(obj, dict):
        # Neither a dict nor a list: hand it back untouched.
        return obj
    pairs = ((key, _to_hashable(value)) for key, value in obj.items())
    return frozenset(pairs)


class DuplicationError(SampleValidationError):
    """Exception raised when a sample is duplicated.

    Parameters
    ----------
    sample
        The duplicated sample. It is interpolated into the exception message.
    """

    def __init__(self, sample):
        message = f'Duplicated sample found: {sample}'
        super().__init__(message)


class DeduplicationHelper:
    """Helper class to deduplicate samples.

    Different from the deduplication on the HPO side,
    this class simply checks if a sample has been tried before, and does nothing else.

    Parameters
    ----------
    raise_on_dup
        If true, :meth:`dedup` raises :class:`DuplicationError` for a sample that
        was seen before, instead of returning ``False``.
    """

    def __init__(self, raise_on_dup: bool = False):
        self._raise_on_dup = raise_on_dup
        self._history = set()

    def dedup(self, sample: Any) -> bool:
        """Record ``sample`` if it is new, and report whether it was new.

        The sample is converted with ``_to_hashable`` before it is looked up.

        Returns
        -------
        ``True`` if the sample has not been seen before (it is then added to the
        history), ``False`` if it has.

        Raises
        ------
        DuplicationError
            If the sample was seen before and ``raise_on_dup`` is true.
        """
        key = _to_hashable(sample)
        if key not in self._history:
            self._history.add(key)
            return True

        _logger.debug('Duplicated sample found: %s', key)
        if not self._raise_on_dup:
            return False
        raise DuplicationError(key)

    def remove(self, sample: Any) -> None:
        """Remove a sample from the history.

        Raises
        ------
        KeyError
            If the sample is not in the history (behaviour of ``set.remove``).
        """
        key = _to_hashable(sample)
        self._history.remove(key)

    def reset(self):
        """Forget all recorded samples by starting a fresh, empty history."""
        self._history = set()

    def state_dict(self):
        """Return the history as a list under the ``'dedup_history'`` key."""
        return {'dedup_history': list(self._history)}

    def load_state_dict(self, state_dict):
        """Replace the history with the entries in ``state_dict['dedup_history']``."""
        self._history = set(state_dict['dedup_history'])


class RetrySamplingHelper:
    """Helper class to retry a function until it succeeds.

    Typical use case is to retry random sampling until a non-duplicate / valid sample is found.

    Parameters
    ----------
    retries
        Number of retries.
    exception_types
        Exception types to catch.
    raise_last
        Whether to raise the last exception if all retries failed.
    """

    def __init__(self,
                 retries: int = 500,
                 exception_types: tuple[Type[Exception], ...] = (SampleValidationError,),
                 raise_last: bool = False):
        self.retries = retries
        self.exception_types = exception_types
        self.raise_last = raise_last

    def retry(self, func: Callable[..., T], *args, **kwargs) -> T | None:
        """Call ``func(*args, **kwargs)`` until it succeeds or attempts run out.

        ``func`` is called at most ``self.retries`` times. Only exceptions matching
        ``self.exception_types`` trigger another attempt; any other exception
        propagates immediately.

        Returns
        -------
        The first successful return value of ``func``, or ``None`` if every attempt
        failed and ``raise_last`` is false. If ``retries`` is not positive,
        ``func`` is never called and ``None`` is returned.

        Raises
        ------
        Exception
            The exception from the final attempt, if every attempt failed and
            ``raise_last`` is true.
        """
        for attempt in range(self.retries):
            try:
                return func(*args, **kwargs)
            except self.exception_types as e:
                # Log only at a few milestones to avoid flooding the debug log.
                if attempt in (0, 10, 100, 1000):
                    _logger.debug('Sampling failed. %d retries so far. Exception caught: %r', attempt, e)
                # On the final attempt, re-raise the caught exception if requested.
                if attempt >= self.retries - 1 and self.raise_last:
                    _logger.warning('Sampling failed after %d retries. Giving up and raising the last exception.', self.retries)
                    raise

        _logger.warning('Sampling failed after %d retries. Giving up and returning None.', self.retries)
        return None