from typing import Union, Dict, Optional
from easydict import EasyDict
import torch
import torch.nn as nn

from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import RegressionHead, ReparameterizationHead, DiscreteHead, MultiHead, \
    FCEncoder, ConvEncoder


@MODEL_REGISTRY.register('discrete_maqac')
class DiscreteMAQAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to discrete action Multi-Agent Q-value \
        Actor-Critic (MAQAC) model. The model is composed of an actor and a critic, both implemented as MLP \
        networks. The actor network predicts the action probability distribution from each agent's local \
        observation, and the critic network predicts the Q value of each action from the global state.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            agent_obs_shape: Union[int, SequenceType],
            global_obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            twin_critic: bool = False,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
| """ | |
| Overview: | |
| Initialize the DiscreteMAQAC Model according to arguments. | |
| Arguments: | |
| - agent_obs_shape (:obj:`Union[int, SequenceType]`): Agent's observation's space. | |
| - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation's space. | |
| - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space. | |
| - action_shape (:obj:`Union[int, SequenceType]`): Action's space. | |
| - twin_critic (:obj:`bool`): Whether include twin critic. | |
| - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``. | |
| - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \ | |
| for actor's nn. | |
| - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``. | |
| - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \ | |
| for critic's nn. | |
| - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` the after \ | |
| ``layer_fn``, if ``None`` then default set to ``nn.ReLU()`` | |
| - norm_type (:obj:`Optional[str]`): The type of normalization to use, see ``ding.torch_utils.fc_block`` \ | |
| for more details. | |
| """ | |
        super(DiscreteMAQAC, self).__init__()
        agent_obs_shape: int = squeeze(agent_obs_shape)
        global_obs_shape: int = squeeze(global_obs_shape)
        action_shape: int = squeeze(action_shape)
        # The actor maps each agent's local observation to per-action logits.
        self.actor = nn.Sequential(
            nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
            DiscreteHead(
                actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
            )
        )

        self.twin_critic = twin_critic
        if self.twin_critic:
            # Twin critic: two independent Q networks, whose minimum can be used to mitigate Q-value overestimation.
            self.critic = nn.ModuleList()
            for _ in range(2):
                self.critic.append(
                    nn.Sequential(
                        nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
                        DiscreteHead(
                            critic_head_hidden_size,
                            action_shape,
                            critic_head_layer_num,
                            activation=activation,
                            norm_type=norm_type
                        )
                    )
                )
        else:
            self.critic = nn.Sequential(
                nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
                DiscreteHead(
                    critic_head_hidden_size,
                    action_shape,
                    critic_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            )

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            Use observation tensor to predict output, with ``compute_actor`` or ``compute_critic`` mode.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
                        N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
                        N2 corresponds to ``action_shape``.
            - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, \
                whose key-values vary in different forward modes.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     }
            >>> }
            >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
            >>> logit = model(data, mode='compute_actor')['logit']
            >>> value = model(data, mode='compute_critic')['q_value']
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: Dict) -> Dict:
        """
        Overview:
            Use observation tensor to predict action logits.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
                        N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
                        N2 corresponds to ``action_shape``.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): The output dict of the actor network's forward computation \
                graph, with keys:
                - logit (:obj:`torch.Tensor`): Action's output logit (real value range), whose shape is \
                    :math:`(B, A, N2)`, where N2 corresponds to ``action_shape``.
                - action_mask (:obj:`torch.Tensor`): The same action mask tensor as the input, with shape \
                    :math:`(B, A, N2)`.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     }
            >>> }
            >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
            >>> logit = model.compute_actor(data)['logit']
        """
        action_mask = inputs['obs']['action_mask']
        x = self.actor(inputs['obs']['agent_state'])
        return {'logit': x['logit'], 'action_mask': action_mask}

    def compute_critic(self, inputs: Dict) -> Dict:
        """
        Overview:
            Use observation tensor to predict Q value.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
                        N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
                        N2 corresponds to ``action_shape``.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, \
                whose key-values vary with ``twin_critic``.
                - q_value (:obj:`list`): If ``twin_critic=True``, q_value is a list of 2 tensors, each with shape \
                    :math:`(B, A, N2)`, where B is batch size and A is agent num. N2 corresponds to \
                    ``action_shape``. Otherwise, q_value is a single ``torch.Tensor`` with the same shape.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     }
            >>> }
            >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
            >>> value = model.compute_critic(data)['q_value']
        """
        # The ``logit`` output of ``DiscreteHead`` is interpreted as the Q value of each action here.
        if self.twin_critic:
            x = [m(inputs['obs']['global_state'])['logit'] for m in self.critic]
        else:
            x = self.critic(inputs['obs']['global_state'])['logit']
        return {'q_value': x}
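

# A minimal shape-check sketch for ``DiscreteMAQAC``, assuming only the classes defined in this file: it
# mirrors the docstring examples above, so every size below is an illustrative assumption and the helper
# name ``_demo_discrete_maqac`` is hypothetical. The function is defined but never called, so importing
# this file stays side-effect free.
def _demo_discrete_maqac() -> None:
    B, agent_num = 32, 8
    agent_obs_shape, global_obs_shape, action_shape = 216, 264, 14
    data = {
        'obs': {
            'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            'global_state': torch.randn(B, agent_num, global_obs_shape),
            'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape)),
        }
    }
    model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
    # The actor consumes only the per-agent local observation.
    logit = model(data, mode='compute_actor')['logit']
    assert logit.shape == (B, agent_num, action_shape)
    # With twin_critic=True, the critic returns a list of two per-action Q-value tensors.
    q_value = model(data, mode='compute_critic')['q_value']
    assert isinstance(q_value, list) and len(q_value) == 2
    assert all(q.shape == (B, agent_num, action_shape) for q in q_value)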


@MODEL_REGISTRY.register('continuous_maqac')
class ContinuousMAQAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to continuous action Multi-Agent Q-value \
        Actor-Critic (MAQAC) model. The model is composed of an actor and a critic, both implemented as MLP \
        networks. The actor network predicts the action (``regression``) or the parameters of the action \
        distribution (``reparameterization``), and the critic network predicts the Q value of the state-action \
        pair from the global state.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            agent_obs_shape: Union[int, SequenceType],
            global_obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            action_space: str,
            twin_critic: bool = False,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
| """ | |
| Overview: | |
| Initialize the QAC Model according to arguments. | |
| Arguments: | |
| - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space. | |
| - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's space, such as 4, (3, ) | |
| - action_space (:obj:`str`): Whether choose ``regression`` or ``reparameterization``. | |
| - twin_critic (:obj:`bool`): Whether include twin critic. | |
| - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``. | |
| - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \ | |
| for actor's nn. | |
| - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``. | |
| - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \ | |
| for critic's nn. | |
| - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` the after \ | |
| ``layer_fn``, if ``None`` then default set to ``nn.ReLU()`` | |
| - norm_type (:obj:`Optional[str]`): The type of normalization to use, see ``ding.torch_utils.fc_block`` \ | |
| for more details. | |
| """ | |
        super(ContinuousMAQAC, self).__init__()
        obs_shape: int = squeeze(agent_obs_shape)
        global_obs_shape: int = squeeze(global_obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.action_space = action_space
        assert self.action_space in ['regression', 'reparameterization'], self.action_space
        if self.action_space == 'regression':  # DDPG, TD3
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                RegressionHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    final_tanh=True,
                    activation=activation,
                    norm_type=norm_type
                )
            )
        else:  # SAC
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                ReparameterizationHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    sigma_type='conditioned',
                    activation=activation,
                    norm_type=norm_type
                )
            )
        self.twin_critic = twin_critic
        # The critic consumes the concatenation of the global state and the action.
        critic_input_size = global_obs_shape + action_shape
        if self.twin_critic:
            self.critic = nn.ModuleList()
            for _ in range(2):
                self.critic.append(
                    nn.Sequential(
                        nn.Linear(critic_input_size, critic_head_hidden_size), activation,
                        RegressionHead(
                            critic_head_hidden_size,
                            1,
                            critic_head_layer_num,
                            final_tanh=False,
                            activation=activation,
                            norm_type=norm_type
                        )
                    )
                )
        else:
            self.critic = nn.Sequential(
                nn.Linear(critic_input_size, critic_head_hidden_size), activation,
                RegressionHead(
                    critic_head_hidden_size,
                    1,
                    critic_head_layer_num,
                    final_tanh=False,
                    activation=activation,
                    norm_type=norm_type
                )
            )

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            Use observation and action tensor to predict output in ``compute_actor`` or ``compute_critic`` mode.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
                        N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
                        N2 corresponds to ``action_shape``.
                - ``action`` (:obj:`torch.Tensor`): The action tensor data, \
                    with shape :math:`(B, A, N3)`, where B is batch size and A is agent num. \
                    N3 corresponds to ``action_shape``.
            - mode (:obj:`str`): Name of the forward mode.
        Returns:
            - outputs (:obj:`Dict`): Outputs of network forward, whose key-values will be different for different \
                ``mode``, ``twin_critic``, ``action_space``.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> act_space = 'reparameterization'  # or 'regression'
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     },
            >>>     'action': torch.randn(B, agent_num, squeeze(action_shape))
            >>> }
            >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
            >>> if act_space == 'regression':
            >>>     action = model(data['obs'], mode='compute_actor')['action']
            >>> elif act_space == 'reparameterization':
            >>>     (mu, sigma) = model(data['obs'], mode='compute_actor')['logit']
            >>> value = model(data, mode='compute_critic')['q_value']
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: Dict) -> Dict:
        """
        Overview:
            Use observation tensor to predict action logits.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                    with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                    N0 corresponds to ``agent_obs_shape``.
        Returns:
            - outputs (:obj:`Dict`): Outputs of network forward.
        ReturnKeys (``action_space == 'regression'``):
            - action (:obj:`torch.Tensor`): Action tensor with same size as ``action_shape``.
        ReturnKeys (``action_space == 'reparameterization'``):
            - logit (:obj:`list`): A list of two tensors ``[mu, sigma]``, each with shape :math:`(B, A, N3)`, \
                where B is batch size, A is agent num and N3 corresponds to ``action_shape``.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> act_space = 'reparameterization'  # or 'regression'
            >>> data = {
            >>>     'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>> }
            >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
            >>> if act_space == 'regression':
            >>>     action = model.compute_actor(data)['action']
            >>> elif act_space == 'reparameterization':
            >>>     (mu, sigma) = model.compute_actor(data)['logit']
        """
        inputs = inputs['agent_state']
        if self.action_space == 'regression':
            x = self.actor(inputs)
            return {'action': x['pred']}
        else:
            x = self.actor(inputs)
            return {'logit': [x['mu'], x['sigma']]}

    def compute_critic(self, inputs: Dict) -> Dict:
        """
        Overview:
            Use observation tensor and action tensor to predict Q value.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
                        N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
                        N2 corresponds to ``action_shape``.
                - ``action`` (:obj:`torch.Tensor`): The action tensor data, \
                    with shape :math:`(B, A, N3)`, where B is batch size and A is agent num. \
                    N3 corresponds to ``action_shape``.
        Returns:
            - outputs (:obj:`Dict`): Outputs of network forward.
        ReturnKeys (``twin_critic=True``):
            - q_value (:obj:`list`): A list of 2 tensors, each with shape :math:`(B, A)`, where B is batch size \
                and A is agent num.
        ReturnKeys (``twin_critic=False``):
            - q_value (:obj:`torch.Tensor`): Q-value tensor with shape :math:`(B, A)`, where B is batch size and \
                A is agent num.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> act_space = 'reparameterization'  # or 'regression'
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     },
            >>>     'action': torch.randn(B, agent_num, squeeze(action_shape))
            >>> }
            >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
            >>> value = model.compute_critic(data)['q_value']
        """
        obs, action = inputs['obs']['global_state'], inputs['action']
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        # Q(s, a): concatenate the global state with the action along the last dimension.
        x = torch.cat([obs, action], dim=-1)
        if self.twin_critic:
            x = [m(x)['pred'] for m in self.critic]
        else:
            x = self.critic(x)['pred']
        return {'q_value': x}
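

# A minimal, self-contained usage sketch, assuming only the classes defined in this file: it mirrors the
# docstring examples of ``ContinuousMAQAC``, so all sizes below are illustrative assumptions. The block is
# guarded by ``__main__`` so nothing runs on import.
if __name__ == '__main__':
    B, agent_num = 32, 8
    agent_obs_shape, global_obs_shape, action_shape = 216, 264, 14
    data = {
        'obs': {
            'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            'global_state': torch.randn(B, agent_num, global_obs_shape),
            'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape)),
        },
        'action': torch.randn(B, agent_num, action_shape),
    }
    for act_space in ['regression', 'reparameterization']:
        model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
        if act_space == 'regression':  # deterministic actor, as in DDPG/TD3
            action = model(data['obs'], mode='compute_actor')['action']
            assert action.shape == (B, agent_num, action_shape)
        else:  # stochastic actor parameterized by (mu, sigma), as in SAC
            mu, sigma = model(data['obs'], mode='compute_actor')['logit']
            assert mu.shape == sigma.shape == (B, agent_num, action_shape)
        # The critic scores the (global state, action) pair; with twin_critic=False it returns one tensor.
        q_value = model(data, mode='compute_critic')['q_value']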