# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
batch_tuner.py including:
class BatchTuner
"""
import logging
import numbers

import nni
from nni.tuner import Tuner
# Key names and the only accepted type of the single search-space entry
# that BatchTuner understands.
TYPE, CHOICE, VALUE = '_type', 'choice', '_value'

# Logger used by BatchTuner.
LOGGER = logging.getLogger('batch_tuner_AutoML')
class BatchTuner(Tuner):
    """
    BatchTuner hands out, one per call to ``generate_parameters``, each
    configuration the user listed in the search space, in the order given.

    Examples
    --------
    The only accepted search space is a single ``choice`` entry whose
    ``_value`` is a list of complete parameter sets:

    ::

        {'combine_params':
            { '_type': 'choice',
              '_value': [{...}, {...}, {...}],
            }
        }
    """

    def __init__(self):
        # Index of the most recently generated entry in ``self._values``;
        # -1 means nothing has been generated yet.
        self._count = -1
        # Candidate parameter sets taken from the search space.
        self._values = []

    def is_valid(self, search_space):
        """
        Check that the search space is valid and return its candidate values.

        A valid search space has exactly one key, whose spec is a dict with
        ``_type == 'choice'`` and a list as ``_value``.

        Parameters
        ----------
        search_space : dict

        Returns
        -------
        list
            The candidate parameter sets (the ``_value`` list).

        Raises
        ------
        RuntimeError
            If the search space does not have exactly one key, the entry's
            type is not ``choice``, or its value is not a list.
        """
        if len(search_space) != 1:
            raise RuntimeError('BatchTuner only supports one combined-parameters key.')
        # Exactly one key is guaranteed by the check above.
        param = next(iter(search_space))
        param_spec = search_space[param]
        if not isinstance(param_spec, dict) or param_spec.get(TYPE) != CHOICE:
            raise RuntimeError('BatchTuner only supports one combined-parameters '
                               'type, which is choice.')
        values = param_spec.get(VALUE)
        if not isinstance(values, list):
            raise RuntimeError('The combined-parameters value in BatchTuner is not a list.')
        return values

    def update_search_space(self, search_space):
        """Validate the search space and store its candidate parameter sets.

        Parameters
        ----------
        search_space : dict
            See :meth:`is_valid` for the accepted format.

        Raises
        ------
        RuntimeError
            If the search space is invalid.
        """
        self._values = self.is_valid(search_space)

    def generate_parameters(self, parameter_id, **kwargs):
        """Return the next candidate parameter set, in search-space order.

        Parameters
        ----------
        parameter_id : int
            Not used to choose the candidate; candidates are handed out
            sequentially.

        Returns
        -------
        dict
            A candidate parameter group.

        Raises
        ------
        nni.NoMoreTrialError
            When every candidate has already been generated.
        """
        self._count += 1
        if self._count >= len(self._values):
            raise nni.NoMoreTrialError('no more parameters now.')
        return self._values[self._count]

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        """Ignore trial results; the candidate order does not depend on them."""

    def import_data(self, data):
        """Import results of already-run trials so they are not run again.

        Candidates that were already handed out by ``generate_parameters``
        are dropped, and generation restarts from the first remaining
        candidate. Each imported trial with a usable value then removes one
        matching candidate, if present.

        Parameters
        ----------
        data : list of dict
            Each item has at least the keys ``'parameter'`` and ``'value'``.
            Items whose value is ``None`` or an empty container/string are
            skipped. Numeric values, including ``0``, count as completed.

        Raises
        ------
        ValueError
            If an item lacks the ``'parameter'`` or ``'value'`` key.
        """
        if not self._values:
            LOGGER.info("Search space has not been initialized, skip this data import")
            return

        # Keep only the candidates not generated yet and restart from the
        # beginning of that remainder.
        self._values = self._values[(self._count + 1):]
        self._count = -1

        completed_num = 0
        for trial_info in data:
            LOGGER.info("Importing data, current processing progress %s / %s",
                        completed_num, len(data))
            # Explicit check instead of ``assert`` so validation survives ``python -O``.
            if "parameter" not in trial_info or "value" not in trial_info:
                raise ValueError("Imported trial data must contain 'parameter' and "
                                 "'value' keys, got: %r" % (trial_info,))
            params = trial_info["parameter"]
            value = trial_info["value"]
            if self._is_missing_value(value):
                LOGGER.info("Useless trial data, value is %s, skip this trial data.", value)
                continue
            completed_num += 1
            # Parameter sets are usually dicts (unhashable), hence list membership.
            if params in self._values:
                self._values.remove(params)
        LOGGER.info("Successfully import data to batch tuner, "
                    "total data: %d, imported data: %d.", len(data), completed_num)

    @staticmethod
    def _is_missing_value(value):
        """Return True if an imported trial ``value`` carries no usable result."""
        if value is None:
            return True
        if isinstance(value, numbers.Number):
            # 0 / 0.0 is a legitimate metric: the trial did complete.
            return False
        # Empty containers/strings (e.g. ``{}``) carry no result.
        return not value