Spaces:
Sleeping
Sleeping
| import pytest | |
| import torch | |
| from ding.rl_utils import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action | |
def test_vtrace_discrete_action():
    """Check the discrete-action V-trace loss terms and their gradient flow.

    Builds random inputs with T time steps, batch size B and N discrete
    actions, packs them with ``vtrace_data`` and verifies that:

    * every term returned by ``vtrace_error_discrete_action`` is a scalar;
    * no gradient exists on the leaf inputs before ``backward()``;
    * after ``backward()`` on the summed loss, gradients have been
      populated on both the target logits and the value tensor.
    """
    T, B, N = 4, 8, 16
    # ``value`` carries one extra time step (T + 1) compared to ``reward``.
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)
    target_output = torch.randn(T, B, N).requires_grad_(True)
    behaviour_output = torch.randn(T, B, N)
    action = torch.randint(0, N, size=(T, B))
    data = vtrace_data(target_output, behaviour_output, action, value, reward, None)

    losses = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
    assert all(term.shape == tuple() for term in losses)

    # Nothing has been back-propagated yet.
    assert target_output.grad is None
    assert value.grad is None

    total_loss = sum(losses)
    total_loss.backward()

    # Inspect ``.grad`` rather than the tensors themselves: the inputs are
    # tensors by construction, so an isinstance check on them proves nothing
    # about whether backward() reached them.
    assert isinstance(target_output.grad, torch.Tensor)
    assert isinstance(value.grad, torch.Tensor)
def test_vtrace_continuous_action():
    """Check the continuous-action V-trace loss terms and their gradient flow.

    Builds random ``mu`` / ``sigma`` outputs for the target and behaviour
    policies with T time steps, batch size B and action dimension N, packs
    them with ``vtrace_data`` and verifies that:

    * every term returned by ``vtrace_error_continuous_action`` is a scalar;
    * no gradient exists on the leaf inputs before ``backward()``;
    * after ``backward()`` on the summed loss, gradients have been
      populated on ``mu``, on the leaf tensor behind ``sigma`` and on the
      value tensor.
    """
    T, B, N = 4, 8, 16
    # ``value`` carries one extra time step (T + 1) compared to ``reward``.
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)

    # ``sigma`` is exp(log_sigma) to keep it positive. That makes it a
    # non-leaf tensor whose ``.grad`` is never populated, so keep a handle on
    # the leaf ``log_sigma`` and check gradients there instead.
    log_sigma = torch.randn(T, B, N).requires_grad_(True)
    target_output = {
        'mu': torch.randn(T, B, N).requires_grad_(True),
        'sigma': torch.exp(log_sigma),
    }
    behaviour_output = {
        'mu': torch.randn(T, B, N),
        'sigma': torch.exp(torch.randn(T, B, N)),
    }
    action = torch.randn((T, B, N))
    data = vtrace_data(target_output, behaviour_output, action, value, reward, None)

    losses = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
    assert all(term.shape == tuple() for term in losses)

    # Nothing has been back-propagated yet.
    assert target_output['mu'].grad is None
    assert log_sigma.grad is None
    assert value.grad is None

    total_loss = sum(losses)
    total_loss.backward()

    # Inspect ``.grad`` rather than the tensors themselves: the inputs are
    # tensors by construction, so an isinstance check on them proves nothing
    # about whether backward() reached them.
    assert isinstance(target_output['mu'].grad, torch.Tensor)
    assert isinstance(log_sigma.grad, torch.Tensor)
    assert isinstance(value.grad, torch.Tensor)