Pierre Tassel committed on
Commit d746b98 · 1 Parent(s): e8861ce
Files changed (11)
  1. MyDummyVecEnv.py +123 -0
  2. MyRemoteVectorEnv.py +130 -0
  3. MyVecEnv.py +47 -0
  4. Network.py +114 -0
  5. actor.pt +0 -0
  6. app.py +155 -0
  7. checkpoint.pt +0 -0
  8. dmu01.txt +21 -0
  9. la01.txt +11 -0
  10. requirements.txt +4 -0
  11. ta01 +16 -0
MyDummyVecEnv.py ADDED
@@ -0,0 +1,123 @@
+ from collections import OrderedDict
+ from typing import Any, Callable, List, Optional, Sequence, Type, Union
+
+ import gym
+ import numpy as np
+
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
+ from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info
+
+
+ class MyDummyVecEnv(VecEnv):
+     """
+     Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the
+     current Python process. This is useful for computationally simple environments such as ``CartPole-v1``,
+     where the overhead of multiprocessing or multithreading outweighs the environment computation time.
+     It can also be used for RL methods that require a vectorized environment while training on a single environment.
+
+     :param env_fns: a list of functions that return environments to vectorize
+     """
+
+     def __init__(self, env_fns: List[Callable[[], gym.Env]]):
+         self.envs = [fn() for fn in env_fns]
+         env = self.envs[0]
+         VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
+         obs_space = env.observation_space
+         self.keys, shapes, dtypes = obs_space_info(obs_space)
+
+         self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys])
+         self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
+         self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
+         self.buf_infos = [{} for _ in range(self.num_envs)]
+         self.actions = None
+
+     def step_async(self, actions: np.ndarray) -> None:
+         self.actions = actions
+
+     def step_wait(self) -> VecEnvStepReturn:
+         for env_idx in range(self.num_envs):
+             obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
+                 self.actions[env_idx]
+             )
+             if self.buf_dones[env_idx]:
+                 # save final observation where user can get it, then reset
+                 self.buf_infos[env_idx]["terminal_observation"] = obs
+                 obs = self.envs[env_idx].reset()
+             self._save_obs(env_idx, obs)
+         return (self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos)
+
+     def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
+         seeds = list()
+         for idx, env in enumerate(self.envs):
+             seeds.append(env.seed(seed + idx))
+         return seeds
+
+     def reset(self) -> VecEnvObs:
+         for env_idx in range(self.num_envs):
+             obs = self.envs[env_idx].reset()
+             self._save_obs(env_idx, obs)
+         return self._obs_from_buf()
+
+     def close(self) -> None:
+         for env in self.envs:
+             env.close()
+
+     def get_images(self) -> Sequence[np.ndarray]:
+         return [env.render(mode="rgb_array") for env in self.envs]
+
+     def render(self, mode: str = "human") -> Optional[np.ndarray]:
+         """
+         Gym environment rendering. If there are multiple environments, they
+         are tiled together in one image via ``BaseVecEnv.render()``.
+         Otherwise (if ``self.num_envs == 1``), the render call is passed directly to the
+         underlying environment.
+
+         Therefore, some arguments such as ``mode`` will have values that are valid
+         only when ``num_envs == 1``.
+
+         :param mode: The rendering type.
+         """
+         if self.num_envs == 1:
+             return self.envs[0].render(mode=mode)
+         else:
+             return super().render(mode=mode)
+
+     def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
+         for key in self.keys:
+             if key is None:
+                 self.buf_obs[key][env_idx] = obs
+             else:
+                 self.buf_obs[key][env_idx] = obs[key]
+
+     def _obs_from_buf(self) -> VecEnvObs:
+         return dict_to_obs(self.observation_space, self.buf_obs)
+
+     def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+         """Return attribute from vectorized environment (see base class)."""
+         target_envs = self._get_target_envs(indices)
+         return [getattr(env_i, attr_name) for env_i in target_envs]
+
+     def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
+         """Set attribute inside vectorized environments (see base class)."""
+         target_envs = self._get_target_envs(indices)
+         for env_i in target_envs:
+             setattr(env_i, attr_name, value)
+
+     def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+         """Call instance methods of vectorized environments."""
+         target_envs = self._get_target_envs(indices)
+         return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
+
+     def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+         """Check if worker environments are wrapped with a given wrapper."""
+         target_envs = self._get_target_envs(indices)
+         # Import here to avoid a circular import
+         from stable_baselines3.common import env_util
+
+         return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]
+
+     def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
+         indices = self._get_indices(indices)
+         return [self.envs[i] for i in indices]
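A minimal usage sketch for MyDummyVecEnv (not part of this commit); it assumes a registered Gym environment such as CartPole-v1 purely for illustration:

import gym
import numpy as np

from MyDummyVecEnv import MyDummyVecEnv

# two copies of the same environment, stepped sequentially in this process
vec_env = MyDummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
obs = vec_env.reset()                                   # batched observations, shape (2, 4) for CartPole-v1
actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
obs, rewards, dones, infos = vec_env.step(actions)      # VecEnv.step = step_async + step_wait
vec_env.close()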
MyRemoteVectorEnv.py ADDED
@@ -0,0 +1,130 @@
+ from typing import Tuple, Callable, List, Optional
+ from collections import OrderedDict
+
+ import gym
+ import torch
+ import numpy as np
+ import ray
+ from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
+ from ray.rllib.utils.annotations import PublicAPI
+ from ray.rllib.utils.typing import EnvType, EnvID, MultiAgentDict
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
+ from stable_baselines3.common.vec_env.util import obs_space_info, dict_to_obs
+
+ from MyDummyVecEnv import MyDummyVecEnv
+
+
+ @PublicAPI
+ class MyRemoteVectorEnv(BaseEnv):
+     """Vector env that executes its environments in remote Ray workers.
+
+     This provides dynamic batching of inference as observations are returned
+     from the remote simulator actors. Each worker runs ``env_per_worker``
+     environments wrapped in a ``MyDummyVecEnv``, and the whole set is stepped
+     asynchronously through ``send_actions`` / ``poll``. In this project the
+     class is instantiated by ``WrapperRay`` (see ``MyVecEnv.py``) rather than
+     through RLlib's ``remote_worker_envs`` option.
+     """
+
+     @property
+     def observation_space(self):
+         return self._observation_space
+
+     def __init__(self, make_env: Callable[[int], EnvType], num_workers: int, env_per_worker: int,
+                  observation_space: Optional[gym.spaces.Space], device: torch.device):
+         self.make_local_env = make_env
+         self.num_workers = num_workers
+         self.env_per_worker = env_per_worker
+         self.num_envs = num_workers * env_per_worker
+         self.poll_timeout = None
+
+         self.actors = None  # lazy init
+         self.pending = None  # lazy init
+
+         self.observation_space = observation_space
+         self.keys, shapes, dtypes = obs_space_info(self.observation_space)
+
+         self.device = device
+
+         self.buf_obs = OrderedDict(
+             [(k, torch.zeros((self.num_envs,) + tuple(shapes[k]), dtype=torch.float, device=self.device)) for k in self.keys])
+         self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
+         self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
+         self.buf_infos = [{} for _ in range(self.num_envs)]
+
+     def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
+         for key in self.keys:
+             self.buf_obs[key][env_idx * self.env_per_worker: (env_idx + 1) * self.env_per_worker] = \
+                 torch.from_numpy(obs[key]).to(self.device, non_blocking=True)
+
+     def poll(self) -> Tuple[VecEnvObs, np.ndarray, np.ndarray, List]:
+         if self.actors is None:
+
+             def make_remote_env(i):
+                 return _RemoteSingleAgentEnv.remote(self.make_local_env, i, self.env_per_worker)
+
+             self.actors = [make_remote_env(i) for i in range(self.num_workers)]
+
+         if self.pending is None:
+             self.pending = {a.reset.remote(): a for a in self.actors}
+
+         # each keyed by env_id in [0, num_remote_envs)
+         ready = []
+
+         # Wait for at least 1 env to be ready here
+         while not ready:
+             ready, _ = ray.wait(
+                 list(self.pending),
+                 num_returns=len(self.pending),
+                 timeout=self.poll_timeout)
+
+         for obj_ref in ready:
+             actor = self.pending.pop(obj_ref)
+             env_id = self.actors.index(actor)
+             ob, rew, done, info = ray.get(obj_ref)
+
+             self._save_obs(env_id, ob)
+             self.buf_rews[env_id * self.env_per_worker: (env_id + 1) * self.env_per_worker] = rew
+             self.buf_dones[env_id * self.env_per_worker: (env_id + 1) * self.env_per_worker] = done
+             self.buf_infos[env_id * self.env_per_worker: (env_id + 1) * self.env_per_worker] = info
+         return (self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos)
+
+     def _obs_from_buf(self) -> VecEnvObs:
+         return dict_to_obs(self.observation_space, self.buf_obs)
+
+     @PublicAPI
+     def send_actions(self, action_list) -> None:
+         for worker_id in range(self.num_workers):
+             actions = action_list[worker_id * self.env_per_worker: (worker_id + 1) * self.env_per_worker]
+             actor = self.actors[worker_id]
+             obj_ref = actor.step.remote(actions)
+             self.pending[obj_ref] = actor
+
+     @PublicAPI
+     def try_reset(self,
+                   env_id: Optional[EnvID] = None) -> Optional[MultiAgentDict]:
+         actor = self.actors[env_id]
+         obj_ref = actor.reset.remote()
+         self.pending[obj_ref] = actor
+         return ASYNC_RESET_RETURN
+
+     @PublicAPI
+     def stop(self) -> None:
+         if self.actors is not None:
+             for actor in self.actors:
+                 actor.__ray_terminate__.remote()
+
+     @observation_space.setter
+     def observation_space(self, value):
+         self._observation_space = value
+
+
+ @ray.remote(num_cpus=1)
+ class _RemoteSingleAgentEnv:
+     """Wrapper class for making a gym env a remote actor."""
+
+     def __init__(self, make_env, i, env_per_worker):
+         # bind k per lambda so each sub-env gets its own rank
+         self.env = MyDummyVecEnv([lambda k=k: make_env((i * env_per_worker) + k) for k in range(env_per_worker)])
+
+     def reset(self):
+         # one info dict per sub-env, so the list-slice assignment in poll() keeps its length
+         return self.env.reset(), 0, False, [{} for _ in range(self.env.num_envs)]
+
+     def step(self, actions):
+         return self.env.step(actions)
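A hedged, self-contained sketch of the send_actions/poll protocol above; ToyDictEnv is a made-up Dict-observation environment standing in for the real CP job-shop env and is not part of this commit.

import gym
import numpy as np
import ray
import torch

from MyRemoteVectorEnv import MyRemoteVectorEnv


class ToyDictEnv(gym.Env):
    # hypothetical stand-in with a Dict observation space, like the real environment
    observation_space = gym.spaces.Dict({"obs": gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32)})
    action_space = gym.spaces.Discrete(2)

    def reset(self):
        return {"obs": np.zeros(4, dtype=np.float32)}

    def step(self, action):
        return {"obs": np.zeros(4, dtype=np.float32)}, 0.0, True, {}


ray.init(ignore_reinit_error=True)
env = MyRemoteVectorEnv(lambda rank: ToyDictEnv(), num_workers=2, env_per_worker=1,
                        observation_space=ToyDictEnv.observation_space,
                        device=torch.device("cpu"))
obs, rewards, dones, infos = env.poll()                   # first poll creates the remote actors and resets them
env.send_actions(np.zeros(env.num_envs, dtype=np.int64))  # one action per environment
obs, rewards, dones, infos = env.poll()                    # collect the step results
env.stop()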
MyVecEnv.py ADDED
@@ -0,0 +1,47 @@
+ from typing import Optional, List, Union, Sequence, Type, Any
+
+ import gym
+ import numpy as np
+ from ray.rllib import BaseEnv
+ from stable_baselines3.common.vec_env import VecEnv
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnvIndices, VecEnvStepReturn, VecEnvObs
+
+ from MyRemoteVectorEnv import MyRemoteVectorEnv
+
+
+ class WrapperRay(VecEnv):
+
+     def __init__(self, make_env, num_workers, per_worker_env, device):
+         self.one_env = make_env(0)
+         self.remote: BaseEnv = MyRemoteVectorEnv(make_env, num_workers, per_worker_env, self.one_env.observation_space, device)
+         super(WrapperRay, self).__init__(num_workers * per_worker_env, self.one_env.observation_space, self.one_env.action_space)
+
+     def reset(self) -> VecEnvObs:
+         return self.remote.poll()[0]
+
+     def step_async(self, actions: np.ndarray) -> None:
+         self.remote.send_actions(actions)
+
+     def step_wait(self) -> VecEnvStepReturn:
+         return self.remote.poll()
+
+     def close(self) -> None:
+         self.remote.stop()
+
+     def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+         pass
+
+     def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
+         pass
+
+     def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+         pass
+
+     def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+         pass
+
+     def get_images(self) -> Sequence[np.ndarray]:
+         pass
+
+     def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
+         pass
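Reusing the hypothetical ToyDictEnv from the sketch above, WrapperRay exposes the same remote machinery through the standard Stable-Baselines3 VecEnv interface (again an illustrative sketch, not part of the commit):

import numpy as np
import torch

from MyVecEnv import WrapperRay

envs = WrapperRay(lambda rank: ToyDictEnv(), num_workers=2, per_worker_env=1, device=torch.device("cpu"))
obs = envs.reset()                               # dict of torch tensors, batched over all envs
actions = np.zeros(envs.num_envs, dtype=np.int64)
obs, rewards, dones, infos = envs.step(actions)  # step_async -> send_actions, step_wait -> poll
envs.close()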
Network.py ADDED
@@ -0,0 +1,114 @@
+ import math
+
+ import numpy as np
+ import torch
+ from torch import nn, Tensor
+ from torch.distributions import Categorical
+
+
+ class PositionalEncoding(nn.Module):
+     """Standard sinusoidal positional-encoding table, looked up by position index."""
+
+     def __init__(self, d_model: int, max_len: int = 100):
+         super().__init__()
+         position = torch.arange(max_len).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+         pe = torch.zeros(max_len, d_model)
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         self.register_buffer('pe', pe)
+
+     def forward(self, positions: Tensor) -> Tensor:
+         return self.pe[positions]
+
+
+ class Actor(nn.Module):
+
+     def __init__(self, pos_encoder):
+         super(Actor, self).__init__()
+         self.activation = nn.Tanh()
+         self.project = nn.Linear(4, 8)
+         nn.init.xavier_uniform_(self.project.weight, gain=1.0)
+         nn.init.constant_(self.project.bias, 0)
+         self.pos_encoder = pos_encoder
+
+         self.embedding_fixed = nn.Embedding(2, 1)
+         self.embedding_legal_op = nn.Embedding(2, 1)
+
+         self.tokens_start_end = nn.Embedding(3, 4)
+
+         # self.conv_transform = nn.Conv1d(5, 1, 1)
+         # nn.init.kaiming_normal_(self.conv_transform.weight, mode="fan_out", nonlinearity="relu")
+         # nn.init.constant_(self.conv_transform.bias, 0)
+
+         self.enc1 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0, batch_first=True,
+                                                norm_first=True)
+         self.enc2 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0, batch_first=True,
+                                                norm_first=True)
+
+         self.final_tmp = nn.Sequential(
+             layer_init_tanh(nn.Linear(8, 32)),
+             nn.Tanh(),
+             layer_init_tanh(nn.Linear(32, 1), std=0.01)
+         )
+         self.no_op = nn.Sequential(
+             layer_init_tanh(nn.Linear(8, 32)),
+             nn.Tanh(),
+             layer_init_tanh(nn.Linear(32, 1), std=0.01)
+         )
+
+     def forward(self, obs, attention_interval_mask, job_resource, mask, indexes_inter, tokens_start_end):
+         embedded_obs = torch.cat((self.embedding_fixed(obs[:, :, :, 0].long()), obs[:, :, :, 1:3],
+                                   self.embedding_legal_op(obs[:, :, :, 3].long())), dim=3)
+         non_zero_tokens = tokens_start_end != 0
+         t = tokens_start_end[non_zero_tokens].long()
+         embedded_obs[non_zero_tokens] = self.tokens_start_end(t)
+         pos_encoder = self.pos_encoder(indexes_inter.long())
+         pos_encoder[non_zero_tokens] = 0
+         obs = self.project(embedded_obs) + pos_encoder
+
+         # encode the interval tokens within each job, then pool them into one vector per job
+         transformed_obs = obs.view(-1, obs.shape[2], obs.shape[3])
+         attention_interval_mask = attention_interval_mask.view(-1, attention_interval_mask.shape[-1])
+         transformed_obs = self.enc1(transformed_obs, src_key_padding_mask=attention_interval_mask == 1)
+         transformed_obs = transformed_obs.view(obs.shape)
+         obs = transformed_obs.mean(dim=2)
+
+         job_resource = job_resource[:, :-1, :-1] == 0
+
+         # attention across jobs, restricted by the job/resource mask
+         obs_action = self.enc2(obs, src_mask=job_resource) + obs
+
+         logits = torch.cat((self.final_tmp(obs_action).squeeze(2), self.no_op(obs_action).mean(dim=1)), dim=1)
+         # illegal actions get the most negative float32 value, i.e. effectively zero probability after softmax
+         return logits.masked_fill(mask == 0, -3.4028234663852886e+38)
+
+
+ class Agent(nn.Module):
+     def __init__(self):
+         super(Agent, self).__init__()
+         self.pos_encoder = PositionalEncoding(8)
+         self.actor = Actor(self.pos_encoder)
+
+     def forward(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end,
+                 action=None):
+         logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
+         probs = Categorical(logits=logits)
+         if action is None:
+             probabilities = probs.probs
+             actions = torch.multinomial(probabilities, probabilities.shape[1])
+             return actions, torch.log(probabilities), probs.entropy()
+         else:
+             return logits, probs.log_prob(action), probs.entropy()
+
+     def get_action_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
+         logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
+         probs = Categorical(logits=logits)
+         return probs.sample()
+
+     def get_logits_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
+         logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
+         return logits
+
+
+ def layer_init_tanh(layer, std=np.sqrt(2), bias_const=0.0):
+     torch.nn.init.orthogonal_(layer.weight, std)
+     if layer.bias is not None:
+         torch.nn.init.constant_(layer.bias, bias_const)
+     return layer
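A smoke test for the Agent above with made-up sizes (batch of 1, 3 jobs, 5 interval tokens per job); the tensor semantics in the comments are assumptions, since the real inputs come from CompiledJssEnvCP via app.py and the deployed weights are the TorchScript actor.pt.

import torch
from Network import Agent

agent = Agent()
B, J, I = 1, 3, 5                                   # batch, jobs, interval tokens (made-up sizes)
data = torch.zeros(B, J, I, 4)                      # per-interval features: fixed flag, two scalars, legal-op flag
attention_interval_mask = torch.zeros(B, J, I)      # 1 marks padded interval tokens
job_resource_masks = torch.ones(B, J + 1, J + 1)    # 0 masks attention between a pair of jobs
action_mask = torch.ones(B, J + 1)                  # last column is the no-op action
index_interval = torch.zeros(B, J, I)               # positional indices into the encoding table
start_end_tokens = torch.zeros(B, J, I)             # non-zero values mark start/end tokens

with torch.inference_mode():
    action = agent.get_action_only(data, attention_interval_mask, job_resource_masks,
                                   action_mask, index_interval, start_end_tokens)
print(action.shape)  # torch.Size([1])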
actor.pt ADDED
Binary file (80.3 kB).
 
app.py ADDED
@@ -0,0 +1,155 @@
+ import collections
+ import random
+ import time
+
+ import multiprocessing as mp
+ import json
+
+ from PIL import Image
+ from compiled_jss.CPEnv import CompiledJssEnvCP
+
+ from stable_baselines3.common.vec_env import VecEnvWrapper
+ from torch.distributions import Categorical
+
+ import torch
+ import numpy as np
+
+ from MyVecEnv import WrapperRay
+
+ import gradio as gr
+ import docplex.cp.utils_visu as visu
+ import matplotlib.pyplot as plt
+
+
+ class VecPyTorch(VecEnvWrapper):
+
+     def __init__(self, venv, device):
+         super(VecPyTorch, self).__init__(venv)
+         self.device = device
+
+     def reset(self):
+         return self.venv.reset()
+
+     def step_async(self, actions):
+         self.venv.step_async(actions)
+
+     def step_wait(self):
+         return self.venv.step_wait()
+
+
+ def make_env(seed, instance):
+     def thunk():
+         _env = CompiledJssEnvCP(instance)
+         return _env
+
+     return thunk
+
+
+ def solve(file):
+     random.seed(0)
+     np.random.seed(0)
+     torch.manual_seed(0)
+     num_workers = min(mp.cpu_count(), 32)
+     with torch.inference_mode():
+         device = torch.device('cpu')
+         actor = torch.jit.load('actor.pt', map_location=device)
+         actor.eval()
+         start_time = time.time()
+         fn_env = [make_env(0, file.name)
+                   for _ in range(num_workers)]
+         ray_wrapper_env = WrapperRay(lambda n: fn_env[n](),
+                                      num_workers, 1, device)
+         envs = VecPyTorch(ray_wrapper_env, device)
+         current_solution_cost = float('inf')
+         current_solution = ''
+         obs = envs.reset()
+         total_episode = 0
+         while total_episode < envs.num_envs:
+             logits = actor(obs['interval_rep'], obs['attention_interval_mask'], obs['job_resource_mask'],
+                            obs['action_mask'], obs['index_interval'], obs['start_end_tokens'])
+             # temperature vector
+             if num_workers >= 4:
+                 temperature = torch.arange(0.5, 2.0, step=(1.5 / num_workers), device=device)
+             else:
+                 temperature = torch.ones(num_workers, device=device)
+             logits = logits / temperature[:, None]
+             probs = Categorical(logits=logits).probs
+             # random sample based on logits
+             actions = torch.multinomial(probs, probs.shape[1]).cpu().numpy()
+             obs, reward, done, infos = envs.step(actions)
+             total_episode += done.sum()
+             # total_actions += 1
+             # print(f'Episode {total_episode} / {envs.num_envs} - Actions {total_actions}', end='\r')
+             for env_idx, info in enumerate(infos):
+                 if 'makespan' in info and int(info['makespan']) < current_solution_cost:
+                     current_solution_cost = int(info['makespan'])
+                     current_solution = json.loads(info['solution'])
+         total_time = time.time() - start_time
+         pretty_output = ""
+         for job_id in range(len(current_solution)):
+             pretty_output += f"Job {job_id}: {current_solution[job_id]}\n"
+
+         jobs_data = []
+         file.seek(0)
+         line_str: str = file.readline()
+         line_cnt: int = 1
+         while line_str:
+             data = []
+             split_data = line_str.split()
+             if line_cnt == 1:
+                 jobs_count, machines_count = int(split_data[0]), int(
+                     split_data[1]
+                 )
+             else:
+                 i = 0
+                 this_job_op_count = 0
+                 while i < len(split_data):
+                     machine, op_time = int(split_data[i]), int(split_data[i + 1])
+                     data.append((machine, op_time))
+                     i += 2
+                     this_job_op_count += 1
+                 jobs_data.append(data)
+             line_str = file.readline()
+             line_cnt += 1
+         visu.timeline('Solution for job-shop, solved using ')
+         visu.panel('Jobs')
+         # convert the current_solution start times to integers
+         current_solution = [[int(x) for x in y] for y in current_solution]
+         for job_id in range(len(current_solution)):
+             visu.sequence(name=f'J{job_id}',
+                           intervals=[(current_solution[job_id][task_id],
+                                       current_solution[job_id][task_id] + jobs_data[job_id][task_id][1],
+                                       jobs_data[job_id][task_id][0],
+                                       f'M{jobs_data[job_id][task_id][0]}')
+                                      for task_id in range(len(current_solution[job_id]))])
+         visu.panel('Machines')
+         machine_solution = collections.defaultdict(list)
+         for job_id in range(len(current_solution)):
+             for task_id in range(len(current_solution[job_id])):
+                 machine = jobs_data[job_id][task_id][0]  # jobs_data entries are (machine, processing_time)
+                 machine_solution[machine].append((current_solution[job_id][task_id],
+                                                   current_solution[job_id][task_id] + jobs_data[job_id][task_id][1],
+                                                   machine, f'J{job_id}'))
+         # sort dictionary keys
+         machine_solution = {k: machine_solution[k] for k in sorted(machine_solution.keys())}
+         for machine_id in machine_solution:
+             visu.sequence(name=f'M{machine_id}',
+                           intervals=machine_solution[machine_id])
+         plt.rcParams["font.family"] = "Times New Roman"
+         plt.rcParams["font.size"] = "30"
+         plt.gca().set_aspect('equal')
+         plt.rcParams["figure.figsize"] = (45, 50)
+         from io import BytesIO
+         buffer = BytesIO()
+
+         visu.show(pngfile=buffer)
+         reloadedPILImage = Image.open(buffer)
+         return pretty_output, reloadedPILImage, str(total_time) + " seconds"
+
+ title = "Job-Shop Scheduling CP RL"
+ description = "A Job-Shop Scheduling Reinforcement Learning based solver, using an underlying CP model as an " \
+               "environment. "
+ article = "<p style='text-align: center'>Article Under Review</p>"
+ examples = ['ta01', 'dmu01.txt', 'la01.txt']
+ iface = gr.Interface(fn=solve, inputs=gr.File(label="Instance File"), outputs=[gr.Text(label="Solution"), gr.Image(label="Solution's Gantt Chart"), gr.Text(label="Elapsed Time")], title=title, description=description, article=article, examples=examples)
+ iface.launch(enable_queue=True)
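The instance files below (ta01, dmu01.txt, la01.txt) use the standard job-shop text format that solve() parses: a header line with the number of jobs and machines, then one line per job listing (machine, processing time) pairs. A standalone sketch of the same format:

def parse_jssp_instance(path):
    # header: "<jobs> <machines>", then one line per job of alternating machine id / processing time
    with open(path) as f:
        jobs_count, machines_count = map(int, f.readline().split())
        jobs_data = []
        for _ in range(jobs_count):
            values = list(map(int, f.readline().split()))
            jobs_data.append([(values[i], values[i + 1]) for i in range(0, len(values), 2)])
    return jobs_count, machines_count, jobs_data

# parse_jssp_instance('la01.txt') -> (10, 5, [[(1, 21), (0, 53), (4, 95), (3, 55), (2, 34)], ...])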
checkpoint.pt ADDED
Binary file (75.6 kB).
 
dmu01.txt ADDED
@@ -0,0 +1,21 @@
+ 20 15
+ 0 160 13 5 6 139 11 99 12 9 5 98 2 28 1 107 3 196 10 165 7 114 4 7 14 34 8 133 9 76
+ 14 105 7 160 3 19 2 189 11 25 1 95 12 15 0 122 4 165 9 2 10 66 13 111 8 51 6 83 5 183
+ 11 61 5 11 9 130 4 147 13 106 12 1 6 141 7 136 10 33 0 13 2 15 8 10 14 62 3 4 1 142
+ 13 117 1 11 4 162 0 192 5 35 8 172 3 4 14 193 2 141 11 139 6 62 9 12 12 1 7 135 10 25
+ 5 53 9 89 10 168 12 41 11 121 1 181 3 43 0 118 4 61 14 193 2 124 6 176 13 28 8 125 7 136
+ 5 152 0 115 2 122 14 5 12 46 13 144 11 29 7 176 1 115 6 18 4 23 9 26 3 175 8 110 10 75
+ 6 50 1 62 3 186 12 57 11 156 10 32 2 134 9 141 4 189 13 118 0 102 7 3 8 177 14 43 5 41
+ 13 35 0 171 14 160 9 32 7 5 11 154 8 195 3 113 12 162 5 152 6 140 2 72 4 16 10 104 1 171
+ 13 68 8 54 6 116 4 9 14 99 12 155 10 22 5 135 0 67 1 165 9 100 11 47 3 46 7 55 2 12
+ 1 135 5 105 9 49 8 4 12 176 3 52 11 128 7 188 6 170 10 170 2 169 4 62 0 120 13 28 14 70
+ 2 93 1 172 13 124 6 72 7 189 14 122 5 38 0 120 12 114 11 51 9 77 8 65 4 176 3 171 10 169
+ 3 122 6 21 4 6 13 189 14 75 5 5 9 180 0 160 1 14 11 73 12 45 2 61 7 148 10 96 8 194
+ 9 94 12 198 8 100 5 194 2 127 10 95 4 43 3 52 6 166 1 31 14 100 13 104 7 166 11 139 0 143
+ 5 4 3 78 11 199 8 119 12 167 0 54 9 38 14 114 13 10 4 115 7 101 1 104 2 61 6 75 10 175
+ 10 18 11 115 6 166 8 41 14 124 12 101 7 38 13 29 0 91 2 118 9 40 5 55 1 82 4 89 3 100
+ 11 2 9 107 14 99 3 152 7 51 4 13 10 112 0 96 1 150 6 97 13 67 5 57 2 45 8 17 12 184
+ 1 176 11 15 3 92 9 9 14 77 12 4 7 83 10 195 4 156 6 102 2 91 13 65 8 19 5 163 0 93
+ 8 38 0 32 14 80 11 109 9 71 1 100 12 139 7 52 3 163 13 40 4 5 6 28 2 105 5 186 10 186
+ 11 1 3 73 0 106 4 80 12 150 13 5 5 71 9 145 1 138 6 148 10 168 7 60 2 107 14 164 8 178
+ 1 14 10 5 4 115 2 70 11 112 5 76 9 20 0 104 7 167 13 58 8 193 12 30 6 132 3 6 14 19
la01.txt ADDED
@@ -0,0 +1,11 @@
+ 10 5
+ 1 21 0 53 4 95 3 55 2 34
+ 0 21 3 52 4 16 2 26 1 71
+ 3 39 4 98 1 42 2 31 0 12
+ 1 77 0 55 4 79 2 66 3 77
+ 0 83 3 34 2 64 1 19 4 37
+ 1 54 2 43 4 79 0 92 3 62
+ 3 69 4 77 1 87 2 87 0 93
+ 2 38 0 60 1 41 3 24 4 83
+ 3 17 1 49 4 25 0 44 2 98
+ 4 77 3 79 2 43 1 75 0 96
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ job-shop-cp-env==1.0.0
+ ray==2.1.0
+ ray[rllib]==2.1.0
+ stable-baselines3==1.6.2
ta01 ADDED
@@ -0,0 +1,16 @@
+ 15 15
+ 6 94 12 66 4 10 7 53 3 26 2 15 10 65 11 82 8 10 14 27 9 93 13 92 5 96 0 70 1 83
+ 4 74 5 31 7 88 14 51 13 57 8 78 11 8 9 7 6 91 10 79 0 18 3 51 12 18 1 99 2 33
+ 1 4 8 82 9 40 12 86 6 50 11 54 13 21 5 6 0 54 2 68 7 82 10 20 4 39 3 35 14 68
+ 5 73 2 23 9 30 6 30 10 53 0 94 13 58 4 93 7 32 14 91 11 30 8 56 12 27 1 92 3 9
+ 7 78 8 23 6 21 10 60 4 36 9 29 2 95 14 99 12 79 5 76 1 93 13 42 11 52 0 42 3 96
+ 5 29 3 61 12 88 13 70 11 16 4 31 14 65 7 83 2 78 1 26 10 50 0 87 9 62 6 14 8 30
+ 12 18 3 75 7 20 8 4 14 91 6 68 1 19 11 54 4 85 5 73 2 43 10 24 0 37 13 87 9 66
+ 11 32 5 52 0 9 7 49 12 61 13 35 14 99 1 62 2 6 8 62 4 7 3 80 9 3 6 57 10 7
+ 10 85 11 30 6 96 14 91 0 13 1 87 2 82 5 83 12 78 4 56 8 85 7 8 9 66 13 88 3 15
+ 6 5 11 59 9 30 2 60 8 41 0 17 13 66 3 89 10 78 7 88 1 69 12 45 14 82 4 6 5 13
+ 4 90 7 27 13 1 0 8 5 91 12 80 6 89 8 49 14 32 10 28 3 90 1 93 11 6 9 35 2 73
+ 2 47 14 43 0 75 12 8 6 51 10 3 7 84 5 34 8 28 9 60 13 69 1 45 3 67 11 58 4 87
+ 5 65 8 62 10 97 2 20 3 31 6 33 9 33 0 77 13 50 4 80 1 48 11 90 12 75 7 96 14 44
+ 8 28 14 21 4 51 13 75 5 17 6 89 9 59 1 56 12 63 7 18 11 17 10 30 3 16 2 7 0 35
+ 10 57 8 16 12 42 6 34 4 37 1 26 13 68 14 73 11 5 0 8 7 12 3 87 2 83 9 20 5 97