import logging
from typing import Optional, Callable

from nni.nas.execution import query_available_resources
from .base import BaseStrategy
from .utils import dry_run_for_search_space

# Optional RL dependencies: torch, tianshou and the local RL helpers in ``._rl_impl``.
# If any of these imports fails, ``has_tianshou`` is False and the RL strategy raises
# ImportError when it is constructed.
try:
    import torch
    from tianshou.data import Collector, VectorReplayBuffer
    from tianshou.env import BaseVectorEnv
    # BasePolicy is only referenced inside a string type annotation, hence the pylint pragma.
    from tianshou.policy import BasePolicy, PPOPolicy  # pylint: disable=unused-import
    from ._rl_impl import ModelEvaluationEnv, MultiThreadEnvWorker, Preprocessor, Actor, Critic
except ImportError:
    has_tianshou = False
else:
    has_tianshou = True

# Logger for this module, keyed by the module's import path.
_logger = logging.getLogger(name=__name__)

class PolicyBasedRL(BaseStrategy):
    """
    Algorithm for policy-based reinforcement learning.

    This is a wrapper of algorithms provided in tianshou (PPO by default),
    and can be easily customized with other algorithms that inherit ``BasePolicy``
    (e.g., `REINFORCE <https://link.springer.com/content/pdf/10.1007/BF00992696.pdf>`__
    as in `this paper <https://arxiv.org/abs/1611.01578>`__).

    Parameters
    ----------
    max_collect : int
        How many times the collector runs to collect trials for RL. Default: 100.
    trial_per_collect : int
        How many trials (trajectories) the collector collects each time.
        After each collect, the policy is updated from the replay buffer. Default: 20.
    policy_fn : function
        Takes :class:`ModelEvaluationEnv` as input and returns a policy.
        See :meth:`PolicyBasedRL._default_policy_fn` for an example.

    Raises
    ------
    ImportError
        If the optional RL dependencies (torch / tianshou) could not be imported.
    """

    def __init__(self, max_collect: int = 100, trial_per_collect: int = 20,
                 policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None):
        if not has_tianshou:
            raise ImportError('`tianshou` is required to run RL-based strategy. '
                              'Please use "pip install tianshou" to install it beforehand.')

        self.policy_fn = policy_fn if policy_fn is not None else self._default_policy_fn
        self.max_collect = max_collect
        self.trial_per_collect = trial_per_collect

    @staticmethod
    def _default_policy_fn(env):
        """Build the default PPO policy for ``env``.

        The actor and critic share one ``Preprocessor`` network built from the
        env's observation space. The actions come from a categorical distribution,
        and future rewards are not discounted (``discount_factor=1.``).
        """
        net = Preprocessor(env.observation_space)
        actor = Actor(env.action_space, net)
        critic = Critic(net)
        # The parameters are put in a set because actor and critic share ``net``.
        # This keeps the shared parameters from being passed to the optimizer twice.
        optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=1e-4)
        return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
                         discount_factor=1., action_space=env.action_space)

    def run(self, base_model, applied_mutators):
        """Search the space defined by ``base_model`` and ``applied_mutators`` with RL.

        Runs ``max_collect`` rounds. Each round collects ``trial_per_collect``
        episodes from a vectorized set of ``ModelEvaluationEnv`` instances, then
        updates the policy from the collector's replay buffer.
        """
        search_space = dry_run_for_search_space(base_model, applied_mutators)
        concurrency = query_available_resources()

        def env_fn():
            return ModelEvaluationEnv(base_model, applied_mutators, search_space)

        # A throwaway env instance is handed to ``policy_fn``.
        # It can read the observation/action spaces from it, as ``_default_policy_fn`` does.
        policy = self.policy_fn(env_fn())

        # One environment per available resource slot.
        # MultiThreadEnvWorker presumably steps each env in its own thread; confirm in ``_rl_impl``.
        env = BaseVectorEnv([env_fn for _ in range(concurrency)], MultiThreadEnvWorker)
        # The replay buffer holds 20000 transitions in total, split across the ``len(env)`` sub-environments.
        collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))

        for cur_collect in range(1, self.max_collect + 1):
            _logger.info('Collect [%d] Running...', cur_collect)
            result = collector.collect(n_episode=self.trial_per_collect)
            _logger.info('Collect [%d] Result: %s', cur_collect, result)
            # sample_size=0: in tianshou this presumably means "use the whole buffer".
            # Verify against the installed version. The update then runs 5 passes
            # with mini-batches of 64.
            policy.update(0, collector.buffer, batch_size=64, repeat=5)