# Source code for nni.algorithms.hpo.batch_tuner

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
batch_tuner.py including:
    class BatchTuner -- a tuner that hands out a user-supplied list of
    configurations one per trial, in order.
"""

import logging
import numbers

import nni
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner

# Search-space keys and the only parameter type BatchTuner accepts.
TYPE, CHOICE, VALUE = '_type', 'choice', '_value'

# Logger used by BatchTuner for data-import progress messages.
LOGGER = logging.getLogger('batch_tuner_AutoML')

class BatchTuner(Tuner):
    """
    BatchTuner runs every configuration the user lists, one per trial, in order.

    The candidates come from the ``_value`` list of a single ``choice``
    parameter. Once all of them have been generated,
    :meth:`generate_parameters` raises :class:`nni.NoMoreTrialError`.

    Examples
    --------
    The search space must contain exactly one key whose ``_type`` is
    ``choice`` and whose ``_value`` is a list of complete configurations:

    ::

        {'combine_params':
            {
                '_type': 'choice',
                '_value': [{...}, {...}, {...}],
            }
        }
    """

    def __init__(self):
        # Index of the most recently generated candidate; -1 means none yet.
        self._count = -1
        # Candidate configurations; indices greater than self._count have
        # not been generated yet.
        self._values = []

    def is_valid(self, search_space):
        """
        Check that the search space has the single-``choice`` shape BatchTuner accepts.

        Parameters
        ----------
        search_space : dict
            Must contain exactly one key, whose ``_type`` is ``choice`` and
            whose ``_value`` is a list.

        Returns
        -------
        list
            The candidate configurations (the ``_value`` list itself).

        Raises
        ------
        RuntimeError
            If there is not exactly one key, its type is not ``choice``, or
            its value is not a list.
        """
        if len(search_space) != 1:
            raise RuntimeError('BatchTuner only supports one combined-parameters key.')
        # Exactly one key is present, so its spec can be unpacked directly.
        (param_spec,) = search_space.values()
        if param_spec[TYPE] != CHOICE:
            raise RuntimeError('BatchTuner only supports one combined-parameters '
                               'key whose type is choice.')
        candidates = param_spec[VALUE]
        if not isinstance(candidates, list):
            raise RuntimeError('The combined-parameters value in BatchTuner is not a list.')
        return candidates

    def update_search_space(self, search_space):
        """
        Update the search space.

        Parameters
        ----------
        search_space : dict
            Validated with ``validate_search_space`` (``choice`` only) and then
            with :meth:`is_valid`; its ``_value`` list becomes the candidates.
        """
        validate_search_space(search_space, ['choice'])
        self._values = self.is_valid(search_space)

    def generate_parameters(self, parameter_id, **kwargs):
        """
        Return the next candidate configuration, as a serializable object.

        Parameters
        ----------
        parameter_id : int

        Returns
        -------
        dict
            A candidate parameter group.

        Raises
        ------
        nni.NoMoreTrialError
            When every candidate has already been generated.
        """
        self._count += 1
        if self._count >= len(self._values):
            raise nni.NoMoreTrialError('no more parameters now.')
        return self._values[self._count]

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        """Ignore trial results: BatchTuner's candidate order is fixed up front."""

    def import_data(self, data):
        """
        Import previously finished trials so their configurations are not run again.

        Candidates already generated by this tuner are dropped and generation
        restarts from the first remaining candidate. Each imported
        configuration with a usable value is then removed from the remaining
        candidates.

        Parameters
        ----------
        data : list of dict
            Each dict has at least the keys ``'parameter'`` and ``'value'``.
        """
        if not self._values:
            LOGGER.info("Search space has not been initialized, skip this data import")
            return

        # Keep only the candidates not yet handed out. Slicing also makes a
        # new list, so the removals below never touch the search-space list.
        self._values = self._values[self._count + 1:]
        self._count = -1
        completed_num = 0
        for trial_info in data:
            LOGGER.info("Importing data, current processing progress %s / %s",
                        completed_num, len(data))
            # simply validate data format
            assert "parameter" in trial_info
            params = trial_info["parameter"]
            assert "value" in trial_info
            value = trial_info['value']
            if self._is_missing_value(value):
                LOGGER.info("Useless trial data, value is %s, skip this trial data.", value)
                continue
            completed_num += 1
            if params in self._values:
                self._values.remove(params)
        LOGGER.info("Successfully import data to batch tuner, "
                    "total data: %d, imported data: %d.", len(data), completed_num)

    @staticmethod
    def _is_missing_value(value):
        """
        Return True when an imported trial ``value`` carries no usable metric.

        ``None`` and empty non-numeric values (e.g. ``''``, ``{}``) count as
        missing. Any number, including ``0`` and ``0.0``, is a real metric;
        a plain truthiness test would wrongly discard it.
        """
        if value is None:
            return True
        if isinstance(value, numbers.Number):
            return False
        return not value