# Source code for nni.nas.strategy.rl

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import logging
import warnings
from typing import Optional, TYPE_CHECKING

from nni.mutable import SampleValidationError
from nni.nas.execution import ExecutionEngine
from nni.nas.space import ExecutableModelSpace

from .base import Strategy

# tianshou is an optional dependency: the strategy raises a helpful
# ImportError at construction time (see PolicyBasedRL.__init__) instead of
# failing at import time when it is missing.
try:
    has_tianshou = True
    from tianshou.data import ReplayBuffer
    from ._rl_impl import PolicyFactory, TuningTrajectoryGenerator, default_policy_fn
except ImportError:
    has_tianshou = False

if TYPE_CHECKING:
    # Only needed for type annotations; kept out of the runtime path so the
    # module still imports when tianshou is absent.
    from tianshou.policy import BasePolicy

_logger = logging.getLogger(__name__)

class PolicyBasedRL(Strategy):
    """
    Algorithm for policy-based reinforcement learning.

    This is a wrapper of algorithms provided in tianshou (PPO by default),
    and can be easily customized with other algorithms that inherit ``BasePolicy``
    (e.g., REINFORCE).

    Parameters
    ----------
    samples_per_update
        How many models (trajectories) each time collector collects.
        After each collect, trainer will sample batch from replay buffer and do the update.
    replay_buffer_size
        Size of replay buffer.
        If it's none, the size will be the expected trajectory length times ``samples_per_update``.
    reward_for_invalid
        The reward for a sample that didn't pass validation, or the training doesn't return a metric.
        If not provided, failed models will be simply ignored as if nothing happened.
    policy_fn
        Since environment is created on the fly, the policy needs to be a factory function that creates a policy on-the-fly.
        It takes :class:`~nni.nas.strategy._rl_impl.TuningEnvironment` as input and returns a policy.
        By default, it will use the policy returned by :func:`~nni.nas.strategy._rl_impl.default_policy_fn`.
    update_kwargs
        Keyword arguments for ``policy.update``. See tianshou's BasePolicy for details.
        There is a special key ``"update_times"`` that can be used to specify how many times ``policy.update`` is called,
        which can be used to sufficiently exploit the current available trajectories in the replay buffer
        (for example when actor and critic needs to be updated alternatively multiple times).
        By default, it's ``{'batch_size': 32, 'repeat': 5, 'update_times': 5}``.
    """

    # Stop sampling after this many consecutive invalid samples (see `_run`).
    _invalid_patience = 20

    def __init__(self, *, samples_per_update: int = 20,
                 replay_buffer_size: int | None = None,
                 reward_for_invalid: float | None = None,
                 policy_fn: Optional[PolicyFactory] = None,
                 update_kwargs: dict | None = None,
                 **kwargs):
        super().__init__()

        if 'max_collect' in kwargs:
            warnings.warn('`max_collect` is deprecated. It has no effect now.', DeprecationWarning)
        if 'trial_per_collect' in kwargs:
            warnings.warn('`trial_per_collect` is deprecated. Use `samples_per_update` instead.', DeprecationWarning)
            samples_per_update = kwargs['trial_per_collect']

        if not has_tianshou:
            raise ImportError('`tianshou` is required to run RL-based strategy. '
                              'Please use "pip install tianshou" to install it beforehand.')

        self.policy_fn = policy_fn or default_policy_fn
        self.samples_per_update = samples_per_update
        self.replay_buffer_size = replay_buffer_size
        self.reward_for_invalid = reward_for_invalid
        self.update_kwargs = {'batch_size': 32, 'repeat': 5, 'update_times': 5} if update_kwargs is None else update_kwargs

        self._current_episode = 0
        self._successful_episode = 0
        # Models that have been submitted but not yet completed, paired with
        # the generator that produced them (so the reward can be sent back).
        self._running_models: list[tuple[ExecutableModelSpace, TuningTrajectoryGenerator]] = []
        self._trajectory_count = 0

        # Policy and replay buffer are created lazily in `_initialize`
        # because they depend on the model space.
        self._policy: BasePolicy | None = None
        self._replay_buffer: ReplayBuffer | None = None

    def extra_repr(self) -> str:
        return f'samples_per_update={self.samples_per_update}, replay_buffer_size={self.replay_buffer_size}, ' + \
            f'reward_for_invalid={self.reward_for_invalid}'

    def _harvest_running_models(self) -> bool:
        """Harvest completed models and add their trajectories to replay buffer.

        Return true if the policy has just been updated on the latest buffer
        or no new trajectories found. False otherwise.
        """
        running_indices = []
        recently_updated = True
        for index, (model, generator) in enumerate(self._running_models):
            if model.status.completed():
                if model.metric is not None:
                    # No matter success or failure, as long as it is completed and gets a metric.
                    trajectory = generator.send_reward(model.metric)
                    _logger.info('[Trajectory %4d] (%s, %s) %s',
                                 self._trajectory_count + 1, model.status.value, model.metric, model.sample)
                    recently_updated = self._add_trajectory(trajectory)
                else:
                    _logger.info('%s has no metric. Skip.', model)
            else:
                # Still running; keep it for the next harvest.
                running_indices.append(index)
        self._running_models = [self._running_models[i] for i in running_indices]
        return recently_updated

    def _add_trajectory(self, trajectory: ReplayBuffer) -> bool:
        """Add the trajectory to replay buffer and execute update if necessary.

        Return true if an update is just executed. False otherwise.
        """
        assert self._replay_buffer is not None
        self._replay_buffer.update(trajectory)
        self._trajectory_count += 1
        if self._trajectory_count % self.samples_per_update == 0:
            self._update_policy()
            return True
        return False

    def _update_policy(self) -> None:
        """Update the RL policy on current replay buffer."""
        _logger.info('[Trajectory %4d] Updating policy...', self._trajectory_count)
        assert self._policy is not None and self._replay_buffer is not None
        update_times = self.update_kwargs.get('update_times', 1)
        for _ in range(update_times):
            # NOTE(review): `update_times` itself is also forwarded to
            # `policy.update`; tianshou policies appear to absorb extra
            # kwargs via `**kwargs` in `learn` — confirm against the pinned
            # tianshou version.
            self._policy.update(0, self._replay_buffer, **self.update_kwargs)

    def _initialize(self, model_space: ExecutableModelSpace, engine: ExecutionEngine) -> ExecutableModelSpace:
        """Create the policy and replay buffer for the given model space."""
        generator = TuningTrajectoryGenerator(model_space, self.policy_fn)
        self._policy = generator.policy
        if self.replay_buffer_size is None:
            # Default: enough room for `samples_per_update` full trajectories.
            replay_buffer_size = generator.expected_trajectory_length * self.samples_per_update
        else:
            replay_buffer_size = self.replay_buffer_size
        self._replay_buffer = ReplayBuffer(replay_buffer_size)
        return model_space

    def _run(self) -> None:
        """Main loop: sample models with the policy, submit them, and learn from rewards."""
        assert self._policy is not None and self._replay_buffer is not None
        self._policy.train()

        _invalid_count = 0
        _logger.info('Sampling models with RL policy:\n%s', self._policy)

        while True:
            self._harvest_running_models()

            if not self.wait_for_resource():
                _logger.info('Budget exhausted. No more sampling.')
                break

            generator = TuningTrajectoryGenerator(self.model_space, self._policy)
            sample = generator.next_sample()
            try:
                model = self.model_space.freeze(sample)
                # A valid sample resets the invalid-sample counter.
                _invalid_count = 0
            except SampleValidationError:
                _logger.debug('Invalid sample generated. It will be handled following the setting of `reward_for_invalid`: %s', sample)
                _invalid_count += 1
                if _invalid_count > self._invalid_patience:
                    _logger.warning('Too many (over %d) invalid samples generated. No more sampling.', self._invalid_patience)
                    break
                if self.reward_for_invalid is not None:
                    # Penalize the invalid sample so the policy learns to avoid it.
                    trajectory = generator.send_reward(self.reward_for_invalid)
                    _logger.info('[Trajectory %4d] (invalid, %s) %s',
                                 self._trajectory_count + 1,
                                 self.reward_for_invalid if self.reward_for_invalid is not None else 'skip',
                                 sample)
                    self._add_trajectory(trajectory)
                continue

            # Now the model training is destined to happen.
            self._running_models.append((model, generator))
            self.engine.submit_models(model)

        _logger.info('Harvesting final running models.')
        # First use engine.wait to wait for models to finish training.
        self.engine.wait_models(*[model for model, _ in self._running_models])
        # Then put their metrics into buffer.
        if not self._harvest_running_models():
            # Train on final harvested models.
            _logger.info('Training on final harvested models.')
            self._update_policy()

    def state_dict(self) -> dict:
        """Serialize counters, the policy state, and the replay buffer."""
        result = {
            'current_episode': self._current_episode,
            'successful_episode': self._successful_episode,
            'trajectory_count': self._trajectory_count,
            'num_running_models': len(self._running_models),
        }
        if self._policy is None or self._replay_buffer is None:
            _logger.warning('State dict of policy and replay buffer is not saved because they are not initialized yet.')
        else:
            result['policy'] = self._policy.state_dict()
            result['replay_buffer'] = self._replay_buffer
        return result

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore counters, the policy state, and the replay buffer."""
        if state_dict.get('num_running_models', 0) > 0:
            # Running models cannot be resumed across processes; their
            # trajectories are simply lost.
            _logger.warning('Loaded state dict has %d running models. They will be ignored.',
                            state_dict['num_running_models'])
        self._current_episode = state_dict['current_episode']
        self._successful_episode = state_dict['successful_episode']
        self._trajectory_count = state_dict['trajectory_count']
        if self._policy is None or self._replay_buffer is None:
            _logger.warning('State dict of policy and replay buffer is not restored because they are not initialized yet.')
        elif 'policy' not in state_dict or 'replay_buffer' not in state_dict:
            _logger.warning('Policy and replay buffer is not restored because they are not found in the saved state_dict.')
        else:
            self._policy.load_state_dict(state_dict['policy'])
            self._replay_buffer = state_dict['replay_buffer']