Spaces:
Runtime error
Runtime error
Pierre Tassel
committed on
Commit
·
9a90bc0
1
Parent(s):
dfe9f8e
improvements
Browse files
- MyDummyVecEnv.py +7 -6
- app.py +30 -22
MyDummyVecEnv.py
CHANGED
@@ -7,6 +7,8 @@ import numpy as np
|
|
7 |
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
|
8 |
from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info
|
9 |
|
|
|
|
|
10 |
|
11 |
class MyDummyVecEnv(VecEnv):
|
12 |
"""
|
@@ -20,14 +22,16 @@ class MyDummyVecEnv(VecEnv):
|
|
20 |
that return environments to vectorize
|
21 |
"""
|
22 |
|
23 |
-
def __init__(self, env_fns: List[Callable[[], gym.Env]]):
|
24 |
self.envs = [fn() for fn in env_fns]
|
25 |
env = self.envs[0]
|
26 |
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
|
27 |
obs_space = env.observation_space
|
28 |
self.keys, shapes, dtypes = obs_space_info(obs_space)
|
|
|
29 |
|
30 |
-
self.buf_obs = OrderedDict(
|
|
|
31 |
self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
|
32 |
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
|
33 |
self.buf_infos = [{} for _ in range(self.num_envs)]
|
@@ -86,10 +90,7 @@ class MyDummyVecEnv(VecEnv):
|
|
86 |
|
87 |
def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
|
88 |
for key in self.keys:
|
89 |
-
|
90 |
-
self.buf_obs[key][env_idx] = obs
|
91 |
-
else:
|
92 |
-
self.buf_obs[key][env_idx] = obs[key]
|
93 |
|
94 |
def _obs_from_buf(self) -> VecEnvObs:
|
95 |
return dict_to_obs(self.observation_space, self.buf_obs)
|
|
|
7 |
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
|
8 |
from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info
|
9 |
|
10 |
+
import torch
|
11 |
+
|
12 |
|
13 |
class MyDummyVecEnv(VecEnv):
|
14 |
"""
|
|
|
22 |
that return environments to vectorize
|
23 |
"""
|
24 |
|
25 |
+
def __init__(self, env_fns: List[Callable[[], gym.Env]], device):
|
26 |
self.envs = [fn() for fn in env_fns]
|
27 |
env = self.envs[0]
|
28 |
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
|
29 |
obs_space = env.observation_space
|
30 |
self.keys, shapes, dtypes = obs_space_info(obs_space)
|
31 |
+
self.device = device
|
32 |
|
33 |
+
self.buf_obs = OrderedDict(
|
34 |
+
[(k, torch.zeros((self.num_envs,) + tuple(shapes[k]), dtype=torch.float, device=self.device)) for k in self.keys])
|
35 |
self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
|
36 |
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
|
37 |
self.buf_infos = [{} for _ in range(self.num_envs)]
|
|
|
90 |
|
91 |
def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
|
92 |
for key in self.keys:
|
93 |
+
self.buf_obs[key][env_idx] = torch.from_numpy(obs[key]).to(self.device, non_blocking=True)
|
|
|
|
|
|
|
94 |
|
95 |
def _obs_from_buf(self) -> VecEnvObs:
|
96 |
return dict_to_obs(self.observation_space, self.buf_obs)
|
app.py
CHANGED
@@ -3,6 +3,7 @@ import os
|
|
3 |
import random
|
4 |
import time
|
5 |
|
|
|
6 |
import plotly.figure_factory as ff
|
7 |
import json
|
8 |
|
@@ -17,6 +18,7 @@ from torch.distributions import Categorical
|
|
17 |
import torch
|
18 |
import numpy as np
|
19 |
|
|
|
20 |
from MyVecEnv import WrapperRay
|
21 |
|
22 |
import gradio as gr
|
@@ -46,11 +48,11 @@ def make_env(seed, instance):
|
|
46 |
return thunk
|
47 |
|
48 |
|
49 |
-
def solve(file):
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
with torch.inference_mode():
|
55 |
device = torch.device('cpu')
|
56 |
actor = torch.jit.load('actor.pt', map_location=device)
|
@@ -58,9 +60,8 @@ def solve(file):
|
|
58 |
start_time = time.time()
|
59 |
fn_env = [make_env(0, file.name)
|
60 |
for _ in range(num_workers)]
|
61 |
-
|
62 |
-
|
63 |
-
envs = VecPyTorch(ray_wrapper_env, device)
|
64 |
current_solution_cost = float('inf')
|
65 |
current_solution = ''
|
66 |
obs = envs.reset()
|
@@ -146,24 +147,31 @@ def solve(file):
|
|
146 |
fig.update_yaxes(
|
147 |
autorange=True
|
148 |
)
|
149 |
-
return current_solution_cost,
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
|
|
|
|
|
|
|
|
160 |
|
161 |
article = "<p style='text-align: center'>Article Under Review</p>"
|
162 |
# list all non-hidden files in the 'instances' directory
|
163 |
-
examples = ['instances/' + f for f in os.listdir('instances') if not f.startswith('.')]
|
164 |
iface = gr.Interface(fn=solve,
|
165 |
-
inputs=gr.File(label="Instance File"),
|
166 |
-
|
|
|
|
|
|
|
167 |
title=title,
|
168 |
description=description,
|
169 |
article=article,
|
|
|
3 |
import random
|
4 |
import time
|
5 |
|
6 |
+
import gym
|
7 |
import plotly.figure_factory as ff
|
8 |
import json
|
9 |
|
|
|
18 |
import torch
|
19 |
import numpy as np
|
20 |
|
21 |
+
from MyDummyVecEnv import MyDummyVecEnv
|
22 |
from MyVecEnv import WrapperRay
|
23 |
|
24 |
import gradio as gr
|
|
|
48 |
return thunk
|
49 |
|
50 |
|
51 |
+
def solve(file, num_workers, seed):
|
52 |
+
seed = int(abs(seed))
|
53 |
+
random.seed(seed)
|
54 |
+
np.random.seed(seed)
|
55 |
+
torch.manual_seed(seed)
|
56 |
with torch.inference_mode():
|
57 |
device = torch.device('cpu')
|
58 |
actor = torch.jit.load('actor.pt', map_location=device)
|
|
|
60 |
start_time = time.time()
|
61 |
fn_env = [make_env(0, file.name)
|
62 |
for _ in range(num_workers)]
|
63 |
+
async_envs = MyDummyVecEnv(fn_env, device)
|
64 |
+
envs = VecPyTorch(async_envs, device)
|
|
|
65 |
current_solution_cost = float('inf')
|
66 |
current_solution = ''
|
67 |
obs = envs.reset()
|
|
|
147 |
fig.update_yaxes(
|
148 |
autorange=True
|
149 |
)
|
150 |
+
return current_solution_cost, str(total_time) + " seconds", pretty_output, fig
|
151 |
+
|
152 |
+
|
153 |
+
title = "Job-Shop Scheduling CP environment with RL dispatching"
|
154 |
+
description = """A Job-Shop Scheduling Reinforcement Learning based solver using an underlying CP model as an
|
155 |
+
environment. <br>
|
156 |
+
For fast inference,
|
157 |
+
check out the cached examples below.<br> Any Job-Shop Scheduling instance following the standard specification is
|
158 |
+
compatible. <a href='http://jobshop.jjvh.nl/index.php'>Check out this website for more instances</a>.<br>
|
159 |
+
Increasing the number of workers will provide better solutions, but will slow down the solving time.
|
160 |
+
This behavior is different than the one from the paper repository as here agents are run sequentially,
|
161 |
+
whereas we run agents in parallel (technical limitation due to the platform here). <br>
|
162 |
+
<br>
|
163 |
+
For large instance, we recommend running the approach locally outside the interface, as it causes a lot
|
164 |
+
of overhead and the resource available on this platform are low (1 vCPU and no GPU).<br> """
|
165 |
|
166 |
article = "<p style='text-align: center'>Article Under Review</p>"
|
167 |
# list all non-hidden files in the 'instances' directory
|
168 |
+
examples = [['instances/' + f, 16, 0] for f in os.listdir('instances') if not f.startswith('.')]
|
169 |
iface = gr.Interface(fn=solve,
|
170 |
+
inputs=[gr.File(label="Instance File"),
|
171 |
+
gr.Slider(8, 32, value=16, label="Number of Workers", step=1),
|
172 |
+
gr.Number(0, label="Seed", precision=0)],
|
173 |
+
outputs=[gr.Text(label="Makespan"), gr.Text(label="Elapsed Time"), gr.Text(label="Solution"),
|
174 |
+
gr.Plot(label="Solution's Gantt Chart")],
|
175 |
title=title,
|
176 |
description=description,
|
177 |
article=article,
|