File size: 2,394 Bytes
83e5230 f87f11a 83e5230 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import pickle\n",
"from atari import AtariEnv\n",
"from networks import QNetwork\n",
"\n",
"# Evaluate one saved Q-network on one Atari game for a single rollout, then\n",
"# report the undiscounted return and the sparsity of the loaded parameters.\n",
"\n",
"# ------- START TO MODIFY ------- #\n",
"ALGO = \"eaudedqn\"  # choose between eaudedqn, polyprunedqn, dqn, eaudecql, polyprunecql, and cql.\n",
"GAME = \"BeamRider\"  # choose between BeamRider, MsPacman, Qbert, Pong, Enduro, SpaceInvaders, Assault, CrazyClimber, Boxing, and VideoPinball.\n",
"FEATURE_SIZE = 32  # choose between 32, 512, and 2048.\n",
"NETWORK_SEED = 1  # choose between 1, 2, 3, 4, and 5.\n",
"EVALUATION_SEED = 0  # seed of the PRNG key handed to the evaluation rollout.\n",
"HORIZON = 27000  # horizon argument handed to the evaluation rollout.\n",
"EPSILON = 0.01  # epsilon argument handed to the rollout (presumably epsilon-greedy exploration; confirm in AtariEnv).\n",
"RECORD_VIDEO = False  # when True, a path ending in \"_eval\" is handed to the rollout (presumably the video destination; confirm in AtariEnv).\n",
"# ------- END TO MODIFY ------- #\n",
"\n",
"# Location of the pickled parameters for the selected configuration.\n",
"params_path = f\"models/{GAME}/{ALGO}/feature_size_{FEATURE_SIZE}_seed_{NETWORK_SEED}\"\n",
"# Only pass a recording path when RECORD_VIDEO is set; otherwise pass None.\n",
"video_path = f\"{params_path}_eval\" if RECORD_VIDEO else None\n",
"\n",
"env = AtariEnv(GAME)\n",
"\n",
"q = QNetwork([32, 64, 64, FEATURE_SIZE], env.n_actions)\n",
"\n",
"# NOTE(security): pickle.load can execute arbitrary code while unpickling;\n",
"# only load parameter files that come from a trusted source.\n",
"with open(params_path, \"rb\") as handle:\n",
"    q_params = pickle.load(handle)\n",
"\n",
"return_, absorbing = env.evaluate_one_simulation(\n",
"    q, q_params, HORIZON, EPSILON, jax.random.PRNGKey(EVALUATION_SEED), video_path\n",
")\n",
"print(\"Undiscounted return:\", return_)\n",
"print(\"N steps\", env.n_steps, \"; Horizon\", HORIZON, \"; Absorbing\", absorbing)\n",
"\n",
"# Sparsity level = 1 - (non-zero entries / total entries), summed over every leaf of the parameter pytree.\n",
"non_zeros = sum(jax.tree.leaves(jax.tree.map(jnp.count_nonzero, q_params)))\n",
"n_params = sum(jax.tree.leaves(jax.tree.map(jnp.size, q_params)))\n",
"print(\"Sparsity level:\", (1 - jnp.float32(non_zeros) / jnp.float32(n_params)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|