Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| from ding.utils import MODEL_REGISTRY | |
| from .qmix import QMix | |
class MADQN(nn.Module):
    """Pair of QMix networks selected at call time by a ``cooperation`` flag.

    Two independent ``QMix`` instances are held:

    * ``current`` is built on the per-agent observation shape.
    * ``cooperation`` is built on ``global_obs_shape`` when
      ``global_cooperation`` is true, otherwise on ``obs_shape``.
    """

    def __init__(
            self,
            agent_num: int,
            obs_shape: int,
            action_shape: int,
            hidden_size_list: list,
            global_obs_shape: int = None,
            mixer: bool = False,
            global_cooperation: bool = True,
            lstm_type: str = 'gru',
            dueling: bool = False
    ) -> None:
        """Build the ``current`` and ``cooperation`` QMix networks.

        Args:
            agent_num: Forwarded to both QMix networks.
            obs_shape: Observation shape used by ``current`` (and by
                ``cooperation`` when ``global_cooperation`` is false).
            action_shape: Forwarded to both QMix networks.
            hidden_size_list: Forwarded to both QMix networks.
            global_obs_shape: Forwarded to both QMix networks as
                ``global_obs_shape``; additionally used as the ``obs_shape``
                of ``cooperation`` when ``global_cooperation`` is true.
                NOTE(review): with the defaults (``None`` here and
                ``global_cooperation=True``) ``None`` is passed to QMix as
                ``obs_shape`` - confirm callers always supply a value.
            mixer: Forwarded to both QMix networks.
            global_cooperation: Whether ``cooperation`` consumes the global
                state in place of the per-agent state.
            lstm_type: Forwarded to both QMix networks.
            dueling: Forwarded to both QMix networks.
        """
        super().__init__()
        self.current = QMix(
            agent_num=agent_num,
            obs_shape=obs_shape,
            action_shape=action_shape,
            hidden_size_list=hidden_size_list,
            global_obs_shape=global_obs_shape,
            mixer=mixer,
            lstm_type=lstm_type,
            dueling=dueling
        )
        self.global_cooperation = global_cooperation
        # The cooperation network sees either the global or the per-agent observation.
        cooperation_obs_shape = global_obs_shape if self.global_cooperation else obs_shape
        self.cooperation = QMix(
            agent_num=agent_num,
            obs_shape=cooperation_obs_shape,
            action_shape=action_shape,
            hidden_size_list=hidden_size_list,
            global_obs_shape=global_obs_shape,
            mixer=mixer,
            lstm_type=lstm_type,
            dueling=dueling
        )

    def forward(self, data: dict, cooperation: bool = False, single_step: bool = True) -> dict:
        """Run one of the two QMix networks on ``data``.

        Args:
            data: Input dict. When the cooperation network is used with
                ``global_cooperation`` enabled, ``data['obs']`` must contain
                a ``'global_state'`` entry.
            cooperation: If true, evaluate ``self.cooperation``; otherwise
                evaluate ``self.current``.
            single_step: Forwarded unchanged to the selected network.

        Returns:
            Whatever the selected QMix network returns.
        """
        if not cooperation:
            return self.current(data, single_step=single_step)
        if self.global_cooperation:
            # Feed the global state in place of the per-agent state, but do it
            # on shallow copies so the caller's ``data`` is not modified and
            # can still be passed to the ``current`` network afterwards.
            swapped_obs = {**data['obs'], 'agent_state': data['obs']['global_state']}
            data = {**data, 'obs': swapped_obs}
        return self.cooperation(data, single_step=single_step)