{ "cells": [ { "cell_type": "code", "source": [ "!git clone https://github.com/cpwan/RLOR\n", "%cd RLOR" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5a69iB04JzY2", "outputId": "13a3a63e-34bf-4d8b-a853-9d2597cd03d5" }, "id": "5a69iB04JzY2", "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Cloning into 'RLOR'...\n", "remote: Enumerating objects: 52, done.\u001b[K\n", "remote: Counting objects: 100% (52/52), done.\u001b[K\n", "remote: Compressing objects: 100% (35/35), done.\u001b[K\n", "remote: Total 52 (delta 12), reused 52 (delta 12), pack-reused 0\u001b[K\n", "Unpacking objects: 100% (52/52), 5.19 MiB | 7.89 MiB/s, done.\n", "/content/RLOR\n" ] } ] }, { "cell_type": "code", "execution_count": 2, "id": "dbe3c5ed", "metadata": { "id": "dbe3c5ed" }, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import gym\n", "from models.attention_model_wrapper import Agent" ] }, { "cell_type": "markdown", "id": "985bf6e6", "metadata": { "id": "985bf6e6" }, "source": [ "# Define our agent" ] }, { "cell_type": "code", "execution_count": 3, "id": "953a7fde", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "953a7fde", "outputId": "06f10aaf-57ca-4870-d22b-f71a14ea4ec4" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 3 } ], "source": [ "# Use the GPU when one is available and fall back to the CPU otherwise,\n", "# so the notebook also runs on a runtime without CUDA.\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "ckpt_path = './runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt'\n", "agent = Agent(device=device, name='tsp').to(device)\n", "# map_location remaps the stored tensors onto `device` while loading; without it a\n", "# checkpoint whose tensors were saved from a GPU cannot be loaded on a CPU-only machine.\n", "agent.load_state_dict(torch.load(ckpt_path, map_location=device))" ] }, { "cell_type": "markdown", "id": "2cbaa255", "metadata": { "id": "2cbaa255" }, "source": [ "# Define our environment" ] }, { "cell_type": "code", "execution_count": 4, "id": "c2bd466f", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "c2bd466f", "outputId": "92450c8d-f5db-444d-f465-da4a98667799" }, "outputs": [ { "output_type": "stream", "name": 
"stderr", "text": [ "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:31: UserWarning: \u001b[33mWARN: A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (50, 2)\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/core.py:317: DeprecationWarning: \u001b[33mWARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n", " deprecation(\n", "/usr/local/lib/python3.9/dist-packages/gym/wrappers/step_api_compatibility.py:39: DeprecationWarning: \u001b[33mWARN: Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n", " deprecation(\n", "/usr/local/lib/python3.9/dist-packages/gym/vector/vector_env.py:56: DeprecationWarning: \u001b[33mWARN: Initializing vector env in old step API which returns one bool array instead of two. It is recommended to set `new_step_api=True` to use new step API. 
This will be the default behaviour in future.\u001b[0m\n", " deprecation(\n" ] } ], "source": [ "from wrappers.syncVectorEnvPomo import SyncVectorEnv\n", "from wrappers.recordWrapper import RecordEpisodeStatistics\n", "\n", "env_id = 'tsp-v0'\n", "env_entry_point = 'envs.tsp_vector_env:TSPVectorEnv'\n", "seed = 0\n", "\n", "gym.envs.register(\n", "    id=env_id,\n", "    entry_point=env_entry_point,\n", ")\n", "\n", "def make_env(env_id, seed, cfg=None):\n", "    '''\n", "    Build a thunk (zero-argument factory) that creates one seeded environment\n", "    - env_id: id passed to gym.make\n", "    - seed: seed applied to the environment and to its action/observation spaces\n", "    - cfg: optional dict of keyword arguments forwarded to gym.make\n", "    '''\n", "    # Copy into a fresh dict: a literal {} default would be one object shared by every call\n", "    env_kwargs = {} if cfg is None else dict(cfg)\n", "\n", "    def thunk():\n", "        env = gym.make(env_id, **env_kwargs)\n", "        env = RecordEpisodeStatistics(env)\n", "        env.seed(seed)\n", "        env.action_space.seed(seed)\n", "        env.observation_space.seed(seed)\n", "        return env\n", "\n", "    return thunk\n", "\n", "envs = SyncVectorEnv([make_env(env_id, seed, dict(n_traj=1))])" ] }, { "cell_type": "markdown", "id": "c363d489", "metadata": { "id": "c363d489" }, "source": [ "# Inference" ] }, { "cell_type": "code", "execution_count": 5, "id": "bbee9e3c", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bbee9e3c", "outputId": "11750d60-eb7c-4d9c-8b40-3a8a92021ffe" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:174: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator.\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:190: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting.\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:195: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.\u001b[0m\n", " logger.warn(\n", 
"/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:165: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method is not within the observation space.\u001b[0m\n", " logger.warn(f\"{pre} is not within the observation space.\")\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:141: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:227: DeprecationWarning: \u001b[33mWARN: Core environment is written in old step API which returns one bool instead of two. It is recommended to rewrite the environment with new step API. \u001b[0m\n", " logger.deprecation(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:234: UserWarning: \u001b[33mWARN: Expects `done` signal to be a boolean, actual type: \u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:141: UserWarning: \u001b[33mWARN: The obs returned by the `step()` method was expecting numpy array dtype to be float32, actual type: float64\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:165: UserWarning: \u001b[33mWARN: The obs returned by the `step()` method is not within the observation space.\u001b[0m\n", " logger.warn(f\"{pre} is not within the observation space.\")\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:260: UserWarning: \u001b[33mWARN: The reward returned by `step()` must be a float, int, np.integer or np.floating, actual type: \u001b[0m\n", " logger.warn(\n" ] } ], "source": [ "agent.eval()\n", "obs = envs.reset()\n", "# One action per node plus a final one that closes the tour. The node count is read\n", "# from the observation instead of being hard-coded (assumes obs['observations'] is\n", "# laid out as (n_envs, n_nodes, 2) -- TODO confirm against the environment).\n", "num_nodes = np.shape(obs['observations'])[1]\n", "num_steps = num_nodes + 1\n", "trajectories = []\n", "for _ in range(num_steps):\n", "    # ALGO LOGIC: action logic\n", "    with torch.no_grad():\n", "        action, _logits = agent(obs)\n", "    # Convert the action to a numpy array once and reuse it for stepping and logging\n", "    action_np = action.cpu().numpy()\n", "    obs, _reward, _done, info = envs.step(action_np)\n", "    trajectories.append(action_np)" ] }, { "cell_type": "code", "execution_count": 6, "id": "f0fbf6fd", "metadata": { "id": "f0fbf6fd" }, "outputs": [], "source": [ "nodes_coordinates = obs['observations'][0]\n", "final_return = info[0]['episode']['r']\n", "resulting_traj = np.array(trajectories)[:,0,0]" ] }, { "cell_type": "markdown", "source": [ "## Results" ], "metadata": { "id": "5n9rBoH5Q8gn" }, "id": "5n9rBoH5Q8gn" }, { "cell_type": "code", "execution_count": 7, "id": "dff29ef4", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dff29ef4", "outputId": "dcffcda5-5728-464c-ee3e-0fcc702443ff" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "A route of length 5.9085\n", "The route is:\n", " [26 34 33 49 37 21 48 43 31 28 42 29 47 39 38 23 27 30 7 32 24 40 20 14\n", " 25 1 18 22 0 11 2 16 45 15 46 12 17 41 8 13 3 6 44 9 10 19 36 5\n", " 35 4 26]\n" ] } ], "source": [ "# The recorded episode return is negative while a route length cannot be (the environment\n", "# presumably rewards with the negated tour length -- TODO confirm), so flip the sign and\n", "# report the first trajectory as a plain number rather than a one-element array.\n", "route_length = -float(np.ravel(final_return)[0])\n", "print(f'A route of length {route_length:.4f}')\n", "print('The route is:\\n', resulting_traj)" ] }, { "cell_type": "markdown", "id": "b009802e", "metadata": { "id": "b009802e" }, "source": [ "## Display it in a 2d-grid\n", "- Darker color means later steps in the route." 
] }, { "cell_type": "code", "execution_count": 8, "id": "dc681a06", "metadata": { "tags": [ "\"hide-cell\"" ], "cellView": "form", "id": "dc681a06" }, "outputs": [], "source": [ "#@title Helper function for plotting\n", "# colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb\n", "import matplotlib.pyplot as plt\n", "from matplotlib.collections import LineCollection\n", "from matplotlib.colors import ListedColormap, BoundaryNorm\n", "\n", "def make_segments(x, y):\n", "    '''\n", "    Turn x and y coordinates into the segment array LineCollection expects:\n", "    numlines x (points per line) x 2 (x and y), one segment per pair of consecutive points\n", "    '''\n", "    vertices = np.array([x, y]).T.reshape(-1, 1, 2)\n", "    return np.concatenate([vertices[:-1], vertices[1:]], axis=1)\n", "\n", "def colorline(x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0), linewidth=1, alpha=1.0):\n", "    '''\n", "    Draw the line through x and y on the current axes as a colored LineCollection\n", "    - z: optional color values (an array or a single number); by default they are\n", "      spaced evenly over [0.3, 1.0]\n", "    - cmap, norm, linewidth, alpha: passed on to LineCollection\n", "    Returns the LineCollection that was added\n", "    '''\n", "    if z is None:\n", "        color_values = np.linspace(0.3, 1.0, len(x))\n", "    elif hasattr(z, \"__iter__\"):\n", "        color_values = np.asarray(z)\n", "    else:\n", "        # A single number: wrap it so LineCollection receives an array\n", "        color_values = np.array([z])\n", "\n", "    line_collection = LineCollection(make_segments(x, y), array=color_values, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha)\n", "    plt.gca().add_collection(line_collection)\n", "    return line_collection\n", "\n", "def plot(coords):\n", "    '''\n", "    Draw the route through coords (one x, y pair per row) with the 'Reds' colormap on square axes\n", "    Returns the LineCollection of the route\n", "    '''\n", "    xs, ys = coords.T\n", "    route_line = colorline(xs, ys, cmap='Reds')\n", "    plt.axis('square')\n", "    return route_line" ] }, { "cell_type": "code", "execution_count": 9, "id": "bb0548fb", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 282 }, 
"id": "bb0548fb", "outputId": "3e967b86-32e9-4be3-e5b7-c8a457aa12a7" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 9 }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": { "needs_background": "light" } } ], "source": [ "# Put the node coordinates in visiting order, then draw the route\n", "route_coordinates = nodes_coordinates[resulting_traj]\n", "plot(route_coordinates)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "colab": { "provenance": [] }, "accelerator": "GPU", "gpuClass": "standard" }, "nbformat": 4, "nbformat_minor": 5 }