Spaces:
Sleeping
Sleeping
| import pytest | |
| import torch | |
| from ding.rl_utils import gae_data, gae | |
def test_gae():
    """Check that ``gae`` returns an advantage with the same shape as its inputs.

    Two layouts are exercised:
      * a batched trajectory, where every input is shaped (T, B);
      * a flat trajectory shaped (T,), covering the single-trajectory and
        concatenated-trajectory cases.
    In both cases ``done`` is all zeros and the last ``gae_data`` field is
    passed as ``None``.
    """

    def _advantage_for(shape):
        # Inputs are drawn in the order value, next_value, reward, done.
        # This matches the original sequence of RNG calls.
        v = torch.randn(*shape)
        nv = torch.randn(*shape)
        r = torch.randn(*shape)
        d = torch.zeros(shape)
        return gae(gae_data(v, nv, r, d, None))

    # batch trajectory case
    steps, batch = 32, 4
    assert _advantage_for((steps, batch)).shape == (steps, batch)

    # single trajectory case / concat trajectory case
    steps = 24
    assert _advantage_for((steps,)).shape == (steps, )
def test_gae_multi_agent():
    """Check ``gae`` when values have a trailing agent axis that rewards lack.

    ``value`` and ``next_value`` are shaped (T, B, A), while ``reward`` and
    ``done`` are shaped (T, B) without the agent axis. The resulting
    advantage is expected to keep the per-agent shape (T, B, A).
    """
    steps, batch, agents = 32, 4, 8
    per_agent_shape = (steps, batch, agents)
    no_agent_shape = (steps, batch)

    value = torch.randn(*per_agent_shape)
    next_value = torch.randn(*per_agent_shape)
    reward = torch.randn(*no_agent_shape)
    done = torch.zeros(*no_agent_shape)  # all-zero done flags

    advantage = gae(gae_data(value, next_value, reward, done, None))
    assert advantage.shape == per_agent_shape