{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate a trained Q-network on an Atari game\n",
    "\n",
    "This notebook loads pickled network parameters from `models/<GAME>/<ALGO>/feature_size_<FEATURE_SIZE>_seed_<NETWORK_SEED>`, runs a single evaluation simulation through `AtariEnv.evaluate_one_simulation`, and prints the undiscounted return together with the fraction of parameter entries that are exactly zero.\n",
    "\n",
    "Edit the constants between the `START TO MODIFY` / `END TO MODIFY` markers, then run all cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules before each cell execution so that edits to the\n",
    "# project files are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from atari import AtariEnv\n",
    "from networks import QNetwork\n",
    "\n",
    "# ------- START TO MODIFY ------- #\n",
    "ALGO = \"eaudedqn\"  # choose between eaudedqn, polyprunedqn, dqn, eaudecql, polyprunecql, and cql.\n",
    "GAME = \"BeamRider\"  # choose between BeamRider, MsPacman, Qbert, Pong, Enduro, SpaceInvaders, Assault, CrazyClimber, Boxing, and VideoPinball.\n",
    "FEATURE_SIZE = 32  # choose between 32, 512, and 2048.\n",
    "NETWORK_SEED = 1  # choose between 1, 2, 3, 4, and 5.\n",
    "EVALUATION_SEED = 0\n",
    "HORIZON = 27000\n",
    "EPSILON = 0.01\n",
    "RECORD_VIDEO = False\n",
    "# ------- END TO MODIFY ------- #\n",
    "\n",
    "# File holding the pickled parameters for the selected game / algorithm / feature size / seed.\n",
    "params_path = f\"models/{GAME}/{ALGO}/feature_size_{FEATURE_SIZE}_seed_{NETWORK_SEED}\"\n",
    "# Last argument of the evaluation call: the parameter path with an \"_eval\" suffix\n",
    "# when RECORD_VIDEO is set, otherwise None.\n",
    "video_path = params_path + \"_eval\" if RECORD_VIDEO else None\n",
    "\n",
    "env = AtariEnv(GAME)\n",
    "\n",
    "q = QNetwork([32, 64, 64, FEATURE_SIZE], env.n_actions)\n",
    "\n",
    "# SECURITY: pickle.load can execute arbitrary code while unpickling.\n",
    "# Only load parameter files that come from a trusted source.\n",
    "with open(params_path, \"rb\") as handle:\n",
    "    q_params = pickle.load(handle)\n",
    "\n",
    "return_, absorbing = env.evaluate_one_simulation(\n",
    "    q, q_params, HORIZON, EPSILON, jax.random.PRNGKey(EVALUATION_SEED), video_path\n",
    ")\n",
    "print(\"Undiscounted return:\", return_)\n",
    "print(\"N steps\", env.n_steps, \"; Horizon\", HORIZON, \"; Absorbing\", absorbing)\n",
    "\n",
    "# Sparsity level = 1 - (non-zero entries / total entries), summed over every leaf\n",
    "# of the parameter pytree.\n",
    "non_zeros = sum(jax.tree.leaves(jax.tree.map(jnp.count_nonzero, q_params)))\n",
    "n_params = sum(jax.tree.leaves(jax.tree.map(jnp.size, q_params)))\n",
    "print(\"Sparsity level:\", (1 - jnp.float32(non_zeros) / jnp.float32(n_params)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}