Spaces:
Sleeping
Sleeping
| import pytest | |
| from itertools import product | |
| import torch | |
| from ding.model.template import ProcedureCloningMCTS, ProcedureCloningBFS | |
# Batch size and action-sequence length shared by the shape tests.
B, T = 4, 15

# Candidate observation shapes and action dimensions. The observation shape
# is presumably laid out (H, W, C) -- TODO confirm against the models.
obs_shape = [(64, 64, 3)]
action_dim = [9]

# Expected width of the goal prediction returned by ProcedureCloningMCTS.
obs_embeddings = 256

# Cartesian product of the candidates above, as (obs_shape, action_dim) pairs.
args = list(product(obs_shape, action_dim))
@pytest.mark.parametrize('obs_shape, action_dim', args)
class TestProcedureCloning:
    """Output-shape checks for the procedure-cloning model templates.

    ``obs_shape`` and ``action_dim`` are supplied to every test method by the
    class-level ``parametrize`` mark, which expands the module-level ``args``
    list of ``(obs_shape, action_dim)`` pairs. Without that mark pytest would
    look for fixtures with those names and error out before running the tests.
    """

    def test_procedure_cloning_mcts(self, obs_shape, action_dim):
        """Check the shapes returned by ``ProcedureCloningMCTS``.

        States and goals are random tensors of shape ``(B, *obs_shape)`` and
        actions are a random ``(B, T, action_dim)`` sequence. The forward pass
        must return goal predictions of shape ``(B, obs_embeddings)`` and
        action predictions of shape ``(B, T + 1, action_dim)``.
        """
        inputs = {
            'states': torch.randn(B, *obs_shape),
            'goals': torch.randn(B, *obs_shape),
            'actions': torch.randn(B, T, action_dim)
        }
        model = ProcedureCloningMCTS(obs_shape=obs_shape, action_dim=action_dim)
        goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions'])
        assert goal_preds.shape == (B, obs_embeddings)
        # One more step than the T actions passed in.
        assert action_preds.shape == (B, T + 1, action_dim)

    def test_procedure_cloning_bfs(self, obs_shape, action_dim):
        """Check the ``'logit'`` shape returned by ``ProcedureCloningBFS``.

        The model is constructed with the last axis of ``obs_shape`` moved to
        the front, but is fed inputs of shape ``(B, *obs_shape)``. The output
        dict's ``'logit'`` entry must have shape
        ``(B, obs_shape[0], obs_shape[1], action_dim + 1)``.
        """
        # Move the last axis first (presumably HWC -> CHW). NOTE(review): the
        # inputs below keep the original layout, so the model is assumed to
        # handle that ordering itself -- confirm.
        o_shape = (obs_shape[2], obs_shape[0], obs_shape[1])
        model = ProcedureCloningBFS(obs_shape=o_shape, action_shape=action_dim)
        inputs = torch.randn(B, *obs_shape)
        map_preds = model(inputs)
        # One entry more than action_dim per spatial cell. NOTE(review): the
        # meaning of the extra entry is not visible from this file.
        assert map_preds['logit'].shape == (B, obs_shape[0], obs_shape[1], action_dim + 1)