Source code for nni.algorithms.hpo.pbt_tuner

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import logging
import os
import random
import numpy as np
from schema import Schema, Optional

import nni
from nni import ClassArgsValidator
import nni.parameter_expressions
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward, split_index, json2parameter, json2space


logger = logging.getLogger('pbt_tuner_AutoML')


def perturbation(hyperparameter_type, value, resample_probablity, uv, ub, lv, lb, random_state):
    """
    Perturbation for hyperparameters

    Parameters
    ----------
    hyperparameter_type : str
        type of hyperparameter
    value : list
        parameters for sampling hyperparameter
    resample_probability : float
        probability for resampling
    uv : float/int
        upper value after perturbation
    ub : float/int
        upper bound
    lv : float/int
        lower value after perturbation
    lb : float/int
        lower bound
    random_state : RandomState
        random state
    """
    if random.random() < resample_probablity:
        if hyperparameter_type == "choice":
            return value.index(nni.parameter_expressions.choice(value, random_state))
        else:
            return getattr(nni.parameter_expressions, hyperparameter_type)(*(value + [random_state]))
    else:
        if random.random() > 0.5:
            return min(uv, ub)
        else:
            return max(lv, lb)


def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probability, epoch, search_space):
    """
    Replace checkpoint of bot_trial with top, and perturb hyperparameters

    Parameters
    ----------
    bot_trial_info : TrialInfo
        bottom model whose parameters should be replaced
    top_trial_info : TrialInfo
        better model
    factor : float
        factor for perturbation
    resample_probability : float
        probability for resampling
    epoch : int
        step of PBTTuner
    search_space : dict
        search_space to keep perturbed hyperparameters in range
    """
    bot_checkpoint_dir = bot_trial_info.checkpoint_dir
    top_hyper_parameters = top_trial_info.hyper_parameters
    hyper_parameters = copy.deepcopy(top_hyper_parameters)
    random_state = np.random.RandomState()
    hyper_parameters['load_checkpoint_dir'] = hyper_parameters['save_checkpoint_dir']
    hyper_parameters['save_checkpoint_dir'] = os.path.join(bot_checkpoint_dir, str(epoch))
    for key in hyper_parameters.keys():
        hyper_parameter = hyper_parameters[key]
        if key == 'load_checkpoint_dir' or key == 'save_checkpoint_dir':
            continue
        elif search_space[key]["_type"] == "choice":
            choices = search_space[key]["_value"]
            ub, uv = len(choices) - 1, choices.index(hyper_parameter) + 1
            lb, lv = 0, choices.index(hyper_parameter) - 1
        elif search_space[key]["_type"] == "randint":
            lb, ub = search_space[key]["_value"][:2]
            ub -= 1
            uv = hyper_parameter + 1
            lv = hyper_parameter - 1
        elif search_space[key]["_type"] == "uniform":
            lb, ub = search_space[key]["_value"][:2]
            perturb = (ub - lb) * factor
            uv = hyper_parameter + perturb
            lv = hyper_parameter - perturb
        elif search_space[key]["_type"] == "quniform":
            lb, ub, q = search_space[key]["_value"][:3]
            multi = round(hyper_parameter / q)
            uv = (multi + 1) * q
            lv = (multi - 1) * q
        elif search_space[key]["_type"] == "loguniform":
            lb, ub = search_space[key]["_value"][:2]
            perturb = (np.log(ub) - np.log(lb)) * factor
            uv = np.exp(min(np.log(hyper_parameter) + perturb, np.log(ub)))
            lv = np.exp(max(np.log(hyper_parameter) - perturb, np.log(lb)))
        elif search_space[key]["_type"] == "qloguniform":
            lb, ub, q = search_space[key]["_value"][:3]
            multi = round(hyper_parameter / q)
            uv = (multi + 1) * q
            lv = (multi - 1) * q
        elif search_space[key]["_type"] == "normal":
            sigma = search_space[key]["_value"][1]
            perturb = sigma * factor
            uv = ub = hyper_parameter + perturb
            lv = lb = hyper_parameter - perturb
        elif search_space[key]["_type"] == "qnormal":
            q = search_space[key]["_value"][2]
            uv = ub = hyper_parameter + q
            lv = lb = hyper_parameter - q
        elif search_space[key]["_type"] == "lognormal":
            sigma = search_space[key]["_value"][1]
            perturb = sigma * factor
            uv = ub = np.exp(np.log(hyper_parameter) + perturb)
            lv = lb = np.exp(np.log(hyper_parameter) - perturb)
        elif search_space[key]["_type"] == "qlognormal":
            q = search_space[key]["_value"][2]
            uv = ub = hyper_parameter + q
            lv, lb = hyper_parameter - q, 1E-10
        else:
            logger.warning("Illegal type to perturb: %s", search_space[key]["_type"])
            continue

        if search_space[key]["_type"] == "choice":
            idx = perturbation(search_space[key]["_type"], search_space[key]["_value"],
                               resample_probability, uv, ub, lv, lb, random_state)
            hyper_parameters[key] = choices[idx]
        else:
            hyper_parameters[key] = perturbation(search_space[key]["_type"], search_space[key]["_value"],
                                                 resample_probability, uv, ub, lv, lb, random_state)
    bot_trial_info.hyper_parameters = hyper_parameters
    bot_trial_info.clean_id()


