Pierre Tassel committed on
Commit
9a90bc0
·
1 Parent(s): dfe9f8e

improvements

Browse files
Files changed (2) hide show
  1. MyDummyVecEnv.py +7 -6
  2. app.py +30 -22
MyDummyVecEnv.py CHANGED
@@ -7,6 +7,8 @@ import numpy as np
7
  from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
8
  from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info
9
 
 
 
10
 
11
  class MyDummyVecEnv(VecEnv):
12
  """
@@ -20,14 +22,16 @@ class MyDummyVecEnv(VecEnv):
20
  that return environments to vectorize
21
  """
22
 
23
- def __init__(self, env_fns: List[Callable[[], gym.Env]]):
24
  self.envs = [fn() for fn in env_fns]
25
  env = self.envs[0]
26
  VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
27
  obs_space = env.observation_space
28
  self.keys, shapes, dtypes = obs_space_info(obs_space)
 
29
 
30
- self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys])
 
31
  self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
32
  self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
33
  self.buf_infos = [{} for _ in range(self.num_envs)]
@@ -86,10 +90,7 @@ class MyDummyVecEnv(VecEnv):
86
 
87
  def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
88
  for key in self.keys:
89
- if key is None:
90
- self.buf_obs[key][env_idx] = obs
91
- else:
92
- self.buf_obs[key][env_idx] = obs[key]
93
 
94
  def _obs_from_buf(self) -> VecEnvObs:
95
  return dict_to_obs(self.observation_space, self.buf_obs)
 
7
  from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
8
  from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info
9
 
10
+ import torch
11
+
12
 
13
  class MyDummyVecEnv(VecEnv):
14
  """
 
22
  that return environments to vectorize
23
  """
24
 
25
+ def __init__(self, env_fns: List[Callable[[], gym.Env]], device):
26
  self.envs = [fn() for fn in env_fns]
27
  env = self.envs[0]
28
  VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
29
  obs_space = env.observation_space
30
  self.keys, shapes, dtypes = obs_space_info(obs_space)
31
+ self.device = device
32
 
33
+ self.buf_obs = OrderedDict(
34
+ [(k, torch.zeros((self.num_envs,) + tuple(shapes[k]), dtype=torch.float, device=self.device)) for k in self.keys])
35
  self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
36
  self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
37
  self.buf_infos = [{} for _ in range(self.num_envs)]
 
90
 
91
  def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
92
  for key in self.keys:
93
+ self.buf_obs[key][env_idx] = torch.from_numpy(obs[key]).to(self.device, non_blocking=True)
 
 
 
94
 
95
  def _obs_from_buf(self) -> VecEnvObs:
96
  return dict_to_obs(self.observation_space, self.buf_obs)
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import random
4
  import time
5
 
 
6
  import plotly.figure_factory as ff
7
  import json
8
 
@@ -17,6 +18,7 @@ from torch.distributions import Categorical
17
  import torch
18
  import numpy as np
19
 
 
20
  from MyVecEnv import WrapperRay
21
 
22
  import gradio as gr
@@ -46,11 +48,11 @@ def make_env(seed, instance):
46
  return thunk
47
 
48
 
49
- def solve(file):
50
- random.seed(0)
51
- np.random.seed(0)
52
- torch.manual_seed(0)
53
- num_workers = 1 # only one CPU available
54
  with torch.inference_mode():
55
  device = torch.device('cpu')
56
  actor = torch.jit.load('actor.pt', map_location=device)
@@ -58,9 +60,8 @@ def solve(file):
58
  start_time = time.time()
59
  fn_env = [make_env(0, file.name)
60
  for _ in range(num_workers)]
61
- ray_wrapper_env = WrapperRay(lambda n: fn_env[n](),
62
- num_workers, 1, device)
63
- envs = VecPyTorch(ray_wrapper_env, device)
64
  current_solution_cost = float('inf')
65
  current_solution = ''
66
  obs = envs.reset()
@@ -146,24 +147,31 @@ def solve(file):
146
  fig.update_yaxes(
147
  autorange=True
148
  )
