Source code for nni.nas.experiment.config.engine
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    'ExecutionEngineConfig',
    'TrainingServiceEngineConfig',
    'CgoEngineConfig',
    'SequentialEngineConfig',
]
import logging
from dataclasses import dataclass
from typing import Optional
from nni.experiment.config import ExperimentConfig
from nni.experiment.config.training_services import RemoteConfig
from nni.experiment.config.utils import parse_time
from .utils import NamedSubclassConfigBase
# Module-level logger, named after this module via ``__name__``.
_logger = logging.getLogger(__name__)
[文档]
@dataclass(init=False)
class ExecutionEngineConfig(NamedSubclassConfigBase):
    """Common base class of every execution engine config.

    It declares no fields of its own; its purpose is to let callers
    recognise any engine config with a single ``isinstance`` check.
    """
[文档]
@dataclass(init=False)
class TrainingServiceEngineConfig(ExecutionEngineConfig):
    """Config of the engine that works together with an NNI training service.

    Settings specific to the training service would naturally belong here,
    but for historical reasons they currently live in the top-level
    experiment config instead.
    """

    # Identifier of this engine config (presumably what
    # NamedSubclassConfigBase uses to pick the subclass -- confirm in .utils).
    name: str = 'ts'
[文档]
@dataclass(init=False)
class SequentialEngineConfig(ExecutionEngineConfig):
    """Config of the engine that runs models one after another.

    ``max_model_count`` and ``max_duration`` may be left as ``None``; they are
    then filled in from the parent experiment config by :meth:`_canonicalize`.
    """

    name: str = 'sequential'

    continue_on_failure: bool = False
    # Defaults to the experiment's ``max_trial_number`` when left unset.
    max_model_count: Optional[int] = None
    # Defaults to the experiment's ``max_trial_duration`` (run through
    # ``parse_time``) when left unset and the experiment defines one.
    max_duration: Optional[float] = None

    def _canonicalize(self, parents):
        """Fill unset limits from the parent experiment config.

        ``parents`` must be non-empty and ``parents[0]`` must be an
        :class:`ExperimentConfig`; both are enforced with assertions.
        A warning is logged when the experiment requests an integer trial
        concurrency greater than 1. The result of the base-class
        ``_canonicalize`` is returned unchanged.
        """
        assert len(parents) > 0
        experiment = parents[0]
        assert isinstance(experiment, ExperimentConfig), 'SequentialEngineConfig must be a child of ExperimentConfig'

        # Inherit the model budget from the experiment-wide trial budget.
        if self.max_model_count is None:
            self.max_model_count = experiment.max_trial_number

        # Inherit the time budget only if the experiment actually defines one.
        if self.max_duration is None:
            trial_duration = experiment.max_trial_duration
            if trial_duration is not None:
                self.max_duration = parse_time(trial_duration)

        concurrency = experiment.trial_concurrency
        if isinstance(concurrency, int) and concurrency > 1:
            _logger.warning('Sequential engine does not support trial concurrency > 1')

        return super()._canonicalize(parents)
[文档]
@dataclass(init=False)
class CgoEngineConfig(ExecutionEngineConfig):
    """Config of the cross-graph optimization (CGO) engine."""

    name: str = 'cgo'

    max_concurrency_cgo: int
    batch_waiting_time: int
    # Overwritten with the parent experiment's training service config
    # during canonicalization.
    training_service: Optional[RemoteConfig] = None

    def _canonicalize(self, parents):
        """Adopt the training service config of the parent experiment.

        ``parents`` must be non-empty and ``parents[0]`` must be an
        :class:`ExperimentConfig`; both are enforced with assertions.

        Raises:
            TypeError: if the experiment's training service is not a
                :class:`RemoteConfig`.
        """
        assert len(parents) > 0
        experiment = parents[0]
        assert isinstance(experiment, ExperimentConfig), 'CgoEngineConfig must be a child of ExperimentConfig'

        service = experiment.training_service
        if not isinstance(service, RemoteConfig):
            raise TypeError("CGO execution engine currently only supports remote training service")

        self.training_service = service
        return super()._canonicalize(parents)