from typing import Union, Optional, Dict
import torch
import torch.nn as nn
from easydict import EasyDict

from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, \
    MultiHead, RegressionHead, ReparameterizationHead


class DiscreteBC(nn.Module):
    """
    Overview:
        The DiscreteBC network.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            dueling: bool = True,
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            strides: Optional[list] = None,
    ) -> None:
| """ | |
| Overview: | |
| Init the DiscreteBC (encoder + head) Model according to input arguments. | |
| Arguments: | |
| - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84]. | |
| - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. | |
| - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \ | |
| the last element must match ``head_hidden_size``. | |
| - dueling (:obj:`dueling`): Whether choose ``DuelingHead`` or ``DiscreteHead(default)``. | |
| - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network. | |
| - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output | |
| - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \ | |
| if ``None`` then default set it to ``nn.ReLU()``. | |
| - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \ | |
| ``ding.torch_utils.fc_block`` for more details. | |
| - strides (:obj:`Optional[list]`): The strides for each convolution layers, such as [2, 2, 2]. The length \ | |
| of this argument should be the same as ``encoder_hidden_size_list``. | |
| """ | |
        super(DiscreteBC, self).__init__()
        # For compatibility: 1, (1, ), [4, 32, 32]
        obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]
        # FC Encoder
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        # Conv Encoder
        elif len(obs_shape) == 3:
            if not strides:
                self.encoder = ConvEncoder(
                    obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
                )
            else:
                self.encoder = ConvEncoder(
                    obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type, stride=strides
                )
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own BC".format(obs_shape)
            )
        # Head Type
        if dueling:
            head_cls = DuelingHead
        else:
            head_cls = DiscreteHead
        multi_head = not isinstance(action_shape, int)
        if multi_head:
            self.head = MultiHead(
                head_cls,
                head_hidden_size,
                action_shape,
                layer_num=head_layer_num,
                activation=activation,
                norm_type=norm_type
            )
        else:
            self.head = head_cls(
                head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            DiscreteBC forward computation graph: input the observation tensor to predict the q_value.
        Arguments:
            - x (:obj:`torch.Tensor`): Observation inputs.
        Returns:
            - outputs (:obj:`Dict`): DiscreteBC forward outputs, such as q_value.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Discrete Q-value output of each action dimension.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``.
        Examples:
            >>> model = DiscreteBC(32, 6)  # arguments: 'obs_shape' and 'action_shape'
            >>> inputs = torch.randn(4, 32)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 6])
        """
        x = self.encoder(x)
        x = self.head(x)
        return x
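

def _discrete_bc_usage_sketch() -> None:
    # A minimal usage sketch (this helper is illustrative and not part of the original
    # module). It exercises the two pre-defined encoder branches described in
    # ``DiscreteBC.__init__``: FCEncoder for vector observations and ConvEncoder for
    # image observations, plus the MultiHead branch for multi-discrete action spaces.
    # All shapes and hyper-parameters below are assumptions chosen for illustration.

    # Vector observation -> FCEncoder; scalar action_shape -> single (Dueling) head.
    vector_model = DiscreteBC(obs_shape=8, action_shape=5)
    vector_obs = torch.randn(4, 8)
    assert vector_model(vector_obs)['logit'].shape == torch.Size([4, 5])

    # Image observation -> ConvEncoder; ``strides`` should have the same length as
    # ``encoder_hidden_size_list`` (here the default [128, 128, 64]).
    # A sequence-type action_shape selects the MultiHead branch, so ``logit`` is
    # expected to be a list with one tensor per sub-action space.
    image_model = DiscreteBC(obs_shape=[4, 84, 84], action_shape=[2, 3], strides=[4, 2, 1])
    image_obs = torch.randn(4, 4, 84, 84)
    outputs = image_model(image_obs)
    assert outputs['logit'][0].shape == torch.Size([4, 2])
    assert outputs['logit'][1].shape == torch.Size([4, 3])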


class ContinuousBC(nn.Module):
    """
    Overview:
        The ContinuousBC network.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            action_space: str,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
| """ | |
| Overview: | |
| Initialize the ContinuousBC Model according to input arguments. | |
| Arguments: | |
| - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ). | |
| - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \ | |
| EasyDict({'action_type_shape': 3, 'action_args_shape': 4}). | |
| - action_space (:obj:`str`): The type of action space, \ | |
| including [``regression``, ``reparameterization``]. | |
| - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head. | |
| - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \ | |
| for actor head. | |
| - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \ | |
| after each FC layer, if ``None`` then default set to ``nn.ReLU()``. | |
| - norm_type (:obj:`Optional[str]`): The type of normalization to after network layer (FC, Conv), \ | |
| see ``ding.torch_utils.network`` for more details. | |
| """ | |
        super(ContinuousBC, self).__init__()
        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.action_space = action_space
        assert self.action_space in ['regression', 'reparameterization']
        if self.action_space == 'regression':
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                RegressionHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    final_tanh=True,
                    activation=activation,
                    norm_type=norm_type
                )
            )
        elif self.action_space == 'reparameterization':
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                ReparameterizationHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    sigma_type='conditioned',
                    activation=activation,
                    norm_type=norm_type
                )
            )

    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Dict:
        """
        Overview:
            The unique execution (forward) method of ContinuousBC.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Observation data, usually a tensor.
        Returns:
            - output (:obj:`Dict`): Output dict data, whose key-values differ among distinct action spaces.
        ReturnsKeys:
            - action (:obj:`torch.Tensor`): Action output of the actor network (``regression`` mode), \
                with shape :math:`(B, action_shape)`.
            - logit (:obj:`List[torch.Tensor]`): Reparameterized action output of the actor network \
                (``reparameterization`` mode), i.e. [mu, sigma], each with shape :math:`(B, action_shape)`.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
            - action (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``.
            - logit (:obj:`List[torch.FloatTensor]`): :math:`(B, M)`, where B is batch size and M is ``action_shape``.
        Examples (Regression):
            >>> model = ContinuousBC(32, 6, action_space='regression')
            >>> inputs = torch.randn(4, 32)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['action'].shape == torch.Size([4, 6])
        Examples (Reparameterization):
            >>> model = ContinuousBC(32, 6, action_space='reparameterization')
            >>> inputs = torch.randn(4, 32)
            >>> outputs = model(inputs)
            >>> assert isinstance(outputs, dict) and outputs['logit'][0].shape == torch.Size([4, 6])
            >>> assert outputs['logit'][1].shape == torch.Size([4, 6])
        """
        if self.action_space == 'regression':
            x = self.actor(inputs)
            return {'action': x['pred']}
        elif self.action_space == 'reparameterization':
            x = self.actor(inputs)
            return {'logit': [x['mu'], x['sigma']]}
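

def _continuous_bc_usage_sketch() -> None:
    # A minimal usage sketch (this helper is illustrative and not part of the original
    # module). It contrasts the two ``action_space`` modes of ContinuousBC at the
    # output side: ``regression`` returns an ``action`` tensor directly, while
    # ``reparameterization`` returns ``logit`` = [mu, sigma], from which a caller may
    # build a Gaussian distribution (e.g. for a log-likelihood behaviour-cloning loss).
    # The distribution step below is an assumption for illustration, not code taken
    # from any DI-engine policy.
    from torch.distributions import Independent, Normal

    obs = torch.randn(4, 32)

    # Regression mode: the actor ends with a tanh-bounded RegressionHead.
    reg_model = ContinuousBC(obs_shape=32, action_shape=6, action_space='regression')
    assert reg_model(obs)['action'].shape == torch.Size([4, 6])

    # Reparameterization mode: the actor ends with a ReparameterizationHead
    # (``sigma_type='conditioned'``) and returns the Gaussian parameters.
    rep_model = ContinuousBC(obs_shape=32, action_shape=6, action_space='reparameterization')
    mu, sigma = rep_model(obs)['logit']
    dist = Independent(Normal(mu, sigma), 1)
    sampled_action = dist.rsample()  # differentiable sample with shape (4, 6)
    assert sampled_action.shape == torch.Size([4, 6])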