Spaces:
Sleeping
Sleeping
| from typing import Union, Dict, Optional | |
| import torch | |
| import torch.nn as nn | |
| from ding.utils import SequenceType, squeeze, MODEL_REGISTRY | |
| from ..common import ReparameterizationHead, RegressionHead, DiscreteHead | |
class MAVAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to (state) Value Actor-Critic (VAC) for \
        multi-agent, such as MAPPO(https://arxiv.org/abs/2103.01955). This model supports discrete and \
        continuous action spaces. The MAVAC is composed of four parts: ``actor_encoder``, ``critic_encoder``, \
        ``actor_head`` and ``critic_head``. Encoders are used to extract features from the observation; heads \
        are used to predict the corresponding value or action logit.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``.
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            agent_obs_shape: Union[int, SequenceType],
            global_obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            agent_num: int,
            actor_head_hidden_size: int = 256,
            actor_head_layer_num: int = 2,
            critic_head_hidden_size: int = 512,
            critic_head_layer_num: int = 1,
            action_space: str = 'discrete',
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            sigma_type: Optional[str] = 'independent',
            bound_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Init the MAVAC Model according to arguments.
        Arguments:
            - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for a single agent, such as 8. \
                After ``squeeze`` it is used as ``in_features`` of an ``nn.Linear``, so it must describe a 1-D \
                observation (an int, or a single-element sequence such as [8]).
            - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation's space, such as 8. Same 1-D \
                requirement as ``agent_obs_shape``.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape for a single agent, such as 6. \
                After ``squeeze`` it is forwarded to ``DiscreteHead`` / ``ReparameterizationHead``.
            - agent_num (:obj:`int`): Temporarily reserved; not read by the current implementation. It may be \
                required by subsequent changes to the model.
            - actor_head_hidden_size (:obj:`int`): Output size of the linear layer placed in front of the actor \
                head, and the ``hidden_size`` of that head. Defaults to 256.
            - actor_head_layer_num (:obj:`int`): The number of layers used in the actor head. Defaults to 2.
            - critic_head_hidden_size (:obj:`int`): Output size of the linear layer placed in front of the \
                critic's ``RegressionHead``, and the ``hidden_size`` of that head. Defaults to 512.
            - critic_head_layer_num (:obj:`int`): The number of layers used in the critic's ``RegressionHead``. \
                Defaults to 1.
            - action_space (:obj:`str`): One of ['discrete', 'continuous'], selecting ``DiscreteHead`` or \
                ``ReparameterizationHead`` for the actor. Any other value raises ``AssertionError``.
            - activation (:obj:`Optional[nn.Module]`): Activation used right after each input linear layer and \
                passed into the heads. ``None`` falls back to ``nn.ReLU()``. Note that the default ``nn.ReLU()`` \
                instance is created once, when the function is defined, and is therefore shared by every model \
                built with the default.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details. You can choose one of ['BN', 'IN', 'SyncBN', 'LN'].
            - sigma_type (:obj:`Optional[str]`): The type of sigma in continuous action space, forwarded to \
                ``ReparameterizationHead`` (see it for details). In MAPPO it defaults to ``independent``, which \
                means state-independent sigma parameters. Only used when ``action_space == 'continuous'``.
            - bound_type (:obj:`Optional[str]`): The type of action bound method in continuous action space, \
                forwarded to ``ReparameterizationHead``. Defaults to ``None``, which means no bound. Only used \
                when ``action_space == 'continuous'``.
        """
        super(MAVAC, self).__init__()
        # Fail fast on an unsupported action space, before any sub-module is built.
        assert action_space in ['discrete', 'continuous'], action_space
        if activation is None:
            # Honour the documented fallback: a ``None`` placed inside ``nn.Sequential`` is accepted at build
            # time but fails with "NoneType is not callable" on the first forward pass.
            activation = nn.ReLU()

        # NOTE(review): ``squeeze`` comes from ding.utils; presumably it unwraps single-element sequences into a
        # plain int, which ``nn.Linear`` below requires -- confirm against ding.utils.
        agent_obs_shape: int = squeeze(agent_obs_shape)
        global_obs_shape: int = squeeze(global_obs_shape)
        action_shape: int = squeeze(action_shape)
        self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape
        self.action_space = action_space

        # Encoder type:
        # We directly connect the head after a Linear layer instead of using the 3-layer FCEncoder.
        # In SMAC tasks this can obviously improve the performance.
        # Users can change the model according to their own needs.
        self.actor_encoder = nn.Identity()
        self.critic_encoder = nn.Identity()

        # Head type: Linear -> activation -> head, for both critic and actor.
        self.critic_head = nn.Sequential(
            nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
            RegressionHead(
                critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
            )
        )
        if self.action_space == 'discrete':
            self.actor_head = nn.Sequential(
                nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
                DiscreteHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            )
        else:  # 'continuous' (validated above)
            self.actor_head = nn.Sequential(
                nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
                ReparameterizationHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    sigma_type=sigma_type,
                    activation=activation,
                    norm_type=norm_type,
                    bound_type=bound_type
                )
            )

        # Group each encoder with its head so callers can write e.g. ``self.critic.parameters()``.
        # The same sub-modules are also registered individually above, so ``print(self)`` lists them twice.
        self.actor = nn.ModuleList([self.actor_encoder, self.actor_head])
        self.critic = nn.ModuleList([self.critic_encoder, self.critic_head])

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            MAVAC forward computation graph: takes the observation dict and predicts the state value and/or the \
            action logit. ``mode`` is one of ``compute_actor``, ``compute_critic``, ``compute_actor_critic``; \
            each mode only runs the network modules it needs.
        Arguments:
            - inputs (:obj:`Dict`): The input dict including observation and related info, \
                whose key-values vary with ``mode``.
            - mode (:obj:`str`): The forward mode; all valid modes are listed in ``MAVAC.mode``.
        Returns:
            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph, whose key-values vary \
                with ``mode``.
        Raises:
            - AssertionError: If ``mode`` is not in ``MAVAC.mode``.
        Examples (Actor):
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14, agent_num=8)
            >>> inputs = {
            ...     'agent_state': torch.randn(10, 8, 64),
            ...     'global_state': torch.randn(10, 8, 128),
            ...     'action_mask': torch.randint(0, 2, size=(10, 8, 14))
            ... }
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])
        Examples (Critic):
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14, agent_num=8)
            >>> inputs = {
            ...     'agent_state': torch.randn(10, 8, 64),
            ...     'global_state': torch.randn(10, 8, 128),
            ...     'action_mask': torch.randint(0, 2, size=(10, 8, 14))
            ... }
            >>> critic_outputs = model(inputs, 'compute_critic')
            >>> assert critic_outputs['value'].shape == torch.Size([10, 8])
        Examples (Actor-Critic):
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14, agent_num=8)
            >>> inputs = {
            ...     'agent_state': torch.randn(10, 8, 64),
            ...     'global_state': torch.randn(10, 8, 128),
            ...     'action_mask': torch.randint(0, 2, size=(10, 8, 14))
            ... }
            >>> outputs = model(inputs, 'compute_actor_critic')
            >>> assert outputs['value'].shape == torch.Size([10, 8])
            >>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, x: Dict) -> Dict:
        """
        Overview:
            MAVAC forward computation graph for the actor part: predicts the action logit from the agent \
            observation tensor in ``x``.
        Arguments:
            - x (:obj:`Dict`): Input data dict with keys ['agent_state', 'action_mask'(optional)].
                - agent_state: (:obj:`torch.Tensor`): Each agent's local state (obs).
                - action_mask (optional): (:obj:`torch.Tensor`): Only used when ``action_space`` is discrete. \
                    Entries equal to 0 mark illegal actions; when the key is absent no masking is applied.
        Returns:
            - outputs (:obj:`Dict`): The output dict of the forward computation graph for actor, including ``logit``.
        ReturnsKeys:
            - logit: For a discrete action space, the ``'logit'`` tensor of the actor head, with illegal actions \
                set to -99999999 in place. For a continuous action space, the whole output dict of the actor head \
                (the mu and sigma of the Gaussian distribution).
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size, M is ``agent_num`` and \
                N is ``action_shape`` (discrete case).
        Examples:
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14, agent_num=8)
            >>> inputs = {
            ...     'agent_state': torch.randn(10, 8, 64),
            ...     'global_state': torch.randn(10, 8, 128),
            ...     'action_mask': torch.randint(0, 2, size=(10, 8, 14))
            ... }
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])
        """
        if self.action_space == 'discrete':
            feature = self.actor_encoder(x['agent_state'])
            logit = self.actor_head(feature)['logit']
            action_mask = x.get('action_mask')
            if action_mask is not None:
                # A huge negative logit gives illegal actions (mask == 0) ~zero probability after softmax.
                logit[action_mask == 0.0] = -99999999
        elif self.action_space == 'continuous':
            feature = self.actor_encoder(x['agent_state'])
            # The head's full output dict is returned as ``logit``.
            logit = self.actor_head(feature)
        else:
            raise ValueError("unsupported action_space: {}".format(self.action_space))
        return {'logit': logit}

    def compute_critic(self, x: Dict) -> Dict:
        """
        Overview:
            MAVAC forward computation graph for the critic part: predicts the state value from the global \
            observation tensor in ``x``.
        Arguments:
            - x (:obj:`Dict`): Input data dict with keys ['global_state'].
                - global_state: (:obj:`torch.Tensor`): Global state (obs).
        Returns:
            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph for critic, \
                including ``value``.
        ReturnsKeys:
            - value (:obj:`torch.Tensor`): The ``'pred'`` output of the critic head, i.e. the predicted state value.
        Shapes:
            - value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``agent_num``.
        Examples:
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14, agent_num=8)
            >>> inputs = {
            ...     'agent_state': torch.randn(10, 8, 64),
            ...     'global_state': torch.randn(10, 8, 128),
            ...     'action_mask': torch.randint(0, 2, size=(10, 8, 14))
            ... }
            >>> critic_outputs = model(inputs, 'compute_critic')
            >>> assert critic_outputs['value'].shape == torch.Size([10, 8])
        """
        x = self.critic_encoder(x['global_state'])
        x = self.critic_head(x)
        return {'value': x['pred']}

    def compute_actor_critic(self, x: Dict) -> Dict:
        """
        Overview:
            MAVAC forward computation graph for both the actor and the critic part: predicts the action logit \
            and the state value. Equivalent to calling ``compute_actor`` and ``compute_critic`` on the same input.
        Arguments:
            - x (:obj:`Dict`): The input dict containing ``agent_state``, ``global_state`` and optionally \
                ``action_mask``.
        Returns:
            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph for both actor and critic, \
                including ``logit`` and ``value``.
        ReturnsKeys:
            - logit: Same as the ``logit`` returned by ``compute_actor``.
            - value (:obj:`torch.Tensor`): Same as the ``value`` returned by ``compute_critic``.
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size, M is ``agent_num`` and \
                N is ``action_shape``.
            - value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``agent_num``.
        Examples:
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14, agent_num=8)
            >>> inputs = {
            ...     'agent_state': torch.randn(10, 8, 64),
            ...     'global_state': torch.randn(10, 8, 128),
            ...     'action_mask': torch.randint(0, 2, size=(10, 8, 14))
            ... }
            >>> outputs = model(inputs, 'compute_actor_critic')
            >>> assert outputs['value'].shape == torch.Size([10, 8])
            >>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
        """
        logit = self.compute_actor(x)['logit']
        value = self.compute_critic(x)['value']
        return {'logit': logit, 'value': value}