nni.algorithms.hpo.random_tuner 源代码
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Naive random tuner.
You can specify an integer seed to make the random results reproducible.
"""
from __future__ import annotations
__all__ = ['RandomTuner']
import logging
import numpy as np
import schema
from nni import ClassArgsValidator
from nni.common.hpo_utils import Deduplicator, format_search_space, deformat_parameters
from nni.tuner import Tuner
_logger = logging.getLogger('nni.tuner.random')
class RandomTuner(Tuner):
    """
    A naive tuner that generates fully random hyperparameters.

    Examples
    --------

    .. code-block::

        config.tuner.name = 'Random'
        config.tuner.class_args = {
            'seed': 100
        }

    Parameters
    ----------
    seed
        The random seed. When omitted, a seed is drawn at random and logged.
    optimize_mode
        Accepted but not used; a log message records that it was ignored.
    """

    def __init__(self, seed: int | None = None, optimize_mode: str | None = None):
        # Both are populated later by ``update_search_space``.
        self.space = None
        self.dedup = None

        # Pick a concrete seed even when none was given, so the value that
        # ends up in the log is enough to reproduce the experiment.
        if seed is None:
            seed = np.random.default_rng().integers(2 ** 31)
        self.rng = np.random.default_rng(seed)
        _logger.info(f'Using random seed {seed}')

        if optimize_mode is not None:
            _logger.info(f'Ignored optimize_mode "{optimize_mode}"')

    def update_search_space(self, space):
        """Store the formatted search space and build a deduplicator for it."""
        formatted = format_search_space(space)
        self.space = formatted
        self.dedup = Deduplicator(formatted)

    def generate_parameters(self, *args, **kwargs):
        """Sample one configuration, deduplicate it, and return it deformatted.

        Positional and keyword arguments are accepted but not used.
        """
        candidate = self.dedup(suggest(self.rng, self.space))
        return deformat_parameters(candidate, self.space)

    def receive_trial_result(self, *args, **kwargs):
        """Do nothing: trial results are not used by this tuner."""
class RandomClassArgsValidator(ClassArgsValidator):
    """Validator for the ``class_args`` passed to the random tuner."""

    def validate_class_args(self, **kwargs):
        """Check ``kwargs`` against the accepted class arguments.

        Both keys are optional: ``optimize_mode`` must be a ``str`` and
        ``seed`` must be an ``int``. Validation is delegated to
        ``schema.Schema.validate``, which reports any mismatch.
        """
        accepted_args = {
            schema.Optional('optimize_mode'): str,
            schema.Optional('seed'): int,
        }
        validator = schema.Schema(accepted_args)
        validator.validate(kwargs)
def suggest(rng, space):
    """Draw a random value for every activated parameter in ``space``.

    ``space`` is iterated in order via ``space.items()``. Each spec is asked
    whether it is activated given the values chosen so far, so the result has
    to be built incrementally rather than in a single expression.

    Returns a dict mapping each activated key to the value produced by
    ``suggest_parameter``.
    """
    chosen = {}
    for name, param_spec in space.items():
        # Skip specs that are not active under the values picked so far.
        if not param_spec.is_activated_in(chosen):
            continue
        chosen[name] = suggest_parameter(rng, param_spec)
    return chosen
def suggest_parameter(rng, spec):
    """Draw a single random value for one parameter spec.

    * categorical spec: ``rng.integers(spec.size)``
    * normally distributed spec: ``rng.normal(spec.mu, spec.sigma)``
    * anything else: ``rng.uniform(spec.low, spec.high)``

    The categorical check takes precedence over the distribution check.
    ``rng`` is expected to offer ``integers``/``normal``/``uniform`` in the
    manner of ``numpy.random.Generator``.
    """
    if spec.categorical:
        return rng.integers(spec.size)
    if not spec.normal_distributed:
        return rng.uniform(spec.low, spec.high)
    return rng.normal(spec.mu, spec.sigma)