Spaces:
Sleeping
Sleeping
| import pytest | |
| import torch | |
| from ding.rl_utils.upgo import upgo_loss, upgo_returns, tb_cross_entropy | |
def test_upgo():
    """Smoke-test the UPGO helpers: output shapes, input validation, gradients.

    Three things are exercised, in order:

    1. ``tb_cross_entropy`` is expected to return a ``(T, B)`` tensor for
       3-D ``(T, B, N)`` and 4-D ``(T, B, N, N2)`` inputs, and to raise
       ``AssertionError`` for a 5-D input.
    2. ``upgo_returns`` is expected to return a ``(T, B)`` tensor given
       ``(T, B)`` rewards and ``(T + 1, B)`` bootstrap values.
    3. ``upgo_loss`` must yield a value whose ``backward()`` populates
       ``logit.grad``.
    """
    T, B, N, N2 = 4, 8, 5, 7

    # tb_cross_entropy: 3 tests
    # Case 1: 4-D input, i.e. one extra dimension after (T, B, N).
    logit = torch.randn(T, B, N, N2).softmax(-1).requires_grad_(True)
    action = logit.argmax(-1).detach()
    ce = tb_cross_entropy(logit, action)
    assert ce.shape == (T, B)

    # Case 2: 5-D input must be rejected by an internal assertion.
    logit = torch.randn(T, B, N, N2, 2).softmax(-1).requires_grad_(True)
    action = logit.argmax(-1).detach()
    with pytest.raises(AssertionError):
        tb_cross_entropy(logit, action)

    # Case 3: plain 3-D input. This `logit` / `action` pair is reused by the
    # upgo_loss check below, so it must stay the last one created.
    logit = torch.randn(T, B, N).softmax(-1).requires_grad_(True)
    action = logit.argmax(-1).detach()
    ce = tb_cross_entropy(logit, action)
    assert ce.shape == (T, B)

    # upgo_returns
    rewards = torch.randn(T, B)
    # One more step than `rewards` along the first dimension.
    bootstrap_values = torch.randn(T + 1, B).requires_grad_(True)
    returns = upgo_returns(rewards, bootstrap_values)
    assert returns.shape == (T, B)

    # upgo loss
    # NOTE(review): randn can produce negative values; if `rhos` is meant to
    # be an importance ratio it would normally be positive - confirm this is
    # acceptable for a shape/gradient smoke test.
    rhos = torch.randn(T, B)
    loss = upgo_loss(logit, rhos, action, rewards, bootstrap_values)

    # Both inputs are grad-enabled leaves with no gradient accumulated yet.
    assert logit.requires_grad
    assert bootstrap_values.requires_grad
    for t in [logit, bootstrap_values]:
        assert t.grad is None

    loss.backward()

    # Only `logit` is checked after backward.
    # NOTE(review): presumably upgo_loss does not propagate gradient into
    # `bootstrap_values`; confirm and, if so, assert that explicitly.
    assert isinstance(logit.grad, torch.Tensor)