from typing import Any, Tuple, Callable, Optional, List, Dict, Union
from abc import ABC

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, Independent, Normal

from ding.torch_utils import get_tensor_data, zeros_like
from ding.rl_utils import create_noise_generator
from ding.utils.data import default_collate


class IModelWrapper(ABC):
    """
    Overview:
        The basic interface class of model wrappers. A model wrapper is a wrapper class of a torch.nn.Module \
        model, which is used to add some extra operations to the wrapped model, such as hidden state \
        maintenance for RNN-based models, argmax action selection for discrete action spaces, etc.
    Interfaces:
        ``__init__``, ``__getattr__``, ``info``, ``reset``, ``forward``.
    """

    def __init__(self, model: nn.Module) -> None:
        """
        Overview:
            Initialize the model and other necessary member variables in the model wrapper.
        """
        self._model = model

    def __getattr__(self, key: str) -> Any:
        """
        Overview:
            Get original attributes of the torch.nn.Module model, such as variables and methods defined in \
            the model.
        Arguments:
            - key (:obj:`str`): The string key to query.
        Returns:
            - ret (:obj:`Any`): The queried attribute.
        """
        return getattr(self._model, key)

    def info(self, attr_name: str) -> str:
        """
        Overview:
            Get some string information of the indicated ``attr_name``, which is used to debug wrappers.
            This method will recursively search for the indicated ``attr_name``.
        Arguments:
            - attr_name (:obj:`str`): The string key to query information.
        Returns:
            - info_string (:obj:`str`): The information string of the indicated ``attr_name``.
        """
        if attr_name in dir(self):
            if isinstance(self._model, IModelWrapper):
                return '{} {}'.format(self.__class__.__name__, self._model.info(attr_name))
            else:
                if attr_name in dir(self._model):
                    return '{} {}'.format(self.__class__.__name__, self._model.__class__.__name__)
                else:
                    return '{}'.format(self.__class__.__name__)
        else:
            if isinstance(self._model, IModelWrapper):
                return '{}'.format(self._model.info(attr_name))
            else:
                return '{}'.format(self._model.__class__.__name__)

    def reset(self, data_id: List[int] = None, **kwargs) -> None:
        """
        Overview:
            Basic interface, reset some stateful variables in the model wrapper, such as the hidden state \
            of an RNN. Here we do nothing and just implement this interface method.
            Other derived model wrappers can override this method to add some extra operations.
        Arguments:
            - data_id (:obj:`List[int]`): The data id list to reset. If None, reset all data. In practice, \
                model wrappers often need to maintain some stateful variables for each data trajectory, \
                so we leave this ``data_id`` argument to reset the stateful variables of the indicated data.
        """
        pass

    def forward(self, *args, **kwargs) -> Any:
        """
        Overview:
            Basic interface, call the wrapped model's forward method. Other derived model wrappers can \
            override this method to add some extra operations.
        """
        return self._model.forward(*args, **kwargs)


class BaseModelWrapper(IModelWrapper):
    """
    Overview:
        Placeholder class for the model wrapper. This class is used to wrap the model without any extra \
        operations, including an empty ``reset`` method and a ``forward`` method which directly calls the \
        wrapped model's forward. To keep the consistency of the model wrapper interface, we use this class \
        to wrap the model without specific operations in the implementation of DI-engine's policy.
    """
    pass


class HiddenStateWrapper(IModelWrapper):
    """
    Overview:
        Maintain the hidden state for an RNN-based model. Each sample in a batch has its own state.
    Interfaces:
        ``__init__``, ``reset``, ``forward``.
    """

    def __init__(
            self,
            model: Any,
            state_num: int,
            save_prev_state: bool = False,
            init_fn: Callable = lambda: None,
    ) -> None:
        """
        Overview:
            Maintain the hidden state for an RNN-based model. Each sample in a batch has its own state. \
            Init the maintained state and state function; then wrap the ``model.forward`` method with an \
            automatically saved ``data['prev_state']`` input, and create the ``model.reset`` method.
        Arguments:
            - model (:obj:`Any`): Wrapped model class, should contain forward method.
            - state_num (:obj:`int`): Number of states to process.
            - save_prev_state (:obj:`bool`): Whether to output the prev state in output.
            - init_fn (:obj:`Callable`): The function which is used to init every hidden state when init and reset, \
                defaults to returning None for hidden states.
        .. note::
            1. This helper must deal with an actual batch with some parts of samples, e.g: 6 samples of state_num 8.
            2. This helper must deal with the single sample state reset.
        """
        super().__init__(model)
        self._state_num = state_num
        # This is to maintain hidden states (when it comes to this wrapper, \
        # map self._state into data['prev_state'] and update next_state, store in self._state)
        self._state = {i: init_fn() for i in range(state_num)}
        self._save_prev_state = save_prev_state
        self._init_fn = init_fn

    def forward(self, data, **kwargs):
        state_id = kwargs.pop('data_id', None)
        valid_id = kwargs.pop('valid_id', None)  # None, not used in any code in DI-engine
        data, state_info = self.before_forward(data, state_id)  # update data['prev_state'] with self._state
        output = self._model.forward(data, **kwargs)
        h = output.pop('next_state', None)
        if h is not None:
            self.after_forward(h, state_info, valid_id)  # this is to store the 'next hidden state' for each time step
        if self._save_prev_state:
            prev_state = get_tensor_data(data['prev_state'])
            # for compatibility, because of the incompatibility between None and torch.Tensor
            for i in range(len(prev_state)):
                if prev_state[i] is None:
                    prev_state[i] = zeros_like(h[0])
            output['prev_state'] = prev_state
        return output

    def reset(self, *args, **kwargs):
        state = kwargs.pop('state', None)
        state_id = kwargs.get('data_id', None)
        self.reset_state(state, state_id)
        if hasattr(self._model, 'reset'):
            return self._model.reset(*args, **kwargs)

    def reset_state(self, state: Optional[list] = None, state_id: Optional[list] = None) -> None:
        if state_id is None:  # train: init all states
            state_id = [i for i in range(self._state_num)]
        if state is None:  # collect: init states that are done
            state = [self._init_fn() for i in range(len(state_id))]
        assert len(state) == len(state_id), '{}/{}'.format(len(state), len(state_id))
        for idx, s in zip(state_id, state):
            self._state[idx] = s

    def before_forward(self, data: dict, state_id: Optional[list]) -> Tuple[dict, dict]:
        if state_id is None:
            state_id = [i for i in range(self._state_num)]
        state_info = {idx: self._state[idx] for idx in state_id}
        data['prev_state'] = list(state_info.values())
        return data, state_info

    def after_forward(self, h: Any, state_info: dict, valid_id: Optional[list] = None) -> None:
        assert len(h) == len(state_info), '{}/{}'.format(len(h), len(state_info))
        for i, idx in enumerate(state_info.keys()):
            if valid_id is None:
                self._state[idx] = h[i]
            else:
                if idx in valid_id:
                    self._state[idx] = h[i]
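
# Usage sketch (illustrative, not part of the library): wrapping a hypothetical
# RNN-based Q-network so that each of 4 parallel environments keeps its own
# hidden state. Assumes the wrapped model's forward consumes data['prev_state']
# and returns a 'next_state' key in its output dict, as this wrapper expects.
#
#   model = model_wrap(MyDRQN(), wrapper_name='hidden_state', state_num=4, save_prev_state=True)
#   output = model.forward({'obs': obs}, data_id=[0, 1, 2, 3])
#   model.reset(data_id=[1])  # re-initialize only env 1's state after it terminates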


class TransformerInputWrapper(IModelWrapper):

    def __init__(self, model: Any, seq_len: int, init_fn: Callable = lambda: None) -> None:
        """
        Overview:
            Given N, the length of the sequences received by a Transformer model, maintain the last N-1 input
            observations. In this way we can provide at each step all the observations needed by the Transformer
            to compute its output. We need this because some methods such as 'collect' and 'evaluate' only
            provide the model 1 observation per step and have no memory of past observations, but the
            Transformer needs a sequence of N observations. The wrapper method ``forward`` will save the input
            observation in a FIFO memory of length N and the method ``reset`` will reset the memory. The empty
            memory spaces will be initialized with 'init_fn' or zero by calling the method ``reset_input``.
            Since different envs can terminate at different steps, the method ``reset_memory_entry`` only
            initializes the memory of specific environments in the batch.
        Arguments:
            - model (:obj:`Any`): Wrapped model class, should contain forward method.
            - seq_len (:obj:`int`): Number of past observations to remember.
            - init_fn (:obj:`Callable`): The function which is used to init every memory location when init and reset.
        """
        super().__init__(model)
        self.seq_len = seq_len
        self._init_fn = init_fn
        self.obs_memory = None  # shape (N, bs, *obs_shape)
        self.init_obs = None  # sample of observation used to initialize the memory
        self.bs = None
        self.memory_idx = []  # len bs, index of where to put the next element in the sequence for each batch

    def forward(self,
                input_obs: torch.Tensor,
                only_last_logit: bool = True,
                data_id: List = None,
                **kwargs) -> Dict[str, torch.Tensor]:
        """
        Arguments:
            - input_obs (:obj:`torch.Tensor`): Input observation without sequence shape: ``(bs, *obs_shape)``.
            - only_last_logit (:obj:`bool`): If True, 'logit' only contains the output corresponding to the \
                current observation (shape: bs, embedding_dim), otherwise logit has shape (seq_len, bs, \
                embedding_dim).
            - data_id (:obj:`List`): Ids of the envs that are currently running. Memory update and logit \
                return only take effect for those environments. If `None`, all envs are considered running.
        Returns:
            - Dictionary containing the input sequence 'input_seq' stored in memory and the transformer \
                output 'logit'.
        """
        if self.obs_memory is None:
            self.reset_input(torch.zeros_like(input_obs))  # init the memory with the size of the input observation
        if data_id is None:
            data_id = list(range(self.bs))
        assert self.obs_memory.shape[0] == self.seq_len
        # implements a FIFO queue, self.memory_idx is the index where to put the next element
        for i, b in enumerate(data_id):
            if self.memory_idx[b] == self.seq_len:
                # roll back of 1 position along the sequence dim
                self.obs_memory[:, b] = torch.roll(self.obs_memory[:, b], -1, 0)
                self.obs_memory[self.memory_idx[b] - 1, b] = input_obs[i]
            if self.memory_idx[b] < self.seq_len:
                self.obs_memory[self.memory_idx[b], b] = input_obs[i]
                if self.memory_idx[b] != self.seq_len:
                    self.memory_idx[b] += 1
        out = self._model.forward(self.obs_memory, **kwargs)
        out['input_seq'] = self.obs_memory
        if only_last_logit:
            # return only the logits for running environments
            out['logit'] = [out['logit'][self.memory_idx[b] - 1][b] for b in range(self.bs) if b in data_id]
            out['logit'] = default_collate(out['logit'])
        return out

    def reset_input(self, input_obs: torch.Tensor):
        """
        Overview:
            Initialize the whole memory.
        """
        init_obs = torch.zeros_like(input_obs)
        self.init_obs = init_obs
        self.obs_memory = []  # List(bs, *obs_shape)
        for i in range(self.seq_len):
            self.obs_memory.append(init_obs.clone() if init_obs is not None else self._init_fn())
        self.obs_memory = default_collate(self.obs_memory)  # shape (N, bs, *obs_shape)
        self.bs = self.init_obs.shape[0]
        self.memory_idx = [0 for _ in range(self.bs)]

    # called before evaluation
    # called after each evaluation iteration for each done env
    # called after each collect iteration for each done env
    def reset(self, *args, **kwargs):
        state_id = kwargs.get('data_id', None)
        input_obs = kwargs.get('input_obs', None)
        if input_obs is not None:
            self.reset_input(input_obs)
        if state_id is not None:
            self.reset_memory_entry(state_id)
        if input_obs is None and state_id is None:
            self.obs_memory = None
        if hasattr(self._model, 'reset'):
            return self._model.reset(*args, **kwargs)

    def reset_memory_entry(self, state_id: Optional[list] = None) -> None:
        """
        Overview:
            Reset specific batch entries of the memory; batch ids are specified in 'state_id'.
        """
        assert self.init_obs is not None, 'Call method "reset_input" first'
        for _id in state_id:
            self.memory_idx[_id] = 0
            self.obs_memory[:, _id] = self.init_obs[_id]  # init the corresponding sequence with broadcasting
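
# Memory behavior sketch (illustrative): with seq_len=3 and a single env, feeding
# observations o1..o4 one per forward call fills ``obs_memory`` as a FIFO
# (0 denotes the zero-initialized placeholder):
#
#   step 1: [o1, 0,  0 ]   memory_idx = 1
#   step 2: [o1, o2, 0 ]   memory_idx = 2
#   step 3: [o1, o2, o3]   memory_idx = 3 (full)
#   step 4: [o2, o3, o4]   memory rolled by one position, memory_idx stays 3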


class TransformerSegmentWrapper(IModelWrapper):

    def __init__(self, model: Any, seq_len: int) -> None:
        """
        Overview:
            Given T, the length of a trajectory, and N, the length of the sequences received by a Transformer
            model, split T into sequences of N elements and forward each sequence one by one. If T % N != 0,
            the last sequence will be zero-padded. Usually used during the Transformer training phase.
        Arguments:
            - model (:obj:`Any`): Wrapped model class, should contain forward method.
            - seq_len (:obj:`int`): N, length of a sequence.
        """
        super().__init__(model)
        self.seq_len = seq_len

    def forward(self, obs: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation trajectory with shape ``(T, bs, *obs_shape)``.
        Returns:
            - Dict of model outputs, concatenated over the sequences along the time dimension.
        """
        sequences = list(torch.split(obs, self.seq_len, dim=0))
        if sequences[-1].shape[0] < self.seq_len:
            last = sequences[-1].clone()
            diff = self.seq_len - last.shape[0]
            sequences[-1] = F.pad(input=last, pad=(0, 0, 0, 0, 0, diff), mode='constant', value=0)
        outputs = []
        for i, seq in enumerate(sequences):
            out = self._model.forward(seq, **kwargs)
            outputs.append(out)
        out = {}
        for k in outputs[0].keys():
            out_k = [o[k] for o in outputs]
            out_k = torch.cat(out_k, dim=0)
            out[k] = out_k
        return out
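
# Padding sketch (illustrative): with a trajectory of T=10 steps and seq_len=4,
# ``torch.split`` yields segments of lengths [4, 4, 2]; the last segment is
# zero-padded to length 4 by ``F.pad`` (for a 3-D (T, bs, obs_dim) tensor the pad
# tuple (0, 0, 0, 0, 0, diff) pads the end of the first, i.e. time, dimension),
# so the model always receives sequences of exactly ``seq_len`` steps.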


class TransformerMemoryWrapper(IModelWrapper):

    def __init__(
            self,
            model: Any,
            batch_size: int,
    ) -> None:
        """
        Overview:
            Store a copy of the Transformer memory in order to reuse it across different phases. To make it
            clearer, suppose the training pipeline is divided into 3 phases: evaluate, collect, learn. The
            goal of the wrapper is to maintain the content of the memory at the end of each phase and reuse
            it when the same phase is executed again. In this way, it prevents different phases from
            interfering with each other's memory.
        Arguments:
            - model (:obj:`Any`): Wrapped model class, should contain forward method.
            - batch_size (:obj:`int`): Memory batch size.
        """
        super().__init__(model)
        # shape (layer_num, memory_len, bs, embedding_dim)
        self._model.reset_memory(batch_size=batch_size)
        self.memory = self._model.get_memory()
        self.mem_shape = self.memory.shape

    def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Restore the stored memory into the wrapped model, forward the arguments through it, then save \
            the updated memory back.
        Returns:
            - Output of the wrapped model's forward method.
        """
        self._model.reset_memory(state=self.memory)
        out = self._model.forward(*args, **kwargs)
        self.memory = self._model.get_memory()
        return out

    def reset(self, *args, **kwargs):
        state_id = kwargs.get('data_id', None)
        if state_id is None:
            self.memory = torch.zeros(self.mem_shape)
        else:
            self.reset_memory_entry(state_id)
        if hasattr(self._model, 'reset'):
            return self._model.reset(*args, **kwargs)

    def reset_memory_entry(self, state_id: Optional[list] = None) -> None:
        """
        Overview:
            Reset specific batch entries of the memory; batch ids are specified in 'state_id'.
        """
        for _id in state_id:
            self.memory[:, :, _id] = torch.zeros((self.mem_shape[-1]))

    def show_memory_occupancy(self, layer=0) -> None:
        memory = self.memory
        memory_shape = memory.shape
        print('Layer {}-------------------------------------------'.format(layer))
        for b in range(memory_shape[-2]):
            print('b{}: '.format(b), end='')
            for m in range(memory_shape[1]):
                if sum(abs(memory[layer][m][b].flatten())) != 0:
                    print(1, end='')
                else:
                    print(0, end='')
            print()


def sample_action(logit=None, prob=None):
    if prob is None:
        prob = torch.softmax(logit, dim=-1)
    shape = prob.shape
    prob += 1e-8
    prob = prob.view(-1, shape[-1])
    # prob can also be treated as weight in multinomial sample
    action = torch.multinomial(prob, 1).squeeze(-1)
    action = action.view(*shape[:-1])
    return action
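
# Behavior sketch of ``sample_action`` (illustrative): logits are softmaxed into a
# categorical distribution and one action index is drawn per batch element, with
# any leading batch shape preserved; the 1e-8 offset keeps zero-probability
# entries numerically valid for ``torch.multinomial``.
#
#   logit = torch.tensor([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0]])
#   action = sample_action(logit=logit)  # shape (2,), almost surely tensor([0, 1])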


class ArgmaxSampleWrapper(IModelWrapper):
    """
    Overview:
        Used to help the model to sample argmax action.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        """
        Overview:
            Run the model's forward computation graph, and use the output logit to greedily select the max \
            action (argmax).
        """
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
        logit = output['logit']
        assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
        if isinstance(logit, torch.Tensor):
            logit = [logit]
        if 'action_mask' in output:
            mask = output['action_mask']
            if isinstance(mask, torch.Tensor):
                mask = [mask]
            logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
        action = [l.argmax(dim=-1) for l in logit]
        if len(action) == 1:
            action, logit = action[0], logit[0]
        output['action'] = action
        return output


class CombinationArgmaxSampleWrapper(IModelWrapper):
    r"""
    Overview:
        Used to help the model to sample combination argmax actions, i.e. greedily select ``shot_number`` \
        distinct actions without replacement.
    Interfaces:
        ``forward``.
    """

    def forward(self, shot_number, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        # Generate actions.
        act = []
        mask = torch.zeros_like(output['logit'])
        for ii in range(shot_number):
            masked_logit = output['logit'] + mask
            actions = masked_logit.argmax(dim=-1)
            act.append(actions)
            for jj in range(actions.shape[0]):
                mask[jj][actions[jj]] = -1e8
        # `act` is shaped: (B, shot_number)
        act = torch.stack(act, dim=1)
        output['action'] = act
        return output


class CombinationMultinomialSampleWrapper(IModelWrapper):
    r"""
    Overview:
        Used to help the model to sample combination multinomial actions, i.e. draw ``shot_number`` \
        distinct actions without replacement.
    Interfaces:
        ``forward``.
    """

    def forward(self, shot_number, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        # Generate actions.
        act = []
        mask = torch.zeros_like(output['logit'])
        for ii in range(shot_number):
            dist = torch.distributions.Categorical(logits=output['logit'] + mask)
            actions = dist.sample()
            act.append(actions)
            for jj in range(actions.shape[0]):
                mask[jj][actions[jj]] = -1e8
        # `act` is shaped: (B, shot_number)
        act = torch.stack(act, dim=1)
        output['action'] = act
        return output


class HybridArgmaxSampleWrapper(IModelWrapper):
    r"""
    Overview:
        Used to help the model to sample argmax action in a hybrid action space, \
        i.e. {'action_type': discrete, 'action_args': continuous}.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
        if 'logit' not in output:
            return output
        logit = output['logit']
        assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
        if isinstance(logit, torch.Tensor):
            logit = [logit]
        if 'action_mask' in output:
            mask = output['action_mask']
            if isinstance(mask, torch.Tensor):
                mask = [mask]
            logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
        action = [l.argmax(dim=-1) for l in logit]
        if len(action) == 1:
            action, logit = action[0], logit[0]
        output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
        return output


class MultinomialSampleWrapper(IModelWrapper):
    """
    Overview:
        Used to help the model get the corresponding action by multinomial sampling from ``output['logit']``.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        if 'alpha' in kwargs.keys():
            alpha = kwargs.pop('alpha')
        else:
            alpha = None
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
        logit = output['logit']
        assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
        if isinstance(logit, torch.Tensor):
            logit = [logit]
        if 'action_mask' in output:
            mask = output['action_mask']
            if isinstance(mask, torch.Tensor):
                mask = [mask]
            logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
        if alpha is None:
            action = [sample_action(logit=l) for l in logit]
        else:
            # Note that if alpha is passed in here, we will divide logit by alpha.
            action = [sample_action(logit=l / alpha) for l in logit]
        if len(action) == 1:
            action, logit = action[0], logit[0]
        output['action'] = action
        return output


class EpsGreedySampleWrapper(IModelWrapper):
    r"""
    Overview:
        Epsilon greedy sampler used in collector_model to help balance exploration and exploitation.
        The type of eps can vary across different algorithms, such as:
        - float (i.e. python native scalar): for the common case
        - Dict[str, float]: for algorithm NGU
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        eps = kwargs.pop('eps')
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
        logit = output['logit']
        assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
        if isinstance(logit, torch.Tensor):
            logit = [logit]
        if 'action_mask' in output:
            mask = output['action_mask']
            if isinstance(mask, torch.Tensor):
                mask = [mask]
            logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
        else:
            mask = None
        action = []
        if isinstance(eps, dict):
            # for NGU policy, eps is a dict, each collect env has a different eps
            for i, l in enumerate(logit[0]):
                eps_tmp = eps[i]
                if np.random.random() > eps_tmp:
                    action.append(l.argmax(dim=-1))
                else:
                    if mask is not None:
                        action.append(
                            sample_action(prob=mask[0][i].float().unsqueeze(0)).to(logit[0].device).squeeze(0)
                        )
                    else:
                        action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]).to(logit[0].device))
            action = torch.stack(action, dim=-1)  # shape torch.Size([env_num])
        else:
            for i, l in enumerate(logit):
                if np.random.random() > eps:
                    action.append(l.argmax(dim=-1))
                else:
                    if mask is not None:
                        action.append(sample_action(prob=mask[i].float()))
                    else:
                        action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
        if len(action) == 1:
            action, logit = action[0], logit[0]
        output['action'] = action
        return output
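
# Usage sketch (illustrative): the collector passes the current epsilon at each
# forward call; with probability eps a random (mask-aware) action is drawn,
# otherwise the greedy argmax action is taken. ``q_network`` is a hypothetical
# model that outputs a dict with a 'logit' key.
#
#   model = model_wrap(q_network, wrapper_name='eps_greedy_sample')
#   output = model.forward({'obs': obs}, eps=0.05)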


class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
    r"""
    Overview:
        Epsilon greedy sampler coupled with multinomial sample used in collector_model
        to help balance exploration and exploitation.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        eps = kwargs.pop('eps')
        if 'alpha' in kwargs.keys():
            alpha = kwargs.pop('alpha')
        else:
            alpha = None
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
        logit = output['logit']
        assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
        if isinstance(logit, torch.Tensor):
            logit = [logit]
        if 'action_mask' in output:
            mask = output['action_mask']
            if isinstance(mask, torch.Tensor):
                mask = [mask]
            logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
        else:
            mask = None
        action = []
        for i, l in enumerate(logit):
            if np.random.random() > eps:
                if alpha is None:
                    action.append(sample_action(logit=l))
                else:
                    # Note that if alpha is passed in here, we will divide logit by alpha.
                    action.append(sample_action(logit=l / alpha))
            else:
                if mask:
                    action.append(sample_action(prob=mask[i].float()))
                else:
                    action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
        if len(action) == 1:
            action, logit = action[0], logit[0]
        output['action'] = action
        return output


class HybridEpsGreedySampleWrapper(IModelWrapper):
    r"""
    Overview:
        Epsilon greedy sampler used in collector_model to help balance exploration and exploitation,
        for a hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        eps = kwargs.pop('eps')
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
        logit = output['logit']
        assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
        if isinstance(logit, torch.Tensor):
            logit = [logit]
        if 'action_mask' in output:
            mask = output['action_mask']
            if isinstance(mask, torch.Tensor):
                mask = [mask]
            logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
        else:
            mask = None
        action = []
        for i, l in enumerate(logit):
            if np.random.random() > eps:
                action.append(l.argmax(dim=-1))
            else:
                if mask:
                    action.append(sample_action(prob=mask[i].float()))
                else:
                    action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
        if len(action) == 1:
            action, logit = action[0], logit[0]
        output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
        return output


class HybridEpsGreedyMultinomialSampleWrapper(IModelWrapper):
    """
    Overview:
        Epsilon greedy sampler coupled with multinomial sample used in collector_model
        to help balance exploration and exploitation,
        for a hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        eps = kwargs.pop('eps')
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
        if 'logit' not in output:
            return output
        logit = output['logit']
        assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
        if isinstance(logit, torch.Tensor):
            logit = [logit]
        if 'action_mask' in output:
            mask = output['action_mask']
            if isinstance(mask, torch.Tensor):
                mask = [mask]
            logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
        else:
            mask = None
        action = []
        for i, l in enumerate(logit):
            if np.random.random() > eps:
                action.append(sample_action(logit=l))
            else:
                if mask:
                    action.append(sample_action(prob=mask[i].float()))
                else:
                    action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
        if len(action) == 1:
            action, logit = action[0], logit[0]
        output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
        return output


class HybridReparamMultinomialSampleWrapper(IModelWrapper):
    """
    Overview:
        Reparameterization sampler coupled with multinomial sample used in collector_model
        to help balance exploration and exploitation,
        for a hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
        logit = output['logit']  # logit: {'action_type': action_type_logit, 'action_args': action_args_logit}
        # discrete part
        action_type_logit = logit['action_type']
        prob = torch.softmax(action_type_logit, dim=-1)
        pi_action = Categorical(prob)
        action_type = pi_action.sample()
        # continuous part
        mu, sigma = logit['action_args']['mu'], logit['action_args']['sigma']
        dist = Independent(Normal(mu, sigma), 1)
        action_args = dist.sample()
        action = {'action_type': action_type, 'action_args': action_args}
        output['action'] = action
        return output


class HybridDeterministicArgmaxSampleWrapper(IModelWrapper):
    """
    Overview:
        Deterministic sampler coupled with argmax sample used in eval_model,
        for a hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
        logit = output['logit']  # logit: {'action_type': action_type_logit, 'action_args': action_args_logit}
        # discrete part
        action_type_logit = logit['action_type']
        action_type = action_type_logit.argmax(dim=-1)
        # continuous part
        mu = logit['action_args']['mu']
        action_args = mu
        action = {'action_type': action_type, 'action_args': action_args}
        output['action'] = action
        return output


class DeterministicSampleWrapper(IModelWrapper):
    """
    Overview:
        Deterministic sampler (just use mu directly) used in eval_model.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
        output['action'] = output['logit']['mu']
        return output


class ReparamSampleWrapper(IModelWrapper):
    """
    Overview:
        Reparameterization Gaussian sampler used in collector_model.
    Interfaces:
        ``forward``.
    """

    def forward(self, *args, **kwargs):
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
        mu, sigma = output['logit']['mu'], output['logit']['sigma']
        dist = Independent(Normal(mu, sigma), 1)
        output['action'] = dist.sample()
        return output


class ActionNoiseWrapper(IModelWrapper):
    r"""
    Overview:
        Add noise to the collector's action output; clip both the generated noise and the action after \
        adding noise.
    Interfaces:
        ``__init__``, ``forward``.
    Arguments:
        - model (:obj:`Any`): Wrapped model class. Should contain ``forward`` method.
        - noise_type (:obj:`str`): The type of noise that should be generated, supports ['gauss', 'ou'].
        - noise_kwargs (:obj:`dict`): Keyword args that should be used in noise init. Depends on ``noise_type``.
        - noise_range (:obj:`Optional[dict]`): Range of noise, used for clipping.
        - action_range (:obj:`Optional[dict]`): Range of action + noise, used for clipping, defaults to \
            clipping to [-1, 1].
    """

    def __init__(
            self,
            model: Any,
            noise_type: str = 'gauss',
            noise_kwargs: dict = {},
            noise_range: Optional[dict] = None,
            action_range: Optional[dict] = {
                'min': -1,
                'max': 1
            }
    ) -> None:
        super().__init__(model)
        self.noise_generator = create_noise_generator(noise_type, noise_kwargs)
        self.noise_range = noise_range
        self.action_range = action_range

    def forward(self, *args, **kwargs):
        # if noise sigma needs to decay, update noise kwargs.
        if 'sigma' in kwargs:
            sigma = kwargs.pop('sigma')
            if sigma is not None:
                self.noise_generator.sigma = sigma
        output = self._model.forward(*args, **kwargs)
        assert isinstance(output, dict), "model output must be dict, but found {}".format(type(output))
        if 'action' in output or 'action_args' in output:
            key = 'action' if 'action' in output else 'action_args'
            action = output[key]
            assert isinstance(action, torch.Tensor)
            action = self.add_noise(action)
            output[key] = action
        return output

    def add_noise(self, action: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Generate noise and clip the noise if needed. Add noise to the action and clip the action if needed.
        Arguments:
            - action (:obj:`torch.Tensor`): Model's action output.
        Returns:
            - noised_action (:obj:`torch.Tensor`): Action processed after adding noise and clipping.
        """
        noise = self.noise_generator(action.shape, action.device)
        if self.noise_range is not None:
            noise = noise.clamp(self.noise_range['min'], self.noise_range['max'])
        action += noise
        if self.action_range is not None:
            action = action.clamp(self.action_range['min'], self.action_range['max'])
        return action
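
# Usage sketch (illustrative): DDPG/TD3-style exploration that adds clipped noise
# to a deterministic action. ``actor`` is a hypothetical continuous-action model;
# the exact keys accepted by ``noise_kwargs`` depend on the chosen noise generator
# (here assumed to be a Gaussian parameterized by ``mu`` and ``sigma``).
#
#   model = model_wrap(
#       actor,
#       wrapper_name='action_noise',
#       noise_type='gauss',
#       noise_kwargs={'mu': 0.0, 'sigma': 0.1},
#       noise_range={'min': -0.5, 'max': 0.5},
#   )
#   output = model.forward({'obs': obs}, sigma=0.05)  # optionally anneal sigma per call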


class TargetNetworkWrapper(IModelWrapper):
    r"""
    Overview:
        Maintain and update the target network.
    Interfaces:
        ``update``, ``reset``.
    """

    def __init__(self, model: Any, update_type: str, update_kwargs: dict):
        super().__init__(model)
        assert update_type in ['momentum', 'assign']
        self._update_type = update_type
        self._update_kwargs = update_kwargs
        self._update_count = 0

    def reset(self, *args, **kwargs):
        target_update_count = kwargs.pop('target_update_count', None)
        self.reset_state(target_update_count)
        if hasattr(self._model, 'reset'):
            return self._model.reset(*args, **kwargs)

    def update(self, state_dict: dict, direct: bool = False) -> None:
        r"""
        Overview:
            Update the target network's state dict.
        Arguments:
            - state_dict (:obj:`dict`): The state_dict from the learner model.
            - direct (:obj:`bool`): Whether to update the target network directly; \
                if True, simply call the model's load_state_dict method.
        """
        if direct:
            self._model.load_state_dict(state_dict, strict=True)
            self._update_count = 0
        else:
            if self._update_type == 'assign':
                if (self._update_count + 1) % self._update_kwargs['freq'] == 0:
                    self._model.load_state_dict(state_dict, strict=True)
                self._update_count += 1
            elif self._update_type == 'momentum':
                theta = self._update_kwargs['theta']
                for name, p in self._model.named_parameters():
                    # default theta = 0.001
                    p.data = (1 - theta) * p.data + theta * state_dict[name]

    def reset_state(self, target_update_count: int = None) -> None:
        r"""
        Overview:
            Reset the update_count.
        Arguments:
            - target_update_count (:obj:`int`): Reset the target update count to this value.
        """
        if target_update_count is not None:
            self._update_count = target_update_count
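
# Update rule sketch: with update_type='momentum' each ``update`` call applies the
# soft update p_target = (1 - theta) * p_target + theta * p_learner; with
# update_type='assign' the learner state dict is copied wholesale every ``freq``
# calls. A minimal sketch, assuming ``model`` is the learner network:
#
#   import copy
#   target_model = model_wrap(copy.deepcopy(model), wrapper_name='target',
#                             update_type='momentum', update_kwargs={'theta': 0.005})
#   target_model.update(model.state_dict())  # one soft-update step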


class TeacherNetworkWrapper(IModelWrapper):
    """
    Overview:
        Set the teacher network. Assign the input ``teacher_cfg`` to the wrapped model's ``_teacher_cfg``. \
        Not implemented yet.
    """

    def __init__(self, model, teacher_cfg):
        super().__init__(model)
        self._model._teacher_cfg = teacher_cfg
        raise NotImplementedError


wrapper_name_map = {
    'base': BaseModelWrapper,
    'hidden_state': HiddenStateWrapper,
    'argmax_sample': ArgmaxSampleWrapper,
    'hybrid_argmax_sample': HybridArgmaxSampleWrapper,
    'eps_greedy_sample': EpsGreedySampleWrapper,
    'eps_greedy_multinomial_sample': EpsGreedyMultinomialSampleWrapper,
    'deterministic_sample': DeterministicSampleWrapper,
    'reparam_sample': ReparamSampleWrapper,
    'hybrid_eps_greedy_sample': HybridEpsGreedySampleWrapper,
    'hybrid_eps_greedy_multinomial_sample': HybridEpsGreedyMultinomialSampleWrapper,
    'hybrid_reparam_multinomial_sample': HybridReparamMultinomialSampleWrapper,
    'hybrid_deterministic_argmax_sample': HybridDeterministicArgmaxSampleWrapper,
    'multinomial_sample': MultinomialSampleWrapper,
    'action_noise': ActionNoiseWrapper,
    'transformer_input': TransformerInputWrapper,
    'transformer_segment': TransformerSegmentWrapper,
    'transformer_memory': TransformerMemoryWrapper,
    # model wrapper
    'target': TargetNetworkWrapper,
    'teacher': TeacherNetworkWrapper,
    'combination_argmax_sample': CombinationArgmaxSampleWrapper,
    'combination_multinomial_sample': CombinationMultinomialSampleWrapper,
}


def model_wrap(model: Union[nn.Module, IModelWrapper], wrapper_name: str = None, **kwargs):
    """
    Overview:
        Wrap the model with the specified wrapper and return the wrapped model.
    Arguments:
        - model (:obj:`Any`): The model to be wrapped.
        - wrapper_name (:obj:`str`): The name of the wrapper to be used.

    .. note::
        The arguments of the wrapper should be passed in as kwargs.
    """
    if wrapper_name in wrapper_name_map:
        # TODO test whether to remove this if branch
        if not isinstance(model, IModelWrapper):
            model = wrapper_name_map['base'](model)
        model = wrapper_name_map[wrapper_name](model, **kwargs)
    else:
        raise TypeError("not support model_wrapper type: {}".format(wrapper_name))
    return model


def register_wrapper(name: str, wrapper_type: type) -> None:
    """
    Overview:
        Register a new wrapper to ``wrapper_name_map``. When a user implements a new wrapper, they must call \
        this function to complete the registration. Then the wrapper can be used via ``model_wrap``.
    Arguments:
        - name (:obj:`str`): The name of the new wrapper to be registered.
        - wrapper_type (:obj:`type`): The wrapper class to be added to ``wrapper_name_map``. This argument \
            should be a subclass of ``IModelWrapper``.
    """
    assert isinstance(name, str)
    assert issubclass(wrapper_type, IModelWrapper)
    wrapper_name_map[name] = wrapper_type
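
# Usage sketch (illustrative): registering and applying a custom wrapper. The
# wrapper below is a hypothetical example that simply halves the output logit.
#
#   class ScaledLogitWrapper(IModelWrapper):
#
#       def forward(self, *args, **kwargs):
#           output = self._model.forward(*args, **kwargs)
#           output['logit'] = output['logit'] * 0.5
#           return output
#
#   register_wrapper('scaled_logit', ScaledLogitWrapper)
#   model = model_wrap(my_model, wrapper_name='scaled_logit')  # my_model is hypothetical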