Pierre Tassel committed · Commit 21e8280
Parent(s): bb95fff

cleanup

Files changed:
- MyRemoteVectorEnv.py  +0 -130
- MyVecEnv.py           +0 -47
- Network.py            +0 -114
- app.py                +0 -3
- checkpoint.pt         +0 -0
MyRemoteVectorEnv.py
DELETED
@@ -1,130 +0,0 @@
from typing import Tuple, Callable, Optional
from collections import OrderedDict

import gym
import torch
import numpy as np
import ray
from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import MultiEnvDict, EnvType, EnvID, MultiAgentDict
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
from stable_baselines3.common.vec_env.util import obs_space_info, dict_to_obs

from MyDummyVecEnv import MyDummyVecEnv


@PublicAPI
class MyRemoteVectorEnv(BaseEnv):
    """Vector env that executes envs in remote workers.

    This provides dynamic batching of inference as observations are returned
    from the remote simulator actors. Both single and multi-agent child envs
    are supported, and envs can be stepped synchronously or async.

    You shouldn't need to instantiate this class directly. It's automatically
    inserted when you use the `remote_worker_envs` option for Trainers.
    """

    @property
    def observation_space(self):
        return self._observation_space

    def __init__(self, make_env: Callable[[int], EnvType], num_workers: int, env_per_worker: int,
                 observation_space: Optional[gym.spaces.Space], device: torch.device):
        self.make_local_env = make_env
        self.num_workers = num_workers
        self.env_per_worker = env_per_worker
        self.num_envs = num_workers * env_per_worker
        self.poll_timeout = None

        self.actors = None  # lazy init
        self.pending = None  # lazy init

        self.observation_space = observation_space
        self.keys, shapes, dtypes = obs_space_info(self.observation_space)

        self.device = device

        # Preallocated buffers holding the latest batched results from all workers.
        self.buf_obs = OrderedDict(
            [(k, torch.zeros((self.num_envs,) + tuple(shapes[k]), dtype=torch.float, device=self.device))
             for k in self.keys])
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]

    def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
        for key in self.keys:
            self.buf_obs[key][env_idx * self.env_per_worker: (env_idx + 1) * self.env_per_worker] = \
                torch.from_numpy(obs[key]).to(self.device, non_blocking=True)

    def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
        if self.actors is None:
            def make_remote_env(i):
                return _RemoteSingleAgentEnv.remote(self.make_local_env, i, self.env_per_worker)

            self.actors = [make_remote_env(i) for i in range(self.num_workers)]

        if self.pending is None:
            self.pending = {a.reset.remote(): a for a in self.actors}

        # each keyed by env_id in [0, num_remote_envs)
        ready = []

        # Wait for at least 1 env to be ready here
        while not ready:
            ready, _ = ray.wait(
                list(self.pending),
                num_returns=len(self.pending),
                timeout=self.poll_timeout)

        for obj_ref in ready:
            actor = self.pending.pop(obj_ref)
            env_id = self.actors.index(actor)
            ob, rew, done, info = ray.get(obj_ref)

            self._save_obs(env_id, ob)
            self.buf_rews[env_id * self.env_per_worker: (env_id + 1) * self.env_per_worker] = rew
            self.buf_dones[env_id * self.env_per_worker: (env_id + 1) * self.env_per_worker] = done
            self.buf_infos[env_id * self.env_per_worker: (env_id + 1) * self.env_per_worker] = info
        return self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos

    def _obs_from_buf(self) -> VecEnvObs:
        return dict_to_obs(self.observation_space, self.buf_obs)

    @PublicAPI
    def send_actions(self, action_list) -> None:
        for worker_id in range(self.num_workers):
            actions = action_list[worker_id * self.env_per_worker: (worker_id + 1) * self.env_per_worker]
            actor = self.actors[worker_id]
            obj_ref = actor.step.remote(actions)
            self.pending[obj_ref] = actor

    @PublicAPI
    def try_reset(self, env_id: Optional[EnvID] = None) -> Optional[MultiAgentDict]:
        actor = self.actors[env_id]
        obj_ref = actor.reset.remote()
        self.pending[obj_ref] = actor
        return ASYNC_RESET_RETURN

    @PublicAPI
    def stop(self) -> None:
        if self.actors is not None:
            for actor in self.actors:
                actor.__ray_terminate__.remote()

    @observation_space.setter
    def observation_space(self, value):
        self._observation_space = value


@ray.remote(num_cpus=1)
class _RemoteSingleAgentEnv:
    """Wrapper class for making a gym env a remote actor."""

    def __init__(self, make_env, i, env_per_worker):
        # Bind k as a default argument so each sub-env gets its own index
        # (a bare `lambda: make_env(... + k)` would capture the final k for every env).
        self.env = MyDummyVecEnv([lambda k=k: make_env((i * env_per_worker) + k)
                                  for k in range(env_per_worker)])

    def reset(self):
        return self.env.reset(), 0, False, {}

    def step(self, actions):
        return self.env.step(actions)
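For context, a self-contained sketch of the polling pattern this class relied on: one Ray actor per worker, a `pending` dict mapping in-flight object refs to actors, and `ray.wait` to collect whichever workers have finished, mirroring `poll()` and `send_actions()` above. The `ToyEnv` actor and the loop counts are illustrative stand-ins, not part of the original code.

import ray

ray.init(ignore_reinit_error=True)


@ray.remote(num_cpus=1)
class ToyEnv:
    # Stand-in for _RemoteSingleAgentEnv: returns (obs, reward, done, info).
    def __init__(self, seed):
        self.t = seed

    def reset(self):
        return self.t, 0.0, False, {}

    def step(self, action):
        self.t += 1
        return self.t, float(action), self.t % 5 == 0, {}


actors = [ToyEnv.remote(i) for i in range(4)]
pending = {a.reset.remote(): a for a in actors}        # same lazy reset as poll()

for _ in range(3):                                     # a few polling rounds
    ready, _ = ray.wait(list(pending), num_returns=len(pending), timeout=None)
    results = {}
    for obj_ref in ready:
        actor = pending.pop(obj_ref)
        obs, rew, done, info = ray.get(obj_ref)
        results[actors.index(actor)] = (obs, rew, done)
    for worker_id, actor in enumerate(actors):         # analogous to send_actions()
        pending[actor.step.remote(worker_id)] = actor
    print(results)

ray.shutdown()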
MyVecEnv.py
DELETED
@@ -1,47 +0,0 @@
from typing import Optional, List, Union, Sequence, Type, Any

import gym
import numpy as np
from ray.rllib import BaseEnv
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnvIndices, VecEnvStepReturn, VecEnvObs

from MyRemoteVectorEnv import MyRemoteVectorEnv


class WrapperRay(VecEnv):

    def __init__(self, make_env, num_workers, per_worker_env, device):
        self.one_env = make_env(0)
        self.remote: BaseEnv = MyRemoteVectorEnv(make_env, num_workers, per_worker_env,
                                                 self.one_env.observation_space, device)
        super(WrapperRay, self).__init__(num_workers * per_worker_env,
                                         self.one_env.observation_space, self.one_env.action_space)

    def reset(self) -> VecEnvObs:
        return self.remote.poll()[0]

    def step_async(self, actions: np.ndarray) -> None:
        self.remote.send_actions(actions)

    def step_wait(self) -> VecEnvStepReturn:
        return self.remote.poll()

    def close(self) -> None:
        self.remote.stop()

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        pass

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        pass

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        pass

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        pass

    def get_images(self) -> Sequence[np.ndarray]:
        pass

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        pass
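A minimal sketch of the stable-baselines3 VecEnv protocol that WrapperRay implemented: reset, step_async, step_wait, close, each forwarding to the Ray-backed MyRemoteVectorEnv. Here a plain DummyVecEnv plays the same role so the snippet runs without Ray; CartPole-v1 is only an illustrative environment, assuming the gym/stable-baselines3 versions this Space pinned.

import gym
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv

vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])

obs = vec_env.reset()                                 # WrapperRay.reset() -> remote.poll()[0]
actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
vec_env.step_async(actions)                           # WrapperRay.step_async() -> remote.send_actions()
obs, rewards, dones, infos = vec_env.step_wait()      # WrapperRay.step_wait() -> remote.poll()
vec_env.close()                                       # WrapperRay.close() -> remote.stop()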
Network.py
DELETED
@@ -1,114 +0,0 @@
import math

import numpy as np
import torch
from torch import nn, Tensor
from torch.distributions import Categorical


class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, max_len: int = 100):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, positions: Tensor) -> Tensor:
        return self.pe[positions]


class Actor(nn.Module):

    def __init__(self, pos_encoder):
        super(Actor, self).__init__()
        self.activation = nn.Tanh()
        self.project = nn.Linear(4, 8)
        nn.init.xavier_uniform_(self.project.weight, gain=1.0)
        nn.init.constant_(self.project.bias, 0)
        self.pos_encoder = pos_encoder

        self.embedding_fixed = nn.Embedding(2, 1)
        self.embedding_legal_op = nn.Embedding(2, 1)

        self.tokens_start_end = nn.Embedding(3, 4)

        # self.conv_transform = nn.Conv1d(5, 1, 1)
        # nn.init.kaiming_normal_(self.conv_transform.weight, mode="fan_out", nonlinearity="relu")
        # nn.init.constant_(self.conv_transform.bias, 0)

        self.enc1 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0, batch_first=True,
                                               norm_first=True)
        self.enc2 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0, batch_first=True,
                                               norm_first=True)

        self.final_tmp = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01)
        )
        self.no_op = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01)
        )

    def forward(self, obs, attention_interval_mask, job_resource, mask, indexes_inter, tokens_start_end):
        # Embed the two binary channels (0: fixed flag, 3: legal-op flag) and keep the two continuous ones.
        embedded_obs = torch.cat((self.embedding_fixed(obs[:, :, :, 0].long()), obs[:, :, :, 1:3],
                                  self.embedding_legal_op(obs[:, :, :, 3].long())), dim=3)
        # Replace start/end sentinel positions by dedicated token embeddings, without positional encoding.
        non_zero_tokens = tokens_start_end != 0
        t = tokens_start_end[non_zero_tokens].long()
        embedded_obs[non_zero_tokens] = self.tokens_start_end(t)
        pos_encoder = self.pos_encoder(indexes_inter.long())
        pos_encoder[non_zero_tokens] = 0
        obs = self.project(embedded_obs) + pos_encoder

        # First encoder: self-attention over the intervals of each job, then mean-pool one vector per job.
        transformed_obs = obs.view(-1, obs.shape[2], obs.shape[3])
        attention_interval_mask = attention_interval_mask.view(-1, attention_interval_mask.shape[-1])
        transformed_obs = self.enc1(transformed_obs, src_key_padding_mask=attention_interval_mask == 1)
        transformed_obs = transformed_obs.view(obs.shape)
        obs = transformed_obs.mean(dim=2)

        job_resource = job_resource[:, :-1, :-1] == 0

        # Second encoder: self-attention across jobs, masked by job_resource, with a residual connection.
        obs_action = self.enc2(obs, src_mask=job_resource) + obs

        # Per-job logits plus one pooled no-op logit; illegal actions are masked to the most negative float32.
        logits = torch.cat((self.final_tmp(obs_action).squeeze(2), self.no_op(obs_action).mean(dim=1)), dim=1)
        return logits.masked_fill(mask == 0, -3.4028234663852886e+38)


class Agent(nn.Module):
    def __init__(self):
        super(Agent, self).__init__()
        self.pos_encoder = PositionalEncoding(8)
        self.actor = Actor(self.pos_encoder)

    def forward(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end,
                action=None):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        probs = Categorical(logits=logits)
        if action is None:
            probabilities = probs.probs
            actions = torch.multinomial(probabilities, probabilities.shape[1])
            return actions, torch.log(probabilities), probs.entropy()
        else:
            return logits, probs.log_prob(action), probs.entropy()

    def get_action_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        probs = Categorical(logits=logits)
        return probs.sample()

    def get_logits_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        return logits


def layer_init_tanh(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer
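A hedged smoke test of the removed Agent, assuming the listing above is saved as Network.py. The tensor shapes (2 envs, 5 jobs, 7 intervals) and the all-ones masks are illustrative guesses derived from the indexing in Actor.forward, not the actual shapes produced by CompiledJssEnvCP.

import torch
from Network import Agent

batch, jobs, intervals = 2, 5, 7

obs = torch.rand(batch, jobs, intervals, 4)
obs[:, :, :, 0] = (obs[:, :, :, 0] > 0.5).float()              # channel 0: binary "fixed" flag
obs[:, :, :, 3] = (obs[:, :, :, 3] > 0.5).float()              # channel 3: binary "legal op" flag
attention_interval_mask = torch.zeros(batch, jobs, intervals)   # 1 would mark padded intervals
job_resource = torch.ones(batch, jobs + 1, jobs + 1)            # 0 would block job-to-job attention
mask = torch.ones(batch, jobs + 1)                              # action mask incl. the no-op action
indexes_inter = torch.arange(intervals).expand(batch, jobs, intervals)
tokens_start_end = torch.zeros(batch, jobs, intervals)          # 0 = regular token, 1/2 = start/end

agent = Agent()
with torch.no_grad():
    actions, log_probs, entropy = agent(obs, attention_interval_mask, job_resource,
                                         mask, indexes_inter, tokens_start_end)
print(actions.shape, log_probs.shape, entropy.shape)            # (2, 6), (2, 6), (2,)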
app.py
CHANGED
@@ -3,12 +3,10 @@ import os
 import random
 import time
 
-import gym
 import plotly.figure_factory as ff
 import json
 
 import pandas as pd
-import ray
 
 from compiled_jss.CPEnv import CompiledJssEnvCP
 
@@ -19,7 +17,6 @@ import torch
 import numpy as np
 
 from MyDummyVecEnv import MyDummyVecEnv
-from MyVecEnv import WrapperRay
 
 import gradio as gr
 
checkpoint.pt
DELETED
Binary file (75.6 kB)
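A hedged sketch of how a checkpoint like the removed checkpoint.pt would typically have been restored into the Agent from Network.py before this cleanup. Whether the file stored a plain state_dict or a pickled module is an assumption; only torch.load and load_state_dict are known APIs here.

import torch
from Network import Agent

agent = Agent()
state = torch.load("checkpoint.pt", map_location="cpu")
if isinstance(state, dict):
    agent.load_state_dict(state)   # assumes a plain state_dict was saved
else:
    agent = state                  # assumes the whole module was pickled
agent.eval()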