Pierre Tassel committed
Commit 21e8280 · 1 Parent(s): bb95fff
Files changed (5)
  1. MyRemoteVectorEnv.py +0 -130
  2. MyVecEnv.py +0 -47
  3. Network.py +0 -114
  4. app.py +0 -3
  5. checkpoint.pt +0 -0
MyRemoteVectorEnv.py DELETED
@@ -1,130 +0,0 @@
from typing import Tuple, Callable, Optional
from collections import OrderedDict

import gym
import torch
import numpy as np
import ray
from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import MultiEnvDict, EnvType, EnvID, MultiAgentDict
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
from stable_baselines3.common.vec_env.util import obs_space_info, dict_to_obs

from MyDummyVecEnv import MyDummyVecEnv


@PublicAPI
class MyRemoteVectorEnv(BaseEnv):
    """Vector env that executes envs in remote workers.

    This provides dynamic batching of inference as observations are returned
    from the remote simulator actors. Both single and multi-agent child envs
    are supported, and envs can be stepped synchronously or async.
    You shouldn't need to instantiate this class directly. It's automatically
    inserted when you use the `remote_worker_envs` option for Trainers.
    """

    @property
    def observation_space(self):
        return self._observation_space

    def __init__(self, make_env: Callable[[int], EnvType], num_workers: int, env_per_worker: int,
                 observation_space: Optional[gym.spaces.Space], device: torch.device):
        self.make_local_env = make_env
        self.num_workers = num_workers
        self.env_per_worker = env_per_worker
        self.num_envs = num_workers * env_per_worker
        self.poll_timeout = None

        self.actors = None  # lazy init
        self.pending = None  # lazy init

        self.observation_space = observation_space
        self.keys, shapes, dtypes = obs_space_info(self.observation_space)

        self.device = device

        self.buf_obs = OrderedDict(
            [(k, torch.zeros((self.num_envs,) + tuple(shapes[k]), dtype=torch.float, device=self.device))
             for k in self.keys])
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]

    def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
        for key in self.keys:
            self.buf_obs[key][env_idx * self.env_per_worker: (env_idx + 1) * self.env_per_worker] = \
                torch.from_numpy(obs[key]).to(self.device, non_blocking=True)

    def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
                            MultiEnvDict, MultiEnvDict]:
        if self.actors is None:

            def make_remote_env(i):
                return _RemoteSingleAgentEnv.remote(self.make_local_env, i, self.env_per_worker)

            self.actors = [make_remote_env(i) for i in range(self.num_workers)]

        if self.pending is None:
            self.pending = {a.reset.remote(): a for a in self.actors}

        # each keyed by env_id in [0, num_remote_envs)
        ready = []

        # Wait for at least 1 env to be ready here
        while not ready:
            ready, _ = ray.wait(
                list(self.pending),
                num_returns=len(self.pending),
                timeout=self.poll_timeout)

        for obj_ref in ready:
            actor = self.pending.pop(obj_ref)
            env_id = self.actors.index(actor)
            ob, rew, done, info = ray.get(obj_ref)

            self._save_obs(env_id, ob)
            self.buf_rews[env_id * self.env_per_worker: (env_id + 1) * self.env_per_worker] = rew
            self.buf_dones[env_id * self.env_per_worker: (env_id + 1) * self.env_per_worker] = done
            self.buf_infos[env_id * self.env_per_worker: (env_id + 1) * self.env_per_worker] = info
        return (self._obs_from_buf(), self.buf_rews, self.buf_dones, self.buf_infos)

    def _obs_from_buf(self) -> VecEnvObs:
        return dict_to_obs(self.observation_space, self.buf_obs)

    @PublicAPI
    def send_actions(self, action_list) -> None:
        for worker_id in range(self.num_workers):
            actions = action_list[worker_id * self.env_per_worker: (worker_id + 1) * self.env_per_worker]
            actor = self.actors[worker_id]
            obj_ref = actor.step.remote(actions)
            self.pending[obj_ref] = actor

    @PublicAPI
    def try_reset(self, env_id: Optional[EnvID] = None) -> Optional[MultiAgentDict]:
        actor = self.actors[env_id]
        obj_ref = actor.reset.remote()
        self.pending[obj_ref] = actor
        return ASYNC_RESET_RETURN

    @PublicAPI
    def stop(self) -> None:
        if self.actors is not None:
            for actor in self.actors:
                actor.__ray_terminate__.remote()

    @observation_space.setter
    def observation_space(self, value):
        self._observation_space = value


@ray.remote(num_cpus=1)
class _RemoteSingleAgentEnv:
    """Wrapper class for making a gym env a remote actor."""

    def __init__(self, make_env, i, env_per_worker):
        self.env = MyDummyVecEnv([lambda: make_env((i * env_per_worker) + k) for k in range(env_per_worker)])

    def reset(self):
        return self.env.reset(), 0, False, {}

    def step(self, actions):
        return self.env.step(actions)
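For reference, the core pattern the deleted class builds on is Ray's actor API plus ray.wait: poll() gathers the pending object refs and copies the results into the buffers, while send_actions() immediately re-dispatches work to each actor. Below is a minimal, self-contained sketch of that collect/re-dispatch cycle, using a hypothetical ToyWorker actor and toy payloads instead of real environments (none of these names are part of the commit):

# Toy sketch of the ray.wait collect/re-dispatch cycle used by
# MyRemoteVectorEnv.poll() and send_actions(). ToyWorker is illustrative only.
import ray


@ray.remote(num_cpus=1)
class ToyWorker:
    def __init__(self, worker_id: int):
        self.worker_id = worker_id

    def step(self, action):
        # A real worker would step its vectorised environments here.
        return self.worker_id, action + 1


ray.init(ignore_reinit_error=True)
workers = [ToyWorker.remote(i) for i in range(4)]

# One pending object ref per worker, keyed ref -> actor, like self.pending above.
pending = {w.step.remote(0): w for w in workers}

for _ in range(3):  # three "inference" rounds
    # poll(): block until the pending refs are ready, then read them back.
    ready, _ = ray.wait(list(pending), num_returns=len(pending), timeout=None)
    results = {}
    for obj_ref in ready:
        worker = pending.pop(obj_ref)
        worker_id, value = ray.get(obj_ref)
        results[worker_id] = value
    # send_actions(): dispatch the next batch of work to every actor.
    for worker_id, worker in enumerate(workers):
        pending[worker.step.remote(results[worker_id])] = worker

ray.shutdown()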
MyVecEnv.py DELETED
@@ -1,47 +0,0 @@
from typing import Optional, List, Union, Sequence, Type, Any

import gym
import numpy as np
from ray.rllib import BaseEnv
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnvIndices, VecEnvStepReturn, VecEnvObs

from MyRemoteVectorEnv import MyRemoteVectorEnv


class WrapperRay(VecEnv):

    def __init__(self, make_env, num_workers, per_worker_env, device):
        self.one_env = make_env(0)
        self.remote: BaseEnv = MyRemoteVectorEnv(make_env, num_workers, per_worker_env,
                                                 self.one_env.observation_space, device)
        super(WrapperRay, self).__init__(num_workers * per_worker_env,
                                         self.one_env.observation_space, self.one_env.action_space)

    def reset(self) -> VecEnvObs:
        return self.remote.poll()[0]

    def step_async(self, actions: np.ndarray) -> None:
        self.remote.send_actions(actions)

    def step_wait(self) -> VecEnvStepReturn:
        return self.remote.poll()

    def close(self) -> None:
        self.remote.stop()

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        pass

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        pass

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        pass

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        pass

    def get_images(self) -> Sequence[np.ndarray]:
        pass

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        pass
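WrapperRay was the stable-baselines3 entry point: reset() and step_wait() both delegate to MyRemoteVectorEnv.poll(), while step_async() forwards actions to the remote workers. A hedged sketch of how it was presumably driven before this commit removed it; the make_env factory, the CompiledJssEnvCP constructor arguments, and the instance path are hypothetical placeholders, not taken from the repository:

# Hedged usage sketch (not part of the commit). Assumes Ray is installed and
# that CompiledJssEnvCP accepts an instance path; both are assumptions.
import ray
import torch
from compiled_jss.CPEnv import CompiledJssEnvCP
from MyVecEnv import WrapperRay


def make_env(rank: int):
    # Hypothetical: the real app builds its envs from its own instance files.
    return CompiledJssEnvCP("instances/ta01")


if __name__ == "__main__":
    ray.init()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vec_env = WrapperRay(make_env, num_workers=2, per_worker_env=4, device=device)

    obs = vec_env.reset()  # first poll(): initial observations for all 8 envs
    for _ in range(10):
        # Random actions stand in for the Agent in Network.py.
        actions = [vec_env.action_space.sample() for _ in range(vec_env.num_envs)]
        obs, rewards, dones, infos = vec_env.step(actions)  # step_async + step_wait
    vec_env.close()
    ray.shutdown()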
Network.py DELETED
@@ -1,114 +0,0 @@
import math

import numpy as np
import torch
from torch import nn, Tensor
from torch.distributions import Categorical


class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, max_len: int = 100):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, positions: Tensor) -> Tensor:
        return self.pe[positions]


class Actor(nn.Module):

    def __init__(self, pos_encoder):
        super(Actor, self).__init__()
        self.activation = nn.Tanh()
        self.project = nn.Linear(4, 8)
        nn.init.xavier_uniform_(self.project.weight, gain=1.0)
        nn.init.constant_(self.project.bias, 0)
        self.pos_encoder = pos_encoder

        self.embedding_fixed = nn.Embedding(2, 1)
        self.embedding_legal_op = nn.Embedding(2, 1)

        self.tokens_start_end = nn.Embedding(3, 4)

        # self.conv_transform = nn.Conv1d(5, 1, 1)
        # nn.init.kaiming_normal_(self.conv_transform.weight, mode="fan_out", nonlinearity="relu")
        # nn.init.constant_(self.conv_transform.bias, 0)

        self.enc1 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0, batch_first=True,
                                               norm_first=True)
        self.enc2 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0, batch_first=True,
                                               norm_first=True)

        self.final_tmp = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01)
        )
        self.no_op = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01)
        )

    def forward(self, obs, attention_interval_mask, job_resource, mask, indexes_inter, tokens_start_end):
        embedded_obs = torch.cat((self.embedding_fixed(obs[:, :, :, 0].long()), obs[:, :, :, 1:3],
                                  self.embedding_legal_op(obs[:, :, :, 3].long())), dim=3)
        non_zero_tokens = tokens_start_end != 0
        t = tokens_start_end[non_zero_tokens].long()
        embedded_obs[non_zero_tokens] = self.tokens_start_end(t)
        pos_encoder = self.pos_encoder(indexes_inter.long())
        pos_encoder[non_zero_tokens] = 0
        obs = self.project(embedded_obs) + pos_encoder

        transformed_obs = obs.view(-1, obs.shape[2], obs.shape[3])
        attention_interval_mask = attention_interval_mask.view(-1, attention_interval_mask.shape[-1])
        transformed_obs = self.enc1(transformed_obs, src_key_padding_mask=attention_interval_mask == 1)
        transformed_obs = transformed_obs.view(obs.shape)
        obs = transformed_obs.mean(dim=2)

        job_resource = job_resource[:, :-1, :-1] == 0

        obs_action = self.enc2(obs, src_mask=job_resource) + obs

        logits = torch.cat((self.final_tmp(obs_action).squeeze(2), self.no_op(obs_action).mean(dim=1)), dim=1)
        return logits.masked_fill(mask == 0, -3.4028234663852886e+38)


class Agent(nn.Module):
    def __init__(self):
        super(Agent, self).__init__()
        self.pos_encoder = PositionalEncoding(8)
        self.actor = Actor(self.pos_encoder)

    def forward(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end,
                action=None):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        probs = Categorical(logits=logits)
        if action is None:
            probabilities = probs.probs
            actions = torch.multinomial(probabilities, probabilities.shape[1])
            return actions, torch.log(probabilities), probs.entropy()
        else:
            return logits, probs.log_prob(action), probs.entropy()

    def get_action_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        probs = Categorical(logits=logits)
        return probs.sample()

    def get_logits_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        return logits


def layer_init_tanh(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer
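Two small, self-contained pieces of the deleted Network.py are easy to exercise in isolation: the sinusoidal PositionalEncoding lookup and the layer_init_tanh initialiser. A minimal sketch, assuming the file above is still importable (it is removed by this commit); the tensor shapes are arbitrary and chosen only to match the d_model=8 used by Agent:

# Minimal sketch: query the sinusoidal table for a batch of integer positions
# and initialise a small head with layer_init_tanh. Shapes are illustrative.
import torch
from Network import PositionalEncoding, layer_init_tanh  # file deleted by this commit

pe = PositionalEncoding(d_model=8, max_len=100)
positions = torch.tensor([[0, 1, 2, 3]])      # (batch=1, seq=4) integer indices
print(pe(positions).shape)                    # torch.Size([1, 4, 8])

head = layer_init_tanh(torch.nn.Linear(8, 1), std=0.01)  # orthogonal weights, zero bias
print(head(torch.zeros(1, 8)).shape)          # torch.Size([1, 1])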
app.py CHANGED
@@ -3,12 +3,10 @@ import os
 import random
 import time
 
-import gym
 import plotly.figure_factory as ff
 import json
 
 import pandas as pd
-import ray
 
 from compiled_jss.CPEnv import CompiledJssEnvCP
 
@@ -19,7 +17,6 @@ import torch
 import numpy as np
 
 from MyDummyVecEnv import MyDummyVecEnv
-from MyVecEnv import WrapperRay
 
 import gradio as gr
checkpoint.pt DELETED
Binary file (75.6 kB)