149
- return current_solution_cost, str(total_time) + " seconds", pretty_output, fig
150
-
151
- ray.init(log_to_driver=False,
152
- ignore_reinit_error=True,
153
- include_dashboard=False)
154
- title = "Job-Shop Scheduling CP RL"
155
- description = """A Job-Shop Scheduling Reinforcement Learning based solver using an underlying CP model as an environment. <br>
156
- However, the results you obtain here don't represent the full potential of the approach due to resource limitations on the HuggingFace platform (a single vCPU available, no GPU).<br>
157
- We recommend running this locally outside the interface for large instances, as it causes a lot of overhead.<br>
158
- For fast inference, check out the cached examples below.<br>
159
- Any Job-Shop Scheduling instance following the standard specification is compatible. <a href='http://jobshop.jjvh.nl/index.php'>Check out this website for more instances</a>."""
 
 
 
 
160
 
161
  article = "<p style='text-align: center'>Article Under Review</p>"
162
  # list all non-hidden files in the 'instances' directory
163
- examples = ['instances/' + f for f in os.listdir('instances') if not f.startswith('.')]
164
  iface = gr.Interface(fn=solve,
165
- inputs=gr.File(label="Instance File"),
166
- outputs=[gr.Text(label="Makespan"), gr.Text(label="Elapsed Time"), gr.Text(label="Solution"), gr.Plot(label="Solution's Gantt Chart")],
 
 
 
167
  title=title,
168
  description=description,
169
  article=article,
 
3
  import random
4
  import time
5
 
6
+ import gym
7
  import plotly.figure_factory as ff
8
  import json
9
 
 
18
  import torch
19
  import numpy as np
20
 
21
+ from MyDummyVecEnv import MyDummyVecEnv
22
  from MyVecEnv import WrapperRay
23
 
24
  import gradio as gr
 
48
  return thunk
49
 
50
 
51
+ def solve(file, num_workers, seed):
52
+ seed = int(abs(seed))
53
+ random.seed(seed)
54
+ np.random.seed(seed)
55
+ torch.manual_seed(seed)
56
  with torch.inference_mode():
57
  device = torch.device('cpu')
58
  actor = torch.jit.load('actor.pt', map_location=device)
 
60
  start_time = time.time()
61
  fn_env = [make_env(0, file.name)
62
  for _ in range(num_workers)]
63
+ async_envs = MyDummyVecEnv(fn_env, device)
64
+ envs = VecPyTorch(async_envs, device)
 
65
  current_solution_cost = float('inf')
66
  current_solution = ''
67
  obs = envs.reset()
 
147
  fig.update_yaxes(
148
  autorange=True
149
  )
150
+ return current_solution_cost, str(total_time) + " seconds", pretty_output, fig
151
+
152
+
153
+ title = "Job-Shop Scheduling CP environment with RL dispatching"
154
+ description = """A Job-Shop Scheduling Reinforcement Learning based solver using an underlying CP model as an
155
+ environment. <br>
156
+ For fast inference,
157
+ check out the cached examples below.<br> Any Job-Shop Scheduling instance following the standard specification is
158
+ compatible. <a href='http://jobshop.jjvh.nl/index.php'>Check out this website for more instances</a>.<br>
159
+ Increasing the number of workers will provide better solutions, but will slow down the solving time.
160
+ This behavior is different than the one from the paper repository as here agents are run sequentially,
161
+ whereas we run agents in parallel (technical limitation due to the platform here). <br>
162
+ <br>
163
+ For large instance, we recommend running the approach locally outside the interface, as it causes a lot
164
+ of overhead and the resource available on this platform are low (1 vCPU and no GPU).<br> """
165
 
166
  article = "<p style='text-align: center'>Article Under Review</p>"
167
  # list all non-hidden files in the 'instances' directory
168
+ examples = [['instances/' + f, 16, 0] for f in os.listdir('instances') if not f.startswith('.')]
169
  iface = gr.Interface(fn=solve,
170
+ inputs=[gr.File(label="Instance File"),
171
+ gr.Slider(8, 32, value=16, label="Number of Workers", step=1),
172
+ gr.Number(0, label="Seed", precision=0)],
173
+ outputs=[gr.Text(label="Makespan"), gr.Text(label="Elapsed Time"), gr.Text(label="Solution"),
174
+ gr.Plot(label="Solution's Gantt Chart")],
175
  title=title,
176
  description=description,
177
  article=article,