# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['CrossGraphOptimization']
import logging
import time
import threading
from collections.abc import Iterable
from typing import List, Dict, Tuple, cast
from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from nni.nas.space import GraphModelSpace, Node, ModelStatus
from nni.nas.execution.engine import Middleware, ExecutionEngine
from nni.nas.execution.event import ModelEventType, IntermediateMetricEvent, FinalMetricEvent, TrainingEndEvent
from nni.typehint import TrialMetric
from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
_logger = logging.getLogger(__name__)
class CrossGraphOptimization(Middleware):
"""
    The execution engine middleware of Cross-Graph Optimization (CGO).
    It's a technique that merges multiple models into one model for training speedup.
    See the `Retiarii paper <https://www.usenix.org/system/files/osdi20-zhang_quanlu.pdf>`__ for details.

    Currently, :class:`CrossGraphOptimization` is only a prototype.
    It's not fully tested, and comes with a number of constraints on the model space and evaluator:

    - The models must be in the format of :class:`~nni.nas.space.GraphModelSpace`.
    - The evaluator has to be a :class:`~nni.nas.evaluator.pytorch.Lightning` evaluator.
    - The ``lightning_module`` argument of the evaluator must be an instance of
      :class:`~nni.nas.execution.cgo.evaluator.MultiModelLightningModule`.
    - The ``trainer`` argument of the evaluator must be an instance of
      :class:`~nni.nas.execution.cgo.evaluator.MultiModelTrainer`.

    There are also a number of limitations:

    - CGO doesn't support stopping and resuming from a checkpoint.
    - Only the remote training service is supported.
    - All model history is stored in memory, so the experiment might not scale well.

    Parameters
    ----------
    remote_config
        The remote training service config.
    max_concurrency
        The maximum number of trials to run concurrently.
    batch_waiting_time
        Seconds to wait for each batch of trial submission.
        The trials within one batch could apply cross-graph optimization.
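
    Examples
    --------
    A minimal sketch of how the middleware could be wired up, assuming ``remote_config``,
    an underlying execution engine ``training_service_engine`` and a list of frozen
    ``models`` already exist (the variable names here are illustrative, not a fixed API):

    .. code-block:: python

        cgo = CrossGraphOptimization(remote_config, max_concurrency=2, batch_waiting_time=60)
        cgo.set_engine(training_service_engine)   # wrap the engine that talks to the training service
        cgo.submit_models(*models)                # GraphModelSpace instances with a multi-model Lightning evaluator
        ...
        cgo.shutdown()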
"""
def __init__(self, remote_config: RemoteConfig,
max_concurrency: int | None = None,
batch_waiting_time: int = 60) -> None:
super().__init__()
_logger.warning('Cross graph optimization is an experimental feature. Usages are subject to change.')
self._history: List[GraphModelSpace] = []
self._running_models: Dict[int, GraphModelSpace] = {}
self.logical_plan_counter = 0
self.available_devices: List[Device] = []
self.max_concurrency: int | None = max_concurrency
devices = self._construct_devices(remote_config)
for device in devices:
self.available_devices.append(device)
self.all_devices = self.available_devices.copy()
self._batch_waiting_time = batch_waiting_time # seconds to wait for all models in a batch to do cross-graph optimization
self._optimizers = [DedupInputOptimizer()]
self._original_models: Dict[int, GraphModelSpace] = {}
self._original_model_to_multi_model: Dict[int, GraphModelSpace] = {}
self._trial_to_original_models: Dict[int, List[int]] = {}
self._trial_used_devices: Dict[int, List[Device]] = {}
self._queuing_models: List[Tuple[float, GraphModelSpace]] = []
self._models_to_retry: List[GraphModelSpace] = []
self._queue_lock = threading.Lock()
self._stopped = False
self._consumer_thread = threading.Thread(target=self._consume_models)
self._consumer_thread.start()
def _construct_devices(self, training_service):
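        # Collect one GPUDevice per (host, gpu_index) pair declared in the remote machine list.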
devices = []
if hasattr(training_service, 'machine_list'):
for machine in cast(RemoteConfig, training_service).machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
return devices
def shutdown(self):
self._stopped = True
self._consumer_thread.join()
if self._engine is None:
_logger.warning('Underlying engine is not set. Skip shutdown.')
else:
self.engine.unregister_model_event_callback(ModelEventType.TrainingEnd, self._training_end_callback)
self.engine.unregister_model_event_callback(ModelEventType.FinalMetric, self._final_metric_callback)
self.engine.unregister_model_event_callback(ModelEventType.IntermediateMetric, self._intermediate_metric_callback)
self.engine.shutdown()
def load_state_dict(self, state_dict: dict) -> None:
_logger.info('Cross graph optimization does not preserve any states by itself. Loading the state of inner engine: %s', self.engine)
return self.engine.load_state_dict(state_dict)
def state_dict(self) -> dict:
return self.engine.state_dict()
def set_engine(self, engine: ExecutionEngine) -> None:
super().set_engine(engine)
self.engine.register_model_event_callback(ModelEventType.TrainingEnd, self._training_end_callback)
self.engine.register_model_event_callback(ModelEventType.FinalMetric, self._final_metric_callback)
self.engine.register_model_event_callback(ModelEventType.IntermediateMetric, self._intermediate_metric_callback)
def add_optimizer(self, opt):
self._optimizers.append(opt)
def submit_models(self, *models: GraphModelSpace) -> None:
if any(not isinstance(model, GraphModelSpace) for model in models):
raise TypeError('Cross graph optimization only supports GraphModelSpace.')
curr_time = time.time()
_logger.info('%d models are submitted.', len(models))
with self._queue_lock:
self._queuing_models.extend([(curr_time, _) for _ in models])
self._history.extend(models)
def _submit_retry_models(self, models: List[GraphModelSpace]) -> None:
_logger.info('%d models are retried.', len(models))
with self._queue_lock:
self._models_to_retry.extend(models)
def _consume_models(self):
# a thread to monitor self._models_to_retry and self._queuing_models to consume them in batch
while not self._stopped:
            # Models waiting for retry are scheduled first.
while self._models_to_retry:
with self._queue_lock:
# Get next model and lock the resource.
if len(self.available_devices) > 0:
m = self._models_to_retry[0]
self._models_to_retry = self._models_to_retry[1:]
m = self._schedule_models_in_batch(m)
else:
break
# submit the single model to avoid cross-graph optimization.
self.engine.submit_models(*m)
time.sleep(1)
# Submit merged models
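            # Queued models are held briefly (up to ``batch_waiting_time`` seconds, measured from the
            # oldest entry) so that several of them can be merged into a single cross-graph trial,
            # bounded by the number of free devices and ``max_concurrency``.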
merged_models = []
with self._queue_lock:
curr_time = time.time()
num_models_to_submit = len(self.available_devices)
if self.max_concurrency is not None:
num_models_to_submit = min(num_models_to_submit, self.max_concurrency)
if self._queuing_models and curr_time - self._queuing_models[0][0] >= self._batch_waiting_time:
num_models_to_submit = min(num_models_to_submit, len(self._queuing_models))
if num_models_to_submit > 0:
merged_models = list(self._schedule_models_in_batch(*[_[1] for _ in self._queuing_models[:num_models_to_submit]]))
self._queuing_models = self._queuing_models[num_models_to_submit:]
_logger.debug('Scheduled %d models in batch.', num_models_to_submit)
# Outside lock to avoid deadlock.
if merged_models:
self.engine.submit_models(*merged_models)
time.sleep(1)
def _schedule_models_in_batch(self, *models: GraphModelSpace) -> Iterable[GraphModelSpace]:
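        """
        Merge a batch of models into one or more multi-model trials.

        Every model is marked as training, a logical plan is built over the batch and rewritten
        by the registered optimizers (e.g., input deduplication), and the assembled multi-models
        are yielded while the devices chosen by their placement are locked.
        """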
_logger.info('%d models are scheduled in batch.', len(models))
_logger.debug('Scheduled model ids: %s', [m.model_id for m in models])
for model in models:
model.status = ModelStatus.Training
logical = self._build_logical(list(models))
for opt in self._optimizers:
opt.convert(logical)
for model, grouped_models in self._assemble(logical):
assert model.placement is not None
_logger.debug('Created grouped model %d. Original model ids: %s', model.model_id, [m.model_id for m in grouped_models])
# unique non-cpu devices used by the trial
self._trial_used_devices[model.model_id] = list(set([_ for _ in model.placement.values() if isinstance(_, GPUDevice)]))
_logger.debug('Model %d uses devices: %s', model.model_id, self._trial_used_devices[model.model_id])
            # currently, it is impossible for the search strategy to submit more models than the number of available devices
for used_device in self._trial_used_devices[model.model_id]:
self.available_devices.remove(used_device) # used_device must be in self.available_devices
self._running_models[model.model_id] = model
self._trial_to_original_models[model.model_id] = []
for m in grouped_models:
self._original_models[m.model_id] = m
self._original_model_to_multi_model[m.model_id] = model
self._trial_to_original_models[model.model_id].append(m.model_id)
yield model
def list_models(self) -> Iterable[GraphModelSpace]:
return self._history
def idle_worker_available(self) -> bool:
        # Models already queued or waiting for retry will claim the available devices first.
with self._queue_lock:
available_for_more_models = len(self.available_devices) - len(self._queuing_models) - len(self._models_to_retry)
        return available_for_more_models > 0
def budget_available(self) -> bool:
return self.engine.budget_available()
def _assemble(self, logical_plan: LogicalPlan) -> Iterable[Tuple[GraphModelSpace, List[GraphModelSpace]]]:
"""
Return the assembled models as a list of tuple.
Each tuple contains the assembled model, the device placement of graph nodes, and the original models.
"""
grouped_models: List[Dict[GraphModelSpace, Device]] = []
# try to use the available_devices first so that it can be launched as early as possible
# if free devices are not enough to assemble all models in one trial, try all devices
if len(self.available_devices) > 0:
grouped_models = AssemblePolicy().group(logical_plan, self.available_devices)
if len(self.available_devices) == 0 or len(grouped_models) > 1:
grouped_models: List[Dict[GraphModelSpace, Device]] = AssemblePolicy().group(logical_plan, self.all_devices)
for multi_model in grouped_models:
model, model_placement = logical_plan.assemble(multi_model)
assert isinstance(model, GraphModelSpace), 'Assembled model must be a GraphModelSpace.'
from nni.nas.evaluator.pytorch import Lightning
from .evaluator import MultiModelLightningModule, MultiModelTrainer
            if not isinstance(model.evaluator, Lightning):
                raise TypeError('Cross-graph optimization only supports a PyTorch Lightning evaluator.')
            if not isinstance(model.evaluator.module, MultiModelLightningModule):
                raise TypeError('Cross-graph optimization only supports MultiModelLightningModule as the lightning module.')
            if not isinstance(model.evaluator.trainer, MultiModelTrainer):
                raise TypeError('Cross-graph optimization only supports MultiModelTrainer as the trainer.')
# Set n_models of the lightning module.
model.evaluator.module.n_models = len(multi_model)
model.status = ModelStatus.Frozen
model.placement = model_placement
model.metrics.strict = False
yield model, list(multi_model.keys())
def _build_logical(self, models: List[GraphModelSpace]) -> LogicalPlan:
assert len(models) > 0
logical_plan = LogicalPlan(model_cls=models[0].__class__, plan_id=self.logical_plan_counter)
for model in models:
logical_plan.add_model(model)
self.logical_plan_counter += 1
return logical_plan
def _training_end_callback(self, event: TrainingEndEvent) -> None:
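        """
        Fan out the training end of a merged model to the original models it contains.

        If the merged trial trained successfully, every original model is marked as trained.
        Otherwise, originals coming from a multi-model trial are reset and retried individually
        (without CGO), while a single-model trial is marked as failed directly.
        The devices held by the trial are released afterwards.
        """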
model = cast(GraphModelSpace, event.model)
_logger.debug(f'Training end for merged model {model.model_id}.')
model = self._running_models[model.model_id]
models_to_retry = []
for model_id in self._original_model_to_multi_model:
if self._original_model_to_multi_model[model_id] == model:
original_model = self._original_models[model_id]
if model.status == ModelStatus.Trained:
self.dispatch_model_event(TrainingEndEvent(original_model, ModelStatus.Trained))
else:
# the failed models in a multi-model will be retried one by one w/o CGO
if len(self._trial_to_original_models[model.model_id]) > 1:
# TODO: should the listeners be notified?
original_model.status = ModelStatus.Frozen
original_model.metrics.clear()
models_to_retry.append(original_model)
else:
self.dispatch_model_event(TrainingEndEvent(original_model, ModelStatus.Failed))
if len(models_to_retry) > 0:
self._submit_retry_models(models_to_retry)
self.available_devices.extend(self._trial_used_devices[model.model_id])
self.available_devices = sorted(list(set(self.available_devices)))
del self._running_models[model.model_id]
def _intermediate_metric_callback(self, event: IntermediateMetricEvent) -> None:
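        """Split the intermediate metric list of a merged model and dispatch one metric to each original model."""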
model = cast(GraphModelSpace, event.model)
metrics = cast(List[TrialMetric], event.metric)
_logger.debug(f'Received intermediate metrics for merged model {model.model_id}: {metrics}')
if not isinstance(metrics, Iterable):
raise TypeError('Intermediate metrics must be a list of TrialMetric.')
if len(metrics) != len(self._trial_to_original_models[model.model_id]):
raise ValueError('Number of intermediate metrics must be equal to number of original models.')
merged_metrics: Dict[int, TrialMetric] = {}
for idx, _ in enumerate(metrics):
merged_metrics[self._trial_to_original_models[model.model_id][idx]] = metrics[idx]
for model_id in merged_metrics:
self.dispatch_model_event(IntermediateMetricEvent(self._original_models[model_id], merged_metrics[model_id]))
def _final_metric_callback(self, event: FinalMetricEvent) -> None:
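        """Split the final metric list of a merged model and dispatch one metric to each original model."""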
model = cast(GraphModelSpace, event.model)
metrics = cast(List[TrialMetric], event.metric)
_logger.debug(f'Received final metrics for merged model {model.model_id}: {metrics}')
if not isinstance(metrics, Iterable):
raise TypeError('Final metrics must be a list of TrialMetric.')
if len(metrics) != len(self._trial_to_original_models[model.model_id]):
raise ValueError('Number of final metrics must be equal to number of original models.')
merged_metrics: Dict[int, TrialMetric] = {}
for idx, _ in enumerate(metrics):
merged_metrics[self._trial_to_original_models[model.model_id][idx]] = metrics[idx]
_logger.debug(f'Mapped to metrics of original models: {merged_metrics}')
for model_id in merged_metrics:
self.dispatch_model_event(FinalMetricEvent(self._original_models[model_id], merged_metrics[model_id]))
class AssemblePolicy:
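    """
    Heuristic policy that partitions the models of a logical plan into groups,
    where each group is assembled into one multi-model trial and every model
    in a group is placed on its own GPU device.
    """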
@staticmethod
def _is_related_node(model: GraphModelSpace, node: Node):
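        """Return whether ``node`` in the logical graph belongs to (or is shared by) ``model``."""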
if isinstance(node, AbstractLogicalNode):
if model in node.related_models:
return True
else:
if model == node.graph.model:
return True
return False
@staticmethod
def _check_graph_connectivity(model: GraphModelSpace,
group_model: Dict[GraphModelSpace, Device],
logical_plan: LogicalPlan) -> bool:
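        """Return whether ``model`` is connected to any model already in ``group_model``, i.e., some logical-graph edge touches both."""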
for edge in logical_plan.logical_graph.edges:
if AssemblePolicy._is_related_node(model, edge.head) or \
AssemblePolicy._is_related_node(model, edge.tail):
for grouped_model in group_model:
if AssemblePolicy._is_related_node(grouped_model, edge.head) or \
AssemblePolicy._is_related_node(grouped_model, edge.tail):
return True
return False
@staticmethod
def _check_evaluator(new_model: GraphModelSpace, group_model: Dict[GraphModelSpace, Device]) -> bool:
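        """Return whether ``new_model`` uses a multi-model Lightning evaluator equal to the one shared by the models in ``group_model``."""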
from nni.nas.evaluator.pytorch import Lightning
from .evaluator import MultiModelLightningModule, MultiModelTrainer
if not (isinstance(new_model.evaluator, Lightning)
and isinstance(new_model.evaluator.module, MultiModelLightningModule)
and isinstance(new_model.evaluator.trainer, MultiModelTrainer)):
return False
for m in group_model:
            if m.evaluator != new_model.evaluator:
return False
return True
@staticmethod
def group(logical_plan, available_devices):
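        """
        Partition the models in ``logical_plan`` into groups and assign one device per model.

        A group is closed whenever the next model is not connected to it or uses a different
        evaluator, and whenever the group already occupies all available devices.
        """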
        # TODO: Packing multiple models in one GPU
        # Currently, we only support one model per GPU
all_grouped_models = []
group_model = {}
        assert len(available_devices) > 0  # There should be at least 1 device, set in CGO_DEVICES
for idx, m in enumerate(logical_plan.models):
            # models in one group should
            # (1) not use more GPUs than available_devices
            # (2) be connected in the logical plan (independent models should be assembled in multiple groups)
            # (3) use the same evaluator (MultiModelLightningModule)
            if len(group_model) > 0 and \
                    (not AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) or
                     not AssemblePolicy._check_evaluator(m, group_model)):
all_grouped_models.append(group_model)
group_model = {}
group_model[m] = available_devices[idx % len(available_devices)]
if len(group_model) == len(available_devices) or \
idx == len(logical_plan.models) - 1:
all_grouped_models.append(group_model)
group_model = {}
return all_grouped_models