Spaces:
Sleeping
Sleeping
| import pytest | |
| import torch | |
| from ding.torch_utils import is_differentiable | |
| from ding.model.template.vae import VanillaVAE | |
def test_vae():
    """Check VanillaVAE output shapes for forward/decode and that its loss is differentiable.

    Builds a VanillaVAE, runs a forward pass on random inputs, and verifies the
    shapes of every returned tensor. It then decodes the sampled latent together
    with the observation and checks that the combined VAE loss back-propagates
    into the model parameters.
    """
    batch_size = 32
    # Width of the latent action: passed as VanillaVAE's third positional argument,
    # and the width that ``z`` is asserted to have below.
    action_shape = 6
    # Width of the raw (original) action that the encoder consumes and the decoder reconstructs.
    original_action_shape = 2
    obs_shape = 6
    hidden_size_list = [256, 256]
    inputs = {
        'action': torch.randn(batch_size, original_action_shape),
        'obs': torch.randn(batch_size, obs_shape),
        'next_obs': torch.randn(batch_size, obs_shape),
    }
    vae_model = VanillaVAE(original_action_shape, obs_shape, action_shape, hidden_size_list)

    outputs = vae_model(inputs)
    assert outputs['recons_action'].shape == (batch_size, original_action_shape)
    assert outputs['prediction_residual'].shape == (batch_size, obs_shape)
    assert isinstance(outputs['input'], dict)
    # mu / log_var parameterize the distribution z is sampled from, so they share z's
    # (latent) width. Asserting against obs_shape only passed because obs_shape happened
    # to equal the latent size here, which would mask a real shape bug.
    assert outputs['mu'].shape == (batch_size, action_shape)
    assert outputs['log_var'].shape == (batch_size, action_shape)
    assert outputs['z'].shape == (batch_size, action_shape)

    outputs_decode = vae_model.decode_with_obs(outputs['z'], inputs['obs'])
    assert outputs_decode['reconstruction_action'].shape == (batch_size, original_action_shape)
    # NOTE(review): 'predition_residual' (sic) is the key used by decode_with_obs's
    # return dict; keep it in sync with the library rather than "fixing" the spelling here.
    assert outputs_decode['predition_residual'].shape == (batch_size, obs_shape)

    # loss_function reads the ground-truth targets from the outputs dict.
    outputs['original_action'] = inputs['action']
    outputs['true_residual'] = inputs['next_obs'] - inputs['obs']
    vae_loss = vae_model.loss_function(outputs, kld_weight=0.01, predict_weight=0.01)
    is_differentiable(vae_loss['loss'], vae_model)