Spaces:
Sleeping
Sleeping
| import pytest | |
| import numpy as np | |
| import torch | |
| from itertools import product | |
| from ding.model import mavac | |
| from ding.model.template.mavac import MAVAC | |
| from ding.torch_utils import is_differentiable | |
# Batch size used as the leading dimension of every generated input tensor.
B = 32
# Candidate per-agent observation sizes; each one is paired with every
# global observation size below to form the test parameter grid.
agent_obs_shape = [216, 265]
# Candidate global observation sizes.
global_obs_shape = [264, 324]
# Number of agents (second dimension of every generated input tensor).
agent_num = 8
# Size of the action dimension (last dimension of the action mask).
action_shape = 14
# Cartesian product of the two size lists:
# one (agent_obs_shape, global_obs_shape) tuple per test case.
args = list(product(agent_obs_shape, global_obs_shape))
class TestVAC:
    """Differentiability checks for the multi-agent actor-critic model ``MAVAC``."""

    def output_check(self, model, outputs, action_shape):
        """Reduce ``outputs`` to a scalar loss and pass it to ``is_differentiable``.

        Args:
            model: Module (or sub-module) handed to ``is_differentiable``.
            outputs: A single tensor when ``action_shape`` is a scalar, or an
                iterable of tensors when ``action_shape`` is a tuple.
            action_shape: Selects how ``outputs`` is reduced to a scalar.

        Raises:
            TypeError: If ``action_shape`` is neither a tuple nor a scalar.
        """
        if isinstance(action_shape, tuple):
            loss = sum(t.sum() for t in outputs)
        elif np.isscalar(action_shape):
            loss = outputs.sum()
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``loss``; fail with a clear message instead.
            raise TypeError(
                "action_shape must be a tuple or a scalar, got {}".format(type(action_shape).__name__)
            )
        # NOTE(review): presumably asserts that gradients of ``loss`` reach the
        # parameters of ``model`` -- confirm against ding.torch_utils.
        is_differentiable(loss, model)

    # Without this parametrization pytest would look for fixtures named
    # ``agent_obs_shape`` / ``global_obs_shape`` and error out; ``args`` holds
    # one (agent_obs_shape, global_obs_shape) pair per test case.
    @pytest.mark.parametrize('agent_obs_shape, global_obs_shape', args)
    def test_vac(self, agent_obs_shape, global_obs_shape):
        """Run ``MAVAC`` in its three modes on random inputs and check each output.

        Args:
            agent_obs_shape: Size of the last dimension of ``agent_state``.
            global_obs_shape: Size of the last dimension of ``global_state``.
        """
        data = {
            'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            'global_state': torch.randn(B, agent_num, global_obs_shape),
            'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape)),
        }
        model = MAVAC(agent_obs_shape, global_obs_shape, action_shape, agent_num)

        # Joint mode: a single forward pass yields both heads.
        output = model(data, mode='compute_actor_critic')
        outputs = output['value'].sum() + output['logit'].sum()
        self.output_check(model, outputs, action_shape)

        # Clear gradients left by the previous check before testing the actor alone.
        for p in model.parameters():
            p.grad = None
        logit = model(data, mode='compute_actor')['logit']
        self.output_check(model.actor, logit, model.action_shape)

        # Clear gradients again before testing the critic alone.
        for p in model.parameters():
            p.grad = None
        value = model(data, mode='compute_critic')['value']
        assert value.shape == (B, agent_num)
        self.output_check(model.critic, value, action_shape)