In [None]:
# IPython autoreload: reload imported modules (e.g. the local `atari` and
# `networks` modules used below) before executing code, so edits to their
# source take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import pickle
from atari import AtariEnv
from networks import QNetwork

# ------- START TO MODIFY ------- #
# Training algorithm of the saved model (selects a directory under models/).
# Options: eaudedqn, polyprunedqn, dqn, eaudecql, polyprunecql, cql.
ALGO: str = "eaudedqn"
# Atari game to evaluate on (also selects a directory under models/).
# Options: BeamRider, MsPacman, Qbert, Pong, Enduro, SpaceInvaders, Assault,
# CrazyClimber, Boxing, VideoPinball.
GAME: str = "BeamRider"
# Width of the last layer given to QNetwork; also part of the model file name.
# Options: 32, 512, 2048.
FEATURE_SIZE: int = 32
# Seed of the saved network; part of the model file name. Options: 1 to 5.
NETWORK_SEED: int = 1
# Seed of the PRNG key handed to the evaluation episode.
EVALUATION_SEED: int = 0
# Step horizon passed to the evaluation episode (presumably a step cap).
HORIZON: int = 27000
# Exploration parameter passed to the evaluation (presumably epsilon-greedy).
EPSILON: float = 0.01
# When True, an output path derived from the model path is passed to the
# evaluation (presumably for recording a video).
RECORD_VIDEO: bool = False
# ------- END TO MODIFY ------- #

# Location of the pickled Q-network parameters for the chosen configuration.
params_path = f"models/{GAME}/{ALGO}/feature_size_{FEATURE_SIZE}_seed_{NETWORK_SEED}"

env = AtariEnv(GAME)

# Layer widths for QNetwork, the last being FEATURE_SIZE, plus one output per
# action (how QNetwork interprets the list is defined in `networks`).
q = QNetwork([32, 64, 64, FEATURE_SIZE], env.n_actions)

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load parameter files that come from a trusted source.
with open(params_path, "rb") as handle:
    q_params = pickle.load(handle)

# Run one evaluation episode. When RECORD_VIDEO is set, a path derived from
# params_path is passed as the last argument (otherwise None).
return_, absorbing = env.evaluate_one_simulation(
    q,
    q_params,
    HORIZON,
    EPSILON,
    jax.random.PRNGKey(EVALUATION_SEED),
    params_path + "_eval" if RECORD_VIDEO else None,
)
print("Undiscounted return:", return_)
print("N steps", env.n_steps, "; Horizon", HORIZON, "; Absorbing", absorbing)

# Sparsity = fraction of parameter entries that are exactly zero.
# Per-leaf counts are converted to Python ints so the totals are exact and the
# ratio is computed in double precision rather than int32/float32.
leaves = jax.tree.leaves(q_params)
non_zeros = sum(int(jnp.count_nonzero(leaf)) for leaf in leaves)
n_params = sum(int(jnp.size(leaf)) for leaf in leaves)
print("Sparsity level:", 1 - non_zeros / n_params)