{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "3f3052cd", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3f3052cd", "outputId": "78d129fd-0956-4f88-ae39-def9953a982e" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Cloning into 'RLOR'...\n", "remote: Enumerating objects: 52, done.\u001b[K\n", "remote: Counting objects: 100% (52/52), done.\u001b[K\n", "remote: Compressing objects: 100% (35/35), done.\u001b[K\n", "remote: Total 52 (delta 12), reused 52 (delta 12), pack-reused 0\u001b[K\n", "Unpacking objects: 100% (52/52), 5.19 MiB | 4.39 MiB/s, done.\n", "/content/RLOR\n" ] } ], "source": [ "!git clone https://github.com/cpwan/RLOR\n", "%cd RLOR" ] },
{ "cell_type": "code", "execution_count": 2, "id": "f01dfb64", "metadata": { "id": "f01dfb64" }, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import gym" ] },
{ "cell_type": "markdown", "id": "985bf6e6", "metadata": { "id": "985bf6e6" }, "source": [ "# Define our agent" ] },
{ "cell_type": "code", "execution_count": 3, "id": "953a7fde", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "953a7fde", "outputId": "9b37d746-b9a0-4d53-c12a-b2445ed6bd9d" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 3 } ], "source": [
"from models.attention_model_wrapper import Agent\n",
"\n",
"# Use the GPU when one is present, otherwise fall back to the CPU so the\n",
"# notebook still runs on a CPU-only runtime.\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"ckpt_path = './runs/cvrp-v0__ppo_or__1__1678159979/ckpt/12000.pt'\n",
"agent = Agent(device=device, name='cvrp').to(device)\n",
"# map_location remaps the stored tensors onto the selected device; without it\n",
"# a checkpoint saved from a GPU cannot be loaded on a machine without CUDA.\n",
"agent.load_state_dict(torch.load(ckpt_path, map_location=device))" ] },
{ "cell_type": "markdown", "id": "2cbaa255", "metadata": { "id": "2cbaa255" }, "source": [ "# Define our environment\n", "## CVRP\n", "Given a depot, n nodes with their demands, and the capacity of the vehicle, \n", "find the shortest path that fulfills the demand of every node.\n" ] },
{ "cell_type": "code", "execution_count": 4, "id": "81fd7b68", "metadata": { "colab": {
"base_uri": "https://localhost:8080/" }, "id": "81fd7b68", "outputId": "f4a9d9d8-29f7-413d-b462-d27f04a0153a" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:31: UserWarning: \u001b[33mWARN: A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (50, 2)\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/core.py:317: DeprecationWarning: \u001b[33mWARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n", " deprecation(\n", "/usr/local/lib/python3.9/dist-packages/gym/wrappers/step_api_compatibility.py:39: DeprecationWarning: \u001b[33mWARN: Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n", " deprecation(\n", "/usr/local/lib/python3.9/dist-packages/gym/vector/vector_env.py:56: DeprecationWarning: \u001b[33mWARN: Initializing vector env in old step API which returns one bool array instead of two. It is recommended to set `new_step_api=True` to use new step API. 
This will be the default behaviour in future.\u001b[0m\n", " deprecation(\n" ] } ], "source": [
"from wrappers.syncVectorEnvPomo import SyncVectorEnv\n",
"from wrappers.recordWrapper import RecordEpisodeStatistics\n",
"\n",
"env_id = 'cvrp-v0'\n",
"env_entry_point = 'envs.cvrp_vector_env:CVRPVectorEnv'\n",
"seed = 0\n",
"\n",
"gym.envs.register(\n",
"    id=env_id,\n",
"    entry_point=env_entry_point,\n",
")\n",
"\n",
"def make_env(env_id, seed, cfg=None):\n",
"    '''\n",
"    Build a zero-argument factory that creates one seeded environment.\n",
"\n",
"    env_id: id of a registered gym environment, forwarded to gym.make\n",
"    seed:   seed applied to the environment and to its action/observation spaces\n",
"    cfg:    optional dict of keyword arguments forwarded to gym.make\n",
"    Returns the factory function; nothing is created until it is called.\n",
"    '''\n",
"    # Default to None rather than a shared mutable dict, and copy the mapping so\n",
"    # later changes made by the caller cannot leak into the factory.\n",
"    env_kwargs = dict(cfg) if cfg is not None else {}\n",
"\n",
"    def thunk():\n",
"        env = gym.make(env_id, **env_kwargs)\n",
"        env = RecordEpisodeStatistics(env)\n",
"        env.seed(seed)\n",
"        env.action_space.seed(seed)\n",
"        env.observation_space.seed(seed)\n",
"        return env\n",
"\n",
"    return thunk\n",
"\n",
"envs = SyncVectorEnv([make_env(env_id, seed, dict(n_traj=50))])" ] },
{ "cell_type": "markdown", "id": "c363d489", "metadata": { "id": "c363d489" }, "source": [ "# Inference\n", "We use the Multi-Greedy search strategy: running greedy sampling with different starting nodes" ] },
{ "cell_type": "code", "execution_count": 5, "id": "bbee9e3c", "metadata": { "id": "bbee9e3c", "outputId": "5632253d-9a70-433b-c35a-3d97e478da0d", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:174: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator.\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:190: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting.\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:195: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed `options` to allow the 
environment initialisation to be passed additional information.\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:141: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:165: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method is not within the observation space.\u001b[0m\n", " logger.warn(f\"{pre} is not within the observation space.\")\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:227: DeprecationWarning: \u001b[33mWARN: Core environment is written in old step API which returns one bool instead of two. It is recommended to rewrite the environment with new step API. \u001b[0m\n", " logger.deprecation(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:234: UserWarning: \u001b[33mWARN: Expects `done` signal to be a boolean, actual type: \u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:141: UserWarning: \u001b[33mWARN: The obs returned by the `step()` method was expecting numpy array dtype to be float32, actual type: float64\u001b[0m\n", " logger.warn(\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:165: UserWarning: \u001b[33mWARN: The obs returned by the `step()` method is not within the observation space.\u001b[0m\n", " logger.warn(f\"{pre} is not within the observation space.\")\n", "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:260: UserWarning: \u001b[33mWARN: The reward returned by `step()` must be a float, int, np.integer or np.floating, actual type: \u001b[0m\n", " logger.warn(\n" ] } ], "source": [ "trajectories = []\n", "agent.eval()\n", "obs = envs.reset()\n", "done = np.array([False])\n", "while not done.all():\n", " # ALGO 
LOGIC: action logic\n",
"    with torch.no_grad():\n",
"        action, logits = agent(obs)\n",
"    if not trajectories:  # Multi-greedy inference: on the first step trajectory i is forced to start at node i + 1\n",
"        action = torch.arange(1, envs.n_traj + 1).repeat(1, 1)\n",
"\n",
"    # Convert once and reuse the same array for stepping and for the log.\n",
"    action_np = action.cpu().numpy()\n",
"    obs, reward, done, info = envs.step(action_np)\n",
"    trajectories.append(action_np)" ] },
{ "cell_type": "code", "execution_count": 6, "id": "f0fbf6fd", "metadata": { "id": "f0fbf6fd" }, "outputs": [], "source": [
"# Row 0 is the depot, the following rows are the customers, so the node ids\n",
"# stored in the trajectories index directly into this array.\n",
"nodes_coordinates = np.vstack([obs['depot'], obs['observations'][0]])\n",
"final_return = info[0]['episode']['r']\n",
"# Keep the trajectory with the highest return.\n",
"best_traj = np.argmax(final_return)\n",
"resulting_traj = np.array(trajectories)[:, 0, best_traj]\n",
"# Prepend node 0 so the route starts at the depot.\n",
"resulting_traj_with_depot = np.hstack([np.zeros(1, dtype=int), resulting_traj])" ] },
{ "cell_type": "markdown", "source": [ "## Results" ], "metadata": { "id": "ViNGfd1PQwlw" }, "id": "ViNGfd1PQwlw" },
{ "cell_type": "code", "execution_count": 7, "id": "dff29ef4", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dff29ef4", "outputId": "8a57a330-b340-4d60-dc83-2a1ea548c7d0" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "A route of length 11.283475875854492\n", "The route is:\n", " [ 0 16 32 27 7 30 23 38 40 2 11 0 9 10 19 36 35 4 5 18 0 25 42 43\n", " 48 21 37 31 50 0 13 8 41 17 12 46 0 3 6 44 0 22 26 49 34 33 28 0\n", " 1 14 29 39 47 0 20 24 45 15 0 0]\n" ] } ], "source": [
"# The episode return is negative (it was previously printed as a length of\n",
"# -11.28), i.e. the reward is the negated tour length: flip the sign to report it.\n",
"route_length = -final_return[best_traj]\n",
"print(f'A route of length {route_length}')\n",
"print('The route is:\\n', resulting_traj_with_depot)" ] },
{ "cell_type": "markdown", "id": "1b78c529", "metadata": { "id": "1b78c529" }, "source": [ "### Display it in a 2d-grid\n", "- Darker color means later steps in the route.\n", "- We abuse the errorbar to show the relative size of demand at each customer."
] }, { "cell_type": "code", "execution_count": 8, "id": "dc681a06", "metadata": { "tags": [ "\"hide-cell\"" ], "cellView": "form", "id": "dc681a06" }, "outputs": [], "source": [
"#@title Helper function for plotting\n",
"# colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.collections import LineCollection\n",
"from matplotlib.colors import ListedColormap, BoundaryNorm\n",
"\n",
"def make_segments(x, y):\n",
"    '''\n",
"    Create list of line segments from x and y coordinates, in the correct format for LineCollection:\n",
"    an array of the form numlines x (points per line) x 2 (x and y) array\n",
"    '''\n",
"    points = np.array([x, y]).T.reshape(-1, 1, 2)\n",
"    # Pair every point with its successor: segment i runs from point i to point i + 1.\n",
"    return np.concatenate([points[:-1], points[1:]], axis=1)\n",
"\n",
"def colorline(x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0), linewidth=1, alpha=1.0):\n",
"    '''\n",
"    Plot a colored line with coordinates x and y on the current axes\n",
"    Optionally specify colors in the array z (a single number is accepted too)\n",
"    Optionally specify a colormap, a norm function and a line width\n",
"    Returns the LineCollection that was added to the axes\n",
"    '''\n",
"    # Default colors equally spaced on [0.3, 1.0]:\n",
"    if z is None:\n",
"        z = np.linspace(0.3, 1.0, len(x))\n",
"\n",
"    # Accept a single number as well as a sequence.\n",
"    z = np.atleast_1d(z)\n",
"\n",
"    segments = make_segments(x, y)\n",
"    lc = LineCollection(segments, array=z, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha)\n",
"    plt.gca().add_collection(lc)\n",
"    return lc\n",
"\n",
"def plot(coords, demand, customer_coords=None):\n",
"    '''\n",
"    Draw a route and mark the demand of every customer with a vertical bar\n",
"\n",
"    coords:          (x, y) coordinates of the route, one row per visited node, in visiting order\n",
"    demand:          demand per customer; a quarter of it is used as the bar height\n",
"    customer_coords: (x, y) coordinates of the customers, in the same order as demand;\n",
"                     when omitted they are read from the notebook-level obs (previous behaviour)\n",
"    Returns the LineCollection of the route\n",
"    '''\n",
"    route_x, route_y = np.asarray(coords).T\n",
"    lc = colorline(route_x, route_y, cmap='Reds')\n",
"    plt.axis('square')\n",
"\n",
"    if customer_coords is None:\n",
"        # Backward-compatible fallback for callers that pass only the route and the demand.\n",
"        customer_coords = obs['observations'][0]\n",
"    cust_x, cust_y = np.asarray(customer_coords).T\n",
"\n",
"    # Use the demand argument (it used to be ignored in favour of the global obs).\n",
"    bar = np.asarray(demand) / 4\n",
"    # Asymmetric error bar: nothing below the point, the scaled demand above it.\n",
"    yerr = np.vstack([bar * 0, bar])\n",
"    plt.errorbar(cust_x, cust_y, yerr, fmt='None', elinewidth=2)\n",
"    return lc" ]
}, { "cell_type": "code", "execution_count": 9, "id": "aa5e32f2", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 282 }, "id": "aa5e32f2", "outputId": "8da68d19-138b-4e3e-c481-02e06416c5e7" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 9 }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": { "needs_background": "light" } } ], "source": [ "plot(nodes_coordinates[resulting_traj_with_depot], obs['demand'])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "colab": { "provenance": [] }, "accelerator": "GPU", "gpuClass": "standard" }, "nbformat": 4, "nbformat_minor": 5 }