import copy

import numpy as np
import pytest
import torch
from easydict import EasyDict

from lzero.policy import inverse_scalar_transform, select_action

# Select which tree-search implementation this module tests.
policy = 'GumbelMuZero'

if policy == 'EfficientZero':
    from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree
elif policy == 'GumbelMuZero':
    from lzero.mcts.tree_search.mcts_ctree import GumbelMuZeroMCTSCtree as MCTSCtree
else:
    raise KeyError('Only EfficientZero and GumbelMuZero are supported by this test.')
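# Note: the two ctree variants share the roots/prepare/search interface, but the
# call signatures differ slightly: EfficientZero's `search` additionally takes the
# LSTM reward hidden states, while GumbelMuZero's `prepare` additionally takes the
# root values. The if/elif branches in the tests below mirror this difference.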

class MuZeroModelFake(torch.nn.Module):
    """
    Overview:
        A fake MuZero model used only to test ``EfficientZeroMCTSCtree`` and
        ``GumbelMuZeroMCTSCtree``. Every output is a constant zero tensor, so the
        tests exercise the tree-search mechanics rather than a learned network.
    Interfaces:
        ``__init__``, ``initial_inference``, ``recurrent_inference``
    """

    def __init__(self, action_num):
        super().__init__()
        self.action_num = action_num
    def initial_inference(self, observation):
        encoded_state = observation
        batch_size = encoded_state.shape[0]

        # 601 = 2 * support_scale + 1 categorical value atoms (support_scale=300 in the config below).
        value = torch.zeros(size=(batch_size, 601))
        value_prefix = [0. for _ in range(batch_size)]
        policy_logits = torch.zeros(size=(batch_size, self.action_num))
        latent_state = torch.zeros(size=(batch_size, 12, 3, 3))
        reward_hidden_state_roots = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))

        output = {
            'value': value,
            'value_prefix': value_prefix,
            'policy_logits': policy_logits,
            'latent_state': latent_state,
            'reward_hidden_state': reward_hidden_state_roots
        }
        if policy == 'GumbelMuZero':
            # GumbelMuZero consumes a `reward` field rather than `value_prefix`.
            output['reward'] = value_prefix

        return EasyDict(output)
    def recurrent_inference(self, latent_states, reward_hidden_states, actions=None):
        if policy == 'GumbelMuZero':
            # GumbelMuZero's search calls recurrent_inference(latent_states, actions),
            # so the second positional argument actually carries the actions here.
            assert actions is None
            actions = reward_hidden_states
        batch_size = latent_states.shape[0]
        latent_state = torch.zeros(size=(batch_size, 12, 3, 3))
        reward_hidden_state_roots = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
        value = torch.zeros(size=(batch_size, 601))
        value_prefix = torch.zeros(size=(batch_size, 601))
        policy_logits = torch.zeros(size=(batch_size, self.action_num))

        output = {
            'value': value,
            'value_prefix': value_prefix,
            'policy_logits': policy_logits,
            'latent_state': latent_state,
            'reward_hidden_state': reward_hidden_state_roots
        }
        if policy == 'GumbelMuZero':
            output['reward'] = value_prefix

        return EasyDict(output)
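
# Quick shape check for the fake model (a minimal sketch added for illustration,
# not part of the test flow proper): the shapes simply mirror what the real
# EfficientZero/GumbelMuZero networks emit.
_fake_out = MuZeroModelFake(action_num=9).initial_inference(torch.zeros(4, 8))
assert _fake_out.latent_state.shape == (4, 12, 3, 3)
assert _fake_out.policy_logits.shape == (4, 9)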

policy_config = EasyDict(
    lstm_horizon_len=5,
    num_simulations=8,
    batch_size=16,
    pb_c_base=1,
    pb_c_init=1,
    max_num_considered_actions=6,
    discount_factor=0.9,
    root_dirichlet_alpha=0.3,
    root_noise_weight=0.2,
    dirichlet_alpha=0.3,
    exploration_fraction=1,
    device='cpu',
    value_delta_max=0.01,
    model=dict(
        action_space_size=9,
        support_scale=300,
        categorical_distribution=True,
    ),
)
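# A few notes on the fields above (as used by the ctree implementations):
# - support_scale=300 gives a categorical value head with 2 * 300 + 1 = 601
#   atoms, matching the 601-dim value tensors produced by MuZeroModelFake.
# - max_num_considered_actions only matters for GumbelMuZero, where it bounds
#   the number of root actions considered by Sequential Halving.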

batch_size = env_nums = policy_config.batch_size
action_space_size = policy_config.model.action_space_size

model = MuZeroModelFake(action_num=9)
stack_obs = torch.zeros(size=(batch_size, 8), dtype=torch.float)

network_output = model.initial_inference(stack_obs.float())

latent_state_roots = network_output['latent_state']
reward_hidden_state_roots = network_output['reward_hidden_state']
pred_values_pool = network_output['value']
value_prefix_pool = network_output['value_prefix']
policy_logits_pool = network_output['policy_logits']

# Network output post-processing: convert the categorical value logits to scalars
# and move everything to numpy/lists for the C++ tree.
pred_values_pool = inverse_scalar_transform(pred_values_pool, policy_config.model.support_scale).detach().cpu().numpy()
latent_state_roots = latent_state_roots.detach().cpu().numpy()
reward_hidden_state_roots = (
    reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy()
)
policy_logits_pool = policy_logits_pool.detach().cpu().numpy().tolist()
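
# Sanity sketch (assuming the standard MuZero categorical-to-scalar transform):
# all-zero logits give a uniform distribution over the symmetric support
# [-support_scale, support_scale], whose expectation is 0, and the inverse value
# rescaling of 0 is again 0, so every transformed root value should be near zero.
assert np.allclose(pred_values_pool, 0.0, atol=1e-4)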

action_mask = [
    [0, 0, 0, 1, 0, 1, 1, 0, 0],
    [1, 0, 0, 1, 0, 0, 1, 0, 0],
    [1, 1, 0, 0, 1, 0, 1, 0, 1],
    [1, 0, 0, 1, 1, 1, 0, 0, 0],
    [0, 0, 1, 0, 0, 1, 0, 0, 1],
    [0, 1, 1, 0, 1, 0, 0, 0, 0],
    [1, 0, 1, 1, 1, 0, 0, 1, 1],
    [1, 1, 1, 1, 1, 0, 0, 0, 1],
    [0, 0, 0, 1, 0, 1, 1, 0, 0],
    [0, 1, 1, 0, 1, 1, 1, 1, 0],
    [1, 1, 1, 0, 0, 0, 1, 1, 1],
    [1, 1, 0, 1, 0, 1, 1, 0, 0],
    [0, 0, 1, 0, 0, 1, 0, 0, 0],
    [1, 0, 1, 1, 0, 0, 1, 1, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 0],
    [1, 0, 0, 0, 1, 1, 0, 0, 1],
]
assert len(action_mask) == batch_size
assert len(action_mask[0]) == action_space_size

action_num = [
    int(np.array(action_mask[i]).sum()) for i in range(env_nums)
]  # [3, 3, 5, 4, 3, 3, 6, 6, 3, 6, 6, 5, 2, 5, 1, 4]

legal_actions_list = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(env_nums)]
# legal_actions_list =
# [[3, 5, 6], [0, 3, 6], [0, 1, 4, 6, 8], [0, 3, 4, 5],
#  [2, 5, 8], [1, 2, 4], [0, 2, 3, 4, 7, 8], [0, 1, 2, 3, 4, 8],
#  [3, 5, 6], [1, 2, 4, 5, 6, 7], [0, 1, 2, 6, 7, 8], [0, 1, 3, 5, 6],
#  [2, 5], [0, 2, 3, 6, 7], [1], [0, 4, 5, 8]]

to_play = [2, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1]
assert len(to_play) == batch_size
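
# `to_play` holds per-env player ids (1 or 2) for the two-player setting. Note
# that the tests below actually pass an all-zero to_play list to prepare/search
# (0 serving as the placeholder id here), so this list is kept as reference data.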

@pytest.mark.unittest
def test_mcts_vs_bot_to_play():
    legal_actions_list = [[i for i in range(action_space_size)] for _ in range(env_nums)]  # all actions are legal
    roots = MCTSCtree.roots(env_nums, legal_actions_list)
    noises = [
        np.random.dirichlet([policy_config.root_dirichlet_alpha] * policy_config.model.action_space_size
                            ).astype(np.float32).tolist() for _ in range(env_nums)
    ]
    # In ctree, to_play must be a list, not None.
    if policy == 'EfficientZero':
        roots.prepare(
            policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
        )
        MCTSCtree(policy_config
                  ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
    elif policy == 'GumbelMuZero':
        roots.prepare(
            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
            [0 for _ in range(env_nums)]
        )
        MCTSCtree(policy_config).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])

    roots_distributions = roots.get_distributions()
    roots_values = roots.get_values()
    assert np.array(roots_distributions).shape == (batch_size, action_space_size)
    assert np.array(roots_values).shape == (batch_size, )

@pytest.mark.unittest
def test_mcts_vs_bot_to_play_large():
    # Use a deep copy so the enlarged settings do not leak into the module-level
    # config that the other tests rely on.
    cfg = copy.deepcopy(policy_config)
    cfg.obs_space_size = 100
    cfg.model.action_space_size = 20
    cfg.num_simulations = 500
    cfg.batch_size = 256
    env_nums = cfg.batch_size

    model = MuZeroModelFake(action_num=cfg.model.action_space_size)
    stack_obs = torch.randn(size=(cfg.batch_size, cfg.obs_space_size), dtype=torch.float)
    network_output = model.initial_inference(stack_obs.float())

    latent_state_roots = network_output['latent_state']
    reward_hidden_state_roots = network_output['reward_hidden_state']
    pred_values_pool = network_output['value']
    value_prefix_pool = network_output['value_prefix']
    policy_logits_pool = network_output['policy_logits']

    # Network output post-processing, as at module level above.
    pred_values_pool = inverse_scalar_transform(pred_values_pool, cfg.model.support_scale).detach().cpu().numpy()
    latent_state_roots = latent_state_roots.detach().cpu().numpy()
    reward_hidden_state_roots = (
        reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy()
    )
    policy_logits_pool = policy_logits_pool.detach().cpu().numpy().tolist()

    # All actions are legal.
    legal_actions_list = [[i for i in range(cfg.model.action_space_size)] for _ in range(env_nums)]
    roots = MCTSCtree.roots(env_nums, legal_actions_list)
    noises = [
        np.random.dirichlet([cfg.root_dirichlet_alpha] * cfg.model.action_space_size).astype(np.float32).tolist()
        for _ in range(env_nums)
    ]
    # In ctree, to_play must be a list, not None.
    if policy == 'EfficientZero':
        roots.prepare(cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)])
        MCTSCtree(cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
    elif policy == 'GumbelMuZero':
        roots.prepare(
            cfg.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
            [0 for _ in range(env_nums)]
        )
        MCTSCtree(cfg).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])

    roots_distributions = roots.get_distributions()
    roots_values = roots.get_values()
    assert np.array(roots_distributions).shape == (cfg.batch_size, cfg.model.action_space_size)
    assert np.array(roots_values).shape == (cfg.batch_size, )

@pytest.mark.unittest
def test_mcts_vs_bot_to_play_legal_action():
    for i in range(env_nums):
        assert action_num[i] == len(legal_actions_list[i])

    roots = MCTSCtree.roots(env_nums, legal_actions_list)
    noises = [
        np.random.dirichlet([policy_config.root_dirichlet_alpha] * int(sum(action_mask[j]))).astype(np.float32).tolist()
        for j in range(env_nums)
    ]
    # In ctree, to_play must be a list, not None.
    if policy == 'EfficientZero':
        roots.prepare(
            policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
        )
        MCTSCtree(policy_config
                  ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
    elif policy == 'GumbelMuZero':
        roots.prepare(
            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
            [0 for _ in range(env_nums)]
        )
        MCTSCtree(policy_config).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])

    roots_distributions = roots.get_distributions()
    roots_values = roots.get_values()
    assert len(roots_distributions) == env_nums
    assert len(roots_values) == env_nums
    for i in range(env_nums):
        assert len(roots_distributions[i]) == action_num[i]

    temperature = [1 for _ in range(env_nums)]
    for i in range(env_nums):
        distributions = roots_distributions[i]
        # Sample an action index from the visit-count distribution.
        action_index, visit_count_distribution_entropy = select_action(
            distributions, temperature=temperature[i], deterministic=False
        )
        # Map the index within the legal-action subset back to the full action id.
        action = np.where(np.array(action_mask[i]) == 1.0)[0][action_index]
        assert action_index < action_num[i]
        assert action == legal_actions_list[i][action_index]
        print('\n action_index={}, legal_action={}, action={}'.format(action_index, legal_actions_list[i], action))
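
# For reference, a minimal sketch of select_action's selection rule (assuming the
# standard visit-count sampling: probabilities proportional to count ** (1 / T),
# with deterministic=True taking the argmax of the visit counts instead):
_idx, _entropy = select_action([8, 1, 1], temperature=1, deterministic=True)
assert _idx == 0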

@pytest.mark.unittest
def test_mcts_self_play():
    legal_actions_list = [[i for i in range(action_space_size)] for _ in range(env_nums)]  # all actions are legal
    roots = MCTSCtree.roots(env_nums, legal_actions_list)
    noises = [
        np.random.dirichlet([policy_config.root_dirichlet_alpha] * policy_config.model.action_space_size
                            ).astype(np.float32).tolist() for _ in range(env_nums)
    ]
    # In ctree, to_play must be a list, not None.
    if policy == 'EfficientZero':
        roots.prepare(
            policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
        )
        MCTSCtree(policy_config
                  ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
    elif policy == 'GumbelMuZero':
        roots.prepare(
            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
            [0 for _ in range(env_nums)]
        )
        MCTSCtree(policy_config).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])

    roots_distributions = roots.get_distributions()
    roots_values = roots.get_values()
    assert np.array(roots_distributions).shape == (batch_size, action_space_size)
    assert np.array(roots_values).shape == (batch_size, )

@pytest.mark.unittest
def test_mcts_self_play_legal_action():
    for i in range(env_nums):
        assert action_num[i] == len(legal_actions_list[i])

    roots = MCTSCtree.roots(env_nums, legal_actions_list)
    noises = [
        np.random.dirichlet([policy_config.root_dirichlet_alpha] * int(sum(action_mask[j]))).astype(np.float32).tolist()
        for j in range(env_nums)
    ]
    # In ctree, to_play must be a list, not None.
    if policy == 'EfficientZero':
        roots.prepare(
            policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
        )
        MCTSCtree(policy_config
                  ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
    elif policy == 'GumbelMuZero':
        roots.prepare(
            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
            [0 for _ in range(env_nums)]
        )
        MCTSCtree(policy_config).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])

    roots_distributions = roots.get_distributions()
    roots_values = roots.get_values()
    assert len(roots_distributions) == env_nums
    assert len(roots_values) == env_nums
    for i in range(env_nums):
        assert len(roots_distributions[i]) == action_num[i]

    temperature = [1 for _ in range(env_nums)]
    for i in range(env_nums):
        distributions = roots_distributions[i]
        # Sample an action index from the visit-count distribution.
        action_index, visit_count_distribution_entropy = select_action(
            distributions, temperature=temperature[i], deterministic=False
        )
        # Map the index within the legal-action subset back to the full action id.
        action = np.where(np.array(action_mask[i]) == 1.0)[0][action_index]
        assert action_index < action_num[i]
        assert action == legal_actions_list[i][action_index]
        print('\n action_index={}, legal_action={}, action={}'.format(action_index, legal_actions_list[i], action))
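
# To run this module directly (a sketch; the suite is normally executed via
# pytest, e.g. `pytest -sv` on this file). The large-scale test is skipped here
# because its 500-simulation search is slow on CPU.
if __name__ == "__main__":
    test_mcts_vs_bot_to_play()
    test_mcts_vs_bot_to_play_legal_action()
    test_mcts_self_play()
    test_mcts_self_play_legal_action()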