import gym
import numpy as np
from gym import spaces

from .tsp_data import TSPDataset


def assign_env_config(self, kwargs):
    """Set self.key = value for each key in kwargs."""
    for key, value in kwargs.items():
        setattr(self, key, value)


def dist(loc1, loc2):
    """Euclidean distance between two (n, 2) arrays of 2-D points."""
    return ((loc1[:, 0] - loc2[:, 0]) ** 2 + (loc1[:, 1] - loc2[:, 1]) ** 2) ** 0.5
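
# Illustrative check (values are an added example, not from the original):
# dist(np.array([[0.0, 0.0]]), np.array([[3.0, 4.0]])) -> array([5.])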


class TSPVectorEnv(gym.Env):
    """Vectorized TSP environment: n_traj tours are rolled out in parallel
    over a shared set of max_nodes points in the unit square."""

    def __init__(self, *args, **kwargs):
        self.max_nodes = 50
        self.n_traj = 50

        # Evaluation settings: when eval_data is True, reset() loads a fixed
        # instance from TSPDataset instead of sampling a random one.
        self.eval_data = False
        self.eval_partition = "test"
        self.eval_data_idx = 0
        assign_env_config(self, kwargs)

        obs_dict = {"observations": spaces.Box(low=0, high=1, shape=(self.max_nodes, 2))}
        obs_dict["action_mask"] = spaces.MultiBinary([self.n_traj, self.max_nodes])
        obs_dict["first_node_idx"] = spaces.MultiDiscrete([self.max_nodes] * self.n_traj)
        obs_dict["last_node_idx"] = spaces.MultiDiscrete([self.max_nodes] * self.n_traj)
        # A boolean flag takes the two values 0 and 1, so Discrete(2), not Discrete(1).
        obs_dict["is_initial_action"] = spaces.Discrete(2)

        self.observation_space = spaces.Dict(obs_dict)
        self.action_space = spaces.MultiDiscrete([self.max_nodes] * self.n_traj)
        self.reward_space = None

        self.reset()

    def seed(self, seed):
        np.random.seed(seed)

    def reset(self):
        self.visited = np.zeros((self.n_traj, self.max_nodes), dtype=bool)
        self.num_steps = 0
        self.last = np.zeros(self.n_traj, dtype=int)
        self.first = np.zeros(self.n_traj, dtype=int)

        if self.eval_data:
            self._load_orders()
        else:
            self._generate_orders()
        self.state = self._update_state()
        self.info = {}
        # Per-trajectory done flags, matching the array returned by step().
        self.done = np.zeros(self.n_traj, dtype=bool)
        return self.state

    def _load_orders(self):
        # Fixed evaluation instance, looked up by partition, size, and index.
        self.nodes = np.array(TSPDataset[self.eval_partition, self.max_nodes, self.eval_data_idx])

    def _generate_orders(self):
        # Sample node coordinates uniformly from the unit square; all n_traj
        # trajectories share the same instance.
        self.nodes = np.random.rand(self.max_nodes, 2)

    def step(self, action):
        # action: array of shape (n_traj,) with the next node for each trajectory.
        self._go_to(action)
        self.num_steps += 1
        self.state = self._update_state()

        # A trajectory is done once every node has been visited and the tour
        # has returned to its first node.
        self.done = (action == self.first) & self.is_all_visited()

        return self.state, self.reward, self.done, self.info

    def cost(self, loc1, loc2):
        return dist(loc1, loc2)

    def is_all_visited(self):
        # Shape (n_traj,): True where a trajectory has visited every node.
        return self.visited.all(axis=1)

    def _go_to(self, destination):
        dest_node = self.nodes[destination]
        if self.num_steps != 0:
            dist_to_last = self.cost(dest_node, self.nodes[self.last])
        else:
            # The first move only places each trajectory on its start node.
            dist_to_last = np.zeros(self.n_traj)
            self.first = destination

        self.last = destination
        self.visited[np.arange(self.n_traj), destination] = True
        # Reward is the negative travel distance, so maximizing reward
        # minimizes tour length.
        self.reward = -dist_to_last

    def _update_state(self):
        obs = {"observations": self.nodes}
        obs["action_mask"] = self._update_mask()
        obs["first_node_idx"] = self.first
        obs["last_node_idx"] = self.last
        obs["is_initial_action"] = self.num_steps == 0
        return obs

    def _update_mask(self):
        # Only unvisited nodes are legal destinations...
        action_mask = ~self.visited

        # ...except once a trajectory has visited every node, when its first
        # node is re-enabled so the tour can close.
        action_mask[np.arange(self.n_traj), self.first] |= self.is_all_visited()
        return action_mask
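

# A minimal usage sketch, not part of the original module: it rolls out a
# greedy nearest-unvisited-node policy to exercise the vectorized API. The
# policy and sizes here are illustrative assumptions; run it as part of the
# package (e.g. `python -m <package>.<module>`) so the relative import resolves.
if __name__ == "__main__":
    env = TSPVectorEnv(max_nodes=10, n_traj=4)
    env.seed(0)
    obs = env.reset()
    total_reward = np.zeros(env.n_traj)
    done = np.zeros(env.n_traj, dtype=bool)
    while not done.all():
        # Distance from each trajectory's last node to every node, with
        # already-visited (masked-out) nodes set to infinity.
        last_xy = obs["observations"][obs["last_node_idx"]]      # (n_traj, 2)
        d = np.linalg.norm(obs["observations"][None, :, :] - last_xy[:, None, :], axis=-1)
        d[~obs["action_mask"]] = np.inf
        action = d.argmin(axis=1)                                # (n_traj,)
        obs, reward, done, _ = env.step(action)
        total_reward += reward
    print("tour lengths:", -total_reward)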