class TrialInfo:
    """
    Information of each trial, refresh for each epoch

    """

    def __init__(self, checkpoint_dir=None, hyper_parameters=None, parameter_id=None, score=None):
        self.checkpoint_dir = checkpoint_dir
        self.hyper_parameters = hyper_parameters
        self.parameter_id = parameter_id
        self.score = score

    def clean_id(self):
        self.parameter_id = None

class PBTClassArgsValidator(ClassArgsValidator):
    def validate_class_args(self, **kwargs):
        Schema({
            'optimize_mode': self.choices('optimize_mode', 'maximize', 'minimize'),
            Optional('all_checkpoint_dir'): str,
            Optional('population_size'): self.range('population_size', int, 0, 99999),
            Optional('factors'): float,
            Optional('fraction'): float,
        }).validate(kwargs)

[docs] class PBTTuner(Tuner): """ Population Based Training (PBT) comes from `Population Based Training of Neural Networks <https://arxiv.org/abs/1711.09846v1>`__. It's a simple asynchronous optimization algorithm which effectively utilizes a fixed computational budget to jointly optimize a population of models and their hyperparameters to maximize performance. Importantly, PBT discovers a schedule of hyperparameter settings rather than following the generally sub-optimal strategy of trying to find a single fixed set to use for the whole course of training. .. image:: ../../img/pbt.jpg PBT tuner initializes a population with several trials (i.e., ``population_size``). There are four steps in the above figure, each trial only runs by one step. How long is one step is controlled by trial code, e.g., one epoch. When a trial starts, it loads a checkpoint specified by PBT tuner and continues to run one step, then saves checkpoint to a directory specified by PBT tuner and exits. The trials in a population run steps synchronously, that is, after all the trials finish the ``i``-th step, the ``(i+1)``-th step can be started. Exploitation and exploration of PBT are executed between two consecutive steps. Two important steps to follow if you are trying to use PBT tuner: 1. **Provide checkpoint directory**. Since some trials need to load other trial's checkpoint, users should provide a directory (i.e., ``all_checkpoint_dir``) which is accessible by every trial. It is easy for local mode, users could directly use the default directory or specify any directory on the local machine. For other training services, users should follow :doc:`the document of those training services </experiment/training_service/shared_storage>` to provide a directory in a shared storage, such as NFS, Azure storage. 2. **Modify your trial code**. Before running a step, a trial needs to load a checkpoint, the checkpoint directory is specified in hyper-parameter configuration generated by PBT tuner, i.e., ``params['load_checkpoint_dir']``. Similarly, the directory for saving checkpoint is also included in the configuration, i.e., ``params['save_checkpoint_dir']``. Here, ``all_checkpoint_dir`` is base folder of ``load_checkpoint_dir`` and ``save_checkpoint_dir`` whose format is ``all_checkpoint_dir/<population-id>/<step>``. .. code-block:: python params = nni.get_next_parameter() # the path of the checkpoint to load load_path = os.path.join(params['load_checkpoint_dir'], 'model.pth') # load checkpoint from `load_path` ... # run one step ... # the path for saving a checkpoint save_path = os.path.join(params['save_checkpoint_dir'], 'model.pth') # save checkpoint to `save_path` ... The complete example code can be found :githublink:`here <examples/trials/mnist-pbt-tuner-pytorch>`. Parameters ---------- optimize_mode : ``maximize`` or ``minimize``, default: ``maximize`` If ``maximize``, the tuner will target to maximize metrics. If ``minimize``, the tuner will target to minimize metrics. all_checkpoint_dir : str Directory for trials to load and save checkpoint. If not specified, the directory would be ``~/nni/checkpoint/``. Note that if the experiment is not local mode, users should provide a path in a shared storage which can be accessed by all the trials. population_size : int, default = 10 Number of trials in a population. Each step has this number of trials. In our implementation, one step is running each trial by specific training epochs set by users. factor : float, default = (1.2, 0.8) Factors for perturbation of hyperparameters. resample_probability : float, default = 0.25 Probability for resampling. fraction : float, default = 0.2 Fraction for selecting bottom and top trials. Examples -------- Below is an example of PBT tuner configuration in experiment config file. .. code-block:: yaml tuner: name: PBT classArgs: optimize_mode: maximize all_checkpoint_dir: /the/path/to/store/checkpoints population_size: 10 Notes ----- Assessor is not allowed if PBT tuner is used. """ def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population_size=10, factor=0.2, resample_probability=0.25, fraction=0.2): self.optimize_mode = OptimizeMode(optimize_mode) if all_checkpoint_dir is None: all_checkpoint_dir = os.getenv('NNI_CHECKPOINT_DIRECTORY') logger.info("Checkpoint dir is set to %s by default.", all_checkpoint_dir) self.all_checkpoint_dir = all_checkpoint_dir self.population_size = population_size self.factor = factor self.resample_probability = resample_probability self.fraction = fraction # defined in trial code #self.perturbation_interval = perturbation_interval self.population = None self.pos = -1 self.param_ids = [] self.running = {} self.finished = [] self.credit = 0 self.finished_trials = 0 self.epoch = 0 self.searchspace_json = None self.space = None self.send_trial_callback = None logger.info('PBT tuner initialization') def update_search_space(self, search_space): """ Get search space Parameters ---------- search_space : dict Search space """ logger.info('Update search space %s', search_space) self.searchspace_json = search_space self.space = json2space(self.searchspace_json) self.random_state = np.random.RandomState() self.population = [] is_rand = dict() for item in self.space: is_rand[item] = True for i in range(self.population_size): hyper_parameters = json2parameter( self.searchspace_json, is_rand, self.random_state) hyper_parameters = split_index(hyper_parameters) checkpoint_dir = os.path.join(self.all_checkpoint_dir, str(i)) hyper_parameters['load_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch)) hyper_parameters['save_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch)) self.population.append(TrialInfo(checkpoint_dir=checkpoint_dir, hyper_parameters=hyper_parameters)) def generate_multiple_parameters(self, parameter_id_list, **kwargs): """ Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects. Parameters ---------- parameter_id_list : list of int Unique identifiers for each set of requested hyper-parameters. These will later be used in :meth:`receive_trial_result`. **kwargs Used for send_trial_callback. Returns ------- list A list of newly generated configurations """ result = [] self.send_trial_callback = kwargs['st_callback'] for parameter_id in parameter_id_list: had_exception = False try: logger.debug("generating param for %s", parameter_id) res = self.generate_parameters(parameter_id, **kwargs) except nni.NoMoreTrialError: had_exception = True if not had_exception: result.append(res) return result def generate_parameters(self, parameter_id, **kwargs): """ Generate parameters, if no trial configration for now, self.credit plus 1 to send the config later Parameters ---------- parameter_id : int Unique identifier for requested hyper-parameters. This will later be used in :meth:`receive_trial_result`. **kwargs Not used Returns ------- dict One newly generated configuration """ if self.pos == self.population_size - 1: logger.debug('Credit added by one in parameters request') self.credit += 1 self.param_ids.append(parameter_id) raise nni.NoMoreTrialError('No more parameters now.') self.pos += 1 trial_info = self.population[self.pos] trial_info.parameter_id = parameter_id self.running[parameter_id] = trial_info logger.info('Generate parameter : %s', trial_info.hyper_parameters) return trial_info.hyper_parameters def _proceed_next_epoch(self): """ """ logger.info('Proceeding to next epoch') self.epoch += 1 self.population = [] self.pos = -1 self.running = {} #exploit and explore reverse = True if self.optimize_mode == OptimizeMode.Maximize else False self.finished = sorted(self.finished, key=lambda x: x.score, reverse=reverse) cutoff = int(np.ceil(self.fraction * len(self.finished))) tops = self.finished[:cutoff] bottoms = self.finished[self.finished_trials - cutoff:] for bottom in bottoms: top = np.random.choice(tops) exploit_and_explore(bottom, top, self.factor, self.resample_probability, self.epoch, self.searchspace_json) for trial in self.finished: if trial not in bottoms: trial.clean_id() trial.hyper_parameters['load_checkpoint_dir'] = trial.hyper_parameters['save_checkpoint_dir'] trial.hyper_parameters['save_checkpoint_dir'] = os.path.join(trial.checkpoint_dir, str(self.epoch)) self.finished_trials = 0 for _ in range(self.population_size): trial_info = self.finished.pop() self.population.append(trial_info) while self.credit > 0 and self.pos + 1 < len(self.population): self.credit -= 1 self.pos += 1 parameter_id = self.param_ids.pop() trial_info = self.population[self.pos] trial_info.parameter_id = parameter_id self.running[parameter_id] = trial_info self.send_trial_callback(parameter_id, trial_info.hyper_parameters) def receive_trial_result(self, parameter_id, parameters, value, **kwargs): """ Receive trial's result. if the number of finished trials equals ``self.population_size``, start the next epoch to train the model. Parameters ---------- parameter_id : int Unique identifier of used hyper-parameters, same with :meth:`generate_parameters`. parameters : dict Hyper-parameters generated by :meth:`generate_parameters`. value : dict Result from trial (the return value of :func:`nni.report_final_result`). """ logger.info('Get one trial result, id = %d, value = %s', parameter_id, value) value = extract_scalar_reward(value) trial_info = self.running.pop(parameter_id, None) trial_info.score = value self.finished.append(trial_info) self.finished_trials += 1 if self.finished_trials == self.population_size: self._proceed_next_epoch() def trial_end(self, parameter_id, success, **kwargs): """ Deal with trial failure Parameters ---------- parameter_id : int Unique identifier for hyper-parameters used by this trial. success : bool True if the trial successfully completed; False if failed or terminated. **kwargs Unstable parameters which should be ignored by normal users. """ if success: return if self.optimize_mode == OptimizeMode.Minimize: value = float('inf') else: value = float('-inf') trial_info = self.running.pop(parameter_id, None) trial_info.score = value self.finished.append(trial_info) self.finished_trials += 1 if self.finished_trials == self.population_size: self._proceed_next_epoch() def import_data(self, data): """ Parameters ---------- data : json obj imported data records Returns ------- int the start epoch number after data imported, only used for unittest """ if self.running: logger.warning("Do not support importing data in the middle of experiment") return # the following is for experiment resume _completed_num = 0 epoch_data_dict = {} for trial_info in data: logger.info("Process data record %s / %s", _completed_num, len(data)) _completed_num += 1 # simply validate data format _params = trial_info["parameter"] _value = trial_info['value'] # assign fake value for failed trials if not _value: logger.info("Useless trial data, value is %s, skip this trial data.", _value) _value = float('inf') if self.optimize_mode == OptimizeMode.Minimize else float('-inf') _value = extract_scalar_reward(_value) if 'save_checkpoint_dir' not in _params: logger.warning("Invalid data record: save_checkpoint_dir is missing, abandon data import.") return epoch_num = int(os.path.basename(_params['save_checkpoint_dir'])) if epoch_num not in epoch_data_dict: epoch_data_dict[epoch_num] = [] epoch_data_dict[epoch_num].append((_params, _value)) if not epoch_data_dict: logger.warning("No valid epochs, abandon data import.") return # figure out start epoch for resume max_epoch_num = max(epoch_data_dict, key=int) if len(epoch_data_dict[max_epoch_num]) < self.population_size: max_epoch_num -= 1 # If there is no a single complete round, no data to import, start from scratch if max_epoch_num < 0: logger.warning("No completed epoch, abandon data import.") return assert len(epoch_data_dict[max_epoch_num]) == self.population_size # check existence of trial save checkpoint dir for params, _ in epoch_data_dict[max_epoch_num]: if not os.path.isdir(params['save_checkpoint_dir']): logger.warning("save_checkpoint_dir %s does not exist, data will not be resumed", params['save_checkpoint_dir']) return # resume data self.epoch = max_epoch_num self.finished_trials = self.population_size for params, value in epoch_data_dict[max_epoch_num]: checkpoint_dir = os.path.dirname(params['save_checkpoint_dir']) self.finished.append(TrialInfo(checkpoint_dir=checkpoint_dir, hyper_parameters=params, score=value)) self._proceed_next_epoch() logger.info("Successfully import data to PBT tuner, total data: %d, imported data: %d.", len(data), self.population_size) logger.info("Start from epoch %d ...", self.epoch) return self.epoch # return for test