from typing import Union, Dict, Optional

import torch
import torch.nn as nn

from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import RegressionHead, ReparameterizationHead, DistributionHead
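

# NOTE (assumption, not in the extracted listing): in the upstream DI-engine source this model is
# registered in the model registry, which is why ``MODEL_REGISTRY`` is imported above. The decorator
# below restores that registration under the assumed key 'qac_dist'.
@MODEL_REGISTRY.register('qac_dist')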
class QACDIST(nn.Module):
    """
    Overview:
        The QAC model with distributional Q-value.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
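    # Supported forward modes; ``forward`` dispatches to the method of the same name via ``getattr``.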
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            action_space: str = "regression",
            critic_head_type: str = "categorical",
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            v_min: Optional[float] = -10,
            v_max: Optional[float] = 10,
            n_atom: Optional[int] = 51,
    ) -> None:
        """
        Overview:
            Initialize the QAC distributional model according to the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape.
            - action_space (:obj:`str`): The type of action space, either ``regression`` or ``reparameterization``.
            - critic_head_type (:obj:`str`): The type of critic head; only ``categorical`` is supported.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the actor ``Head``.
            - actor_head_layer_num (:obj:`int`):
                The number of layers used in the actor network to compute the action output.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the critic ``Head``.
            - critic_head_layer_num (:obj:`int`):
                The number of layers used in the critic network to compute the Q value output.
            - activation (:obj:`Optional[nn.Module]`):
                The type of activation function to use in ``MLP`` after each ``layer_fn``;
                if ``None``, it defaults to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`):
                The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
            - v_min (:obj:`float`): Value of the smallest atom in the support.
            - v_max (:obj:`float`): Value of the largest atom in the support.
            - n_atom (:obj:`int`): Number of atoms in the support of the categorical value distribution.
        """
        super(QACDIST, self).__init__()
        obs_shape: int = squeeze(obs_shape)
        action_shape: int = squeeze(action_shape)
        self.action_space = action_space
        assert self.action_space in ['regression', 'reparameterization']
        if self.action_space == 'regression':
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                RegressionHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    final_tanh=True,
                    activation=activation,
                    norm_type=norm_type
                )
            )
        elif self.action_space == 'reparameterization':
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                ReparameterizationHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    sigma_type='conditioned',
                    activation=activation,
                    norm_type=norm_type
                )
            )
        self.critic_head_type = critic_head_type
        assert self.critic_head_type in ['categorical'], self.critic_head_type
        if self.critic_head_type == 'categorical':
            self.critic = nn.Sequential(
                nn.Linear(obs_shape + action_shape, critic_head_hidden_size), activation,
                DistributionHead(
                    critic_head_hidden_size,
                    1,
                    critic_head_layer_num,
                    n_atom=n_atom,
                    v_min=v_min,
                    v_max=v_max,
                    activation=activation,
                    norm_type=norm_type
                )
            )

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            Use the observation and action tensors to predict the output, dispatching to ``compute_actor``
            or ``compute_critic`` according to ``mode``.
        Arguments:
            Forward with ``'compute_actor'``:
                - inputs (:obj:`torch.Tensor`):
                    The encoded embedding tensor, determined by the given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
                    Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` is used depends on ``mode``.
            Forward with ``'compute_critic'``, inputs (:obj:`Dict`) Necessary Keys:
                - ``obs``, ``action`` encoded tensors.
            - mode (:obj:`str`): Name of the forward mode.
        Returns:
            - outputs (:obj:`Dict`): Outputs of the network forward pass.
                Forward with ``'compute_actor'``, Necessary Keys (either):
                    - action (:obj:`torch.Tensor`): Action tensor with the same size as input ``x``.
                    - logit (:obj:`torch.Tensor`):
                        Logit tensor encoding ``mu`` and ``sigma``, both with the same size as input ``x``.
                Forward with ``'compute_critic'``, Necessary Keys:
                    - q_value (:obj:`torch.Tensor`): Q value tensor with the same size as the batch size.
                    - distribution (:obj:`torch.Tensor`): Q value distribution tensor.
        Actor Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, where B is batch size and N0 corresponds to ``hidden_size``.
            - action (:obj:`torch.Tensor`): :math:`(B, N0)`.
        Critic Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``.
            - action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``.
            - q_value (:obj:`torch.FloatTensor`): :math:`(B, 1)`, where B is batch size.
            - distribution (:obj:`torch.FloatTensor`): :math:`(B, 1, N3)`, where B is batch size and N3 is ``n_atom``.
        Actor Examples:
            >>> # Regression mode
            >>> model = QACDIST(64, 64, 'regression')
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['action'].shape == torch.Size([4, 64])
            >>> # Reparameterization mode
            >>> model = QACDIST(64, 64, 'reparameterization')
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> actor_outputs['logit'][0].shape  # mu
            torch.Size([4, 64])
            >>> actor_outputs['logit'][1].shape  # sigma
            torch.Size([4, 64])
        Critic Examples:
            >>> # Categorical mode
            >>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
            >>> model = QACDIST(obs_shape=(N, ), action_shape=1, action_space='regression', \
            ...                 critic_head_type='categorical', n_atom=51)
            >>> q_value = model(inputs, mode='compute_critic')  # q value
            >>> assert q_value['q_value'].shape == torch.Size([4, 1])
            >>> assert q_value['distribution'].shape == torch.Size([4, 1, 51])
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: torch.Tensor) -> Dict:
        """
        Overview:
            Use the encoded embedding tensor to predict the actor output, i.e. forward with the
            ``'compute_actor'`` mode.
        Arguments:
            - inputs (:obj:`torch.Tensor`):
                The encoded embedding tensor, determined by the given ``hidden_size``, i.e. ``(B, N=hidden_size)``,
                where ``hidden_size = actor_head_hidden_size``.
        Returns:
            - outputs (:obj:`Dict`): Outputs of the forward pass through the encoder and head.
        ReturnsKeys (either):
            - action (:obj:`torch.Tensor`): Continuous action tensor with the same size as ``action_shape``.
            - logit (:obj:`torch.Tensor`):
                Logit tensor encoding ``mu`` and ``sigma``, both with the same size as input ``x``.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, where B is batch size and N0 corresponds to ``hidden_size``.
            - action (:obj:`torch.Tensor`): :math:`(B, N0)`.
            - logit (:obj:`list`): 2 elements, ``mu`` and ``sigma``, each with shape :math:`(B, N0)`.
        Examples:
            >>> # Regression mode
            >>> model = QACDIST(64, 64, 'regression')
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['action'].shape == torch.Size([4, 64])
            >>> # Reparameterization mode
            >>> model = QACDIST(64, 64, 'reparameterization')
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> actor_outputs['logit'][0].shape  # mu
            torch.Size([4, 64])
            >>> actor_outputs['logit'][1].shape  # sigma
            torch.Size([4, 64])
        """
        x = self.actor(inputs)
        if self.action_space == 'regression':
            # RegressionHead outputs the predicted action under the 'pred' key.
            return {'action': x['pred']}
        elif self.action_space == 'reparameterization':
            # ReparameterizationHead outputs the (mu, sigma) of a Gaussian policy.
            return {'logit': [x['mu'], x['sigma']]}

    def compute_critic(self, inputs: Dict) -> Dict:
        """
        Overview:
            Use the encoded observation and action tensors to predict the critic output, i.e. forward with the
            ``'compute_critic'`` mode.
        Arguments:
            - inputs (:obj:`Dict`): A dict containing the encoded ``obs`` and ``action`` tensors.
        Returns:
            - outputs (:obj:`Dict`): Q-value output and distribution.
        ReturnKeys:
            - q_value (:obj:`torch.Tensor`): Q value tensor with the same size as the batch size.
            - distribution (:obj:`torch.Tensor`): Q value distribution tensor.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``.
            - action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``.
            - q_value (:obj:`torch.FloatTensor`): :math:`(B, 1)`, where B is batch size.
            - distribution (:obj:`torch.FloatTensor`): :math:`(B, 1, N3)`, where B is batch size and N3 is ``n_atom``.
        Examples:
            >>> # Categorical mode
            >>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
            >>> model = QACDIST(obs_shape=(N, ), action_shape=1, action_space='regression', \
            ...                 critic_head_type='categorical', n_atom=51)
            >>> q_value = model(inputs, mode='compute_critic')  # q value
            >>> assert q_value['q_value'].shape == torch.Size([4, 1])
            >>> assert q_value['distribution'].shape == torch.Size([4, 1, 51])
        """
        obs, action = inputs['obs'], inputs['action']
        assert len(obs.shape) == 2
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        # Concatenate observation and action as the critic input.
        x = torch.cat([obs, action], dim=1)
        x = self.critic(x)
        # DistributionHead's 'logit' is the expected Q value over the atom support; 'distribution' is the
        # per-atom categorical distribution.
        return {'q_value': x['logit'], 'distribution': x['distribution']}
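

if __name__ == '__main__':
    # A minimal usage sketch, not part of the original module: it assumes this file still lives inside
    # the DI-engine (ding) package so the relative import above resolves, and it uses illustrative
    # shapes (obs_shape=8, action_shape=2). It simply exercises both forward modes end to end.
    model = QACDIST(obs_shape=8, action_shape=2, action_space='regression', n_atom=51)
    obs = torch.randn(4, 8)
    actor_out = model(obs, mode='compute_actor')
    assert actor_out['action'].shape == torch.Size([4, 2])
    critic_out = model({'obs': obs, 'action': actor_out['action']}, mode='compute_critic')
    assert critic_out['q_value'].shape == torch.Size([4, 1])
    assert critic_out['distribution'].shape == torch.Size([4, 1, 51])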