import numpy as np
import pytest
import torch

from lzero.mcts.buffer.game_segment import GameSegment
from lzero.mcts.utils import prepare_observation
from lzero.policy import select_action

# args = ['EfficientZero', 'MuZero']
args = ["MuZero"]
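

# ``test_game_segment`` rolls out one short evaluation episode per environment:
# it runs ``initial_inference`` and MCTS to select actions, then stores the
# resulting transitions and search statistics in ``GameSegment`` buffers.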
@pytest.mark.parametrize('test_algo', args)
def test_game_segment(test_algo):
    # import different modules according to ``test_algo``
    if test_algo == 'EfficientZero':
        from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree
        from lzero.model.efficientzero_model import EfficientZeroModel as Model
        from lzero.mcts.tests.config.atari_efficientzero_config_for_test import atari_efficientzero_config as config
        from zoo.atari.envs.atari_lightzero_env import AtariLightZeroEnv
        envs = [AtariLightZeroEnv(config.env) for _ in range(config.env.evaluator_env_num)]
    elif test_algo == 'MuZero':
        from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree
        from lzero.model.muzero_model import MuZeroModel as Model
        from lzero.mcts.tests.config.tictactoe_muzero_bot_mode_config_for_test import tictactoe_muzero_config as config
        from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
        envs = [TicTacToeEnv(config.env) for _ in range(config.env.evaluator_env_num)]

    # create model
    model = Model(**config.policy.model)
    if config.policy.cuda and torch.cuda.is_available():
        config.policy.device = 'cuda'
    else:
        config.policy.device = 'cpu'
    model.to(config.policy.device)
    model.eval()
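
    # the rollout below is pure inference, so gradients are disabled for the whole loop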
    with torch.no_grad():
        # initializations
        init_observations = [env.reset() for env in envs]
        dones = np.array([False for _ in range(config.env.evaluator_env_num)])
        game_segments = [
            GameSegment(
                envs[i].action_space, game_segment_length=config.policy.game_segment_length, config=config.policy
            ) for i in range(config.env.evaluator_env_num)
        ]
        for i in range(config.env.evaluator_env_num):
            game_segments[i].reset(
                [init_observations[i]['observation'] for _ in range(config.policy.model.frame_stack_num)]
            )
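
        # at this point each game segment holds ``frame_stack_num`` copies of the initial
        # observation, so ``get_obs()`` returns a full frame stack from the very first step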
        episode_rewards = np.zeros(config.env.evaluator_env_num)

        while not dones.all():
            stack_obs = [game_segment.get_obs() for game_segment in game_segments]
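            # ``prepare_observation`` converts the stacked frames into the array layout the
            # chosen model expects (image-like tensors for 'conv', flat vectors for 'mlp')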
            stack_obs = prepare_observation(stack_obs, config.policy.model.model_type)
            stack_obs = torch.from_numpy(np.array(stack_obs)).to(config.policy.device)

            # ==============================================================
            # the core initial_inference.
            # ==============================================================
            network_output = model.initial_inference(stack_obs)

            # process the network output
            policy_logits_pool = network_output.policy_logits.detach().cpu().numpy().tolist()
            latent_state_roots = network_output.latent_state.detach().cpu().numpy()
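            # EfficientZero additionally predicts value prefixes with an LSTM, so its roots
            # also need the reward hidden state; MuZero predicts per-step rewards directly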
            if test_algo == 'EfficientZero':
                reward_hidden_state_roots = network_output.reward_hidden_state
                value_prefix_pool = network_output.value_prefix
                reward_hidden_state_roots = (
                    reward_hidden_state_roots[0].detach().cpu().numpy(),
                    reward_hidden_state_roots[1].detach().cpu().numpy()
                )
                # for the Atari env, every action is legal
                legal_actions_list = [
                    [i for i in range(config.policy.model.action_space_size)]
                    for _ in range(config.env.evaluator_env_num)
                ]
            elif test_algo == 'MuZero':
                reward_pool = network_output.reward
                # for board games, the legal actions are read from the env's action mask
                # (note: taken from the initial observations and not refreshed after each step)
                legal_actions_list = [
                    [a for a, x in enumerate(init_observations[i]['action_mask']) if x == 1]
                    for i in range(config.env.evaluator_env_num)
                ]

            # ``to_play`` is null-padded with -1 for single-player Atari games and for
            # board games in vs_bot_mode, where no turn information is needed
            to_play = [-1 for _ in range(config.env.evaluator_env_num)]

            if test_algo == 'EfficientZero':
                roots = MCTSCtree.roots(config.env.evaluator_env_num, legal_actions_list)
                roots.prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play)
                MCTSCtree(config.policy).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
            elif test_algo == 'MuZero':
                roots = MCTSCtree.roots(config.env.evaluator_env_num, legal_actions_list)
                roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play)
                MCTSCtree(config.policy).search(roots, model, latent_state_roots, to_play)
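
            # after the search, each root holds a visit-count distribution over actions
            # and a value estimate produced by MCTS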
            roots_distributions = roots.get_distributions()
            roots_values = roots.get_values()
            for i in range(config.env.evaluator_env_num):
                distributions, value, env = roots_distributions[i], roots_values[i], envs[i]
                # ``deterministic=True`` indicates that we select the argmax action instead of sampling.
                action, _ = select_action(distributions, temperature=1, deterministic=True)

                # ==============================================================
                # the core env.step.
                # ==============================================================
                obs, reward, done, info = env.step(action)
                obs = obs['observation']
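
                # record the MCTS statistics and the transition (action, obs, reward)
                # in this env's game segment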
                game_segments[i].store_search_stats(distributions, value)
                game_segments[i].append(action, obs, reward)
                dones[i] = done
                episode_rewards[i] += reward
                if dones[i]:
                    continue

    for env in envs:
        env.close()
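
# To run this smoke test directly (assuming LightZero and the zoo environments are installed),
# invoke pytest on this file, e.g. ``pytest -s <path_to_this_file>``.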