from typing import Union, Optional, Dict

from easydict import EasyDict
import torch
import torch.nn as nn

from ding.torch_utils import get_lstm
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, RegressionHead


@MODEL_REGISTRY.register('pdqn')
class PDQN(nn.Module):
    """
    Overview:
        The neural network and computation graph of the PDQN (https://arxiv.org/abs/1810.06394v1) and \
        MPDQN (https://arxiv.org/abs/1905.04388) algorithms for parameterized action spaces. \
        This model supports a parameterized action space with a discrete ``action_type`` and continuous \
        ``action_args``. In principle, PDQN consists of an x network (continuous action parameter network) \
        and a Q network (discrete action type network). For simplicity, the code is instead split into \
        ``encoder`` and ``actor_head``, which contain the encoders and heads of the two networks respectively.
    Interface:
        ``__init__``, ``forward``, ``compute_discrete``, ``compute_continuous``.
    """
    mode = ['compute_discrete', 'compute_continuous']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: EasyDict,
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            dueling: bool = True,
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            multi_pass: Optional[bool] = False,
            action_mask: Optional[list] = None
    ) -> None:
| """ | |
| Overview: | |
| Init the PDQN (encoder + head) Model according to input arguments. | |
| Arguments: | |
| - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84]. | |
| - action_shape (:obj:`EasyDict`): Action space shape in dict type, such as \ | |
| EasyDict({'action_type_shape': 3, 'action_args_shape': 5}). | |
| - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \ | |
| the last element must match ``head_hidden_size``. | |
| - dueling (:obj:`dueling`): Whether choose ``DuelingHead`` or ``DiscreteHead(default)``. | |
| - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network. | |
| - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output. | |
| - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \ | |
| if ``None`` then default set it to ``nn.ReLU()``. | |
| - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \ | |
| ``ding.torch_utils.fc_block`` for more details. | |
| - multi_pass (:obj:`Optional[bool]`): Whether to use multi pass version. | |
| - action_mask: (:obj:`Optional[list]`): An action mask indicating how action args are \ | |
| associated to each discrete action. For example, if there are 3 discrete action, \ | |
| 4 continous action args, and the first discrete action associates with the first \ | |
| continuous action args, the second discrete action associates with the second continuous \ | |
| action args, and the third discrete action associates with the remaining 2 action args, \ | |
| the action mask will be like: [[1,0,0,0],[0,1,0,0],[0,0,1,1]] with shape 3*4. | |
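        Examples:
            >>> # A minimal construction sketch of the multi-pass (MPDQN) variant, reusing the
            >>> # hypothetical 3x4 mask from the ``action_mask`` description above.
            >>> mask = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]]
            >>> act_shape = EasyDict({'action_type_shape': 3, 'action_args_shape': 4})
            >>> model = PDQN(8, act_shape, multi_pass=True, action_mask=mask)
            >>> assert model.action_scatter_index.tolist() == [0, 1, 2, 2]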
| """ | |
| super(PDQN, self).__init__() | |
| self.multi_pass = multi_pass | |
| if self.multi_pass: | |
| assert isinstance( | |
| action_mask, list | |
| ), 'Please indicate action mask in list form if you set multi_pass to True' | |
| self.action_mask = torch.LongTensor(action_mask) | |
| nonzero = torch.nonzero(self.action_mask) | |
| index = torch.zeros(action_shape.action_args_shape).long() | |
| index.scatter_(dim=0, index=nonzero[:, 1], src=nonzero[:, 0]) | |
| self.action_scatter_index = index # (self.action_args_shape, ) | |
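            # e.g. with the hypothetical mask [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]],
            # nonzero is [[0, 0], [1, 1], [2, 2], [2, 3]] and action_scatter_index becomes
            # [0, 1, 2, 2], i.e. each continuous arg is mapped to its owning discrete action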
        # squeeze action shape input like (3, ) to 3
        action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
        action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
        self.action_args_shape = action_shape.action_args_shape
        self.action_type_shape = action_shape.action_type_shape

        # init head hidden size
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        # squeeze obs input for compatibility: 1, (1, ), [4, 32, 32]
        obs_shape = squeeze(obs_shape)

        # Obs Encoder Type
        if isinstance(obs_shape, int) or len(obs_shape) == 1:  # FC Encoder
            self.dis_encoder = FCEncoder(
                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
            )
            self.cont_encoder = FCEncoder(
                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
            )
        elif len(obs_shape) == 3:  # Conv Encoder
            self.dis_encoder = ConvEncoder(
                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
            )
            self.cont_encoder = ConvEncoder(
                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
            )
        else:
            raise RuntimeError(
                "Pre-defined encoder does not support obs_shape {}, please customize your own PDQN.".format(obs_shape)
            )
        # Continuous Action Head Type
        self.cont_head = RegressionHead(
            head_hidden_size,
            action_shape.action_args_shape,
            head_layer_num,
            final_tanh=True,
            activation=activation,
            norm_type=norm_type
        )

        # Discrete Action Head Type
        if dueling:
            dis_head_cls = DuelingHead
        else:
            dis_head_cls = DiscreteHead
        # the discrete head takes the concatenation of the encoded state and the action args as input
        self.dis_head = dis_head_cls(
            head_hidden_size + action_shape.action_args_shape,
            action_shape.action_type_shape,
            head_layer_num,
            activation=activation,
            norm_type=norm_type
        )

        self.actor_head = nn.ModuleList([self.dis_head, self.cont_head])
        # self.encoder = nn.ModuleList([self.dis_encoder, self.cont_encoder])
        # To speed up the training process, the x network and the Q network share the state encoder,
        # so ``self.dis_encoder`` is created above but not used here.
        self.encoder = nn.ModuleList([self.cont_encoder, self.cont_encoder])

    def forward(self, inputs: Union[torch.Tensor, Dict, EasyDict], mode: str) -> Dict:
        """
        Overview:
            PDQN forward computation graph: input an observation tensor to predict the q_value of \
            discrete action types or the values of continuous action args, according to ``mode``.
        Arguments:
            - inputs (:obj:`Union[torch.Tensor, Dict, EasyDict]`): Inputs including the observation and \
                other info, according to ``mode``.
            - mode (:obj:`str`): Name of the forward mode.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
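        Examples:
            >>> # An end-to-end sketch of both modes, mirroring the per-mode examples below.
            >>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )})
            >>> model = PDQN(4, act_shape)
            >>> obs = torch.randn(64, 4)
            >>> cont = model.forward(obs, mode='compute_continuous')
            >>> dis = model.forward({'state': obs, 'action_args': cont['action_args']}, mode='compute_discrete')
            >>> assert dis['logit'].shape == torch.Size([64, 3])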
| """ | |
| assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) | |
| return getattr(self, mode)(inputs) | |

    def compute_continuous(self, inputs: torch.Tensor) -> Dict:
        """
        Overview:
            Use the observation tensor to predict continuous action args.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Observation inputs.
        Returns:
            - outputs (:obj:`Dict`): A dict with key 'action_args'.
                - 'action_args' (:obj:`torch.Tensor`): The continuous action args.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
            - action_args (:obj:`torch.Tensor`): :math:`(B, M)`, where M is ``action_args_shape``.
        Examples:
            >>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )})
            >>> model = PDQN(4, act_shape)
            >>> inputs = torch.randn(64, 4)
            >>> outputs = model.forward(inputs, mode='compute_continuous')
            >>> assert outputs['action_args'].shape == torch.Size([64, 5])
        """
        cont_x = self.encoder[1](inputs)  # size (B, encoded_state_shape)
        action_args = self.actor_head[1](cont_x)['pred']  # size (B, action_args_shape)
        outputs = {'action_args': action_args}
        return outputs

    def compute_discrete(self, inputs: Union[Dict, EasyDict]) -> Dict:
        """
        Overview:
            Use the observation tensor and continuous action args to predict discrete action types.
        Arguments:
            - inputs (:obj:`Union[Dict, EasyDict]`): A dict with keys 'state' and 'action_args'.
                - state (:obj:`torch.Tensor`): Observation inputs.
                - action_args (:obj:`torch.Tensor`): Action parameters, which are concatenated with the \
                    encoded observation and serve as input to the discrete action type network.
        Returns:
            - outputs (:obj:`Dict`): A dict with keys 'logit' and 'action_args'.
                - 'logit': The logit value of each discrete action.
                - 'action_args': The continuous action args (same as inputs['action_args']) for later usage.
        Examples:
            >>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )})
            >>> model = PDQN(4, act_shape)
            >>> inputs = {'state': torch.randn(64, 4), 'action_args': torch.randn(64, 5)}
            >>> outputs = model.forward(inputs, mode='compute_discrete')
            >>> assert outputs['logit'].shape == torch.Size([64, 3])
            >>> assert outputs['action_args'].shape == torch.Size([64, 5])
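            >>> # A multi-pass (MPDQN) sketch, assuming the hypothetical 3x4 mask from ``__init__``.
            >>> mask = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]]
            >>> act_shape = EasyDict({'action_type_shape': 3, 'action_args_shape': 4})
            >>> model = PDQN(4, act_shape, multi_pass=True, action_mask=mask)
            >>> inputs = {'state': torch.randn(64, 4), 'action_args': torch.randn(64, 4)}
            >>> outputs = model.forward(inputs, mode='compute_discrete')
            >>> assert outputs['logit'].shape == torch.Size([64, 3])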
| """ | |
| dis_x = self.encoder[0](inputs['state']) # size (B, encoded_state_shape) | |
| action_args = inputs['action_args'] # size (B, action_args_shape) | |
| if self.multi_pass: # mpdqn | |
| # fill_value=-2 is a mask value, which is not in normal acton range | |
| # (B, action_args_shape, K) where K is the action_type_shape | |
| mp_action = torch.full( | |
| (dis_x.shape[0], self.action_args_shape, self.action_type_shape), | |
| fill_value=-2, | |
| device=dis_x.device, | |
| dtype=dis_x.dtype | |
| ) | |
| index = self.action_scatter_index.view(1, -1, 1).repeat(dis_x.shape[0], 1, 1).to(dis_x.device) | |
| # index: (B, action_args_shape, 1) src: (B, action_args_shape, 1) | |
| mp_action.scatter_(dim=-1, index=index, src=action_args.unsqueeze(-1)) | |
| mp_action = mp_action.permute(0, 2, 1) # (B, K, action_args_shape) | |
| mp_state = dis_x.unsqueeze(1).repeat(1, self.action_type_shape, 1) # (B, K, obs_shape) | |
| mp_state_action_cat = torch.cat([mp_state, mp_action], dim=-1) | |
| logit = self.actor_head[0](mp_state_action_cat)['logit'] # (B, K, K) | |
| logit = torch.diagonal(logit, dim1=-2, dim2=-1) # (B, K) | |
| else: # pdqn | |
| # size (B, encoded_state_shape + action_args_shape) | |
| if len(action_args.shape) == 1: # (B, ) -> (B, 1) | |
| action_args = action_args.unsqueeze(1) | |
| state_action_cat = torch.cat((dis_x, action_args), dim=-1) | |
| logit = self.actor_head[0](state_action_cat)['logit'] # size (B, K) where K is action_type_shape | |
| outputs = {'logit': logit, 'action_args': action_args} | |
| return outputs | |