import torch
import pytest
from torch import nn

from lzero.model.stochastic_muzero_model import ChanceEncoder


# Initialize a ChanceEncoder instance for testing.
@pytest.fixture
def encoder():
    return ChanceEncoder((3, 32, 32), 4)
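
# NOTE: the constructor arguments above are assumed to be the single-frame
# observation shape (C, H, W) and the number of chance outcomes; the tests
# below feed two stacked frames (2 * 3 = 6 channels) to the forward pass.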


def test_ChanceEncoder(encoder):
    # Create a dummy input tensor: two stacked 3-channel frames.
    x_and_last_x = torch.randn(1, 6, 32, 32)

    # Forward pass through the encoder.
    chance_encoding_t, chance_onehot_t = encoder(x_and_last_x)

    # Check the output shapes: batch of 1, 4 chance outcomes.
    assert chance_encoding_t.shape == (1, 4)
    assert chance_onehot_t.shape == (1, 4)

    # Check that chance_onehot_t is indeed one-hot: entries in {0, 1}, exactly one per row.
    assert torch.all((chance_onehot_t == 0) | (chance_onehot_t == 1))
    assert torch.all(torch.sum(chance_onehot_t, dim=1) == 1)


def test_ChanceEncoder_gradients_chance_encoding(encoder):
    # Create a dummy input tensor: two stacked 3-channel frames.
    x_and_last_x = torch.randn(1, 6, 32, 32)

    # Forward pass through the encoder.
    chance_encoding_t, chance_onehot_t = encoder(x_and_last_x)

    # Create a dummy target tensor for a simple loss function.
    target = torch.randn(1, 4)

    # Use mean squared error as a simple loss function.
    loss = nn.MSELoss()(chance_encoding_t, target)

    # Backward pass.
    loss.backward()

    # Check that gradients are computed and have the correct shapes.
    for param in encoder.parameters():
        assert param.grad is not None
        assert param.grad.shape == param.shape
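
# NOTE: chance_onehot_t comes from an argmax, which has no gradient on its
# own; the test below assumes the encoder applies a straight-through
# estimator, so gradients propagate through the one-hot output to the weights.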


def test_ChanceEncoder_gradients_chance_onehot_t(encoder):
    # Create a dummy input tensor: two stacked 3-channel frames.
    x_and_last_x = torch.randn(1, 6, 32, 32)

    # Forward pass through the encoder.
    chance_encoding_t, chance_onehot_t = encoder(x_and_last_x)

    # Create a dummy target tensor for a simple loss function.
    target = torch.randn(1, 4)

    # Use mean squared error as a simple loss function.
    loss = nn.MSELoss()(chance_onehot_t, target)

    # Backward pass.
    loss.backward()

    # Check that gradients are computed and have the correct shapes.
    for param in encoder.parameters():
        assert param.grad is not None
        assert param.grad.shape == param.shape
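

# Convenience entry point (a minimal sketch): lets this file be run directly
# with the Python interpreter instead of invoking pytest on the command line.
if __name__ == "__main__":
    pytest.main([__file__])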