Spaces: Runtime error
Pierre Tassel committed · Commit d746b98 · Parent(s): e8861ce
wip
Browse files
- MyDummyVecEnv.py +123 -0
- MyRemoteVectorEnv.py +130 -0
- MyVecEnv.py +47 -0
- Network.py +114 -0
- actor.pt +0 -0
- app.py +155 -0
- checkpoint.pt +0 -0
- dmu01.txt +21 -0
- la01.txt +11 -0
- requirements.txt +4 -0
- ta01 +16 -0
MyDummyVecEnv.py
ADDED
@@ -0,0 +1,123 @@
from collections import OrderedDict
from typing import Any, Callable, List, Optional, Sequence, Type, Union

import gym
import numpy as np

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info


class MyDummyVecEnv(VecEnv):
    """
    Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
    Python process. This is useful for computationally simple environments such as ``Cartpole-v1``,
    as the overhead of multiprocessing or multithreading outweighs the environment computation time.
    This can also be used for RL methods that
    require a vectorized environment, but that you want a single environment to train with.

    :param env_fns: a list of functions
        that return environments to vectorize
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]):
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
        obs_space = env.observation_space
        self.keys, shapes, dtypes = obs_space_info(obs_space)

        self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys])
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]
        self.actions = None

    def step_async(self, actions: np.ndarray) -> None:
        self.actions = actions

    def step_wait(self) -> VecEnvStepReturn:
        for env_idx in range(self.num_envs):
            obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
                self.actions[env_idx]
            )
            if self.buf_dones[env_idx]:
                # save final observation where user can get it, then reset
                self.buf_infos[env_idx]["terminal_observation"] = obs
                obs = self.envs[env_idx].reset()
            self._save_obs(env_idx, obs)
        return (self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos)

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        seeds = list()
        for idx, env in enumerate(self.envs):
            seeds.append(env.seed(seed + idx))
        return seeds

    def reset(self) -> VecEnvObs:
        for env_idx in range(self.num_envs):
            obs = self.envs[env_idx].reset()
            self._save_obs(env_idx, obs)
        return self._obs_from_buf()

    def close(self) -> None:
        for env in self.envs:
            env.close()

    def get_images(self) -> Sequence[np.ndarray]:
        return [env.render(mode="rgb_array") for env in self.envs]

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """
        Gym environment rendering. If there are multiple environments then
        they are tiled together in one image via ``BaseVecEnv.render()``.
        Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the
        underlying environment.

        Therefore, some arguments such as ``mode`` will have values that are valid
        only when ``num_envs == 1``.

        :param mode: The rendering type.
        """
        if self.num_envs == 1:
            return self.envs[0].render(mode=mode)
        else:
            return super().render(mode=mode)

    def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
        for key in self.keys:
            if key is None:
                self.buf_obs[key][env_idx] = obs
            else:
                self.buf_obs[key][env_idx] = obs[key]

    def _obs_from_buf(self) -> VecEnvObs:
        return dict_to_obs(self.observation_space, self.buf_obs)

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """Return attribute from vectorized environment (see base class)."""
        target_envs = self._get_target_envs(indices)
        return [getattr(env_i, attr_name) for env_i in target_envs]

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """Set attribute inside vectorized environments (see base class)."""
        target_envs = self._get_target_envs(indices)
        for env_i in target_envs:
            setattr(env_i, attr_name, value)

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        """Call instance methods of vectorized environments."""
        target_envs = self._get_target_envs(indices)
        return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        """Check if worker environments are wrapped with a given wrapper."""
        target_envs = self._get_target_envs(indices)
        # Import here to avoid a circular import
        from stable_baselines3.common import env_util

        return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]

    def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
        indices = self._get_indices(indices)
        return [self.envs[i] for i in indices]
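A minimal usage sketch (not part of the commit) of the class above: MyDummyVecEnv takes a list of zero-argument factories, steps each sub-environment sequentially in the current process, and auto-resets a sub-environment when it reports done. ``CartPole-v1`` is only a stand-in here; in this Space the factories return CompiledJssEnvCP instances (see app.py).

import gym
import numpy as np

from MyDummyVecEnv import MyDummyVecEnv

# two copies of a cheap environment, stepped in lockstep
vec_env = MyDummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
obs = vec_env.reset()                                    # batched observations, shape (2, 4)
actions = np.array([vec_env.action_space.sample() for _ in range(2)])
obs, rewards, dones, infos = vec_env.step(actions)       # step() = step_async() + step_wait()
vec_env.close()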
MyRemoteVectorEnv.py
ADDED
@@ -0,0 +1,130 @@
from typing import Tuple, Callable, Optional
from collections import OrderedDict

import gym
import torch
import numpy as np
import ray
from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import MultiEnvDict, EnvType, EnvID, MultiAgentDict
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
from stable_baselines3.common.vec_env.util import obs_space_info, dict_to_obs

from MyDummyVecEnv import MyDummyVecEnv


@PublicAPI
class MyRemoteVectorEnv(BaseEnv):
    """Vector env that executes envs in remote workers.

    This provides dynamic batching of inference as observations are returned
    from the remote simulator actors. Both single and multi-agent child envs
    are supported, and envs can be stepped synchronously or asynchronously.
    You shouldn't need to instantiate this class directly. It's automatically
    inserted when you use the `remote_worker_envs` option for Trainers.
    """

    @property
    def observation_space(self):
        return self._observation_space

    def __init__(self, make_env: Callable[[int], EnvType], num_workers: int, env_per_worker: int,
                 observation_space: Optional[gym.spaces.Space], device: torch.device):
        self.make_local_env = make_env
        self.num_workers = num_workers
        self.env_per_worker = env_per_worker
        self.num_envs = num_workers * env_per_worker
        self.poll_timeout = None

        self.actors = None  # lazy init
        self.pending = None  # lazy init

        self.observation_space = observation_space
        self.keys, shapes, dtypes = obs_space_info(self.observation_space)

        self.device = device

        self.buf_obs = OrderedDict(
            [(k, torch.zeros((self.num_envs,) + tuple(shapes[k]), dtype=torch.float, device=self.device)) for k in self.keys])
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]

    def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
        for key in self.keys:
            self.buf_obs[key][env_idx * self.env_per_worker: (env_idx + 1) * self.env_per_worker] = \
                torch.from_numpy(obs[key]).to(self.device, non_blocking=True)

    def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        if self.actors is None:

            def make_remote_env(i):
                return _RemoteSingleAgentEnv.remote(self.make_local_env, i, self.env_per_worker)

            self.actors = [make_remote_env(i) for i in range(self.num_workers)]

        if self.pending is None:
            self.pending = {a.reset.remote(): a for a in self.actors}

        # each keyed by env_id in [0, num_remote_envs)
        ready = []

        # Wait for at least 1 env to be ready here
        while not ready:
            ready, _ = ray.wait(
                list(self.pending),
                num_returns=len(self.pending),
                timeout=self.poll_timeout)

        for obj_ref in ready:
            actor = self.pending.pop(obj_ref)
            env_id = self.actors.index(actor)
            ob, rew, done, info = ray.get(obj_ref)

            self._save_obs(env_id, ob)
            self.buf_rews[env_id * self.env_per_worker: (env_id + 1) * self.env_per_worker] = rew
            self.buf_dones[env_id * self.env_per_worker: (env_id + 1) * self.env_per_worker] = done
            self.buf_infos[env_id * self.env_per_worker: (env_id + 1) * self.env_per_worker] = info
        return (self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos)

    def _obs_from_buf(self) -> VecEnvObs:
        return dict_to_obs(self.observation_space, self.buf_obs)

    @PublicAPI
    def send_actions(self, action_list) -> None:
        for worker_id in range(self.num_workers):
            actions = action_list[worker_id * self.env_per_worker: (worker_id + 1) * self.env_per_worker]
            actor = self.actors[worker_id]
            obj_ref = actor.step.remote(actions)
            self.pending[obj_ref] = actor

    @PublicAPI
    def try_reset(self, env_id: Optional[EnvID] = None) -> Optional[MultiAgentDict]:
        actor = self.actors[env_id]
        obj_ref = actor.reset.remote()
        self.pending[obj_ref] = actor
        return ASYNC_RESET_RETURN

    @PublicAPI
    def stop(self) -> None:
        if self.actors is not None:
            for actor in self.actors:
                actor.__ray_terminate__.remote()

    @observation_space.setter
    def observation_space(self, value):
        self._observation_space = value


@ray.remote(num_cpus=1)
class _RemoteSingleAgentEnv:
    """Wrapper class for making a gym env a remote actor."""

    def __init__(self, make_env, i, env_per_worker):
        # bind k at definition time (k=k) so each sub-env gets its own index,
        # instead of every lambda capturing the final value of the loop variable
        self.env = MyDummyVecEnv([lambda k=k: make_env((i * env_per_worker) + k) for k in range(env_per_worker)])

    def reset(self):
        return self.env.reset(), 0, False, {}

    def step(self, actions):
        return self.env.step(actions)
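A hedged, self-contained sketch (not part of the commit) of the poll/send_actions protocol implemented above: the first poll() lazily spawns the Ray actors and collects their reset observations, after which the caller alternates send_actions() and poll(). ToyDictEnv is a hypothetical stand-in with a Dict observation space, which is what obs_space_info and _save_obs expect; the real workers wrap CompiledJssEnvCP.

import gym
import numpy as np
import ray
import torch
from gym import spaces

from MyRemoteVectorEnv import MyRemoteVectorEnv


class ToyDictEnv(gym.Env):
    """Hypothetical stand-in environment with a Dict observation space."""
    observation_space = spaces.Dict({"obs": spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32)})
    action_space = spaces.Discrete(2)

    def reset(self):
        return {"obs": np.zeros(3, dtype=np.float32)}

    def step(self, action):
        return {"obs": np.random.uniform(-1, 1, 3).astype(np.float32)}, 0.0, False, {}


ray.init(ignore_reinit_error=True)
env = MyRemoteVectorEnv(lambda rank: ToyDictEnv(), num_workers=2, env_per_worker=1,
                        observation_space=ToyDictEnv.observation_space,
                        device=torch.device("cpu"))
obs, rewards, dones, infos = env.poll()                  # spawns the actors and gathers their resets
for _ in range(5):
    actions = np.zeros(env.num_envs, dtype=np.int64)     # dummy policy: always action 0
    env.send_actions(actions)                            # one batched step per remote worker
    obs, rewards, dones, infos = env.poll()              # collect whatever has finished
env.stop()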
MyVecEnv.py
ADDED
@@ -0,0 +1,47 @@
from typing import Optional, List, Union, Sequence, Type, Any

import gym
import numpy as np
from ray.rllib import BaseEnv
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnvIndices, VecEnvStepReturn, VecEnvObs

from MyRemoteVectorEnv import MyRemoteVectorEnv


class WrapperRay(VecEnv):

    def __init__(self, make_env, num_workers, per_worker_env, device):
        self.one_env = make_env(0)
        self.remote: BaseEnv = MyRemoteVectorEnv(make_env, num_workers, per_worker_env, self.one_env.observation_space, device)
        super(WrapperRay, self).__init__(num_workers * per_worker_env, self.one_env.observation_space, self.one_env.action_space)

    def reset(self) -> VecEnvObs:
        return self.remote.poll()[0]

    def step_async(self, actions: np.ndarray) -> None:
        self.remote.send_actions(actions)

    def step_wait(self) -> VecEnvStepReturn:
        return self.remote.poll()

    def close(self) -> None:
        self.remote.stop()

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        pass

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        pass

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        pass

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        pass

    def get_images(self) -> Sequence[np.ndarray]:
        pass

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        pass
Network.py
ADDED
@@ -0,0 +1,114 @@
import math

import numpy as np
import torch
from torch import nn, Tensor
from torch.distributions import Categorical


class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, max_len: int = 100):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, positions: Tensor) -> Tensor:
        return self.pe[positions]


class Actor(nn.Module):

    def __init__(self, pos_encoder):
        super(Actor, self).__init__()
        self.activation = nn.Tanh()
        self.project = nn.Linear(4, 8)
        nn.init.xavier_uniform_(self.project.weight, gain=1.0)
        nn.init.constant_(self.project.bias, 0)
        self.pos_encoder = pos_encoder

        self.embedding_fixed = nn.Embedding(2, 1)
        self.embedding_legal_op = nn.Embedding(2, 1)

        self.tokens_start_end = nn.Embedding(3, 4)

        # self.conv_transform = nn.Conv1d(5, 1, 1)
        # nn.init.kaiming_normal_(self.conv_transform.weight, mode="fan_out", nonlinearity="relu")
        # nn.init.constant_(self.conv_transform.bias, 0)

        self.enc1 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0, batch_first=True,
                                               norm_first=True)
        self.enc2 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0, batch_first=True,
                                               norm_first=True)

        self.final_tmp = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01)
        )
        self.no_op = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01)
        )

    def forward(self, obs, attention_interval_mask, job_resource, mask, indexes_inter, tokens_start_end):
        embedded_obs = torch.cat((self.embedding_fixed(obs[:, :, :, 0].long()), obs[:, :, :, 1:3],
                                  self.embedding_legal_op(obs[:, :, :, 3].long())), dim=3)
        non_zero_tokens = tokens_start_end != 0
        t = tokens_start_end[non_zero_tokens].long()
        embedded_obs[non_zero_tokens] = self.tokens_start_end(t)
        pos_encoder = self.pos_encoder(indexes_inter.long())
        pos_encoder[non_zero_tokens] = 0
        obs = self.project(embedded_obs) + pos_encoder

        transformed_obs = obs.view(-1, obs.shape[2], obs.shape[3])
        attention_interval_mask = attention_interval_mask.view(-1, attention_interval_mask.shape[-1])
        transformed_obs = self.enc1(transformed_obs, src_key_padding_mask=attention_interval_mask == 1)
        transformed_obs = transformed_obs.view(obs.shape)
        obs = transformed_obs.mean(dim=2)

        job_resource = job_resource[:, :-1, :-1] == 0

        obs_action = self.enc2(obs, src_mask=job_resource) + obs

        logits = torch.cat((self.final_tmp(obs_action).squeeze(2), self.no_op(obs_action).mean(dim=1)), dim=1)
        return logits.masked_fill(mask == 0, -3.4028234663852886e+38)


class Agent(nn.Module):
    def __init__(self):
        super(Agent, self).__init__()
        self.pos_encoder = PositionalEncoding(8)
        self.actor = Actor(self.pos_encoder)

    def forward(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end,
                action=None):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        probs = Categorical(logits=logits)
        if action is None:
            probabilities = probs.probs
            actions = torch.multinomial(probabilities, probabilities.shape[1])
            return actions, torch.log(probabilities), probs.entropy()
        else:
            return logits, probs.log_prob(action), probs.entropy()

    def get_action_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        probs = Categorical(logits=logits)
        return probs.sample()

    def get_logits_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        return logits


def layer_init_tanh(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer
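A small self-contained illustration (not part of the commit) of the action-masking step at the end of Actor.forward: filling the logits of illegal actions with the most negative float32 value drives their softmax probability to zero, so a Categorical distribution built from the masked logits never samples them.

import torch
from torch.distributions import Categorical

logits = torch.tensor([[1.2, 0.3, -0.5, 0.8]])
mask = torch.tensor([[1, 0, 1, 1]])                      # 0 marks an illegal action
masked = logits.masked_fill(mask == 0, -3.4028234663852886e+38)
dist = Categorical(logits=masked)
print(dist.probs)                                        # action 1 gets probability 0
action = dist.sample()                                   # the masked action is never drawn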
actor.pt
ADDED
Binary file (80.3 kB)
app.py
ADDED
@@ -0,0 +1,155 @@
import collections
import random
import time

import multiprocessing as mp
import json

from PIL import Image
from compiled_jss.CPEnv import CompiledJssEnvCP

from stable_baselines3.common.vec_env import VecEnvWrapper
from torch.distributions import Categorical

import torch
import numpy as np

from MyVecEnv import WrapperRay

import gradio as gr
import docplex.cp.utils_visu as visu
import matplotlib.pyplot as plt


class VecPyTorch(VecEnvWrapper):

    def __init__(self, venv, device):
        super(VecPyTorch, self).__init__(venv)
        self.device = device

    def reset(self):
        return self.venv.reset()

    def step_async(self, actions):
        self.venv.step_async(actions)

    def step_wait(self):
        return self.venv.step_wait()


def make_env(seed, instance):
    def thunk():
        _env = CompiledJssEnvCP(instance)
        return _env

    return thunk


def solve(file):
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    num_workers = min(mp.cpu_count(), 32)
    with torch.inference_mode():
        device = torch.device('cpu')
        actor = torch.jit.load('actor.pt', map_location=device)
        actor.eval()
        start_time = time.time()
        fn_env = [make_env(0, file.name)
                  for _ in range(num_workers)]
        ray_wrapper_env = WrapperRay(lambda n: fn_env[n](),
                                     num_workers, 1, device)
        envs = VecPyTorch(ray_wrapper_env, device)
        current_solution_cost = float('inf')
        current_solution = ''
        obs = envs.reset()
        total_episode = 0
        while total_episode < envs.num_envs:
            logits = actor(obs['interval_rep'], obs['attention_interval_mask'], obs['job_resource_mask'],
                           obs['action_mask'], obs['index_interval'], obs['start_end_tokens'])
            # temperature vector
            if num_workers >= 4:
                temperature = torch.arange(0.5, 2.0, step=(1.5 / num_workers), device=device)
            else:
                temperature = torch.ones(num_workers, device=device)
            logits = logits / temperature[:, None]
            probs = Categorical(logits=logits).probs
            # random sample based on logits
            actions = torch.multinomial(probs, probs.shape[1]).cpu().numpy()
            obs, reward, done, infos = envs.step(actions)
            total_episode += done.sum()
            # total_actions += 1
            # print(f'Episode {total_episode} / {envs.num_envs} - Actions {total_actions}', end='\r')
            for env_idx, info in enumerate(infos):
                if 'makespan' in info and int(info['makespan']) < current_solution_cost:
                    current_solution_cost = int(info['makespan'])
                    current_solution = json.loads(info['solution'])
        total_time = time.time() - start_time
        pretty_output = ""
        for job_id in range(len(current_solution)):
            pretty_output += f"Job {job_id}: {current_solution[job_id]}\n"

        jobs_data = []
        file.seek(0)
        line_str: str = file.readline()
        line_cnt: int = 1
        while line_str:
            data = []
            split_data = line_str.split()
            if line_cnt == 1:
                jobs_count, machines_count = int(split_data[0]), int(
                    split_data[1]
                )
            else:
                i = 0
                this_job_op_count = 0
                while i < len(split_data):
                    machine, op_time = int(split_data[i]), int(split_data[i + 1])
                    data.append((machine, op_time))
                    i += 2
                    this_job_op_count += 1
                jobs_data.append(data)
            line_str = file.readline()
            line_cnt += 1
        visu.timeline(f'Solution for job-shop, solved using ')
        visu.panel('Jobs')
        # convert the current_solution entries to integers
        current_solution = [[int(x) for x in y] for y in current_solution]
        for job_id in range(len(current_solution)):
            visu.sequence(name=f'J{job_id}',
                          intervals=[(current_solution[job_id][task_id],
                                      current_solution[job_id][task_id] + jobs_data[job_id][task_id][1],
                                      jobs_data[job_id][task_id][0],
                                      f'M{jobs_data[job_id][task_id][0]}')
                                     for task_id in range(len(current_solution[job_id]))])
        visu.panel('Machines')
        machine_solution = collections.defaultdict(list)
        for job_id in range(len(current_solution)):
            for task_id in range(len(current_solution[job_id])):
                # jobs_data entries are (machine, processing_time): index 0 is the machine id
                machine = jobs_data[job_id][task_id][0]
                machine_solution[machine].append((current_solution[job_id][task_id],
                                                  current_solution[job_id][task_id] + jobs_data[job_id][task_id][1],
                                                  machine, f'J{job_id}'))
        # sort dictionary keys
        machine_solution = {k: machine_solution[k] for k in sorted(machine_solution.keys())}
        for machine_id in machine_solution:
            visu.sequence(name=f'M{machine_id}',
                          intervals=machine_solution[machine_id])
        plt.rcParams["font.family"] = "Times New Roman"
        plt.rcParams["font.size"] = "30"
        plt.gca().set_aspect('equal')
        plt.rcParams["figure.figsize"] = (45, 50)
        from io import BytesIO
        buffer = BytesIO()

        visu.show(pngfile=buffer)
        reloadedPILImage = Image.open(buffer)
        return pretty_output, reloadedPILImage, str(total_time) + " seconds"


title = "Job-Shop Scheduling CP RL"
description = "A Job-Shop Scheduling Reinforcement Learning based solver, using an underlying CP model as an " \
              "environment. "
article = "<p style='text-align: center'>Article Under Review</p>"
examples = ['ta01', 'dmu01.txt', 'la01.txt']
iface = gr.Interface(fn=solve, inputs=gr.File(label="Instance File"), outputs=[gr.Text(label="Solution"), gr.Image(label="Solution's Gantt Chart"), gr.Text(label="Elapsed Time")], title=title, description=description, article=article, examples=examples)
iface.launch(enable_queue=True)
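A minimal sketch (toy numbers, not part of the commit) of the per-worker temperature trick used in solve() above: every parallel rollout samples from the same policy logits but at its own softmax temperature, so the workers explore differently and compete for the lowest makespan.

import torch
from torch.distributions import Categorical

num_workers = 4
logits = torch.randn(num_workers, 6)                             # 6 candidate actions per worker
temperature = torch.arange(0.5, 2.0, step=1.5 / num_workers)     # one temperature per worker
probs = Categorical(logits=logits / temperature[:, None]).probs
actions = torch.multinomial(probs, probs.shape[1])               # ranked action candidates per worker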
checkpoint.pt
ADDED
Binary file (75.6 kB)
dmu01.txt
ADDED
@@ -0,0 +1,21 @@
20 15
0 160 13 5 6 139 11 99 12 9 5 98 2 28 1 107 3 196 10 165 7 114 4 7 14 34 8 133 9 76
14 105 7 160 3 19 2 189 11 25 1 95 12 15 0 122 4 165 9 2 10 66 13 111 8 51 6 83 5 183
11 61 5 11 9 130 4 147 13 106 12 1 6 141 7 136 10 33 0 13 2 15 8 10 14 62 3 4 1 142
13 117 1 11 4 162 0 192 5 35 8 172 3 4 14 193 2 141 11 139 6 62 9 12 12 1 7 135 10 25
5 53 9 89 10 168 12 41 11 121 1 181 3 43 0 118 4 61 14 193 2 124 6 176 13 28 8 125 7 136
5 152 0 115 2 122 14 5 12 46 13 144 11 29 7 176 1 115 6 18 4 23 9 26 3 175 8 110 10 75
6 50 1 62 3 186 12 57 11 156 10 32 2 134 9 141 4 189 13 118 0 102 7 3 8 177 14 43 5 41
13 35 0 171 14 160 9 32 7 5 11 154 8 195 3 113 12 162 5 152 6 140 2 72 4 16 10 104 1 171
13 68 8 54 6 116 4 9 14 99 12 155 10 22 5 135 0 67 1 165 9 100 11 47 3 46 7 55 2 12
1 135 5 105 9 49 8 4 12 176 3 52 11 128 7 188 6 170 10 170 2 169 4 62 0 120 13 28 14 70
2 93 1 172 13 124 6 72 7 189 14 122 5 38 0 120 12 114 11 51 9 77 8 65 4 176 3 171 10 169
3 122 6 21 4 6 13 189 14 75 5 5 9 180 0 160 1 14 11 73 12 45 2 61 7 148 10 96 8 194
9 94 12 198 8 100 5 194 2 127 10 95 4 43 3 52 6 166 1 31 14 100 13 104 7 166 11 139 0 143
5 4 3 78 11 199 8 119 12 167 0 54 9 38 14 114 13 10 4 115 7 101 1 104 2 61 6 75 10 175
10 18 11 115 6 166 8 41 14 124 12 101 7 38 13 29 0 91 2 118 9 40 5 55 1 82 4 89 3 100
11 2 9 107 14 99 3 152 7 51 4 13 10 112 0 96 1 150 6 97 13 67 5 57 2 45 8 17 12 184
1 176 11 15 3 92 9 9 14 77 12 4 7 83 10 195 4 156 6 102 2 91 13 65 8 19 5 163 0 93
8 38 0 32 14 80 11 109 9 71 1 100 12 139 7 52 3 163 13 40 4 5 6 28 2 105 5 186 10 186
11 1 3 73 0 106 4 80 12 150 13 5 5 71 9 145 1 138 6 148 10 168 7 60 2 107 14 164 8 178
1 14 10 5 4 115 2 70 11 112 5 76 9 20 0 104 7 167 13 58 8 193 12 30 6 132 3 6 14 19
la01.txt
ADDED
@@ -0,0 +1,11 @@
10 5
1 21 0 53 4 95 3 55 2 34
0 21 3 52 4 16 2 26 1 71
3 39 4 98 1 42 2 31 0 12
1 77 0 55 4 79 2 66 3 77
0 83 3 34 2 64 1 19 4 37
1 54 2 43 4 79 0 92 3 62
3 69 4 77 1 87 2 87 0 93
2 38 0 60 1 41 3 24 4 83
3 17 1 49 4 25 0 44 2 98
4 77 3 79 2 43 1 75 0 96
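The three instance files (ta01, dmu01.txt, la01.txt) use the classical job-shop text format that the parser in app.py expects: the first line gives the number of jobs and the number of machines, and each subsequent line describes one job as (machine id, processing time) pairs in processing order. An illustrative decoding of the first job of la01.txt (not part of the commit):

row = "1 21 0 53 4 95 3 55 2 34".split()
ops = [(int(m), int(d)) for m, d in zip(row[0::2], row[1::2])]
# ops == [(1, 21), (0, 53), (4, 95), (3, 55), (2, 34)]
# job 0 runs on machine 1 for 21 time units, then machine 0 for 53, and so on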
requirements.txt
ADDED
@@ -0,0 +1,4 @@
job-shop-cp-env==1.0.0
ray==2.1.0
ray[rllib]==2.1.0
stable-baselines3==1.6.2
ta01
ADDED
@@ -0,0 +1,16 @@
15 15
6 94 12 66 4 10 7 53 3 26 2 15 10 65 11 82 8 10 14 27 9 93 13 92 5 96 0 70 1 83
4 74 5 31 7 88 14 51 13 57 8 78 11 8 9 7 6 91 10 79 0 18 3 51 12 18 1 99 2 33
1 4 8 82 9 40 12 86 6 50 11 54 13 21 5 6 0 54 2 68 7 82 10 20 4 39 3 35 14 68
5 73 2 23 9 30 6 30 10 53 0 94 13 58 4 93 7 32 14 91 11 30 8 56 12 27 1 92 3 9
7 78 8 23 6 21 10 60 4 36 9 29 2 95 14 99 12 79 5 76 1 93 13 42 11 52 0 42 3 96
5 29 3 61 12 88 13 70 11 16 4 31 14 65 7 83 2 78 1 26 10 50 0 87 9 62 6 14 8 30
12 18 3 75 7 20 8 4 14 91 6 68 1 19 11 54 4 85 5 73 2 43 10 24 0 37 13 87 9 66
11 32 5 52 0 9 7 49 12 61 13 35 14 99 1 62 2 6 8 62 4 7 3 80 9 3 6 57 10 7
10 85 11 30 6 96 14 91 0 13 1 87 2 82 5 83 12 78 4 56 8 85 7 8 9 66 13 88 3 15
6 5 11 59 9 30 2 60 8 41 0 17 13 66 3 89 10 78 7 88 1 69 12 45 14 82 4 6 5 13
4 90 7 27 13 1 0 8 5 91 12 80 6 89 8 49 14 32 10 28 3 90 1 93 11 6 9 35 2 73
2 47 14 43 0 75 12 8 6 51 10 3 7 84 5 34 8 28 9 60 13 69 1 45 3 67 11 58 4 87
5 65 8 62 10 97 2 20 3 31 6 33 9 33 0 77 13 50 4 80 1 48 11 90 12 75 7 96 14 44
8 28 14 21 4 51 13 75 5 17 6 89 9 59 1 56 12 63 7 18 11 17 10 30 3 16 2 7 0 35
10 57 8 16 12 42 6 34 4 37 1 26 13 68 14 73 11 5 0 8 7 12 3 87 2 83 9 20 5 97