Source code for nni.algorithms.hpo.curvefitting_assessor.curvefitting_assessor

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import datetime
from schema import Schema, Optional

from nni import ClassArgsValidator
from nni.assessor import Assessor, AssessResult
from nni.utils import extract_scalar_history
from .model_factory import CurveModel

logger = logging.getLogger('curvefitting_Assessor')

class CurvefittingClassArgsValidator(ClassArgsValidator):
    def validate_class_args(self, **kwargs):
        Schema({
            'epoch_num': self.range('epoch_num', int, 0, 9999),
            Optional('start_step'): self.range('start_step', int, 0, 9999),
            Optional('threshold'): self.range('threshold', float, 0, 9999),
            Optional('gap'): self.range('gap', int, 1, 9999),
        }).validate(kwargs)

[docs]class CurvefittingAssessor(Assessor): """CurvefittingAssessor uses learning curve fitting algorithm to predict the learning curve performance in the future. It stops a pending trial X at step S if the trial's forecast result at target step is convergence and lower than the best performance in the history. Parameters ---------- epoch_num : int The total number of epoch start_step : int only after receiving start_step number of reported intermediate results threshold : float The threshold that we decide to early stop the worse performance curve. """ def __init__(self, epoch_num=20, start_step=6, threshold=0.95, gap=1): if start_step <= 0: logger.warning('It\'s recommended to set start_step to a positive number') # Record the target position we predict self.target_pos = epoch_num # Start forecasting when historical data reaches start step self.start_step = start_step # Record the compared threshold self.threshold = threshold # Record the number of gap self.gap = gap # Record the number of intermediate result in the lastest judgment self.last_judgment_num = dict() # Record the best performance self.set_best_performance = False self.completed_best_performance = None self.trial_history = [] logger.info('Successfully initials the curvefitting assessor')
[docs] def trial_end(self, trial_job_id, success): """update the best performance of completed trial job Parameters ---------- trial_job_id : int trial job id success : bool True if succssfully finish the experiment, False otherwise """ if success: if self.set_best_performance: self.completed_best_performance = max(self.completed_best_performance, self.trial_history[-1]) else: self.set_best_performance = True self.completed_best_performance = self.trial_history[-1] logger.info('Updated completed best performance, trial job id: %s', trial_job_id) else: logger.info('No need to update, trial job id: %s', trial_job_id)
[docs] def assess_trial(self, trial_job_id, trial_history): """assess whether a trial should be early stop by curve fitting algorithm Parameters ---------- trial_job_id : int trial job id trial_history : list The history performance matrix of each trial Returns ------- bool AssessResult.Good or AssessResult.Bad Raises ------ Exception unrecognize exception in curvefitting_assessor """ scalar_trial_history = extract_scalar_history(trial_history) self.trial_history = scalar_trial_history if not self.set_best_performance: return AssessResult.Good curr_step = len(scalar_trial_history) if curr_step < self.start_step: return AssessResult.Good if trial_job_id in self.last_judgment_num.keys() and curr_step - self.last_judgment_num[trial_job_id] < self.gap: return AssessResult.Good self.last_judgment_num[trial_job_id] = curr_step try: start_time = datetime.datetime.now() # Predict the final result curvemodel = CurveModel(self.target_pos) predict_y = curvemodel.predict(scalar_trial_history) log_message = "Prediction done. Trial job id = {}, Predict value = {}".format(trial_job_id, predict_y) if predict_y is None: logger.info('%s, wait for more information to predict precisely', log_message) return AssessResult.Good else: logger.info(log_message) standard_performance = self.completed_best_performance * self.threshold end_time = datetime.datetime.now() if (end_time - start_time).seconds > 60: logger.warning( 'Curve Fitting Assessor Runtime Exceeds 60s, Trial Id = %s Trial History = %s', trial_job_id, self.trial_history ) if predict_y > standard_performance: return AssessResult.Good return AssessResult.Bad except Exception as exception: logger.exception('unrecognize exception in curvefitting_assessor %s', exception)