from typing import Optional, List, Dict, Any, Tuple, Union
from abc import ABC, abstractmethod
from collections import namedtuple
from easydict import EasyDict
import copy
import torch

from ding.model import create_model
from ding.utils import import_module, allreduce, broadcast, get_rank, allreduce_async, synchronize, deep_merge_dicts, \
    POLICY_REGISTRY


class Policy(ABC):
    """
    Overview:
        The basic class of Reinforcement Learning (RL) and Imitation Learning (IL) policy in DI-engine.
    Property:
        ``cfg``, ``learn_mode``, ``collect_mode``, ``eval_mode``
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Get the default config of the policy. This method is used to create the default config of a policy \
            class.
        Returns:
            - cfg (:obj:`EasyDict`): The default config of the corresponding policy. For a derived policy class, \
                it recursively merges the default config of its base class with its own default config.

        .. tip::
            This method deepcopies the ``config`` attribute of the class and returns the result, so users don't \
            need to worry about modifications to the returned config.
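
        Examples:
            >>> # A minimal sketch, assuming a registered subclass such as DQNPolicy; the concrete fields of the
            >>> # returned config depend on that subclass.
            >>> cfg = DQNPolicy.default_config()
            >>> cfg.cuda = True  # the returned EasyDict can be modified freely before creating the policy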
| """ | |
| if cls == Policy: | |
| raise RuntimeError("Basic class Policy doesn't have completed default_config") | |
| base_cls = cls.__base__ | |
| if base_cls == Policy: | |
| base_policy_cfg = EasyDict(copy.deepcopy(Policy.config)) | |
| else: | |
| base_policy_cfg = copy.deepcopy(base_cls.default_config()) | |
| cfg = EasyDict(copy.deepcopy(cls.config)) | |
| cfg = deep_merge_dicts(base_policy_cfg, cfg) | |
| cfg.cfg_type = cls.__name__ + 'Dict' | |
| return cfg | |

    learn_function = namedtuple(
        'learn_function', [
            'forward',
            'reset',
            'info',
            'monitor_vars',
            'get_attribute',
            'set_attribute',
            'state_dict',
            'load_state_dict',
        ]
    )
    collect_function = namedtuple(
        'collect_function', [
            'forward',
            'process_transition',
            'get_train_sample',
            'reset',
            'get_attribute',
            'set_attribute',
            'state_dict',
            'load_state_dict',
        ]
    )
    eval_function = namedtuple(
        'eval_function', [
            'forward',
            'reset',
            'get_attribute',
            'set_attribute',
            'state_dict',
            'load_state_dict',
        ]
    )
    total_field = set(['learn', 'collect', 'eval'])
    config = dict(
        # (bool) Whether the learning policy is the same as the data-collecting policy (on-policy).
        on_policy=False,
        # (bool) Whether to use cuda in policy.
        cuda=False,
        # (bool) Whether to use data-parallel multi-gpu mode in policy.
        multi_gpu=False,
        # (bool) Whether to update the model parameters synchronously after allreducing their gradients.
        bp_update_sync=True,
        # (bool) Whether to enable infinite trajectory length in data collecting.
        traj_len_inf=False,
        # neural network model config
        model=dict(),
    )

    def __init__(
            self,
            cfg: EasyDict,
            model: Optional[torch.nn.Module] = None,
            enable_field: Optional[List[str]] = None
    ) -> None:
        """
        Overview:
            Initialize the policy instance according to the input config and model. This method initializes the \
            different fields of the policy, i.e. ``learn``, ``collect`` and ``eval``. The ``learn`` field is used \
            to train the policy, the ``collect`` field is used to collect training data, and the ``eval`` field is \
            used to evaluate the policy. ``enable_field`` specifies which fields to initialize; if it is None, \
            all fields are initialized.
        Arguments:
            - cfg (:obj:`EasyDict`): The final merged config used to initialize the policy. For the default config, \
                see the ``config`` attribute of the policy class and its comments.
            - model (:obj:`torch.nn.Module`): The neural network model used to initialize the policy. If it is \
                None, the model is created according to the ``default_model`` method and the ``cfg.model`` field. \
                Otherwise, the model is set to the instance created by the outside caller.
            - enable_field (:obj:`Optional[List[str]]`): The list of fields to initialize. If it is None, all \
                fields are initialized. Otherwise, only the fields in ``enable_field`` are initialized, which is \
                beneficial to save resources.

        .. note::
            For a derived policy class, the ``_init_learn``, ``_init_collect`` and ``_init_eval`` methods should be \
            implemented to initialize the corresponding fields.
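
        Examples:
            >>> # A hedged sketch, assuming a registered DQNPolicy and a fully merged config ``cfg.policy``;
            >>> # only the learn field is initialized here to save resources.
            >>> policy = DQNPolicy(cfg.policy, enable_field=['learn'])
            >>> train_output = policy.learn_mode.forward(batch_data)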
| """ | |
| self._cfg = cfg | |
| self._on_policy = self._cfg.on_policy | |
| if enable_field is None: | |
| self._enable_field = self.total_field | |
| else: | |
| self._enable_field = enable_field | |
| assert set(self._enable_field).issubset(self.total_field), self._enable_field | |
| if len(set(self._enable_field).intersection(set(['learn', 'collect', 'eval']))) > 0: | |
| model = self._create_model(cfg, model) | |
| self._cuda = cfg.cuda and torch.cuda.is_available() | |
| # now only support multi-gpu for only enable learn mode | |
| if len(set(self._enable_field).intersection(set(['learn']))) > 0: | |
| multi_gpu = self._cfg.multi_gpu | |
| self._rank = get_rank() if multi_gpu else 0 | |
| if self._cuda: | |
| # model.cuda() is an in-place operation. | |
| model.cuda() | |
| if multi_gpu: | |
| bp_update_sync = self._cfg.bp_update_sync | |
| self._bp_update_sync = bp_update_sync | |
| self._init_multi_gpu_setting(model, bp_update_sync) | |
| else: | |
| self._rank = 0 | |
| if self._cuda: | |
| # model.cuda() is an in-place operation. | |
| model.cuda() | |
| self._model = model | |
| self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu' | |
| else: | |
| self._cuda = False | |
| self._rank = 0 | |
| self._device = 'cpu' | |
| # call the initialization method of different modes, such as ``_init_learn``, ``_init_collect``, ``_init_eval`` | |
| for field in self._enable_field: | |
| getattr(self, '_init_' + field)() | |

    def _init_multi_gpu_setting(self, model: torch.nn.Module, bp_update_sync: bool) -> None:
        """
        Overview:
            Initialize the multi-gpu data-parallel training setting: broadcast the model parameters at the \
            beginning of training and prepare the hook functions that allreduce the gradients of model parameters.
        Arguments:
            - model (:obj:`torch.nn.Module`): The neural network model to be trained.
            - bp_update_sync (:obj:`bool`): Whether to update the model parameters synchronously after allreducing \
                their gradients. The asynchronous update can overlap the allreduce of different network layers \
                like a pipeline, which saves time.
        """
        for name, param in model.state_dict().items():
            assert isinstance(param.data, torch.Tensor), type(param.data)
            broadcast(param.data, 0)
        # here we manually set the gradient to a zero tensor at the beginning of training, which is necessary for
        # the case that different GPUs have different computation graphs.
        for name, param in model.named_parameters():
            setattr(param, 'grad', torch.zeros_like(param))
        if not bp_update_sync:

            def make_hook(name, p):

                def hook(*ignore):
                    allreduce_async(name, p.grad.data)

                return hook

            for i, (name, p) in enumerate(model.named_parameters()):
                if p.requires_grad:
                    p_tmp = p.expand_as(p)
                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(make_hook(name, p))

    def _create_model(self, cfg: EasyDict, model: Optional[torch.nn.Module] = None) -> torch.nn.Module:
        """
        Overview:
            Create or validate the neural network model according to the input config and model. If the input \
            model is None, the model is created according to the ``default_model`` method and the ``cfg.model`` \
            field. Otherwise, the model is verified to be an instance of ``torch.nn.Module`` and used directly.
        Arguments:
            - cfg (:obj:`EasyDict`): The final merged config used to initialize the policy.
            - model (:obj:`torch.nn.Module`): The neural network model used to initialize the policy. Users can \
                refer to the default model defined in the corresponding policy to customize their own model.
        Returns:
            - model (:obj:`torch.nn.Module`): The created neural network model. The different modes of the policy \
                will add distinct wrappers and plugins to this model, which are used to train, collect and evaluate.
        Raises:
            - RuntimeError: If the input model is not None and is not an instance of ``torch.nn.Module``.
        """
        if model is None:
            model_cfg = cfg.model
            if 'type' not in model_cfg:
                m_type, import_names = self.default_model()
                model_cfg.type = m_type
                model_cfg.import_names = import_names
            return create_model(model_cfg)
        else:
            if isinstance(model, torch.nn.Module):
                return model
            else:
                raise RuntimeError("invalid model: {}".format(type(model)))

    @property
    def cfg(self) -> EasyDict:
        return self._cfg

    @abstractmethod
    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of the policy, including related attributes and modules. This method is \
            called in ``__init__`` if ``learn`` is in ``enable_field``. Almost every policy has its own learn \
            mode, so this method must be overridden in the subclass.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
            with the prefix ``_learn_`` to avoid conflicts with other modes, such as ``self._learn_attr1``.
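
        Examples:
            >>> # A hedged sketch of a typical ``_init_learn``; the attribute names are illustrative assumptions,
            >>> # and concrete policies usually also wrap the model and build target networks.
            >>> def _init_learn(self) -> None:
            ...     self._learn_model = self._model
            ...     self._optimizer = torch.optim.Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)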
| """ | |
| raise NotImplementedError | |
| def _init_collect(self) -> None: | |
| """ | |
| Overview: | |
| Initialize the collect mode of policy, including related attributes and modules. This method will be \ | |
| called in ``__init__`` method if ``collect`` field is in ``enable_field``. Almost different policies have \ | |
| its own collect mode, so this method must be overrided in subclass. | |
| .. note:: | |
| For the member variables that need to be saved and loaded, please refer to the ``_state_dict_collect`` \ | |
| and ``_load_state_dict_collect`` methods. | |
| .. note:: | |
| If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ | |
| with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. | |
| """ | |
| raise NotImplementedError | |
| def _init_eval(self) -> None: | |
| """ | |
| Overview: | |
| Initialize the eval mode of policy, including related attributes and modules. This method will be \ | |
| called in ``__init__`` method if ``eval`` field is in ``enable_field``. Almost different policies have \ | |
| its own eval mode, so this method must be overrided in subclass. | |
| .. note:: | |
| For the member variables that need to be saved and loaded, please refer to the ``_state_dict_eval`` \ | |
| and ``_load_state_dict_eval`` methods. | |
| .. note:: | |
| If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ | |
| with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. | |
| """ | |
| raise NotImplementedError | |

    @property
    def learn_mode(self) -> 'Policy.learn_function':  # noqa
        """
        Overview:
            Return the interfaces of the learn mode of the policy, which is used to train the model. Here we use a \
            namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. \
            Moreover, a derived subclass can override the interfaces to customize its own learn mode.
        Returns:
            - interfaces (:obj:`Policy.learn_function`): The interfaces of the learn mode of the policy; it is a \
                namedtuple whose fields map to different internal methods.
        Examples:
            >>> policy = Policy(cfg, model)
            >>> policy_learn = policy.learn_mode
            >>> train_output = policy_learn.forward(data)
            >>> state_dict = policy_learn.state_dict()
        """
        return Policy.learn_function(
            self._forward_learn,
            self._reset_learn,
            self.__repr__,
            self._monitor_vars_learn,
            self._get_attribute,
            self._set_attribute,
            self._state_dict_learn,
            self._load_state_dict_learn,
        )

    @property
    def collect_mode(self) -> 'Policy.collect_function':  # noqa
        """
        Overview:
            Return the interfaces of the collect mode of the policy, which is used to collect training data by \
            interacting with envs. Here we use a namedtuple to define immutable interfaces and restrict the usage \
            of the policy in different modes. Moreover, a derived subclass can override the interfaces to \
            customize its own collect mode.
        Returns:
            - interfaces (:obj:`Policy.collect_function`): The interfaces of the collect mode of the policy; it is \
                a namedtuple whose fields map to different internal methods.
        Examples:
            >>> policy = Policy(cfg, model)
            >>> policy_collect = policy.collect_mode
            >>> obs = env_manager.ready_obs
            >>> inference_output = policy_collect.forward(obs)
            >>> next_obs, rew, done, info = env_manager.step(inference_output.action)
        """
        return Policy.collect_function(
            self._forward_collect,
            self._process_transition,
            self._get_train_sample,
            self._reset_collect,
            self._get_attribute,
            self._set_attribute,
            self._state_dict_collect,
            self._load_state_dict_collect,
        )

    @property
    def eval_mode(self) -> 'Policy.eval_function':  # noqa
        """
        Overview:
            Return the interfaces of the eval mode of the policy, which is used to evaluate the policy \
            performance. Here we use a namedtuple to define immutable interfaces and restrict the usage of the \
            policy in different modes. Moreover, a derived subclass can override the interfaces to customize its \
            own eval mode.
        Returns:
            - interfaces (:obj:`Policy.eval_function`): The interfaces of the eval mode of the policy; it is a \
                namedtuple whose fields map to different internal methods.
        Examples:
            >>> policy = Policy(cfg, model)
            >>> policy_eval = policy.eval_mode
            >>> obs = env_manager.ready_obs
            >>> inference_output = policy_eval.forward(obs)
            >>> next_obs, rew, done, info = env_manager.step(inference_output.action)
        """
        return Policy.eval_function(
            self._forward_eval,
            self._reset_eval,
            self._get_attribute,
            self._set_attribute,
            self._state_dict_eval,
            self._load_state_dict_eval,
        )

    def _set_attribute(self, name: str, value: Any) -> None:
        """
        Overview:
            To control access to the policy attributes, we expose different modes to the outside rather than the \
            policy instance itself, and we provide this method to set an attribute of the policy from those \
            modes. The new attribute is named ``_{name}``.
        Arguments:
            - name (:obj:`str`): The name of the attribute.
            - value (:obj:`Any`): The value of the attribute.
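
        Examples:
            >>> # A hedged sketch; the attribute name is just an illustration.
            >>> policy.learn_mode.set_attribute('priority_weight', 0.6)
            >>> # the value is now stored as ``policy._priority_weight``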
| """ | |
| setattr(self, '_' + name, value) | |
| def _get_attribute(self, name: str) -> Any: | |
| """ | |
| Overview: | |
| In order to control the access of the policy attributes, we expose different modes to outside rather than \ | |
| directly use the policy instance. And we also provide a method to get the attribute of the policy in \ | |
| different modes. | |
| Arguments: | |
| - name (:obj:`str`): The name of the attribute. | |
| Returns: | |
| - value (:obj:`Any`): The value of the attribute. | |
| .. note:: | |
| DI-engine's policy will first try to access `_get_{name}` method, and then try to access `_{name}` \ | |
| attribute. If both of them are not found, it will raise a ``NotImplementedError``. | |
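
        Examples:
            >>> # A hedged sketch: ``batch_size`` resolves to ``_get_batch_size()`` defined below, while a plain
            >>> # attribute such as ``device`` resolves to ``self._device``.
            >>> batch_size = policy.learn_mode.get_attribute('batch_size')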
| """ | |
| if hasattr(self, '_get_' + name): | |
| return getattr(self, '_get_' + name)() | |
| elif hasattr(self, '_' + name): | |
| return getattr(self, '_' + name) | |
| else: | |
| raise NotImplementedError | |
| def __repr__(self) -> str: | |
| """ | |
| Overview: | |
| Get the string representation of the policy. | |
| Returns: | |
| - repr (:obj:`str`): The string representation of the policy. | |
| """ | |
| return "DI-engine DRL Policy\n{}".format(repr(self._model)) | |
| def sync_gradients(self, model: torch.nn.Module) -> None: | |
| """ | |
| Overview: | |
| Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training. | |
| Arguments: | |
| - model (:obj:`torch.nn.Module`): The model to synchronize gradients. | |
| .. note:: | |
| This method is only used in multi-gpu training, and it shoule be called after ``backward`` method and \ | |
| before ``step`` method. The user can also use ``bp_update_sync`` config to control whether to synchronize \ | |
| gradients allreduce and optimizer updates. | |
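
        Examples:
            >>> # A hedged sketch of the expected call order inside a multi-gpu training iteration.
            >>> self._optimizer.zero_grad()
            >>> loss.backward()
            >>> if self._cfg.multi_gpu:
            ...     self.sync_gradients(self._learn_model)
            >>> self._optimizer.step()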
| """ | |
| if self._bp_update_sync: | |
| for name, param in model.named_parameters(): | |
| if param.requires_grad: | |
| allreduce(param.grad.data) | |
| else: | |
| synchronize() | |
| # don't need to implement default_model method by force | |
| def default_model(self) -> Tuple[str, List[str]]: | |
| """ | |
| Overview: | |
| Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ | |
| automatically call this method to get the default model setting and create model. | |
| Returns: | |
| - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. | |
| .. note:: | |
| The user can define and use customized network model but must obey the same inferface definition indicated \ | |
| by import_names path. For example about DQN, its registered name is ``dqn`` and the import_names is \ | |
| ``ding.model.template.q_learning.DQN`` | |
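
        Examples:
            >>> # A hedged sketch of a typical override in a DQN-style policy.
            >>> def default_model(self) -> Tuple[str, List[str]]:
            ...     return 'dqn', ['ding.model.template.q_learning']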
| """ | |
| raise NotImplementedError | |
| # *************************************** learn function ************************************ | |
| def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: | |
| """ | |
| Overview: | |
| Policy forward function of learn mode (training policy and updating parameters). Forward means \ | |
| that the policy inputs some training batch data from the replay buffer and then returns the output \ | |
| result, including various training information such as loss value, policy entropy, q value, priority, \ | |
| and so on. This method is left to be implemented by the subclass, and more arguments can be added in \ | |
| ``data`` item if necessary. | |
| Arguments: | |
| - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ | |
| training samples. For each element in list, the key of the dict is the name of data items and the \ | |
| value is the corresponding data. Usually, in the ``_forward_learn`` method, data should be stacked in \ | |
| the batch dimension by some utility functions such as ``default_preprocess_learn``. | |
| Returns: | |
| - output (:obj:`Dict[int, Any]`): The training information of policy forward, including some metrics for \ | |
| monitoring training such as loss, priority, q value, policy entropy, and some data for next step \ | |
| training such as priority. Note the output data item should be Python native scalar rather than \ | |
| PyTorch tensor, which is convenient for the outside to use. | |
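
        Examples:
            >>> # A hedged sketch of a typical return value; the exact keys depend on the algorithm.
            >>> output = policy.learn_mode.forward(batch_data)
            >>> assert isinstance(output['total_loss'], float)
            >>> assert isinstance(output['cur_lr'], float)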
| """ | |
| raise NotImplementedError | |
| # don't need to implement _reset_learn method by force | |
| def _reset_learn(self, data_id: Optional[List[int]] = None) -> None: | |
| """ | |
| Overview: | |
| Reset some stateful variables for learn mode when necessary, such as the hidden state of RNN or the \ | |
| memory bank of some special algortihms. If ``data_id`` is None, it means to reset all the stateful \ | |
| varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ | |
| different trajectories in ``data_id`` will have different hidden state in RNN. | |
| Arguments: | |
| - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ | |
| specified by ``data_id``. | |
| .. note:: | |
| This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary. | |
| """ | |
| pass | |
| def _monitor_vars_learn(self) -> List[str]: | |
| """ | |
| Overview: | |
| Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ | |
| as text logger, tensorboard logger, will use these keys to save the corresponding data. | |
| Returns: | |
| - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. | |
| .. tip:: | |
| The default implementation is ``['cur_lr', 'total_loss']``. Other derived classes can overwrite this \ | |
| method to add their own keys if necessary. | |
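
        Examples:
            >>> # A hedged sketch of a typical override that extends the default keys; the key names are assumed.
            >>> def _monitor_vars_learn(self) -> List[str]:
            ...     return super()._monitor_vars_learn() + ['q_value', 'target_q_value']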
| """ | |
| return ['cur_lr', 'total_loss'] | |
| def _state_dict_learn(self) -> Dict[str, Any]: | |
| """ | |
| Overview: | |
| Return the state_dict of learn mode, usually including model and optimizer. | |
| Returns: | |
| - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. | |
| """ | |
| return { | |
| 'model': self._learn_model.state_dict(), | |
| 'optimizer': self._optimizer.state_dict(), | |
| } | |
| def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: | |
| """ | |
| Overview: | |
| Load the state_dict variable into policy learn mode. | |
| Arguments: | |
| - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. | |
| .. tip:: | |
| If you want to only load some parts of model, you can simply set the ``strict`` argument in \ | |
| load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ | |
| complicated operation. | |
| """ | |
| self._learn_model.load_state_dict(state_dict['model']) | |
| self._optimizer.load_state_dict(state_dict['optimizer']) | |
| def _get_batch_size(self) -> Union[int, Dict[str, int]]: | |
| # some specifial algorithms use different batch size for different optimization parts. | |
| if 'batch_size' in self._cfg: | |
| return self._cfg.batch_size | |
| else: # for compatibility | |
| return self._cfg.learn.batch_size | |
| # *************************************** collect function ************************************ | |
| def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: | |
| """ | |
| Overview: | |
| Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ | |
| that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ | |
| data, such as the action to interact with the envs, or the action logits to calculate the loss in learn \ | |
| mode. This method is left to be implemented by the subclass, and more arguments can be added in ``kwargs`` \ | |
| part if necessary. | |
| Arguments: | |
| - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ | |
| key of the dict is environment id and the value is the corresponding data of the env. | |
| Returns: | |
| - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ | |
| other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ | |
| dict is the same as the input data, i.e. environment id. | |
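
        Examples:
            >>> # A hedged sketch of the expected data format; the extra kwargs (e.g. ``eps``) depend on the
            >>> # algorithm and the observations here are just placeholders.
            >>> obs = {0: obs_of_env_0, 1: obs_of_env_1}
            >>> output = policy.collect_mode.forward(obs, eps=0.1)
            >>> assert set(output.keys()) == set(obs.keys())
            >>> action_of_env_0 = output[0]['action']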
| """ | |
| raise NotImplementedError | |
| def _process_transition( | |
| self, obs: Union[torch.Tensor, Dict[str, torch.Tensor]], policy_output: Dict[str, torch.Tensor], | |
| timestep: namedtuple | |
| ) -> Dict[str, torch.Tensor]: | |
| """ | |
| Overview: | |
| Process and pack one timestep transition data into a dict, such as <s, a, r, s', done>. Some policies \ | |
| need to do some special process and pack its own necessary attributes (e.g. hidden state and logit), \ | |
| so this method is left to be implemented by the subclass. | |
| Arguments: | |
| - obs (:obj:`Union[torch.Tensor, Dict[str, torch.Tensor]]`): The observation of the current timestep. | |
| - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ | |
| as input. Usually, it contains the action and the logit of the action. | |
| - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ | |
| except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ | |
| reward, done, info, etc. | |
| Returns: | |
| - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. | |
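
        Examples:
            >>> # A hedged sketch of a common transition format; the exact keys depend on the algorithm
            >>> # (e.g. on-policy methods usually also store the logit or value).
            >>> def _process_transition(self, obs, policy_output, timestep):
            ...     return {
            ...         'obs': obs,
            ...         'next_obs': timestep.obs,
            ...         'action': policy_output['action'],
            ...         'reward': timestep.reward,
            ...         'done': timestep.done,
            ...     }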
| """ | |
| raise NotImplementedError | |
| def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| """ | |
| Overview: | |
| For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ | |
| can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) \ | |
| or some multi-timestep transitions (DRQN). This method is usually used in collectors to execute necessary \ | |
| RL data preprocessing before training, which can help learner amortize revelant time consumption. \ | |
| In addition, you can also implement this method as an identity function and do the data processing \ | |
| in ``self._forward_learn`` method. | |
| Arguments: | |
| - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ | |
| the same format as the return value of ``self._process_transition`` method. | |
| Returns: | |
| - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ | |
| as input transitions, but may contain more data for training, such as nstep reward, advantage, etc. | |
| .. note:: | |
| We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \ | |
| And the user can customize the this data processing procecure by overriding this two methods and collector \ | |
| itself | |
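
        Examples:
            >>> # A hedged sketch: the simplest valid implementation returns the trajectory unchanged and defers
            >>> # all processing (e.g. nstep reward computation) to ``_forward_learn``.
            >>> def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
            ...     return transitions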
| """ | |
| raise NotImplementedError | |
| # don't need to implement _reset_collect method by force | |
| def _reset_collect(self, data_id: Optional[List[int]] = None) -> None: | |
| """ | |
| Overview: | |
| Reset some stateful variables for collect mode when necessary, such as the hidden state of RNN or the \ | |
| memory bank of some special algortihms. If ``data_id`` is None, it means to reset all the stateful \ | |
| varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ | |
| different environments/episodes in collecting in ``data_id`` will have different hidden state in RNN. | |
| Arguments: | |
| - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ | |
| specified by ``data_id``. | |
| .. note:: | |
| This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary. | |
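
        Examples:
            >>> # A hedged sketch: reset only the environments that have just finished an episode.
            >>> done_env_ids = [env_id for env_id, t in timesteps.items() if t.done]
            >>> policy.collect_mode.reset(done_env_ids)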
| """ | |
| pass | |
| def _state_dict_collect(self) -> Dict[str, Any]: | |
| """ | |
| Overview: | |
| Return the state_dict of collect mode, only including model in usual, which is necessary for distributed \ | |
| training scenarios to auto-recover collectors. | |
| Returns: | |
| - state_dict (:obj:`Dict[str, Any]`): The dict of current policy collect state, for saving and restoring. | |
| .. tip:: | |
| Not all the scenarios need to auto-recover collectors, sometimes, we can directly shutdown the crashed \ | |
| collector and renew a new one. | |
| """ | |
| return {'model': self._collect_model.state_dict()} | |
| def _load_state_dict_collect(self, state_dict: Dict[str, Any]) -> None: | |
| """ | |
| Overview: | |
| Load the state_dict variable into policy collect mode, such as load pretrained state_dict, auto-recover \ | |
| checkpoint, or model replica from learner in distributed training scenarios. | |
| Arguments: | |
| - state_dict (:obj:`Dict[str, Any]`): The dict of policy collect state saved before. | |
| .. tip:: | |
| If you want to only load some parts of model, you can simply set the ``strict`` argument in \ | |
| load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ | |
| complicated operation. | |
| """ | |
| self._collect_model.load_state_dict(state_dict['model'], strict=True) | |
| def _get_n_sample(self) -> Union[int, None]: | |
| if 'n_sample' in self._cfg: | |
| return self._cfg.n_sample | |
| else: # for compatibility | |
| return self._cfg.collect.get('n_sample', None) # for some adpative collecting data case | |
| def _get_n_episode(self) -> Union[int, None]: | |
| if 'n_episode' in self._cfg: | |
| return self._cfg.n_episode | |
| else: # for compatibility | |
| return self._cfg.collect.get('n_episode', None) # for some adpative collecting data case | |
| # *************************************** eval function ************************************ | |
| def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: | |
| """ | |
| Overview: | |
| Policy forward function of eval mode (evaluation policy performance, such as interacting with envs or \ | |
| computing metrics on validation dataset). Forward means that the policy gets some necessary data (mainly \ | |
| observation) from the envs and then returns the output data, such as the action to interact with the envs. \ | |
| This method is left to be implemented by the subclass. | |
| Arguments: | |
| - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ | |
| key of the dict is environment id and the value is the corresponding data of the env. | |
| Returns: | |
| - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ | |
| key of the dict is the same as the input data, i.e. environment id. | |
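
        Examples:
            >>> # A hedged sketch; unlike collect mode, eval mode usually takes no extra exploration kwargs.
            >>> # The observations here are just placeholders.
            >>> obs = {0: obs_of_env_0, 1: obs_of_env_1}
            >>> output = policy.eval_mode.forward(obs)
            >>> action_of_env_0 = output[0]['action']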
| """ | |
| raise NotImplementedError | |
| # don't need to implement _reset_eval method by force | |
| def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: | |
| """ | |
| Overview: | |
| Reset some stateful variables for eval mode when necessary, such as the hidden state of RNN or the \ | |
| memory bank of some special algortihms. If ``data_id`` is None, it means to reset all the stateful \ | |
| varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ | |
| different environments/episodes in evaluation in ``data_id`` will have different hidden state in RNN. | |
| Arguments: | |
| - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ | |
| specified by ``data_id``. | |
| .. note:: | |
| This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary. | |
| """ | |
| pass | |
| def _state_dict_eval(self) -> Dict[str, Any]: | |
| """ | |
| Overview: | |
| Return the state_dict of eval mode, only including model in usual, which is necessary for distributed \ | |
| training scenarios to auto-recover evaluators. | |
| Returns: | |
| - state_dict (:obj:`Dict[str, Any]`): The dict of current policy eval state, for saving and restoring. | |
| .. tip:: | |
| Not all the scenarios need to auto-recover evaluators, sometimes, we can directly shutdown the crashed \ | |
| evaluator and renew a new one. | |
| """ | |
| return {'model': self._eval_model.state_dict()} | |
| def _load_state_dict_eval(self, state_dict: Dict[str, Any]) -> None: | |
| """ | |
| Overview: | |
| Load the state_dict variable into policy eval mode, such as load auto-recover \ | |
| checkpoint, or model replica from learner in distributed training scenarios. | |
| Arguments: | |
| - state_dict (:obj:`Dict[str, Any]`): The dict of policy eval state saved before. | |
| .. tip:: | |
| If you want to only load some parts of model, you can simply set the ``strict`` argument in \ | |
| load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ | |
| complicated operation. | |
| """ | |
| self._eval_model.load_state_dict(state_dict['model'], strict=True) | |


class CommandModePolicy(Policy):
    """
    Overview:
        Policy with command mode, which can be used in the old version of the DI-engine pipeline: \
        ``serial_pipeline``. ``CommandModePolicy`` uses the ``_get_setting_learn``, ``_get_setting_collect`` and \
        ``_get_setting_eval`` methods to exchange information between different workers.
    Interface:
        ``_init_command``, ``_get_setting_learn``, ``_get_setting_collect``, ``_get_setting_eval``
    Property:
        ``command_mode``
    """
    command_function = namedtuple('command_function', ['get_setting_learn', 'get_setting_collect', 'get_setting_eval'])
    total_field = set(['learn', 'collect', 'eval', 'command'])

    @property
    def command_mode(self) -> 'Policy.command_function':  # noqa
        """
        Overview:
            Return the interfaces of the command mode of the policy, which is used to adjust the settings of the \
            other modes according to the global training information. Here we use a namedtuple to define \
            immutable interfaces and restrict the usage of the policy in different modes. Moreover, a derived \
            subclass can override the interfaces to customize its own command mode.
        Returns:
            - interfaces (:obj:`Policy.command_function`): The interfaces of the command mode; it is a namedtuple \
                whose fields map to different internal methods.
        Examples:
            >>> policy = CommandModePolicy(cfg, model)
            >>> policy_command = policy.command_mode
            >>> settings = policy_command.get_setting_learn(command_info)
        """
        return CommandModePolicy.command_function(
            self._get_setting_learn, self._get_setting_collect, self._get_setting_eval
        )

    @abstractmethod
    def _init_command(self) -> None:
        """
        Overview:
            Initialize the command mode of the policy, including related attributes and modules. This method is \
            called in ``__init__`` if ``command`` is in ``enable_field``. Almost every policy has its own command \
            mode, so this method must be overridden in the subclass.

        .. note::
            If you want to set some special member variables in the ``_init_command`` method, you'd better name \
            them with the prefix ``_command_`` to avoid conflicts with other modes, such as ``self._command_attr1``.
        """
        raise NotImplementedError

    # *************************************** command function ************************************

    @abstractmethod
    def _get_setting_learn(self, command_info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Overview:
            According to ``command_info``, i.e. the global training information (e.g. training iteration, \
            collected env steps, evaluation results, etc.), return the setting of learn mode, which contains the \
            dynamically changed hyperparameters for learn mode, such as ``batch_size``, ``learning_rate``, etc.
        Arguments:
            - command_info (:obj:`Dict[str, Any]`): The global training information, which is defined in \
                ``commander``.
        Returns:
            - setting (:obj:`Dict[str, Any]`): The latest setting of learn mode, which is usually used as extra \
                arguments of the ``policy._forward_learn`` method.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_setting_collect(self, command_info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Overview:
            According to ``command_info``, i.e. the global training information (e.g. training iteration, \
            collected env steps, evaluation results, etc.), return the setting of collect mode, which contains \
            the dynamically changed hyperparameters for collect mode, such as ``eps``, ``temperature``, etc.
        Arguments:
            - command_info (:obj:`Dict[str, Any]`): The global training information, which is defined in \
                ``commander``.
        Returns:
            - setting (:obj:`Dict[str, Any]`): The latest setting of collect mode, which is usually used as extra \
                arguments of the ``policy._forward_collect`` method.
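
        Examples:
            >>> # A hedged sketch: epsilon-greedy decay driven by the collected env steps. The ``envstep`` key and
            >>> # the ``self._eps_schedule`` object are assumptions for illustration only.
            >>> def _get_setting_collect(self, command_info: Dict[str, Any]) -> Dict[str, Any]:
            ...     return {'eps': self._eps_schedule(command_info['envstep'])}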
| """ | |
| raise NotImplementedError | |
| def _get_setting_eval(self, command_info: Dict[str, Any]) -> Dict[str, Any]: | |
| """ | |
| Overview: | |
| Accoding to ``command_info``, i.e., global training information (e.g. training iteration, collected env \ | |
| step, evaluation results, etc.), return the setting of eval mode, which contains dynamically changed \ | |
| hyperparameters for eval mode, such as ``temperature``, etc. | |
| Arguments: | |
| - command_info (:obj:`Dict[str, Any]`): The global training information, which is defined in ``commander``. | |
| Returns: | |
| - setting (:obj:`Dict[str, Any]`): The latest setting of eval mode, which is usually used as extra \ | |
| arguments of the ``policy._forward_eval`` method. | |
| """ | |
| raise NotImplementedError | |
| def create_policy(cfg: EasyDict, **kwargs) -> Policy: | |
| """ | |
| Overview: | |
| Create a policy instance according to ``cfg`` and other kwargs. | |
| Arguments: | |
| - cfg (:obj:`EasyDict`): Final merged policy config. | |
| ArgumentsKeys: | |
| - type (:obj:`str`): Policy type set in ``POLICY_REGISTRY.register`` method , such as ``dqn`` . | |
| - import_names (:obj:`List[str]`): A list of module names (paths) to import before creating policy, such \ | |
| as ``ding.policy.dqn`` . | |
| Returns: | |
| - policy (:obj:`Policy`): The created policy instance. | |
| .. tip:: | |
| ``kwargs`` contains other arguments that need to be passed to the policy constructor. You can refer to \ | |
| the ``__init__`` method of the corresponding policy class for details. | |
| .. note:: | |
| For more details about how to merge config, please refer to the system document of DI-engine \ | |
| (`en link <../03_system/config.html>`_). | |
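
    Examples:
        >>> # A hedged sketch, assuming a fully merged config whose ``cfg.policy.type`` is a registered policy
        >>> # (e.g. ``dqn``) and that all three basic fields should be initialized.
        >>> policy = create_policy(cfg.policy, model=None, enable_field=['learn', 'collect', 'eval'])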
| """ | |
| import_module(cfg.get('import_names', [])) | |
| return POLICY_REGISTRY.build(cfg.type, cfg=cfg, **kwargs) | |
| def get_policy_cls(cfg: EasyDict) -> type: | |
| """ | |
| Overview: | |
| Get policy class according to ``cfg``, which is used to access related class variables/methods. | |
| Arguments: | |
| - cfg (:obj:`EasyDict`): Final merged policy config. | |
| ArgumentsKeys: | |
| - type (:obj:`str`): Policy type set in ``POLICY_REGISTRY.register`` method , such as ``dqn`` . | |
| - import_names (:obj:`List[str]`): A list of module names (paths) to import before creating policy, such \ | |
| as ``ding.policy.dqn`` . | |
| Returns: | |
| - policy (:obj:`type`): The policy class. | |
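
    Examples:
        >>> # A hedged sketch: access the default config without instantiating the policy.
        >>> policy_cls = get_policy_cls(cfg.policy)
        >>> default_cfg = policy_cls.default_config()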
| """ | |
| import_module(cfg.get('import_names', [])) | |
| return POLICY_REGISTRY.get(cfg.type) | |