Spaces:
Sleeping
Sleeping
| import pytest | |
| import time | |
| from itertools import product | |
| import numpy as np | |
| import torch | |
| from ding.rl_utils import ppg_data, ppg_joint_error | |
# Parameter grid for test_ppg: every combination of value clipping on/off
# with either no weight or a fixed random weight tensor.
use_value_clip_args = [True, False]
# Values in [1, 2). Its length (4) matches the batch size B hard-coded in
# test_ppg; presumably ppg_data expects one weight per sample -- keep in sync.
random_weight = torch.rand(4) + 1
weight_args = [None, random_weight]
# Cartesian product -> [(True, None), (True, w), (False, None), (False, w)].
args = list(product(use_value_clip_args, weight_args))
# Due to the numeric instability of this unittest, backward() is retried with
# fresh random data when a sporadic RuntimeError occurs.
@pytest.mark.parametrize('use_value_clip, weight', args)
def test_ppg(use_value_clip, weight):
    """Check that ``ppg_joint_error`` yields scalar losses that backpropagate.

    Random policy logits and value estimates are packed with ``ppg_data`` and
    passed to ``ppg_joint_error``. Every returned loss term must be a 0-dim
    tensor, and after summing the terms and calling ``backward()``, both
    ``logit_new`` and ``value_new`` must have received gradients.

    If ``backward()`` raises ``RuntimeError``, the attempt is retried with new
    random data, up to ``max_retries`` extra attempts. When every attempt
    fails, the test fails; it no longer passes silently.

    Args:
        use_value_clip: Forwarded to ``ppg_joint_error``.
        weight: ``None`` or the module-level ``random_weight`` tensor, passed
            through ``ppg_data``.
    """
    max_retries = 10
    # Seed ONCE, outside the retry loop. int(time.time()) only changes once
    # per second, so reseeding on every attempt would replay the exact same
    # failing sample on an immediate retry.
    seed = int(time.time())
    torch.manual_seed(seed)
    B, N = 4, 32  # B must match len(random_weight)
    for _attempt in range(max_retries + 1):
        logit_new = torch.randn(B, N).add_(0.1).clamp_(0.1, 0.99)
        # Out-of-place ops: ``add_`` would mutate ``logit_new`` and return the
        # very same tensor, making the "old" and "new" logits one object.
        logit_old = logit_new.add(torch.rand_like(logit_new) * 0.1).clamp(0.1, 0.99)
        logit_new.requires_grad_(True)
        logit_old.requires_grad_(True)
        action = torch.randint(0, N, size=(B, ))
        value_new = torch.randn(B).requires_grad_(True)
        # Old value estimates act as a fixed snapshot; detach so no gradient
        # flows back into ``value_new`` through them.
        value_old = value_new.detach() + torch.rand_like(value_new) * 0.1
        return_ = torch.randn(B) * 2
        data = ppg_data(logit_new, logit_old, action, value_new, value_old, return_, weight)
        loss = ppg_joint_error(data, use_value_clip=use_value_clip)
        assert all(term.shape == tuple() for term in loss)
        # No gradients may exist before backward() is called.
        assert logit_new.grad is None
        assert value_new.grad is None
        total_loss = sum(loss)
        try:
            total_loss.backward()
        except RuntimeError as e:
            print("[ERROR]: {}".format(e))
            continue
        assert isinstance(logit_new.grad, torch.Tensor)
        assert isinstance(value_new.grad, torch.Tensor)
        break
    else:
        # Every attempt raised: report it instead of silently passing.
        pytest.fail(
            "backward() raised RuntimeError on all {} attempts (seed={})".format(max_retries + 1, seed)
        )