Patrick WAN committed
Commit 52933b5 · 1 Parent(s): faa571f
initial commit
- CITATION.cff +30 -0
- LICENSE +339 -0
- demo/cvrp_search.ipynb +411 -0
- demo/tsp_search.ipynb +397 -0
- environment.yml +113 -0
- envs/cvrp_vector_env.py +144 -0
- envs/tsp_data.py +43 -0
- envs/tsp_vector_env.py +113 -0
- envs/vrp_data.py +57 -0
- models/attention_model_wrapper.py +152 -0
- models/nets/__init__.py +0 -0
- models/nets/attention_model/__init__.py +0 -0
- models/nets/attention_model/attention_model.py +101 -0
- models/nets/attention_model/context.py +232 -0
- models/nets/attention_model/decoder.py +211 -0
- models/nets/attention_model/dynamic_embedding.py +73 -0
- models/nets/attention_model/embedding.py +185 -0
- models/nets/attention_model/encoder.py +128 -0
- models/nets/attention_model/multi_head_attention.py +188 -0
- ppo.py +311 -0
- ppo_or.py +397 -0
- runs/cvrp-v0__ppo_or__1__1678159979/ckpt/12000.pt +3 -0
- runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt +3 -0
- wrappers/recordWrapper.py +62 -0
- wrappers/syncVectorEnvPomo.py +195 -0
CITATION.cff
ADDED
@@ -0,0 +1,30 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "WAN"
  given-names: "Ching Pui"
  orcid: "https://orcid.org/0000-0002-6217-5418"
- family-names: "LI"
  given-names: "Tung"
- family-names: "WANG"
  given-names: "Jason Min"
title: "RLOR: A Flexible Framework of Deep Reinforcement Learning for Operation Research"
version: 1.0.0
doi: 10.5281/zenodo.1234
date-released: 2023-03-23
url: "https://github.com/cpwan/RLOR"
preferred-citation:
  type: misc
  authors:
  - family-names: "WAN"
    given-names: "Ching Pui"
    orcid: "https://orcid.org/0000-0002-6217-5418"
  - family-names: "LI"
    given-names: "Tung"
  - family-names: "WANG"
    given-names: "Jason Min"
  doi: 10.48550/arXiv.2303.13117
  title: "RLOR: A Flexible Framework of Deep Reinforcement Learning for Operation Research"
  year: 2023
  eprint: "arXiv:2303.13117"
  url: "http://arxiv.org/abs/2303.13117"
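Because CITATION.cff is plain YAML, it can be sanity-checked with any YAML parser. The snippet below is an illustration rather than a file in this commit; it assumes PyYAML is installed and is run from the repository root.

```python
# Illustrative check (not in the repo): parse CITATION.cff and print the
# preferred citation it declares.
import yaml

with open("CITATION.cff") as f:
    cff = yaml.safe_load(f)

pref = cff["preferred-citation"]
authors = ", ".join(
    f'{a["given-names"]} {a["family-names"]}' for a in pref["authors"]
)
print(f'{authors}. "{pref["title"]}" ({pref["year"]}). {pref["url"]}')
```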
LICENSE
ADDED
@@ -0,0 +1,339 @@
MIT License

Copyright (c) 2023 Ching Pui Wan (cpwan), Tung Li (TonyLiHK)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Wouter Kool

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 CleanRL developers

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


--------------------------------------------------------------------------------

Code in `cleanrl/ppo_procgen.py` and `cleanrl/ppg_procgen.py` are adapted from https://github.com/AIcrowd/neurips2020-procgen-starter-kit/blob/142d09586d2272a17f44481a115c4bd817cf6a94/models/impala_cnn_torch.py

**NOTE: the original repo did not fill out the copyright section in their license
so the following copyright notice is copied as is per the license requirement.
See https://github.com/AIcrowd/neurips2020-procgen-starter-kit/blob/142d09586d2272a17f44481a115c4bd817cf6a94/LICENSE#L190


Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

--------------------------------------------------------------------------------
Code in `cleanrl/ddpg_continuous_action.py` and `cleanrl/td3_continuous_action.py` are adapted from https://github.com/sfujim/TD3


MIT License

Copyright (c) 2020 Scott Fujimoto

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
Code in `cleanrl/sac_continuous_action.py` is inspired and adapted from [haarnoja/sac](https://github.com/haarnoja/sac), [openai/spinningup](https://github.com/openai/spinningup), [pranz24/pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic), [DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3), and [denisyarats/pytorch_sac](https://github.com/denisyarats/pytorch_sac).

- [haarnoja/sac](https://github.com/haarnoja/sac/blob/8258e33633c7e37833cc39315891e77adfbe14b2/LICENSE.txt)

COPYRIGHT

All contributions by the University of California:
Copyright (c) 2017, 2018 The Regents of the University of California (Regents)
All rights reserved.

All other contributions:
Copyright (c) 2017, 2018, the respective contributors
All rights reserved.

SAC uses a shared copyright model: each contributor holds copyright over
their contributions to the SAC codebase. The project versioning records all such
contribution and copyright details. If a contributor wants to further mark
their specific copyright on a particular contribution, they should indicate
their copyright solely in the commit message of the change when it is
committed.

LICENSE

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

CONTRIBUTION AGREEMENT

By contributing to the SAC repository through pull-request, comment,
or otherwise, the contributor releases their content to the
license and copyright terms herein.

- [openai/spinningup](https://github.com/openai/spinningup/blob/038665d62d569055401d91856abb287263096178/LICENSE)

The MIT License

Copyright (c) 2018 OpenAI (http://openai.com)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

- [DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3/blob/44e53ff8115e8f4bff1d5218f10c8c7d1a4cfc12/LICENSE)

The MIT License

Copyright (c) 2019 Antonin Raffin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

- [denisyarats/pytorch_sac](https://github.com/denisyarats/pytorch_sac/blob/81c5b536d3a1c5616b2531e446450df412a064fb/LICENSE)

MIT License

Copyright (c) 2019 Denis Yarats

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

- [pranz24/pytorch-soft-actor-critic](https://github.com/pranz24/pytorch-soft-actor-critic/blob/master/LICENSE)

MIT License

Copyright (c) 2018 Pranjal Tandon

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


---------------------------------------------------------------------------------
The CONTRIBUTING.md is adopted from https://github.com/entity-neural-network/incubator/blob/2a0c38b30828df78c47b0318c76a4905020618dd/CONTRIBUTING.md
and https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/CONTRIBUTING.md

MIT License

Copyright (c) 2021 Entity Neural Network developers

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.



MIT License

Copyright (c) 2020 Stable-Baselines Team

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


---------------------------------------------------------------------------------
The cleanrl/ppo_continuous_action_isaacgym.py is contributed by Nvidia

SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: MIT

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
demo/cvrp_search.ipynb
ADDED
@@ -0,0 +1,411 @@
{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "3f3052cd",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3f3052cd",
        "outputId": "78d129fd-0956-4f88-ae39-def9953a982e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'RLOR'...\n",
            "remote: Enumerating objects: 52, done.\u001b[K\n",
            "remote: Counting objects: 100% (52/52), done.\u001b[K\n",
            "remote: Compressing objects: 100% (35/35), done.\u001b[K\n",
            "remote: Total 52 (delta 12), reused 52 (delta 12), pack-reused 0\u001b[K\n",
            "Unpacking objects: 100% (52/52), 5.19 MiB | 4.39 MiB/s, done.\n",
            "/content/RLOR\n"
          ]
        }
      ],
      "source": [
        "!git clone https://github.com/cpwan/RLOR\n",
        "%cd RLOR"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "f01dfb64",
      "metadata": {
        "id": "f01dfb64"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "import gym"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "985bf6e6",
      "metadata": {
        "id": "985bf6e6"
      },
      "source": [
        "# Define our agent"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "953a7fde",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "953a7fde",
        "outputId": "9b37d746-b9a0-4d53-c12a-b2445ed6bd9d"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<All keys matched successfully>"
            ]
          },
          "metadata": {},
          "execution_count": 3
        }
      ],
      "source": [
        "from models.attention_model_wrapper import Agent\n",
        "device = 'cuda'\n",
        "ckpt_path = './runs/cvrp-v0__ppo_or__1__1678159979/ckpt/12000.pt'\n",
        "agent = Agent(device=device, name='cvrp').to(device)\n",
        "agent.load_state_dict(torch.load(ckpt_path))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "2cbaa255",
      "metadata": {
        "id": "2cbaa255"
      },
      "source": [
        "# Define our environment\n",
        "## CVRP\n",
        "Given a depot, n nodes with their demands, and the capacity of the vehicle, \n",
        "find the shortest path that fulfills the demand of every node.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "81fd7b68",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "81fd7b68",
        "outputId": "f4a9d9d8-29f7-413d-b462-d27f04a0153a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:31: UserWarning: \u001b[33mWARN: A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (50, 2)\u001b[0m\n",
            "  logger.warn(\n",
            "/usr/local/lib/python3.9/dist-packages/gym/core.py:317: DeprecationWarning: \u001b[33mWARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n",
            "  deprecation(\n",
            "/usr/local/lib/python3.9/dist-packages/gym/wrappers/step_api_compatibility.py:39: DeprecationWarning: \u001b[33mWARN: Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n",
            "  deprecation(\n",
            "/usr/local/lib/python3.9/dist-packages/gym/vector/vector_env.py:56: DeprecationWarning: \u001b[33mWARN: Initializing vector env in old step API which returns one bool array instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n",
            "  deprecation(\n"
          ]
        }
      ],
      "source": [
        "from wrappers.syncVectorEnvPomo import SyncVectorEnv\n",
        "from wrappers.recordWrapper import RecordEpisodeStatistics\n",
        "\n",
        "env_id = 'cvrp-v0'\n",
        "env_entry_point = 'envs.cvrp_vector_env:CVRPVectorEnv'\n",
        "seed = 0\n",
        "\n",
        "gym.envs.register(\n",
        "    id=env_id,\n",
        "    entry_point=env_entry_point,\n",
        ")\n",
        "\n",
        "def make_env(env_id, seed, cfg={}):\n",
        "    def thunk():\n",
        "        env = gym.make(env_id, **cfg)\n",
        "        env = RecordEpisodeStatistics(env)\n",
        "        env.seed(seed)\n",
        "        env.action_space.seed(seed)\n",
        "        env.observation_space.seed(seed)\n",
        "        return env\n",
        "    return thunk\n",
        "\n",
        "envs = SyncVectorEnv([make_env(env_id, seed, dict(n_traj=50))])"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c363d489",
      "metadata": {
        "id": "c363d489"
      },
      "source": [
        "# Inference\n",
        "We use the Multi-Greedy search strategy: running greedy sampling with different starting nodes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "bbee9e3c",
      "metadata": {
        "id": "bbee9e3c",
        "outputId": "5632253d-9a70-433b-c35a-3d97e478da0d",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:174: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator.\u001b[0m\n",
            "  logger.warn(\n",
            "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:190: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting.\u001b[0m\n",
            "  logger.warn(\n",
            "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:195: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.\u001b[0m\n",
            "  logger.warn(\n",
            "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:141: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64\u001b[0m\n",
            "  logger.warn(\n",
            "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:165: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method is not within the observation space.\u001b[0m\n",
            "  logger.warn(f\"{pre} is not within the observation space.\")\n",
            "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:227: DeprecationWarning: \u001b[33mWARN: Core environment is written in old step API which returns one bool instead of two. It is recommended to rewrite the environment with new step API. \u001b[0m\n",
            "  logger.deprecation(\n",
            "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:234: UserWarning: \u001b[33mWARN: Expects `done` signal to be a boolean, actual type: <class 'numpy.ndarray'>\u001b[0m\n",
            "  logger.warn(\n",
            "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:141: UserWarning: \u001b[33mWARN: The obs returned by the `step()` method was expecting numpy array dtype to be float32, actual type: float64\u001b[0m\n",
            "  logger.warn(\n",
            "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:165: UserWarning: \u001b[33mWARN: The obs returned by the `step()` method is not within the observation space.\u001b[0m\n",
            "  logger.warn(f\"{pre} is not within the observation space.\")\n",
            "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:260: UserWarning: \u001b[33mWARN: The reward returned by `step()` must be a float, int, np.integer or np.floating, actual type: <class 'numpy.ndarray'>\u001b[0m\n",
            "  logger.warn(\n"
          ]
        }
      ],
      "source": [
        "trajectories = []\n",
        "agent.eval()\n",
        "obs = envs.reset()\n",
        "done = np.array([False])\n",
        "while not done.all():\n",
        "    # ALGO LOGIC: action logic\n",
        "    with torch.no_grad():\n",
        "        action, logits = agent(obs)\n",
        "    if trajectories==[]: # Multi-greedy inference\n",
        "        action = torch.arange(1, envs.n_traj + 1).repeat(1, 1)\n",
        "    \n",
        "    obs, reward, done, info = envs.step(action.cpu().numpy())\n",
        "    trajectories.append(action.cpu().numpy())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "f0fbf6fd",
      "metadata": {
        "id": "f0fbf6fd"
      },
      "outputs": [],
      "source": [
        "nodes_coordinates = np.vstack([obs['depot'],obs['observations'][0]])\n",
        "final_return = info[0]['episode']['r']\n",
        "best_traj = np.argmax(final_return)\n",
        "resulting_traj = np.array(trajectories)[:,0,best_traj]\n",
        "resulting_traj_with_depot = np.hstack([np.zeros(1,dtype = int),resulting_traj])"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Results"
      ],
      "metadata": {
        "id": "ViNGfd1PQwlw"
      },
      "id": "ViNGfd1PQwlw"
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "id": "dff29ef4",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dff29ef4",
        "outputId": "8a57a330-b340-4d60-dc83-2a1ea548c7d0"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "A route of length -11.283475875854492\n",
            "The route is:\n",
            " [ 0 16 32 27 7 30 23 38 40 2 11 0 9 10 19 36 35 4 5 18 0 25 42 43\n",
            " 48 21 37 31 50 0 13 8 41 17 12 46 0 3 6 44 0 22 26 49 34 33 28 0\n",
            " 1 14 29 39 47 0 20 24 45 15 0 0]\n"
          ]
        }
      ],
      "source": [
        "print(f'A route of length {final_return[best_traj]}')\n",
        "print('The route is:\\n', resulting_traj_with_depot)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "1b78c529",
      "metadata": {
        "id": "1b78c529"
      },
      "source": [
        "### Display it in a 2d-grid\n",
        "- Darker color means later steps in the route.\n",
        "- We abuse the errorbar to show the relative size of demand at each customer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "id": "dc681a06",
      "metadata": {
        "tags": [
          "\"hide-cell\""
        ],
        "cellView": "form",
        "id": "dc681a06"
      },
      "outputs": [],
      "source": [
        "#@title Helper function for plotting\n",
        "# colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib.collections import LineCollection\n",
        "from matplotlib.colors import ListedColormap, BoundaryNorm\n",
        "\n",
        "def make_segments(x, y):\n",
        "    '''\n",
        "    Create list of line segments from x and y coordinates, in the correct format for LineCollection:\n",
        "    an array of the form numlines x (points per line) x 2 (x and y) array\n",
        "    '''\n",
        "\n",
        "    points = np.array([x, y]).T.reshape(-1, 1, 2)\n",
        "    segments = np.concatenate([points[:-1], points[1:]], axis=1)\n",
        "    \n",
        "    return segments\n",
        "\n",
        "def colorline(x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0), linewidth=1, alpha=1.0):\n",
        "    '''\n",
        "    Plot a colored line with coordinates x and y\n",
        "    Optionally specify colors in the array z\n",
        "    Optionally specify a colormap, a norm function and a line width\n",
        "    '''\n",
        "    \n",
        "    # Default colors equally spaced on [0,1]:\n",
        "    if z is None:\n",
        "        z = np.linspace(0.3, 1.0, len(x))\n",
        "    \n",
        "    # Special case if a single number:\n",
        "    if not hasattr(z, \"__iter__\"): # to check for numerical input -- this is a hack\n",
        "        z = np.array([z])\n",
        "    \n",
        "    z = np.asarray(z)\n",
        "    \n",
        "    segments = make_segments(x, y)\n",
        "    lc = LineCollection(segments, array=z, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha)\n",
        "    \n",
        "    ax = plt.gca()\n",
        "    ax.add_collection(lc)\n",
        "    \n",
        "    return lc\n",
        "\n",
        "def plot(coords, demand):\n",
        "    x,y = coords.T\n",
        "    lc = colorline(x,y,cmap='Reds')\n",
        "    plt.axis('square')\n",
        "    x, y =obs['observations'][0].T\n",
        "    h = obs['demand']/4\n",
        "    h = np.vstack([h*0,h])\n",
        "    plt.errorbar(x,y,h,fmt='None',elinewidth=2)\n",
        "    return lc"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "id": "aa5e32f2",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 282
        },
        "id": "aa5e32f2",
        "outputId": "8da68d19-138b-4e3e-c481-02e06416c5e7"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.collections.LineCollection at 0x7f6c20062b20>"
            ]
          },
          "metadata": {},
          "execution_count": 9
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD4CAYAAAAjDTByAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAABXs0lEQVR4nO2dd3RURRfAf/N203ujJ4TepCdIlSbSFLBRbCgqYkUU6U1QBAEFu4gVK/ipgIANQRSlhN4hCSUJJUACIT27b74/NoEkpOwmb7MbeL9zOGfz3rw7d0Pum5k7d+4VUkp0dHRuPBRHK6Cjo+MYdOPX0blB0Y1fR+cGRTd+HZ0bFN34dXRuUIyO6jg4OFiGh4c7qnsdnRuC7du3n5dShhR1z2HGHx4eTlRUlKO619G5IRBCnCjuXqnTfiHEJ0KIRCHEvmLuCyHEW0KIaCHEHiFEm/Ioq6OjUzFYs+b/DOhTwv2+QIPcfyOB98uvlo6Ojr0p1fillBuBpBKaDAS+kBY2A/5CiOpaKaijo2MftPD21wTi8v0cn3vtGoQQI4UQUUKIqHPnzmnQtY6OTlmp0K0+KeViKWWElDIiJKRIB6SOjk4FoYXxJwCh+X6ulXtNR0fHidHC+FcCD+V6/dsDl6SUpzWQq6OjY0dK3ecXQnwDdAOChRDxwHTABUBK+QGwBugHRAPpwCP2UlanchE+YTUAx+f0d7AmOkVRqvFLKYeVcl8CT2umkY6OToVww8f2h09YfWWE0tG5kbjhjV9H50ZFN34dnRsU3fh1dG5QdOPXsTu6X8U50Y0/l8xfV3MjZDLWDVEnjxva+NOWf82W3a9zoEMWGV9/zsXH7idn725Hq6WjUyE4LJmHI5FSkrr4HTLW/kzQR19irBWKxx2DyFq7kpSJY3Bp3Ravp8dgqFbD0arq6NiNG27kl2YzKa+9TObG9QR9/BXGWpZjCUJRcO8/iMDlP2MIrU3yg/eQ9sEiZHqagzXW0bEPN9TIL7OyuDjlJdTUywR9+AWKt/c1bYSHJ14jn8F94D2kvb+IpMG34znyWdz7D0QYDA7QuvIgpSQrJobL//zL5U2b+HnvPpQaNUjdsYvIY0ccrZ5OIW4Y41dTL5P84jMo/gEELvoQ4epaYntD1Wr4zniNnAN7SX1zLpnLv8Lr+fG4tm1XQRpXDrLj47m86V8u//Mvqf/+CwYjPp074t+3D6GzZpJ9/jzRz4xGCOFoVXUKcUMYv/n8OZKeewLXFi3xfWmKTSO4S9Pm+C9eSvafv3F51hSMDRvj/eyLGEJr21Fj5yXn/HlSN/3H5U2buLzpX9TUNLw7dcSnU0eqj30B17DQAoZ+eedO3MNCS5Co4yiue+M3xZ8k6enH8Og/EO/HnyrTCCSEwK1nb1w7dyPju6UkP3of7v0G4DliFIqvnx20dh7MKSmkbt5yZSqfc+o03u1vxrtzJ0JGPIJ7o4Yl/k6zTsbhFqobvzNyXRt/zqEDJD0/Cu/HnsLrnqHllifc3PB86DHcb7+TtMXvkDT4DrwefQL3OwcjjC4aaOx41IwMUrdFkbrpXy5v+pfMI0fxatsG704dCZv/Op7Nb0IYrf+zyYqLxy20lh011ikr163xZ0Vt4eLEF/AdPw2PW3trKlsJDMJnwnQ87hlG6qJ5ZHz/LV7PvYRrxy6Vbm0rc3JI27XLsmbf9C/pu/fg0awpPp06UmPyRLzatEZxcyuz/Ky4OHw7ttdQY/twI+YeuC6NP2Pdb6TMeRn/197ALeJmu/VjrN8Qv7cWk71pI2mLXifju6V4jx6HsV4Du/WpFWc/+JDL/2wibdt23MJr49O5E1WeGoX3ze0weHlp1o8+7XderjvjT/v+W1KXvEfg2x/h0rip3fsTQuDWuSuu7TuS+cMyLj49Ardut+I18hmUwCC7928rx2b35dCtt5GZ3YrgYUMJf2sRxsAAu/QlpbRM+3WHn1Pi1EE+tsShSym5vPhd0pZ+QtBHSyvE8PMjjC54DL6fwGU/I9zcSBo6gPSlHyOzsytUj9IQioL/rT0RZhP+/fvZzfABTMkXQVEw+l3fTtHKilMbv7VIs5mU12eRuWGdJWrPgdtwiq8f3mMm4L/kK3J27yRpyB1krfvVqQ4NVX1+NOnbt3P5r7/s2k9WXJw+6jsxld74ZXY2FyePxRQbQ9DizzEEO0c9AGNYOH7z38Fn8sukffohl0YNJ+dgkeUOKxzF05Oar8wiYcpU1IwMu/VjWe/rnn5npVIbv5qaStLoJ8BsJvCtxSjePo5W6RpcI9oT8Ply3PoPJGXsM6S8PBFz4llHq4VPt254tmrF2UWL7NZHZRn5ZWb61c/JN04lqUpr/OYL57kwajjG0Nr4z3kTUY7tKHsjDAY8BtxNwLLVGKpUI/mBu0hb8h4yI730h+1I9WlTSVq2nIwDB+wi37LH79zGL6P3or4/lZjIeGI6JmL+YDqmj2ahbvkDmXbZ0erZlUpp/Kb4OC48dj/uXbrhO3F6pTlwo3h54fXkaAI+X475eCxJg+8gc80KpKo6RB+XkBCqvfQS8RMnIs1mzeU787RfZmWgrvoMddVnKANGYLj7CQwDHsHw0lsot9yOPHEE88KxmJfOR939LzI709Eqa06lM/6cI4e48PgDeN03HJ8nnq10QTUAhuo18H1lPr6zF5Dx/bdcHDGUnF3bHaJL4JDBKK5uXFj6peays+LinTKuXx47iPrBNFDNKE/OQtRrduWeMBpRGrXGMPgpDGMXIVp0QO7ehHneaMzL30M9vAtpNjlQe+2oVPv8Wdu3cnHCGHzHTcGjV19Hq1NuXJq3wv/jr8n6bQ0p08djbNrccmioRsWNlkJRqDl7NjGDB+Pb+zZcq2tTXd2yx+9cAT4yOwv5x3Lkoe0otw9HNGxVYnvh5o5o2QladkKmpSD3bUHduBJ+WIy4qR1Ki44QWh+hVLoxFKhExp+5/ncuzZ6B/6vzcWvXwdHqaIYQAvfe/XHr2pP0bz4n+eEhuA+4C8+HR1aYA9O9QX2CHnqQU9OnE754sSYyTecvoHi4YygiZ4IjkCePoq5YgqhZ1zLae9iml/DyRdzcC+XmXsjkROSezZhXfAI5WYgWHVBadEBUdZ4XnTVUildW+o/LuDR3FoFvLb6uDD8/wt0dr0eeIODrn5AXL5J0b38yfvgOaaqYKWaVp54iKzqaS7/+qom8zLg43Go5fr0vc7JRf/sWdfm7KLcORrnrCZsNvzAioApK1wEYnn0Nw33Pg2rG/MV8TO9MQt34M/LieW2UtzOVYuRP/ewjghZ/gTEs3NGq2B1DcAg+U2bhfvggaQvnkvH9N3g/9xKu7TvZtV/FzY2as2cT9/wYvDt2xOBTvllH1sk43MLCNNKubMiEWNSflkCVmiijZiK8fDWVL4SA6rUxVK+N7DUEThxG3fMf6vtTIaQmSsuOiGaRCE/n24IGEI6KPIuIiJBRUVEltjFduEDSmKcIeuNtDMFVKkgz50FKSfbGP0l7az6GsHC8nhuLsU49u/YZN248iocHNV+eUS45CW+/i
86fOpLUFMio5G27Vq1Pz0Uc49uocR6tSIrrx24o94vsvJ5Xo7Nvx1ns0HnovXtWqattvPvKnKEs9cZLUk3FU7dyxfEJNZdvqA0v57ou791wp352Ha4cu5dOpggh7+kkubt7CpW3bS2/sIHTjtxF7OPxK8vRnnL/A/i++pO3zz2rbZwmcXLWa0P59S6zGYw2N0npR78vTZXrWxdcXz9BQ6k79rcAo79LhlnLpVFEYvDypO2kC0dNmOO0yRTd+W7GDw88y7S96zb/znfdpcOdAfGppX6utOCxefvslArWWwHbXOhuNLcsYc+AAqt57F9JkJvHHFY5WpUh047cVe+zzFzPyZyZfZO/HnxHx4mht+yuB9FOnuXQ0hmq3dK6wPosjqN21JR+FnZOWaEleuu/YV17DnJ7haHWuQTd+W3F1s+zJa4RUVUi7BN7+19zb/cFH1OnXB7/w8tcCsJYTq1ZTq08vDK72S1NmLUFFjPyVDf/2N+PTuhXxH37kaFWuQTd+W9Ha4Zd2Cdy9EIXW11kpKez+4CMiXxqjXV9WcNLWDL12xKuYFOFHB/my31h50mfVmzaZuA8/IuvsWUerUgDd+G0l90ivZsdLi5ny71n8CWE9uhNQv542/VhJ0t591OjetdxyZE42B85/SOyM7mWWUdxJQvV0AoqfdScgnQGP8NpUHzaEY6+97mhVCqAbv40IgwEUxbKHrQFFOfty0tLY+c4HFT7qA9TsdSsGd+tTiRWHevQgSmAIirunBloVkn08FiW8ruZy7UntMc9x4fc/ubx3n6NVuYJu/GVBS6dfyrV7/Hs//pyandoT1LT0yrxaU9vaOnyloO7fiXKTdp759HzRcubjMSjhFTsjKi9GX1/Cx44hZpp9Kv2WBd34y4KWUX6XCx7lNWVmsmPRO7QbV2rRY7tQs1dPTeSY9+/A0Kx1ueUcn9OfP+uc5p/+g7h86DAA6olYDLWtLBnmRFR/8D6yLyRx4ZffHK0KoBt/2dAwvr9wEo/9n39JlTatCSlD2qyykpZwire3v8uK4CMYPcs/TZfZWajRB1GatNRAO6j7xGM0mTKBf++8l6S//0FmZiJC7BftaC8Uo5H6M5wn3bdu/GVByzx++Rx+5uxsot54i3YTKm7UP7n2N1Z0vZXQ3r1oPuY5TWSqRw+g1ApHeGqXZix0yL20fGMeW4c/SopngF3LkNuTwB7d8Khbh4RPPne0KrrxlwlXd6QGa34pZYEMPge/+pagxo2oVtbMOTZgzs5m88Sp/Dt2PD2WfkKrsWM0q/Vn3rdD0/V+HtX69qblqOHs23qQ02t+0Vx+RVFvxlROLnqbnCTHpvvWjb8MCA0cfuETVlNn4hoQAuHmgTknh23zFxI53v6jfkrsMX7u1Z/Lx45z599/Uq1De03lq/t3arLeLwp/DyNtnxnBnhfHEfftMrv0YW+8GjUkZMAdHJ+3wKF66MZfFrR0+OWO+oe/+x7fsFBqdtTWEAsT+78fWdWzL/WG3MOt33yBW6C2++UyKws19jBK4xaays1DPRFLQNeudPzpew7NnkusE0bOWUP4uBdJ/GklaUejHaZD+Y5t3ahomdDDJxDVbGbb/DfpsXC+NjKLwJSRweYJUzi1YSO9f/iO4NbaOOMKox7ZhxJWD+Gh/f4+gPl4LIbadfGpUYtOq3/iv7uHkp18kUbjx1YqP4BrUCBhzz5NzIxZtPjKMet/q0Z+IUQfIcRhIUS0EGJCEfdfEEIcEELsEUKsE0JUXDC6I9DQ4Sd8Ajj6wwo8goKo1dU+Z9WTDx1mZffe5FxOZdDff9rN8CFvvW+fKb9MS0WmXERUqwGAZ2gonX7+iTO//Ma+iVOd9uhscdR89GHSj0aTtMExWYlLNX4hhAF4F+gLNAWGCSGaFmq2E4iQUrYAvgecK45RazQ81iu9/dg2dwHt7DBySSk5svQr1vQdSLOnRtLt4w9w9fXRtI/CqPt3YGhmH4el+cQxlNDwAtmE3auE0GnF91zau5edTz13TfIPZ0Zxc6Pe9CnETH8Z1WSHEnCl9W9Fm3ZAtJQyVkqZDXwLFCggL6VcL6XMK6O6GailrZrOhXBxp+5KRZNUUjFRBzB4uFO7Vw8NNLtK9uVU/nrsSfa+8wH91vxEo4cesPu0WGZmoB6PRmlsnxgF9UQshiIi+1z8/Gi//Buyk5PZ9vBjmDOc7/hscQT364MxIIAzX39b4X1bY/w1gfyZCONzrxXHo8Daom4IIUYKIaKEEFHnzp2zXktnw0W7kthRn35Hu/EvamqY53ftZkWXHhi9PBm4/lcCmlRMmLB6eB9KeAOEW/nPBhSF+XispXRXERg9PWm39FMMHh5sHnI/OSmX7aKD1gghqD9zOsdffwPT5YrVWVNvvxDiASACmFfUfSnlYillhJQyIiQkRMuuKxZX7YxfIqjbv682sqRk//uL+fXOIbSZMpHOb72hScSetdhrfz8P9URMkSN/HoqrK20/fBfvBg34d9A9ZJ0/bzddtMSnRXMCe3TjxMK3K7Rfa4w/AQjN93Ot3GsFEELcCkwGBkgp7VTE3knQcOSPnKDNWj8rKZk/7htO9LfLuOOPNdS7504NtLMNVaN4/sLkJRe15jSfMBhoMX8OVXp0Y9Ptd5KRcM2fqlNSZ9I4Tn/5NRknTlZYn9YY/zaggRCijhDCFRgKrMzfQAjRGvgQi+Enaq+mk6HhyN/gzoGlNyqFs1u28mOXHvjUDuP231bjW6/ij7vKjHTUk7EoDW+yWx/q+XMoNUJLbSeEoMmUiYQ9cB//9B9EaiUone1WrRq1Rj5G7KzZFdZnqcYvpTQBzwC/AgeBZVLK/UKImUKIAbnN5gHewHIhxC4hxMpixF0faDDyr298mnc2vlL2OnhYUoDtXrCQP+57mI7z59B+zisY3LR7MdmCemgPSt1GCDv2r9SsdU3Go5Ko/8yTNBw7hk0D7ubSnr1200srQp98gpSo7Vzasq30xhpg1W9SSrkGWFPo2rR8n2/VWC+nRijli4G37BJU5/XL6aW2LY6MxEQ2PP405swMBv71O94VmN23KMz7dqDYKaQ3j5LW+8VR+4H7cPH15b97hxH52RKCNA5l1hKDpwd1JlvSfbdZu6pcA4M16OG9DiQtOaVAsQxrSVj/Fz916UmViDb0W/2Tww0fcuP57ejsA1Bql205U2PA7bT54F22DX+Us7+v01grbal6t8VXc/Z/P9q9L934C1EWYywr2am27UerJhNRM19l46hn6Prhu7SdOrHchTW0QKalosYfR2nYzK79lNX4Aap070q7r75g57PPE18BhlVWhKJQf+YMjr06x+7pvh3/l1PJkKkXkbv+IqbzOZTeD5ZLlkdwycU585Man8CGESMxenoy6O91eFSpUq6+tUQ9tAelflOEi33TfRvKmbcvMLItHX9YxubB95NzKYU6I4ZrpJm2+N0ciW9kBHHvfUD4WPvlcdRHfiuRmemo235D/XkJePmh3Pag5WhvOfCqVs2qdifW/MKKrr0I7dub3j9851SGD5b1vr2O8OZHCQ0vtwzfpk3o9PO
PxLz7PkfeWOQ0+fQKU3fKROI/+pisM2fs1odu/KUgc7JR9/yD+tN7YDahDByF0robQoPtPq9S1urmrCw2j5/M5nGTuPXrz2g55jm7O4HKgnm//Q7z5EdoVEjEK7w2nX/+iYQffuLA9JlWvwAqcknoUTuM6g/cx7HZ9jsm43x/SU6CVM2oh7ej/vgeJJ9F6fcISvt+CA9vzfrwqlm88afExLKqVz9S4+IZtHEdVW9up1m/WiLTLiNPxaHUL3zWy7lxr16NTqt+IGnzVnaPftEhB2tKo/bzz5L05wYu22mbUjf+YlBXfIA8fgClx2CUrncjfIMK3NdiFPCqXvS0P2b5D6y6tR8N7x9Gz68+0zzhhpaYD+xGaXhTpaqhl4drQAAdflhGenwC2x99AnOWcwWmGn18CB/3AtFTZ9hleaIbfz7kmeNXPivt+qDc9gAiuIbd+vOqUb3Az6b0dP5+5nl2zJ5L75+W0fSJx5w2QcWVkNsKWO9HP1ybPTGLUM9pX+7K6O3Fzd98AUKwZdhDmFLTrHquopYA1e8fhuniJc6vLvKsXLnQjR+QSWcx//EN6qZVV66JmvWKNTyZdunKZ3XXRmTMXmRiPDIjzaY3tFf1q8afdOAgK7rdhjkrm4Eb1xHc0j5psLTGst637/6+sfFNuN01jPQpY5B2OK5rcHOj7ZIP8KxVk//uGkx2smMTa+ZHGAzUe3kaMTNfRc2dmWj14rmht/pk6kXkzg3IU7GIFp0R3QfDtqKzwsrsLOSx/cijO+H8KWL6NEV4+UFKEmpCjCUF9+VkUM2Wirs+AQifAEtabp8AhG/uZ0/fKzIt037LMc61/e8k8pXpNLhvqNOO9kUhz55CqWf/I8OuQx9GPXmMjDlT8Jg+T3PHp2I00nLRAg5Mn8mmO+6iw/ff4l7NOWoDBHa7Bc8G9Un4+DNCn3pCM7k3pPHLzHTknn+QsXsQjSNQ7ny6SO+9NJshIRp5ZCcy7ghUr4PStB2ENUYYi17jyuxMSzruy8nIvBfChdOoudfITANuAfKm/UcB6P/LSvwbNbTXV7YbSqPmNsXblxUhBO5jppD+0iiyPn0f90eftksfTV+ehou/P//0H0iH77/Fq0645v2UhXozprJrwF1UHXyPZjJvKOOXOdnIg1uQB7YgwpuhDBxVrPde3bQKGbMHfIMQDVqhdLoD4VF6EQrh6g7B1SG4OkWN39Js4lhmOsLLMgPY+3RLPIKDMHp4lOerOYyK2N/PQ7i64vHyfNKefgglLBzXXtrUFSzQhxA0fGE0Lv7+bLr9Ttov/xrfpk0078dWvBrUp8qdAzk+7w0QHTWReUMYv1TNyKO7kLv/RlQNRek3okBlXJmdCaePIU/FXn3IzdPycvALKkJi2REGI3hdnfr7hFbujGf2Xu9f059/IJ6vLCL9xcdRatTC2Mw+yUjrjBiOq78f/941mHZLPyMwsq1d+rGF8LEvsLVTVxioG3+pSCnhxEHUnestUXk9hyCCqiPNJuTpY8iEGIvBX0yEkFBEzbrEvlQfEVitUq27HYlSt+KXKoY69XAfP5OMGWPxevtzlGr22ZGpedcgjD4+bL3/Idp8+J5d+rAFl8AAwkY/C8e1kXfdGH+e9/P4HMtUUJ45jrp9Hagqol0fhJsnMiEGddtvkBgP/sGIGvVQIm6FKqHFruF1SkYYHPMn5HJzZ9QhD5M+5Xm83vpU07qA+anaqyeRn3/MtocfgzvnArm+IAdgTksn8+xZlu9cQcTa8qfMuG6MPw+ZdAZ1x3pISoQa4ZCViVy/HOnuiahRF6VxO+g+GOFWOdfYOldxvfs+1BOx1Jm5Abj64teaoA7t6bD8G/739rtcWLeetZ+MhBGL7dJXcZz75TeOTJqKX7tIWnzxiSa7Hded8aurloDBAIoBYTZBWENE+z6WbTkdzTh6uzs5yz5Fpnez26hbGkII3EdPgCn2r3fv16I5bT/6ALBkUFox6mmCu3eze78ZJ+M4Mmkq6TGxNFm0gMAunTWTfd0Yf+xdRuTJw4iG3RChDcE/RF+32xFjpx6o+3aQNX8qbpNfd9j0P/9yLXvNj7j2s3/iUqEo5CQl4VbVfnEAanY2J9/7kJPvf0jYqJE0//hDFI1TpF03EX5Ku94Y7nkOpUVnREAV3fDtjBAC18fHgJRkf/SmUxyNzfr4HUy7o4q8p3U4buaZs7hVtc/R6qS//2FLt15c2hZF5G9rCB/znOaGD9eR8etUPMJgxO2lVzAf3INp1XeOVgePSa+SMXMCakJc6Y3LSdbZRNw0jgDMOpvI/ief5eBzL1B/6iRafvU5HrXDNO0jP7rx65QL4emF+5R55Pz0NaYtjik4mYexbXvcHhpJ+uTRyFT7Vb8xZ2ZiTk/HJdD6TEwlIc1m4pZ8ypauPXGrUYP2/2wgpG9vTWSXxHWz5tdxHEpINdwmzSVz5ouIoCoY6ldMebCicB04GPOJWNJnTcBz9iK7+CKyE8/hFhKsydIyZecuDo2dgMHbmzYr/od3BYZ46yO/jiYY6jfB7anxZL023i5Hb4tD5iuVriZaUl65Pz0WpCTrvQV26TPzzJlyO/tyLl7k0LiJ7H5wBKGjHqfNT8sr1PBBN34dDTG274rLHUPIfGUsMt26c/FlQaoq5r07yFz0KmkP3s5BZQ2HQveR9uR9ZL4zF5l0Ac9pczFt30L2imWa9h0+YTWtfyj7el9KyenvlrO5U3dA0H7Teqrfe7dDHNT6tF9HU4wDh6GejrfLFqB6Mpac9b9gWv8rwssbY48+eL77JUqwxRBd7x9J9vdfkDZqGC49+uA+dhoZ08ei1NLeaVaWkT/10GEOj5uEOT2dll99hm8r+5xLsBZ95NfRFCEEriNfsGwBLllY7i1ANekC2T9+Q/pzD5Ex5TkwmXCfPh/Pd7/C9e4Hrxg+gOIfgPtjo/FashyMLmRMfwFDyzZkzJpY3q91DbZs85nT0ome+So7Bt1D1UF3EPnrzw43fNBHfh07kLcFmDHhCUyrluEyYIhNz8vMDEz/bsC0/hfMh/djbH8Lro88g6FFW4Sh9FJpSkAQ7k+MwfWeB8le9jkmU3ZZv0qxuFmRPl1Kyfm1v3Jk8jT8O7bn5o1/4lbFeUrT68avYxfytgAzJzyBqFod4823lNhemk2Yd0dh+nMtpq3/YGjSAmPPfrhPnotwdy+TDkpQMO5PvojrvQ9ycPlS3Ee9UCY5RVHatD/j+AkOT5pKxomTNH1nIQGdtDmGqyW68evYDaVKddwmziFz1tgitwCllKixRywG/9dviJCqGLv3wfOx0Sj+2uyhAyjBVfB48kXN5AHFOvzUrCxOvPM+cYuXEPb0k7T4bAmKRvUGtEY3fh27YmjQ9MoWoPucxSghVVETz2Da8Cum9WuRWZkYu/fB47X3NKnIU1EUNfInbfybw+Mm4dmwAZF//IKHkydq0Y1fx+4Y23dFnkmg7gJL3P2+4wsxdu6J2zMTUJq0cMoqRKXhGnw1w1PWmbMcnT6TS1HbaTh7JiG9b3OgZtajG79OhWAcOIzo2lshPR1Du9V2L+ppb4
TBgGoykfDJ5xx7YyE1H7yfJm/Ox+BZefJEWGX8Qog+wCLAACyRUs4pdN8N+AJoC1wAhkgpj2urqk5lRgiBsfXNjlZDMy5t38HhlyZi9POl7aof8WpQ39Eq2Uyp8y0hhAF4F+gLNAWGCSEKF2Z7FEiWUtYH3gTmaq2ojo4zsWf4Y4Q9PYrWPyyrlIYP1gX5tAOipZSxUsps4FtgYKE2A4HPcz9/D/QU+oF6neuQI+M6sC7gIO03rafa3XdW6rwR1hh/TSD/Aen43GtFtpFSmoBLwDU5r4UQI4UQUUKIqHPnzpVNYx0dB+IaGEi98WNx8av8aeEq1M0qpVwspYyQUkaEhDhPpJOOzo2INcafAITm+7lW7rUi2wghjIAfFsefjo6Ok2KN8W8DGggh6gghXIGhQOGk4SuB4bmf7wH+lM6Q1E1HR6dYSt3qk1KahBDPAL9i2er7REq5XwgxE4iSUq4EPgaWCiGigSQsLwgdHR0nxqp9finlGmBNoWvT8n3OBO7VVjUdHR17UvniKnV0dDRBN34dnRsU3fh1dG5QhKOc8kKIc8CJIm4FA+crWB1b0XUsP86uH1wfOtaWUhYZVOMw4y8OIUSUlDLC0XqUhK5j+XF2/eD611Gf9uvo3KDoxq+jc4PijMa/2NEKWIGuY/lxdv3gOtfR6db8Ojo6FYMzjvw6OjoVgG78Ojo3KA4zfiFEHyHEYSFEtBBiQhH33YQQ3+Xe3yKECHdCHV8QQhwQQuwRQqwTQtR2Jv3ytbtbCCGFEBW+bWWNjkKIwbm/x/1CiK+dTUchRJgQYr0QYmfu/3W/CtbvEyFEohBiXzH3hRDirVz99wgh2lglWEpZ4f+wnA6MAeoCrsBuoGmhNk8BH+R+Hgp854Q6dgc8cz8/WZE6WqNfbjsfYCOwGYhwwt9hA2AnEJD7cxUn1HEx8GTu56bA8QrW8RagDbCvmPv9gLWAANoDW6yR66iRvzLkBSxVRynleilleu6Pm7EkOnEa/XKZhSWhamYF6paHNTo+DrwrpUwGkFImOqGOEvDN/ewHnKpA/ZBSbsRyVL44BgJfSAubAX8hRPXS5DrK+DXLC2hHrNExP49ieftWFKXqlzv9C5VSrq5AvfJjze+wIdBQCLFJCLE5N018RWKNjjOAB4QQ8ViOtj9bMapZja1/q4BetEMThBAPABFAV0frkocQQgHeAB52sCqlYcQy9e+GZea0UQjRXEp50ZFKFWIY8JmUcoEQogOWxDU3SSlVRytWHhw18leGvIDW6IgQ4lZgMjBASplVQbpB6fr5ADcBG4QQx7GsBVdWsNPPmt9hPLBSSpkjpTwGHMHyMqgorNHxUWAZgJTyP8Ady4EaZ8Gqv9VrqEjHRT4HhRGIBepw1cnSrFCbpyno8FvmhDq2xuIsauCMv8NC7TdQ8Q4/a36HfYDPcz8HY5m+BjmZjmuBh3M/N8Gy5hcV/LsMp3iHX38KOvy2WiWzIr9AIYX7YXnLxwCTc6/NxDKCguXtuhyIBrYCdZ1Qxz+As8Cu3H8rnUm/Qm0r3Pit/B0KLMuTA8BeYKgT6tgU2JT7YtgF3FbB+n0DnAZysMyUHgVGAaPy/Q7fzdV/r7X/z3p4r47ODYoe4aejc4OiG7+Ozg2Kbvw6OjcouvHr6Nyg6Mavo3ODohu/js4Nim78Ojo3KP8HWNDOFqce/gUAAAAASUVORK5CYII=\n"
|
384 |
+
},
|
385 |
+
"metadata": {
|
386 |
+
"needs_background": "light"
|
387 |
+
}
|
388 |
+
}
|
389 |
+
],
|
390 |
+
"source": [
|
391 |
+
"plot(nodes_coordinates[resulting_traj_with_depot], obs['demand'])"
|
392 |
+
]
|
393 |
+
}
|
394 |
+
],
|
395 |
+
"metadata": {
|
396 |
+
"kernelspec": {
|
397 |
+
"display_name": "Python 3",
|
398 |
+
"name": "python3"
|
399 |
+
},
|
400 |
+
"language_info": {
|
401 |
+
"name": "python"
|
402 |
+
},
|
403 |
+
"colab": {
|
404 |
+
"provenance": []
|
405 |
+
},
|
406 |
+
"accelerator": "GPU",
|
407 |
+
"gpuClass": "standard"
|
408 |
+
},
|
409 |
+
"nbformat": 4,
|
410 |
+
"nbformat_minor": 5
|
411 |
+
}
|
demo/tsp_search.ipynb
ADDED
@@ -0,0 +1,397 @@
{
 "cells": [
  {
   "cell_type": "code",
   "source": [
    "!git clone https://github.com/cpwan/RLOR\n",
    "%cd RLOR"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "5a69iB04JzY2",
    "outputId": "13a3a63e-34bf-4d8b-a853-9d2597cd03d5"
   },
   "id": "5a69iB04JzY2",
   "execution_count": 1,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Cloning into 'RLOR'...\n",
      "remote: Enumerating objects: 52, done.\u001b[K\n",
      "remote: Counting objects: 100% (52/52), done.\u001b[K\n",
      "remote: Compressing objects: 100% (35/35), done.\u001b[K\n",
      "remote: Total 52 (delta 12), reused 52 (delta 12), pack-reused 0\u001b[K\n",
      "Unpacking objects: 100% (52/52), 5.19 MiB | 7.89 MiB/s, done.\n",
      "/content/RLOR\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dbe3c5ed",
   "metadata": {
    "id": "dbe3c5ed"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import gym\n",
    "from models.attention_model_wrapper import Agent"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "985bf6e6",
   "metadata": {
    "id": "985bf6e6"
   },
   "source": [
    "# Define our agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "953a7fde",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "953a7fde",
    "outputId": "06f10aaf-57ca-4870-d22b-f71a14ea4ec4"
   },
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "metadata": {},
     "execution_count": 3
    }
   ],
   "source": [
    "device = 'cuda'\n",
    "ckpt_path = './runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt'\n",
    "agent = Agent(device=device, name='tsp').to(device)\n",
    "agent.load_state_dict(torch.load(ckpt_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cbaa255",
   "metadata": {
    "id": "2cbaa255"
   },
   "source": [
    "# Define our environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c2bd466f",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "c2bd466f",
    "outputId": "92450c8d-f5db-444d-f465-da4a98667799"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:31: UserWarning: \u001b[33mWARN: A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (50, 2)\u001b[0m\n",
      "  logger.warn(\n",
      "/usr/local/lib/python3.9/dist-packages/gym/core.py:317: DeprecationWarning: \u001b[33mWARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n",
      "  deprecation(\n",
      "/usr/local/lib/python3.9/dist-packages/gym/wrappers/step_api_compatibility.py:39: DeprecationWarning: \u001b[33mWARN: Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n",
      "  deprecation(\n",
      "/usr/local/lib/python3.9/dist-packages/gym/vector/vector_env.py:56: DeprecationWarning: \u001b[33mWARN: Initializing vector env in old step API which returns one bool array instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n",
      "  deprecation(\n"
     ]
    }
   ],
   "source": [
    "from wrappers.syncVectorEnvPomo import SyncVectorEnv\n",
    "from wrappers.recordWrapper import RecordEpisodeStatistics\n",
    "\n",
    "env_id = 'tsp-v0'\n",
    "env_entry_point = 'envs.tsp_vector_env:TSPVectorEnv'\n",
    "seed = 0\n",
    "\n",
    "gym.envs.register(\n",
    "    id=env_id,\n",
    "    entry_point=env_entry_point,\n",
    ")\n",
    "\n",
    "def make_env(env_id, seed, cfg={}):\n",
    "    def thunk():\n",
    "        env = gym.make(env_id, **cfg)\n",
    "        env = RecordEpisodeStatistics(env)\n",
    "        env.seed(seed)\n",
    "        env.action_space.seed(seed)\n",
    "        env.observation_space.seed(seed)\n",
    "        return env\n",
    "    return thunk\n",
    "\n",
    "envs = SyncVectorEnv([make_env(env_id, seed, dict(n_traj=1))])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c363d489",
   "metadata": {
    "id": "c363d489"
   },
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bbee9e3c",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "bbee9e3c",
    "outputId": "11750d60-eb7c-4d9c-8b40-3a8a92021ffe"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:174: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator.\u001b[0m\n",
      "  logger.warn(\n",
      "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:190: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting.\u001b[0m\n",
      "  logger.warn(\n",
      "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:195: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information.\u001b[0m\n",
      "  logger.warn(\n",
      "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:165: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method is not within the observation space.\u001b[0m\n",
      "  logger.warn(f\"{pre} is not within the observation space.\")\n",
      "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:141: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64\u001b[0m\n",
      "  logger.warn(\n",
      "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:227: DeprecationWarning: \u001b[33mWARN: Core environment is written in old step API which returns one bool instead of two. It is recommended to rewrite the environment with new step API. \u001b[0m\n",
      "  logger.deprecation(\n",
      "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:234: UserWarning: \u001b[33mWARN: Expects `done` signal to be a boolean, actual type: <class 'numpy.ndarray'>\u001b[0m\n",
      "  logger.warn(\n",
      "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:141: UserWarning: \u001b[33mWARN: The obs returned by the `step()` method was expecting numpy array dtype to be float32, actual type: float64\u001b[0m\n",
      "  logger.warn(\n",
      "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:165: UserWarning: \u001b[33mWARN: The obs returned by the `step()` method is not within the observation space.\u001b[0m\n",
      "  logger.warn(f\"{pre} is not within the observation space.\")\n",
      "/usr/local/lib/python3.9/dist-packages/gym/utils/passive_env_checker.py:260: UserWarning: \u001b[33mWARN: The reward returned by `step()` must be a float, int, np.integer or np.floating, actual type: <class 'numpy.ndarray'>\u001b[0m\n",
      "  logger.warn(\n"
     ]
    }
   ],
   "source": [
    "num_steps = 51\n",
    "trajectories = []\n",
    "agent.eval()\n",
    "obs = envs.reset()\n",
    "for step in range(0, num_steps):\n",
    "    # ALGO LOGIC: action logic\n",
    "    with torch.no_grad():\n",
    "        action, logits = agent(obs)\n",
    "    obs, reward, done, info = envs.step(action.cpu().numpy())\n",
    "    trajectories.append(action.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f0fbf6fd",
   "metadata": {
    "id": "f0fbf6fd"
   },
   "outputs": [],
   "source": [
    "nodes_coordinates = obs['observations'][0]\n",
    "final_return = info[0]['episode']['r']\n",
    "resulting_traj = np.array(trajectories)[:,0,0]"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Results"
   ],
   "metadata": {
    "id": "5n9rBoH5Q8gn"
   },
   "id": "5n9rBoH5Q8gn"
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dff29ef4",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "dff29ef4",
    "outputId": "dcffcda5-5728-464c-ee3e-0fcc702443ff"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "A route of length [-5.908508]\n",
      "The route is:\n",
      " [26 34 33 49 37 21 48 43 31 28 42 29 47 39 38 23 27 30  7 32 24 40 20 14\n",
      " 25  1 18 22  0 11  2 16 45 15 46 12 17 41  8 13  3  6 44  9 10 19 36  5\n",
      " 35  4 26]\n"
     ]
    }
   ],
   "source": [
    "print(f'A route of length {final_return}')\n",
    "print('The route is:\\n', resulting_traj)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b009802e",
   "metadata": {
    "id": "b009802e"
   },
   "source": [
    "## Display it in a 2d-grid\n",
    "- Darker color means later steps in the route."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "dc681a06",
   "metadata": {
    "tags": [
     "\"hide-cell\""
    ],
    "cellView": "form",
    "id": "dc681a06"
   },
   "outputs": [],
   "source": [
    "#@title Helper function for plotting\n",
    "# colorline taken from https://nbviewer.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.collections import LineCollection\n",
    "from matplotlib.colors import ListedColormap, BoundaryNorm\n",
    "\n",
    "def make_segments(x, y):\n",
    "    '''\n",
    "    Create list of line segments from x and y coordinates, in the correct format for LineCollection:\n",
    "    an array of the form numlines x (points per line) x 2 (x and y) array\n",
    "    '''\n",
    "\n",
    "    points = np.array([x, y]).T.reshape(-1, 1, 2)\n",
    "    segments = np.concatenate([points[:-1], points[1:]], axis=1)\n",
    "\n",
    "    return segments\n",
    "\n",
    "def colorline(x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0), linewidth=1, alpha=1.0):\n",
    "    '''\n",
    "    Plot a colored line with coordinates x and y\n",
    "    Optionally specify colors in the array z\n",
    "    Optionally specify a colormap, a norm function and a line width\n",
    "    '''\n",
    "\n",
    "    # Default colors equally spaced on [0,1]:\n",
    "    if z is None:\n",
    "        z = np.linspace(0.3, 1.0, len(x))\n",
    "\n",
    "    # Special case if a single number:\n",
    "    if not hasattr(z, \"__iter__\"):  # to check for numerical input -- this is a hack\n",
    "        z = np.array([z])\n",
    "\n",
    "    z = np.asarray(z)\n",
    "\n",
    "    segments = make_segments(x, y)\n",
    "    lc = LineCollection(segments, array=z, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha)\n",
    "\n",
    "    ax = plt.gca()\n",
    "    ax.add_collection(lc)\n",
    "\n",
    "    return lc\n",
    "\n",
    "def plot(coords):\n",
    "    x, y = coords.T\n",
    "    lc = colorline(x, y, cmap='Reds')\n",
    "    plt.axis('square')\n",
    "    return lc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bb0548fb",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 282
    },
    "id": "bb0548fb",
    "outputId": "3e967b86-32e9-4be3-e5b7-c8a457aa12a7"
   },
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "<matplotlib.collections.LineCollection at 0x7f15aabac8b0>"
      ]
     },
     "metadata": {},
     "execution_count": 9
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ],
      "image/png": "[... base64 PNG data omitted; the cell output is the rendered TSP route plot ...]"
     },
     "metadata": {
      "needs_background": "light"
     }
    }
   ],
   "source": [
    "plot(nodes_coordinates[resulting_traj])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  },
  "colab": {
   "provenance": []
  },
  "accelerator": "GPU",
  "gpuClass": "standard"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
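Outside Colab, the notebook above condenses to a short script. The sketch below simply re-assembles the cells shown (same checkpoint path, environment id, and 51-step loop); the only addition is a CPU fallback for `device`, so treat it as a convenience distillation rather than a file from the repo.

# Greedy TSP rollout, distilled from demo/tsp_search.ipynb.
# Assumes the RLOR repo root is the working directory and that the
# checkpoint path below exists (both taken from the notebook cells above).
import gym
import numpy as np
import torch

from models.attention_model_wrapper import Agent
from wrappers.recordWrapper import RecordEpisodeStatistics
from wrappers.syncVectorEnvPomo import SyncVectorEnv

gym.envs.register(id="tsp-v0", entry_point="envs.tsp_vector_env:TSPVectorEnv")

def make_env(env_id, seed, cfg={}):
    def thunk():
        env = gym.make(env_id, **cfg)
        env = RecordEpisodeStatistics(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env
    return thunk

device = "cuda" if torch.cuda.is_available() else "cpu"  # the notebook hardcodes 'cuda'
agent = Agent(device=device, name="tsp").to(device)
agent.load_state_dict(torch.load("./runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt"))
agent.eval()

envs = SyncVectorEnv([make_env("tsp-v0", seed=0, cfg=dict(n_traj=1))])
obs = envs.reset()
trajectories = []
for _ in range(51):  # 50 nodes plus the return to the starting node
    with torch.no_grad():
        action, logits = agent(obs)
    obs, reward, done, info = envs.step(action.cpu().numpy())
    trajectories.append(action.cpu().numpy())

print("A route of length", info[0]["episode"]["r"])
print("The route is:\n", np.array(trajectories)[:, 0, 0])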
environment.yml
ADDED
@@ -0,0 +1,113 @@
name: cleanrl
channels:
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1
  - _openmp_mutex=4.5
  - asttokens=2.1.0
  - backcall=0.2.0
  - backports=1.0
  - backports.functools_lru_cache=1.6.4
  - bzip2=1.0.8
  - ca-certificates=2022.9.24
  - debugpy=1.6.3
  - decorator=5.1.1
  - entrypoints=0.4
  - executing=1.2.0
  - ipykernel=6.17.0
  - ipython=8.6.0
  - jedi=0.18.1
  - jupyter_client=7.4.4
  - jupyter_core=4.11.2
  - ld_impl_linux-64=2.39
  - libffi=3.4.2
  - libgcc-ng=12.2.0
  - libgomp=12.2.0
  - libnsl=2.0.0
  - libsodium=1.0.18
  - libsqlite=3.39.4
  - libstdcxx-ng=12.2.0
  - libuuid=2.32.1
  - libzlib=1.2.13
  - matplotlib-inline=0.1.6
  - ncurses=6.3
  - nest-asyncio=1.5.6
  - openssl=3.0.7
  - packaging=21.3
  - parso=0.8.3
  - pexpect=4.8.0
  - pickleshare=0.7.5
  - pip=22.3.1
  - prompt-toolkit=3.0.32
  - psutil=5.9.4
  - ptyprocess=0.7.0
  - pure_eval=0.2.2
  - pygments=2.13.0
  - pyparsing=3.0.9
  - python=3.8.13
  - python-dateutil=2.8.2
  - python_abi=3.8
  - pyzmq=24.0.1
  - readline=8.1.2
  - setuptools=65.5.1
  - six=1.16.0
  - sqlite=3.39.4
  - stack_data=0.6.0
  - tk=8.6.12
  - tornado=6.2
  - traitlets=5.5.0
  - wcwidth=0.2.5
  - wheel=0.38.3
  - xz=5.2.6
  - zeromq=4.3.4
  - pip:
      - absl-py==1.3.0
      - cachetools==5.2.0
      - certifi==2022.9.24
      - cfgv==3.3.1
      - charset-normalizer==2.1.1
      - cloudpickle==2.2.0
      - distlib==0.3.6
      - filelock==3.8.0
      - google-auth==2.14.1
      - google-auth-oauthlib==0.4.6
      - grpcio==1.50.0
      - gym==0.23.1
      - gym-notices==0.0.8
      - identify==2.5.8
      - idna==3.4
      - importlib-metadata==5.0.0
      - llvmlite==0.39.1
      - markdown==3.4.1
      - markupsafe==2.1.1
      - nodeenv==1.7.0
      - numba==0.56.4
      - numpy==1.23.4
      - nvidia-cublas-cu11==11.10.3.66
      - nvidia-cuda-nvrtc-cu11==11.7.99
      - nvidia-cuda-runtime-cu11==11.7.99
      - nvidia-cudnn-cu11==8.5.0.96
      - oauthlib==3.2.2
      - pillow==9.3.0
      - platformdirs==2.5.3
      - pre-commit==2.20.0
      - protobuf==3.20.3
      - pyasn1==0.4.8
      - pyasn1-modules==0.2.8
      - pygame==2.1.0
      - pyyaml==6.0
      - requests==2.28.1
      - requests-oauthlib==1.3.1
      - rsa==4.9
      - tensorboard==2.11.0
      - tensorboard-data-server==0.6.1
      - tensorboard-plugin-wit==1.8.1
      - toml==0.10.2
      - torch==1.13.0
      - torchvision==0.14.0
      - typing-extensions==4.4.0
      - urllib3==1.26.12
      - virtualenv==20.16.6
      - werkzeug==2.2.2
      - zipp==3.10.0
prefix: /opt/conda/envs/cleanrl
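Assuming a standard conda installation, this environment can be recreated with `conda env create -f environment.yml` and activated with `conda activate cleanrl`. The trailing `prefix:` line is machine-specific output from `conda env export`; conda recreates the environment from the `name:` field and ignores the exported prefix.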
envs/cvrp_vector_env.py
ADDED
@@ -0,0 +1,144 @@
1 |
+
import gym
|
2 |
+
import numpy as np
|
3 |
+
from gym import spaces
|
4 |
+
|
5 |
+
from .vrp_data import VRPDataset
|
6 |
+
|
7 |
+
|
8 |
+
def assign_env_config(self, kwargs):
|
9 |
+
"""
|
10 |
+
Set self.key = value, for each key in kwargs
|
11 |
+
"""
|
12 |
+
for key, value in kwargs.items():
|
13 |
+
setattr(self, key, value)
|
14 |
+
|
15 |
+
|
16 |
+
def dist(loc1, loc2):
|
17 |
+
return ((loc1[:, 0] - loc2[:, 0]) ** 2 + (loc1[:, 1] - loc2[:, 1]) ** 2) ** 0.5
|
18 |
+
|
19 |
+
|
20 |
+
class CVRPVectorEnv(gym.Env):
|
21 |
+
def __init__(self, *args, **kwargs):
|
22 |
+
self.max_nodes = 50
|
23 |
+
self.capacity_limit = 40
|
24 |
+
self.n_traj = 50
|
25 |
+
# if eval_data==True, load from 'test' set, the '0'th data
|
26 |
+
self.eval_data = False
|
27 |
+
self.eval_partition = "test"
|
28 |
+
self.eval_data_idx = 0
|
29 |
+
self.demand_limit = 10
|
30 |
+
assign_env_config(self, kwargs)
|
31 |
+
|
32 |
+
obs_dict = {"observations": spaces.Box(low=0, high=1, shape=(self.max_nodes, 2))}
|
33 |
+
obs_dict["depot"] = spaces.Box(low=0, high=1, shape=(2,))
|
34 |
+
obs_dict["demand"] = spaces.Box(low=0, high=1, shape=(self.max_nodes,))
|
35 |
+
obs_dict["action_mask"] = spaces.MultiBinary(
|
36 |
+
[self.n_traj, self.max_nodes + 1]
|
37 |
+
) # 1: OK, 0: cannot go
|
38 |
+
obs_dict["last_node_idx"] = spaces.MultiDiscrete([self.max_nodes + 1] * self.n_traj)
|
39 |
+
obs_dict["current_load"] = spaces.Box(low=0, high=1, shape=(self.n_traj,))
|
40 |
+
|
41 |
+
self.observation_space = spaces.Dict(obs_dict)
|
42 |
+
self.action_space = spaces.MultiDiscrete([self.max_nodes + 1] * self.n_traj)
|
43 |
+
self.reward_space = None
|
44 |
+
|
45 |
+
self.reset()
|
46 |
+
|
47 |
+
def seed(self, seed):
|
48 |
+
np.random.seed(seed)
|
49 |
+
|
50 |
+
def _STEP(self, action):
|
51 |
+
|
52 |
+
self._go_to(action) # Go to node 'action', modify the reward
|
53 |
+
self.num_steps += 1
|
54 |
+
self.state = self._update_state()
|
55 |
+
|
56 |
+
# need to revisit the first node after visited all other nodes
|
57 |
+
self.done = (action == 0) & self.is_all_visited()
|
58 |
+
|
59 |
+
return self.state, self.reward, self.done, self.info
|
60 |
+
|
61 |
+
# Euclidean cost function
|
62 |
+
def cost(self, loc1, loc2):
|
63 |
+
return dist(loc1, loc2)
|
64 |
+
|
65 |
+
def is_all_visited(self):
|
66 |
+
# assumes no repetition in the first `max_nodes` steps
|
67 |
+
return self.visited[:, 1:].all(axis=1)
|
68 |
+
|
69 |
+
def _update_state(self):
|
70 |
+
obs = {"observations": self.nodes[1:]} # n x 2 array
|
71 |
+
obs["depot"] = self.nodes[0]
|
72 |
+
obs["action_mask"] = self._update_mask()
|
73 |
+
obs["demand"] = self.demands
|
74 |
+
obs["last_node_idx"] = self.last
|
75 |
+
obs["current_load"] = self.load
|
76 |
+
return obs
|
77 |
+
|
78 |
+
def _update_mask(self):
|
79 |
+
# Only allow to visit unvisited nodes
|
80 |
+
action_mask = ~self.visited
|
81 |
+
|
82 |
+
# can only visit depot when last node is not depot or all visited
|
83 |
+
action_mask[:, 0] |= self.last != 0
|
84 |
+
action_mask[:, 0] |= self.is_all_visited()
|
85 |
+
|
86 |
+
# not allow visit nodes with higher demand than capacity
|
87 |
+
action_mask[:, 1:] &= self.demands <= (
|
88 |
+
self.load.reshape(-1, 1) + 1e-5
|
89 |
+
) # to handle the floating point subtraction precision
|
90 |
+
|
91 |
+
return action_mask
|
92 |
+
|
93 |
+
def _RESET(self):
|
94 |
+
self.visited = np.zeros((self.n_traj, self.max_nodes + 1), dtype=bool)
|
95 |
+
self.visited[:, 0] = True
|
96 |
+
self.num_steps = 0
|
97 |
+
self.last = np.zeros(self.n_traj, dtype=int) # idx of the last elem
|
98 |
+
self.load = np.ones(self.n_traj, dtype=float) # current load
|
99 |
+
|
100 |
+
if self.eval_data:
|
101 |
+
self._load_orders()
|
102 |
+
else:
|
103 |
+
self._generate_orders()
|
104 |
+
self.state = self._update_state()
|
105 |
+
self.info = {}
|
106 |
+
self.done = np.array([False] * self.n_traj)
|
107 |
+
return self.state
|
108 |
+
|
109 |
+
def _load_orders(self):
|
110 |
+
data = VRPDataset[self.eval_partition, self.max_nodes, self.eval_data_idx]
|
111 |
+
self.nodes = np.concatenate((data["depot"][None, ...], data["loc"]))
|
112 |
+
self.demands = data["demand"]
|
113 |
+
self.demands_with_depot = self.demands.copy()
|
114 |
+
|
115 |
+
def _generate_orders(self):
|
116 |
+
self.nodes = np.random.rand(self.max_nodes + 1, 2)
|
117 |
+
self.demands = (
|
118 |
+
np.random.randint(low=1, high=self.demand_limit, size=self.max_nodes)
|
119 |
+
/ self.capacity_limit
|
120 |
+
)
|
121 |
+
self.demands_with_depot = self.demands.copy()
|
122 |
+
|
123 |
+
def _go_to(self, destination):
|
124 |
+
dest_node = self.nodes[destination]
|
125 |
+
dist = self.cost(dest_node, self.nodes[self.last])
|
126 |
+
self.last = destination
|
127 |
+
self.load[destination == 0] = 1
|
128 |
+
self.load[destination > 0] -= self.demands[destination[destination > 0] - 1]
|
129 |
+
self.demands_with_depot[destination[destination > 0] - 1] = 0
|
130 |
+
self.visited[np.arange(self.n_traj), destination] = True
|
131 |
+
self.reward = -dist
|
132 |
+
|
133 |
+
def step(self, action):
|
134 |
+
# return last state after done,
|
135 |
+
# for the sake of PPO's abuse of ff on done observation
|
136 |
+
# see https://github.com/opendilab/DI-engine/issues/497
|
137 |
+
# Not needed for CleanRL
|
138 |
+
# if self.done.all():
|
139 |
+
# return self.state, self.reward, self.done, self.info
|
140 |
+
|
141 |
+
return self._STEP(action)
|
142 |
+
|
143 |
+
def reset(self):
|
144 |
+
return self._RESET()
|
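For reference, the environment above can be exercised end to end with masked random actions. This is a minimal sketch, not part of the commit, assuming the package layout of this repository:

# Minimal sketch (not part of the commit): roll out CVRPVectorEnv with masked
# random actions; kwargs are applied by assign_env_config.
import numpy as np
from envs.cvrp_vector_env import CVRPVectorEnv

env = CVRPVectorEnv(max_nodes=20, n_traj=4)
obs = env.reset()
total_reward = np.zeros(env.n_traj)
done = np.array([False] * env.n_traj)
while not done.all():
    mask = obs["action_mask"]  # [n_traj, max_nodes + 1], 1: OK, 0: cannot go
    # sample one feasible action per trajectory
    action = np.array([np.random.choice(np.flatnonzero(m)) for m in mask])
    obs, reward, done, info = env.step(action)
    total_reward += reward
print(total_reward)  # negative of the total route length per trajectory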
envs/tsp_data.py
ADDED
@@ -0,0 +1,43 @@
import logging
import pickle

root_dir = "/home/cwan5/OR/attention-learn-to-route/data/tsp/"


def load(filename, root_dir=root_dir):
    with open(root_dir + filename, "rb") as f:
        return pickle.load(f)


file_catalog = {
    "test": {
        20: "tsp20_test_seed1234.pkl",
        50: "tsp50_test_seed1234.pkl",
        100: "tsp100_test_seed1234.pkl",
    },
    "eval": {
        20: "tsp20_validation_seed4321.pkl",
        50: "tsp50_validation_seed4321.pkl",
        100: "tsp100_validation_seed4321.pkl",
    },
}


class lazyClass:
    data = {
        "test": {},
        "eval": {},
    }

    def __getitem__(self, index):
        partition, nodes, idx = index
        if not (partition in self.data) or not (nodes in self.data[partition]):
            logging.warning(
                f"Data specified by ({partition}, {nodes}) was not initialized. Attempting to load it for the first time."
            )
            data = load(file_catalog[partition][nodes])
            self.data[partition][nodes] = [tuple(instance) for instance in data]

        return self.data[partition][nodes][idx]


TSPDataset = lazyClass()
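The dataset object above is a lazy, cached index into the pickled benchmark instances. A minimal sketch of its use (not part of the commit; it assumes the pickle files exist under ``root_dir``):

# Minimal sketch (not part of the commit): the first access triggers the pickle
# load and logs a warning; later accesses hit the in-memory cache.
from envs.tsp_data import TSPDataset

instance = TSPDataset["test", 50, 0]  # (partition, graph size, instance index)
print(len(instance))                  # 50 (x, y) coordinates of a tsp50 instance
another = TSPDataset["test", 50, 1]   # served from the cache, no second file read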
envs/tsp_vector_env.py
ADDED
@@ -0,0 +1,113 @@
import gym
import numpy as np
from gym import spaces

from .tsp_data import TSPDataset


def assign_env_config(self, kwargs):
    """
    Set self.key = value for each key in kwargs
    """
    for key, value in kwargs.items():
        setattr(self, key, value)


def dist(loc1, loc2):
    return ((loc1[:, 0] - loc2[:, 0]) ** 2 + (loc1[:, 1] - loc2[:, 1]) ** 2) ** 0.5


class TSPVectorEnv(gym.Env):
    def __init__(self, *args, **kwargs):
        self.max_nodes = 50
        self.n_traj = 50
        # if eval_data==True, load the `eval_data_idx`-th instance from the `eval_partition` set
        self.eval_data = False
        self.eval_partition = "test"
        self.eval_data_idx = 0
        assign_env_config(self, kwargs)

        obs_dict = {"observations": spaces.Box(low=0, high=1, shape=(self.max_nodes, 2))}
        obs_dict["action_mask"] = spaces.MultiBinary(
            [self.n_traj, self.max_nodes]
        )  # 1: OK, 0: cannot go
        obs_dict["first_node_idx"] = spaces.MultiDiscrete([self.max_nodes] * self.n_traj)
        obs_dict["last_node_idx"] = spaces.MultiDiscrete([self.max_nodes] * self.n_traj)
        obs_dict["is_initial_action"] = spaces.Discrete(1)

        self.observation_space = spaces.Dict(obs_dict)
        self.action_space = spaces.MultiDiscrete([self.max_nodes] * self.n_traj)
        self.reward_space = None

        self.reset()

    def seed(self, seed):
        np.random.seed(seed)

    def reset(self):
        self.visited = np.zeros((self.n_traj, self.max_nodes), dtype=bool)
        self.num_steps = 0
        self.last = np.zeros(self.n_traj, dtype=int)  # idx of the last visited node
        self.first = np.zeros(self.n_traj, dtype=int)  # idx of the first visited node

        if self.eval_data:
            self._load_orders()
        else:
            self._generate_orders()
        self.state = self._update_state()
        self.info = {}
        self.done = False
        return self.state

    def _load_orders(self):
        self.nodes = np.array(TSPDataset[self.eval_partition, self.max_nodes, self.eval_data_idx])

    def _generate_orders(self):
        self.nodes = np.random.rand(self.max_nodes, 2)

    def step(self, action):
        self._go_to(action)  # go to node `action` and update the reward
        self.num_steps += 1
        self.state = self._update_state()

        # the episode ends when the tour revisits the first node after all other nodes are visited
        self.done = (action == self.first) & self.is_all_visited()

        return self.state, self.reward, self.done, self.info

    # Euclidean cost function
    def cost(self, loc1, loc2):
        return dist(loc1, loc2)

    def is_all_visited(self):
        # assumes no repetition in the first `max_nodes` steps
        return self.visited[:, :].all(axis=1)

    def _go_to(self, destination):
        dest_node = self.nodes[destination]
        if self.num_steps != 0:
            dist = self.cost(dest_node, self.nodes[self.last])
        else:
            dist = np.zeros(self.n_traj)
            self.first = destination

        self.last = destination

        self.visited[np.arange(self.n_traj), destination] = True
        self.reward = -dist

    def _update_state(self):
        obs = {"observations": self.nodes}  # n x 2 array
        obs["action_mask"] = self._update_mask()
        obs["first_node_idx"] = self.first
        obs["last_node_idx"] = self.last
        obs["is_initial_action"] = self.num_steps == 0
        return obs

    def _update_mask(self):
        # Only allow visits to unvisited nodes
        action_mask = ~self.visited
        # the first node can only be revisited when all nodes have been visited
        action_mask[np.arange(self.n_traj), self.first] |= self.is_all_visited()
        return action_mask
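Because every trajectory in the vector shares one instance and the first action fixes each trajectory's starting node, a POMO-style rollout falls out naturally. A minimal sketch, not part of the commit:

# Minimal sketch (not part of the commit): each of the n_traj trajectories
# starts from a different node of the same instance, then rolls out randomly.
import numpy as np
from envs.tsp_vector_env import TSPVectorEnv

env = TSPVectorEnv(max_nodes=10, n_traj=10)
obs = env.reset()
obs, reward, done, info = env.step(np.arange(env.n_traj))  # distinct first nodes, zero cost
tour_length = np.zeros(env.n_traj)
while not done.all():
    action = np.array([np.random.choice(np.flatnonzero(m)) for m in obs["action_mask"]])
    obs, reward, done, info = env.step(action)
    tour_length -= reward
print(tour_length.min())  # best tour among the parallel trajectories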
envs/vrp_data.py
ADDED
@@ -0,0 +1,57 @@
import logging
import pickle

import numpy as np

root_dir = "/home/cwan5/OR/attention-learn-to-route/data/vrp/"


def load(filename, root_dir=root_dir):
    with open(root_dir + filename, "rb") as f:
        return pickle.load(f)


file_catalog = {
    "test": {
        20: "vrp20_test_seed1234.pkl",
        50: "vrp50_test_seed1234.pkl",
        100: "vrp100_test_seed1234.pkl",
    },
    "eval": {
        20: "vrp20_validation_seed4321.pkl",
        50: "vrp50_validation_seed4321.pkl",
        100: "vrp100_validation_seed4321.pkl",
    },
}


def make_instance(args):
    depot, loc, demand, capacity, *args = args
    grid_size = 1
    if len(args) > 0:
        depot_types, customer_types, grid_size = args
    return {
        "loc": np.array(loc) / grid_size,
        "demand": np.array(demand) / capacity,
        "depot": np.array(depot) / grid_size,
    }


class lazyClass:
    data = {
        "test": {},
        "eval": {},
    }

    def __getitem__(self, index):
        partition, nodes, idx = index
        if not (partition in self.data) or not (nodes in self.data[partition]):
            logging.warning(
                f"Data specified by ({partition}, {nodes}) was not initialized. Attempting to load it for the first time."
            )
            data = load(file_catalog[partition][nodes])
            self.data[partition][nodes] = [make_instance(instance) for instance in data]

        return self.data[partition][nodes][idx]


VRPDataset = lazyClass()
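``make_instance`` normalizes the raw pickled tuples into unit-square coordinates and fractional demands, matching what ``CVRPVectorEnv`` generates on the fly. A small illustrative example (the numbers are made up, not from the dataset):

# Illustrative only: normalize a raw (depot, loc, demand, capacity) tuple as
# stored in the pickles into fractional demands and unit-square coordinates.
from envs.vrp_data import make_instance

raw = ((0.5, 0.5), [(0.1, 0.2), (0.9, 0.4)], [3, 7], 40)
inst = make_instance(raw)
print(inst["demand"])     # [0.075 0.175] -- demands divided by the capacity of 40
print(inst["loc"].shape)  # (2, 2) -- coordinates divided by grid_size (1 by default)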
models/attention_model_wrapper.py
ADDED
@@ -0,0 +1,152 @@
import torch

from .nets.attention_model.attention_model import *


class Problem:
    def __init__(self, name):
        self.NAME = name


class Backbone(nn.Module):
    def __init__(
        self,
        embedding_dim=128,
        problem_name="tsp",
        n_encode_layers=3,
        tanh_clipping=10.0,
        n_heads=8,
        device="cpu",
    ):
        super(Backbone, self).__init__()
        self.device = device
        self.problem = Problem(problem_name)
        self.embedding = AutoEmbedding(self.problem.NAME, {"embedding_dim": embedding_dim})

        self.encoder = GraphAttentionEncoder(
            n_heads=n_heads,
            embed_dim=embedding_dim,
            n_layers=n_encode_layers,
        )

        self.decoder = Decoder(
            embedding_dim, self.embedding.context_dim, n_heads, self.problem, tanh_clipping
        )

    def forward(self, obs):
        state = stateWrapper(obs, device=self.device, problem=self.problem.NAME)
        input = state.states["observations"]
        embedding = self.embedding(input)
        encoded_inputs, _ = self.encoder(embedding)

        # decoding
        cached_embeddings = self.decoder._precompute(encoded_inputs)
        logits, glimpse = self.decoder.advance(cached_embeddings, state)

        return logits, glimpse

    def encode(self, obs):
        state = stateWrapper(obs, device=self.device, problem=self.problem.NAME)
        input = state.states["observations"]
        embedding = self.embedding(input)
        encoded_inputs, _ = self.encoder(embedding)
        cached_embeddings = self.decoder._precompute(encoded_inputs)
        return cached_embeddings

    def decode(self, obs, cached_embeddings):
        state = stateWrapper(obs, device=self.device, problem=self.problem.NAME)
        logits, glimpse = self.decoder.advance(cached_embeddings, state)

        return logits, glimpse


class Actor(nn.Module):
    def __init__(self):
        super(Actor, self).__init__()

    def forward(self, x):
        logits = x[0]  # .squeeze(1) # not needed for POMO
        return logits


class Critic(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Critic, self).__init__()
        hidden_size = kwargs["hidden_size"]
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1)
        )

    def forward(self, x):
        out = self.mlp(x[1])  # B x T x h_dim --mlp--> B x T x 1
        return out


class Agent(nn.Module):
    def __init__(self, embedding_dim=128, device="cpu", name="tsp"):
        super().__init__()
        self.backbone = Backbone(embedding_dim=embedding_dim, device=device, problem_name=name)
        self.critic = Critic(hidden_size=embedding_dim)
        self.actor = Actor()

    def forward(self, x):  # actor only
        x = self.backbone(x)
        logits = self.actor(x)
        action = logits.max(2)[1]
        return action, logits

    def get_value(self, x):
        x = self.backbone(x)
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        x = self.backbone(x)
        logits = self.actor(x)
        probs = torch.distributions.Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)

    def get_value_cached(self, x, state):
        x = self.backbone.decode(x, state)
        return self.critic(x)

    def get_action_and_value_cached(self, x, action=None, state=None):
        if state is None:
            state = self.backbone.encode(x)
        x = self.backbone.decode(x, state)
        logits = self.actor(x)
        probs = torch.distributions.Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x), state


class stateWrapper:
    """
    Wraps a dict of numpy arrays into an object that supplies the functions and data
    expected by the decoder.
    """

    def __init__(self, states, device, problem="tsp"):
        self.device = device
        self.states = {k: torch.tensor(v, device=self.device) for k, v in states.items()}
        if problem == "tsp":
            self.is_initial_action = self.states["is_initial_action"].to(torch.bool)
            self.first_a = self.states["first_node_idx"]
        elif problem == "cvrp":
            input = {
                "loc": self.states["observations"],
                "depot": self.states["depot"].squeeze(-1),
                "demand": self.states["demand"],
            }
            self.states["observations"] = input
            # remaining capacity = VEHICLE_CAPACITY - used_capacity = current_load
            self.VEHICLE_CAPACITY = 0
            self.used_capacity = -self.states["current_load"]

    def get_current_node(self):
        return self.states["last_node_idx"]

    def get_mask(self):
        return (1 - self.states["action_mask"]).to(torch.bool)
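The encode/decode split above is what makes per-step decoding cheap: the encoder runs once per instance, and every later step only runs the decoder against the cached projections. A minimal sketch of that loop, not part of the commit; the ``batched`` helper here is a stand-in for the batching and dtype casting that gym's vector wrappers normally perform:

# Minimal sketch (not part of the commit): encode the instance once, then
# decode one action per step while reusing the cached node embeddings.
import numpy as np
import torch
from envs.tsp_vector_env import TSPVectorEnv
from models.attention_model_wrapper import Agent

env = TSPVectorEnv(max_nodes=10, n_traj=4)
agent = Agent(embedding_dim=128, device="cpu", name="tsp")

def batched(obs):
    # add the leading batch axis a vectorized env would provide, and cast
    # dtypes the way gym's vector wrappers do (float32 obs, int8 masks)
    out = {}
    for k, v in obs.items():
        v = np.asarray(v)
        if v.dtype == np.float64:
            v = v.astype(np.float32)
        if v.dtype == np.bool_:
            v = v.astype(np.int8)
        out[k] = v[None, ...]
    return out

obs = env.reset()
with torch.no_grad():
    action, logp, entropy, value, cache = agent.get_action_and_value_cached(batched(obs))
    obs, reward, done, info = env.step(action.numpy()[0])
    # later steps skip the encoder by passing the cache back in
    action, logp, entropy, value, _ = agent.get_action_and_value_cached(batched(obs), state=cache)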
models/nets/__init__.py
ADDED
File without changes
models/nets/attention_model/__init__.py
ADDED
File without changes
models/nets/attention_model/attention_model.py
ADDED
@@ -0,0 +1,101 @@
from torch import nn

from ...nets.attention_model.decoder import Decoder
from ...nets.attention_model.embedding import AutoEmbedding
from ...nets.attention_model.encoder import GraphAttentionEncoder


class AttentionModel(nn.Module):
    r"""
    The Attention Model from Kool et al., 2019.
    `Link to paper <https://arxiv.org/abs/1803.08475>`_.

    For an instance :math:`s`,

    .. math::
        \log(p_\theta(\pi|s)),\pi = \mathrm{AttentionModel}(s)

    The following is executed:

    .. math::
        \begin{aligned}
        \pmb{x} &= \mathrm{Embedding}(s) \\
        \pmb{h} &= \mathrm{Encoder}(\pmb{x}) \\
        \{\log(\pmb{p}_t)\},\pi &= \mathrm{Decoder}(s, \pmb{h}) \\
        \log(p_\theta(\pi|s)) &= \sum\nolimits_t\log(\pmb{p}_{t,\pi_t})
        \end{aligned}
    where :math:`\pmb{h}_i` is the node embedding for each node :math:`i` in the graph.

    In a nutshell, :math:`\mathrm{Embedding}` is a linear projection.
    The :math:`\mathrm{Encoder}` is a transformer.
    The :math:`\mathrm{Decoder}` uses (multi-head) attentions.
    The policy (sequence of actions) :math:`\pi` is decoded autoregressively.
    The log-likelihood :math:`\log(p_\theta(\pi|s))` of the policy is also returned.

    .. seealso::
        The definitions of :math:`\mathrm{Embedding}`, :math:`\mathrm{Encoder}`, and
        :math:`\mathrm{Decoder}` can be found in the
        :mod:`.embedding`, :mod:`.encoder`, :mod:`.decoder` modules.

    Args:
        embedding_dim : the dimension of the embedded inputs
        hidden_dim : the dimension of the hidden state of the encoder
        problem : an object defining the state of the problem
        n_encode_layers: number of encoder layers
        tanh_clipping : the clipping scale for the decoder
    Inputs: inputs
        * **inputs**: problem instance :math:`s`. [batch, graph_size, input_dim]
    Outputs: ll, pi
        * **ll**: :math:`\log(p_\theta(\pi|s))` the log-likelihood of the policy. [batch]
        * **pi**: the policy generated :math:`\pi`. [batch, T]
    """

    def __init__(
        self,
        embedding_dim,
        hidden_dim,
        problem,
        n_encode_layers=2,
        tanh_clipping=10.0,
        normalization="batch",
        n_heads=8,
    ):
        super(AttentionModel, self).__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_encode_layers = n_encode_layers
        self.decode_type = None
        self.n_heads = n_heads

        self.problem = problem

        self.embedding = AutoEmbedding(problem.NAME, {"embedding_dim": embedding_dim})
        step_context_dim = self.embedding.context_dim

        self.encoder = GraphAttentionEncoder(
            n_heads=n_heads,
            embed_dim=embedding_dim,
            n_layers=self.n_encode_layers,
        )
        self.decoder = Decoder(embedding_dim, step_context_dim, n_heads, problem, tanh_clipping)

    def set_decode_type(self, decode_type):
        self.decoder.set_decode_type(decode_type)

    def forward(self, input):
        embedding = self.embedding(input)
        encoded_inputs, _ = self.encoder(embedding)
        _log_p, pi = self.decoder(input, encoded_inputs)
        ll = self._calc_log_likelihood(_log_p, pi)
        return ll, pi

    def _calc_log_likelihood(self, _log_p, pi):
        # Get the log_p corresponding to the selected actions
        log_p = _log_p.gather(2, pi.unsqueeze(-1)).squeeze(-1)

        assert (log_p > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!"

        # Calculate the log_likelihood
        return log_p.sum(1)
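The gather in ``_calc_log_likelihood`` picks out, at every decoding step, the log-probability of the action that was actually taken. An illustrative shape check with made-up tensors:

# Illustrative only: _log_p is [batch, T, graph_size]; pi is [batch, T].
import torch

_log_p = torch.log(torch.tensor([[[0.5, 0.5], [0.2, 0.8]]]))  # batch=1, T=2, graph=2
pi = torch.tensor([[0, 1]])                                   # chose node 0, then node 1
log_p = _log_p.gather(2, pi.unsqueeze(-1)).squeeze(-1)        # [[log 0.5, log 0.8]]
print(log_p.sum(1))                                           # log-likelihood of the tour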
models/nets/attention_model/context.py
ADDED
@@ -0,0 +1,232 @@
"""
Problem specific global embedding for global context.
"""

import torch
from torch import nn


def AutoContext(problem_name, config):
    """
    Automatically select the corresponding module according to ``problem_name``
    """
    mapping = {
        "tsp": TSPContext,
        "cvrp": VRPContext,
        "sdvrp": VRPContext,
        "pctsp": PCTSPContext,
        "op": OPContext,
    }
    embeddingClass = mapping[problem_name]
    embedding = embeddingClass(**config)
    return embedding


def _gather_by_index(source, index):
    """
    target[i,0,:] = source[i,index[i],:]
    Inputs:
        source: [B x H x D]
        index: [B x 1] or [B]
    Outputs:
        target: [B x 1 x D]
    """

    target = torch.gather(source, 1, index.unsqueeze(-1).expand(-1, -1, source.size(-1)))
    return target


class PrevNodeContext(nn.Module):
    """
    Abstract class for Context.
    Any subclass, by default, will return a concatenation of

    +---------------------+-----------------+
    | prev_node_embedding | state_embedding |
    +---------------------+-----------------+

    The ``prev_node_embedding`` is the node embedding of the last visited node.
    It is obtained by the ``_prev_node_embedding`` method,
    which requires ``state.get_current_node()`` to provide the index of the last visited node.

    The ``state_embedding`` is the global context we want to include, such as the remaining capacity in VRP.
    It is obtained by the ``_state_embedding`` method,
    which is not implemented here; subclasses of this abstract class need to implement it.

    Args:
        problem: an object defining the settings of the environment
        context_dim: the dimension of the output
    Inputs: embeddings, state
        * **embeddings** : [batch x graph size x embed dim]
        * **state**: An object providing observations in the environment. \
            Needs to supply ``state.get_current_node()``
    Outputs: context_embedding
        * **context_embedding**: [batch x 1 x context_dim]

    """

    def __init__(self, context_dim):
        super(PrevNodeContext, self).__init__()
        self.context_dim = context_dim

    def _prev_node_embedding(self, embeddings, state):
        current_node = state.get_current_node()
        prev_node_embedding = _gather_by_index(embeddings, current_node)
        return prev_node_embedding

    def _state_embedding(self, embeddings, state):
        raise NotImplementedError("Please implement the embedding for your own problem.")

    def forward(self, embeddings, state):
        prev_node_embedding = self._prev_node_embedding(embeddings, state)
        state_embedding = self._state_embedding(embeddings, state)
        # embedding of the previous node + problem-specific state (e.g. remaining capacity)
        context_embedding = torch.cat((prev_node_embedding, state_embedding), -1)
        return context_embedding


class TSPContext(PrevNodeContext):
    """
    Context node embedding for the traveling salesman problem.
    Returns a concatenation of

    +------------------------+---------------------------+
    | first node's embedding | previous node's embedding |
    +------------------------+---------------------------+

    .. note::
        Subclass of :class:`.PrevNodeContext`. The arguments, inputs, and outputs follow the same specification.

        In addition to supplying ``state.get_current_node()`` for the index of the previously visited node,
        the input ``state`` needs to supply ``state.first_a`` for the index of the first visited node.

    .. warning::
        The official implementation concatenates the context as [first node, prev node].
        However, following the paper closely, it should instead be [prev node, first node].
        Please check ``forward_code`` and ``forward_paper`` for the two implementations.
        We follow the official implementation in this class.
    """

    def __init__(self, context_dim):
        super(TSPContext, self).__init__(context_dim)
        self.W_placeholder = nn.Parameter(torch.Tensor(self.context_dim).uniform_(-1, 1))

    def _state_embedding(self, embeddings, state):
        first_node = state.first_a
        state_embedding = _gather_by_index(embeddings, first_node)
        return state_embedding

    def forward_paper(self, embeddings, state):
        batch_size = embeddings.size(0)
        if state.i.item() == 0:
            context_embedding = self.W_placeholder[None, None, :].expand(
                batch_size, 1, self.W_placeholder.size(-1)
            )
        else:
            context_embedding = super().forward(embeddings, state)
        return context_embedding

    def forward_code(self, embeddings, state):
        batch_size = embeddings.size(0)
        if state.i.item() == 0:
            context_embedding = self.W_placeholder[None, None, :].expand(
                batch_size, 1, self.W_placeholder.size(-1)
            )
        else:
            context_embedding = _gather_by_index(
                embeddings, torch.cat([state.first_a, state.get_current_node()], -1)
            ).view(batch_size, 1, -1)
        return context_embedding

    def forward_vectorized(self, embeddings, state):
        n_queries = state.states["first_node_idx"].shape[-1]
        batch_size = embeddings.size(0)
        out_shape = (batch_size, n_queries, self.context_dim)

        switch = state.is_initial_action  # tensor, 1 if it is the initial action
        switch = switch[:, None, None].expand(out_shape)  # mask for each data point

        # only used for the first action
        placeholder_embedding = self.W_placeholder[None, None, :].expand(out_shape)
        # used after the first action
        indexes = torch.stack([state.first_a, state.get_current_node()], -1).flatten(-2)
        normal_embedding = _gather_by_index(embeddings, indexes).view(out_shape)

        context_embedding = switch * placeholder_embedding + (~switch) * normal_embedding
        return context_embedding

    def forward(self, embeddings, state):
        return self.forward_vectorized(embeddings, state)


class VRPContext(PrevNodeContext):
    """
    Context node embedding for the capacitated vehicle routing problem.
    Returns a concatenation of

    +---------------------------+--------------------+
    | previous node's embedding | remaining capacity |
    +---------------------------+--------------------+

    .. note::
        Subclass of :class:`.PrevNodeContext`. The arguments, inputs, and outputs follow the same specification.

        In addition to supplying ``state.get_current_node()`` for the index of the previously visited node,
        the input ``state`` needs to supply ``state.VEHICLE_CAPACITY`` and ``state.used_capacity``
        for calculating the remaining capacity.
    """

    def __init__(self, context_dim):
        super(VRPContext, self).__init__(context_dim)

    def _state_embedding(self, embeddings, state):
        state_embedding = state.VEHICLE_CAPACITY - state.used_capacity[:, :, None]
        return state_embedding


class PCTSPContext(PrevNodeContext):
    """
    Context node embedding for the prize collecting traveling salesman problem.
    Returns a concatenation of

    +---------------------------+----------------------------+
    | previous node's embedding | remaining prize to collect |
    +---------------------------+----------------------------+

    .. note::
        Subclass of :class:`.PrevNodeContext`. The arguments, inputs, and outputs follow the same specification.

        In addition to supplying ``state.get_current_node()`` for the index of the previously visited node,
        the input ``state`` needs to supply ``state.get_remaining_prize_to_collect()``.
    """

    def __init__(self, context_dim):
        super(PCTSPContext, self).__init__(context_dim)

    def _state_embedding(self, embeddings, state):
        state_embedding = state.get_remaining_prize_to_collect()[:, :, None]
        return state_embedding


class OPContext(PrevNodeContext):
    """
    Context node embedding for the orienteering problem.
    Returns a concatenation of

    +---------------------------+---------------------------------+
    | previous node's embedding | remaining tour length to travel |
    +---------------------------+---------------------------------+

    .. note::
        Subclass of :class:`.PrevNodeContext`. The arguments, inputs, and outputs follow the same specification.

        In addition to supplying ``state.get_current_node()`` for the index of the previously visited node,
        the input ``state`` needs to supply ``state.get_remaining_length()``.
    """

    def __init__(self, context_dim):
        super(OPContext, self).__init__(context_dim)

    def _state_embedding(self, embeddings, state):
        state_embedding = state.get_remaining_length()[:, :, None]
        return state_embedding
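``_gather_by_index`` is the workhorse of every context above: it pulls one node embedding per (batch, query) index pair. An illustrative shape check:

# Illustrative only: pick the embedding of node index[i] from each batch element.
import torch

source = torch.arange(2 * 3 * 4, dtype=torch.float32).view(2, 3, 4)  # [B=2, H=3, D=4]
index = torch.tensor([[2], [0]])                                     # [B x 1]
target = torch.gather(source, 1, index.unsqueeze(-1).expand(-1, -1, source.size(-1)))
print(target.shape)  # torch.Size([2, 1, 4])
print(target[0, 0])  # equals source[0, 2]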
models/nets/attention_model/decoder.py
ADDED
@@ -0,0 +1,211 @@
import torch
from torch import nn

from ...nets.attention_model.context import AutoContext
from ...nets.attention_model.dynamic_embedding import AutoDynamicEmbedding
from ...nets.attention_model.multi_head_attention import (
    AttentionScore,
    MultiHeadAttention,
)


class Decoder(nn.Module):
    r"""
    The decoder of the Attention Model.

    .. math::
        \{\log(\pmb{p}_t)\},\pi = \mathrm{Decoder}(s, \pmb{h})

    First of all, precompute the keys and values for the embedding :math:`\pmb{h}`:

    .. math::
        \pmb{k}, \pmb{v}, \pmb{k}^\prime = W^K\pmb{h}, W^V\pmb{h}, W^{K^\prime}\pmb{h}
    and the projection of the graph embedding:

    .. math::
        W_{gc}\bar{\pmb{h}} \quad \text{ for } \bar{\pmb{h}} = \frac{1}{N}\sum\nolimits_i \pmb{h}_i.

    Then, the decoder iterates the decoding process autoregressively.
    In each decoding step, we perform multiple attentions to get the logits for each node.

    .. math::
        \begin{aligned}
        \pmb{h}_{(c)} &= [\bar{\pmb{h}}, \text{Context}(s,\pmb{h})] \\
        q & = W^Q \pmb{h}_{(c)} = W_{gc}\bar{\pmb{h}} + W_{sc}\text{Context}(s,\pmb{h}) \\
        q_{gl} &= \mathrm{MultiHeadAttention}(q,\pmb{k},\pmb{v},\mathrm{mask}_t) \\
        \pmb{p}_t &= \mathrm{Softmax}(\mathrm{AttentionScore}_{\text{clip}}(q_{gl},\pmb{k}^\prime, \mathrm{mask}_t))\\
        \pi_{t} &= \mathrm{DecodingStrategy}(\pmb{p}_t) \\
        \mathrm{mask}_{t+1} &= \mathrm{mask}_t.update(\pi_t).
        \end{aligned}

    .. note::
        If there are dynamic node features specified by :mod:`.dynamic_embedding`,
        the keys and values projections are updated in each decoding step by

        .. math::
            \begin{aligned}
            \pmb{k}_{\text{dynamic}}, \pmb{v}_{\text{dynamic}}, \pmb{k}^\prime_{\text{dynamic}} &= \mathrm{DynamicEmbedding}(s)\\
            \pmb{k} &= \pmb{k} + \pmb{k}_{\text{dynamic}}\\
            \pmb{v} &= \pmb{v} +\pmb{v}_{\text{dynamic}} \\
            \pmb{k}^\prime &= \pmb{k}^\prime +\pmb{k}^\prime_{\text{dynamic}}.
            \end{aligned}
    .. seealso::
        * The :math:`\text{Context}` is defined in the :mod:`.context` module.
        * The :math:`\text{AttentionScore}` is defined by the :class:`.AttentionScore` class.
        * The :math:`\text{MultiHeadAttention}` is defined by the :class:`.MultiHeadAttention` class.

    Args:
        embedding_dim : the dimension of the embedded inputs
        step_context_dim : the dimension of the context :math:`\text{Context}(s,\pmb{h})`
        n_heads: number of heads in the :math:`\mathrm{MultiHeadAttention}`
        problem: an object defining the state and the mask updating rule of the problem
        tanh_clipping : the clipping scale of the pointer (attention layer before output)
    Inputs: input, embeddings
        * **input** : dict of inputs, for example ``{'loc': tensor, 'depot': tensor, 'demand': tensor}`` for CVRP.
        * **embeddings**: [batch, graph_size, embedding_dim]
    Outputs: log_ps, pi
        * **log_ps**: [batch, T, graph_size]
        * **pi**: [batch, T]

    """

    def __init__(self, embedding_dim, step_context_dim, n_heads, problem, tanh_clipping):
        super(Decoder, self).__init__()
        # For each node we compute (glimpse key, glimpse value, logit key), so 3 * embedding_dim
        self.project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=False)
        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.project_step_context = nn.Linear(step_context_dim, embedding_dim, bias=False)

        self.context = AutoContext(problem.NAME, {"context_dim": step_context_dim})
        self.dynamic_embedding = AutoDynamicEmbedding(
            problem.NAME, {"embedding_dim": embedding_dim}
        )
        self.glimpse = MultiHeadAttention(embedding_dim=embedding_dim, n_heads=n_heads)
        self.pointer = AttentionScore(use_tanh=True, C=tanh_clipping)

        self.decode_type = None
        self.problem = problem

    def forward(self, input, embeddings):
        outputs = []
        sequences = []

        state = self.problem.make_state(input)

        # Compute keys and values for the glimpse and keys for the logits once, as they can be reused in every step
        cached_embeddings = self._precompute(embeddings)

        # Perform decoding steps
        while not (state.all_finished()):

            log_p, mask = self.advance(cached_embeddings, state)

            # Select the indices of the next nodes in the sequences, result (batch_size) long
            # Squeeze out the steps dimension
            action = self.decode(log_p.exp(), mask)
            state = state.update(action)

            # Collect output of step
            outputs.append(log_p)
            sequences.append(action)

        # Collected lists, return Tensor
        return torch.stack(outputs, 1), torch.stack(sequences, 1)

    def set_decode_type(self, decode_type):
        r"""
        Currently supports

        .. code-block:: python

            ["greedy", "sampling"]

        """
        self.decode_type = decode_type

    def decode(self, probs, mask):
        r"""
        Execute the decoding strategy specified by ``self.decode_type``.

        Inputs:
            * **probs**: [batch_size, graph_size]
            * **mask** (bool): [batch_size, graph_size]
        Outputs:
            * **idxs** (int): index of the action chosen. [batch_size]
        """
        assert (probs == probs).all(), "Probs should not contain any nans"

        if self.decode_type == "greedy":
            _, selected = probs.max(1)
            assert not mask.gather(
                1, selected.unsqueeze(-1)
            ).data.any(), "Decode greedy: infeasible action has maximum probability"

        elif self.decode_type == "sampling":
            selected = probs.multinomial(1).squeeze(1)

            # Check if sampling went OK; it can go wrong due to a bug on GPU
            # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232
            while mask.gather(1, selected.unsqueeze(-1)).data.any():
                print("Sampled bad values, resampling!")
                selected = probs.multinomial(1).squeeze(1)

        else:
            assert False, "Unknown decode type"
        return selected

    def _precompute(self, embeddings):
        # The fixed context projection of the graph embedding is calculated only once for efficiency
        graph_embed = embeddings.mean(1)
        # fixed context = (batch_size, 1, embed_dim) to make it broadcastable with parallel timesteps
        graph_context = self.project_fixed_context(graph_embed).unsqueeze(-2)
        # The projection of the node embeddings for the attention is calculated once up front
        glimpse_key, glimpse_val, logit_key = self.project_node_embeddings(embeddings).chunk(
            3, dim=-1
        )

        cache = (
            embeddings,
            graph_context,
            glimpse_key,
            glimpse_val,
            logit_key,
        )  # single head for the final logit
        return cache

    def advance(self, cached_embeddings, state):
        node_embeddings, graph_context, glimpse_K, glimpse_V, logit_K = cached_embeddings

        # Compute the context node embedding: [graph embedding | prev node | problem-state context]
        # [batch, 1, context dim]
        context = self.context(node_embeddings, state)
        step_context = self.project_step_context(context)  # [batch, 1, embed_dim]
        query = graph_context + step_context  # [batch, 1, embed_dim]

        glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.dynamic_embedding(state)
        glimpse_K = glimpse_K + glimpse_key_dynamic
        glimpse_V = glimpse_V + glimpse_val_dynamic
        logit_K = logit_K + logit_key_dynamic
        # Compute the mask
        mask = state.get_mask()

        # Compute logits (unnormalized log_p)
        logits, glimpse = self.calc_logits(query, glimpse_K, glimpse_V, logit_K, mask)

        return logits, glimpse

    def calc_logits(self, query, glimpse_K, glimpse_V, logit_K, mask):
        # Compute the glimpse with multi-head attention,
        # then use the glimpse as a query to compute logits for each node

        # [batch, 1, embed dim]
        glimpse = self.glimpse(query, glimpse_K, glimpse_V, mask)

        logits = self.pointer(glimpse, logit_K, mask)

        return logits, glimpse
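The two decoding strategies differ only in how an index is drawn from ``probs``; the mask serves as a feasibility check, since masked nodes already receive zero probability upstream. An illustrative toy call:

# Illustrative only: greedy takes the argmax, sampling draws from the
# distribution; mask is True for infeasible nodes and is only asserted against.
import torch

probs = torch.tensor([[0.0, 0.7, 0.3]])      # masked nodes get probability zero upstream
mask = torch.tensor([[True, False, False]])  # node 0 is infeasible
_, greedy = probs.max(1)                     # tensor([1])
sampled = probs.multinomial(1).squeeze(1)    # 1 with prob 0.7, 2 with prob 0.3
assert not mask.gather(1, sampled.unsqueeze(-1)).any()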
models/nets/attention_model/dynamic_embedding.py
ADDED
@@ -0,0 +1,73 @@
"""
Problem specific node embedding for dynamic features.
"""

import torch.nn as nn


def AutoDynamicEmbedding(problem_name, config):
    """
    Automatically select the corresponding module according to ``problem_name``
    """
    mapping = {
        "tsp": NonDynamicEmbedding,
        "cvrp": NonDynamicEmbedding,
        "sdvrp": SDVRPDynamicEmbedding,
        "pctsp": NonDynamicEmbedding,
        "op": NonDynamicEmbedding,
    }
    embeddingClass = mapping[problem_name]
    embedding = embeddingClass(**config)
    return embedding


class SDVRPDynamicEmbedding(nn.Module):
    """
    Embedding of the dynamic node features for the split delivery vehicle routing problem.

    It is implemented as a linear projection of the demand left at each node.

    Args:
        embedding_dim: dimension of output
    Inputs: state
        * **state** : a class that provides the ``state.demands_with_depot`` tensor
    Outputs: glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic
        * **glimpse_key_dynamic** : [batch, graph_size, embedding_dim]
        * **glimpse_val_dynamic** : [batch, graph_size, embedding_dim]
        * **logit_key_dynamic** : [batch, graph_size, embedding_dim]

    """

    def __init__(self, embedding_dim):
        super(SDVRPDynamicEmbedding, self).__init__()
        self.projection = nn.Linear(1, 3 * embedding_dim, bias=False)

    def forward(self, state):
        glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.projection(
            state.demands_with_depot[:, 0, :, None].clone()
        ).chunk(3, dim=-1)
        return glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic


class NonDynamicEmbedding(nn.Module):
    """
    Embedding for problems that do not have any dynamic node features.

    It is implemented as simply returning zeros
    (scalar zeros, which broadcast against the cached projections).

    Args:
        embedding_dim: dimension of output
    Inputs: state
        * **state** : not used, kept for interface consistency
    Outputs: glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic
        * **glimpse_key_dynamic** : scalar zero
        * **glimpse_val_dynamic** : scalar zero
        * **logit_key_dynamic** : scalar zero

    """

    def __init__(self, embedding_dim):
        super(NonDynamicEmbedding, self).__init__()

    def forward(self, state):
        return 0, 0, 0
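The ``chunk(3, dim=-1)`` pattern above lets a single linear layer produce all three dynamic projections in one matmul. A tiny illustrative sketch:

# Illustrative only: one Linear(1, 3 * d) maps per-node remaining demand to
# key / value / logit-key updates in a single matmul, split afterwards.
import torch
import torch.nn as nn

d = 4
projection = nn.Linear(1, 3 * d, bias=False)
demands = torch.rand(2, 7, 1)  # [batch, graph_size, 1]
k_dyn, v_dyn, logit_k_dyn = projection(demands).chunk(3, dim=-1)
print(k_dyn.shape)             # torch.Size([2, 7, 4])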
models/nets/attention_model/embedding.py
ADDED
@@ -0,0 +1,185 @@
"""
Problem specific node embedding for static features.
"""

import torch
import torch.nn as nn


def AutoEmbedding(problem_name, config):
    """
    Automatically select the corresponding module according to ``problem_name``
    """
    mapping = {
        "tsp": TSPEmbedding,
        "cvrp": VRPEmbedding,
        "sdvrp": VRPEmbedding,
        "pctsp": PCTSPEmbedding,
        "op": OPEmbedding,
    }
    embeddingClass = mapping[problem_name]
    embedding = embeddingClass(**config)
    return embedding


class TSPEmbedding(nn.Module):
    """
    Embedding for the traveling salesman problem.

    Args:
        embedding_dim: dimension of output
    Inputs: input
        * **input** : [batch, n_customer, 2]
    Outputs: out
        * **out** : [batch, n_customer, embedding_dim]
    """

    def __init__(self, embedding_dim):
        super(TSPEmbedding, self).__init__()
        node_dim = 2  # x, y
        self.context_dim = 2 * embedding_dim  # embedding of the first and last node

        self.init_embed = nn.Linear(node_dim, embedding_dim)

    def forward(self, input):
        out = self.init_embed(input)
        return out


class VRPEmbedding(nn.Module):
    """
    Embedding for the capacitated vehicle routing problem.
    The shapes of the tensors in ``input`` are summarized as follows:

    +-----------+-------------------------+
    | key       | size of tensor          |
    +===========+=========================+
    | 'loc'     | [batch, n_customer, 2]  |
    +-----------+-------------------------+
    | 'depot'   | [batch, 2]              |
    +-----------+-------------------------+
    | 'demand'  | [batch, n_customer]     |
    +-----------+-------------------------+

    Args:
        embedding_dim: dimension of output
    Inputs: input
        * **input** : dict of ['loc', 'depot', 'demand']
    Outputs: out
        * **out** : [batch, n_customer+1, embedding_dim]
    """

    def __init__(self, embedding_dim):
        super(VRPEmbedding, self).__init__()
        node_dim = 3  # x, y, demand

        self.context_dim = embedding_dim + 1  # embedding of the last node + remaining capacity

        self.init_embed = nn.Linear(node_dim, embedding_dim)
        self.init_embed_depot = nn.Linear(2, embedding_dim)  # depot embedding

    def forward(self, input):  # dict of 'loc', 'demand', 'depot'
        # [batch, 2] -> [batch, 1, embedding_dim]
        depot_embedding = self.init_embed_depot(input["depot"])[:, None, :]
        # [batch, n_customer, 2] ++ [batch, n_customer, 1] -> [batch, n_customer, embedding_dim]
        node_embeddings = self.init_embed(
            torch.cat((input["loc"], input["demand"][:, :, None]), -1)
        )
        # [batch, n_customer+1, embedding_dim]
        out = torch.cat((depot_embedding, node_embeddings), 1)
        return out


class PCTSPEmbedding(nn.Module):
    """
    Embedding for the prize collecting traveling salesman problem.
    The shapes of the tensors in ``input`` are summarized as follows:

    +------------------------+-------------------------+
    | key                    | size of tensor          |
    +========================+=========================+
    | 'loc'                  | [batch, n_customer, 2]  |
    +------------------------+-------------------------+
    | 'depot'                | [batch, 2]              |
    +------------------------+-------------------------+
    | 'deterministic_prize'  | [batch, n_customer]     |
    +------------------------+-------------------------+
    | 'penalty'              | [batch, n_customer]     |
    +------------------------+-------------------------+

    Args:
        embedding_dim: dimension of output
    Inputs: input
        * **input** : dict of ['loc', 'depot', 'deterministic_prize', 'penalty']
    Outputs: out
        * **out** : [batch, n_customer+1, embedding_dim]
    """

    def __init__(self, embedding_dim):
        super(PCTSPEmbedding, self).__init__()
        node_dim = 4  # x, y, prize, penalty
        self.context_dim = embedding_dim + 1  # embedding of the last node + remaining prize to collect

        self.init_embed = nn.Linear(node_dim, embedding_dim)
        self.init_embed_depot = nn.Linear(2, embedding_dim)  # depot embedding

    def forward(self, input):  # dict of 'loc', 'deterministic_prize', 'penalty', 'depot'
        # [batch, 2] -> [batch, 1, embedding_dim]
        depot_embedding = self.init_embed_depot(input["depot"])[:, None, :]
        # [batch, n_customer, 2] ++ [batch, n_customer, 1] ++ [batch, n_customer, 1]
        #   -> [batch, n_customer, embedding_dim]
        node_embeddings = self.init_embed(
            torch.cat(
                (
                    input["loc"],
                    input["deterministic_prize"][:, :, None],
                    input["penalty"][:, :, None],
                ),
                -1,
            )
        )
        # [batch, n_customer+1, embedding_dim]
        out = torch.cat((depot_embedding, node_embeddings), 1)
        return out


class OPEmbedding(nn.Module):
    """
    Embedding for the orienteering problem.
    The shapes of the tensors in ``input`` are summarized as follows:

    +----------+-------------------------+
    | key      | size of tensor          |
    +==========+=========================+
    | 'loc'    | [batch, n_customer, 2]  |
    +----------+-------------------------+
    | 'depot'  | [batch, 2]              |
    +----------+-------------------------+
    | 'prize'  | [batch, n_customer]     |
    +----------+-------------------------+

    Args:
        embedding_dim: dimension of output
    Inputs: input
        * **input** : dict of ['loc', 'depot', 'prize']
    Outputs: out
        * **out** : [batch, n_customer+1, embedding_dim]
    """

    def __init__(self, embedding_dim):
        super(OPEmbedding, self).__init__()
        node_dim = 3  # x, y, prize
        self.context_dim = embedding_dim + 1  # embedding of the last node + remaining prize to collect

        self.init_embed = nn.Linear(node_dim, embedding_dim)
        self.init_embed_depot = nn.Linear(2, embedding_dim)  # depot embedding

    def forward(self, input):  # dict of 'loc', 'prize', 'depot'
        # [batch, 2] -> [batch, 1, embedding_dim]
        depot_embedding = self.init_embed_depot(input["depot"])[:, None, :]
        # [batch, n_customer, 2] ++ [batch, n_customer, 1] -> [batch, n_customer, embedding_dim]
        node_embeddings = self.init_embed(
            torch.cat((input["loc"], input["prize"][:, :, None]), -1)
        )
        # [batch, n_customer+1, embedding_dim]
        out = torch.cat((depot_embedding, node_embeddings), 1)
        return out
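A quick shape check for the VRP embedding with made-up tensors (note the extra depot node in the output):

# Illustrative only: the depot and the customers are projected by separate
# linear layers, then concatenated along the node dimension.
import torch
from models.nets.attention_model.embedding import VRPEmbedding

emb = VRPEmbedding(embedding_dim=128)
inp = {
    "loc": torch.rand(8, 20, 2),  # [batch, n_customer, 2]
    "depot": torch.rand(8, 2),    # [batch, 2]
    "demand": torch.rand(8, 20),  # [batch, n_customer]
}
print(emb(inp).shape)             # torch.Size([8, 21, 128]) -- depot is node 0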
models/nets/attention_model/encoder.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch import nn

from ...nets.attention_model.multi_head_attention import MultiHeadAttentionProj


class SkipConnection(nn.Module):
    def __init__(self, module):
        super(SkipConnection, self).__init__()
        self.module = module

    def forward(self, input):
        return input + self.module(input)


class Normalization(nn.Module):
    def __init__(self, embedding_dim):
        super(Normalization, self).__init__()

        self.normalizer = nn.BatchNorm1d(embedding_dim, affine=True)

    def forward(self, input):
        # out = self.normalizer(input.permute(0,2,1)).permute(0,2,1) # slightly different 3e-6
        # return out
        return self.normalizer(input.view(-1, input.size(-1))).view(input.size())


class MultiHeadAttentionLayer(nn.Sequential):
    r"""
    A layer with attention mechanism and normalization.

    For an embedding :math:`\pmb{x}`,

    .. math::
        \pmb{h} = \mathrm{MultiHeadAttentionLayer}(\pmb{x})

    The following is executed:

    .. math::
        \begin{aligned}
        \pmb{x}_0&=\pmb{x}+\mathrm{MultiHeadAttentionProj}(\pmb{x}) \\
        \pmb{x}_1&=\mathrm{BatchNorm}(\pmb{x}_0) \\
        \pmb{x}_2&=\pmb{x}_1+\mathrm{MLP_{\text{2 layers}}}(\pmb{x}_1)\\
        \pmb{h} &=\mathrm{BatchNorm}(\pmb{x}_2)
        \end{aligned}

    .. seealso::
        The :math:`\mathrm{MultiHeadAttentionProj}` computes the self attention
        of the embedding :math:`\pmb{x}`. Check :class:`~.MultiHeadAttentionProj` for details.

    Args:
        n_heads : number of heads
        embedding_dim : dimension of the query, keys, values
        feed_forward_hidden : size of the hidden layer in the MLP
    Inputs: inputs
        * **inputs**: embedding :math:`\pmb{x}`. [batch, graph_size, embedding_dim]
    Outputs: out
        * **out**: the output :math:`\pmb{h}` [batch, graph_size, embedding_dim]
    """

    def __init__(
        self,
        n_heads,
        embedding_dim,
        feed_forward_hidden=512,
    ):
        super(MultiHeadAttentionLayer, self).__init__(
            SkipConnection(
                MultiHeadAttentionProj(
                    embedding_dim=embedding_dim,
                    n_heads=n_heads,
                )
            ),
            Normalization(embedding_dim),
            SkipConnection(
                nn.Sequential(
                    nn.Linear(embedding_dim, feed_forward_hidden),
                    nn.ReLU(),
                    nn.Linear(feed_forward_hidden, embedding_dim),
                )
                if feed_forward_hidden > 0
                else nn.Linear(embedding_dim, embedding_dim)
            ),
            Normalization(embedding_dim),
        )


class GraphAttentionEncoder(nn.Module):
    r"""
    Graph attention by self attention on graph nodes.

    For an embedding :math:`\pmb{x}`, repeat ``n_layers`` times:

    .. math::
        \pmb{h} = \mathrm{MultiHeadAttentionLayer}(\pmb{x})

    .. seealso::
        Check :class:`~.MultiHeadAttentionLayer` for details.

    Args:
        n_heads : number of heads
        embed_dim : dimension of the query, keys, values
        n_layers : number of :class:`~.MultiHeadAttentionLayer` to iterate.
        feed_forward_hidden : size of the hidden layer in the MLP
    Inputs: x
        * **x**: embedding :math:`\pmb{x}`. [batch, graph_size, embedding_dim]
    Outputs: (h, h_mean)
        * **h**: the output :math:`\pmb{h}` [batch, graph_size, embedding_dim]
        * **h_mean**: the mean of :math:`\pmb{h}` over the graph nodes [batch, embedding_dim]
    """

    def __init__(self, n_heads, embed_dim, n_layers, feed_forward_hidden=512):
        super(GraphAttentionEncoder, self).__init__()

        self.layers = nn.Sequential(
            *(
                MultiHeadAttentionLayer(n_heads, embed_dim, feed_forward_hidden)
                for _ in range(n_layers)
            )
        )

    def forward(self, x, mask=None):

        assert mask is None, "TODO mask not yet supported!"

        h = self.layers(x)

        return (h, h.mean(dim=1))
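
A minimal smoke test for the encoder above (a sketch, not part of this commit; it assumes execution from the repository root so the package import resolves):

# Sketch: exercise GraphAttentionEncoder on a random embedding.
import torch
from models.nets.attention_model.encoder import GraphAttentionEncoder

encoder = GraphAttentionEncoder(n_heads=8, embed_dim=128, n_layers=3)
x = torch.rand(2, 20, 128)          # [batch, graph_size, embedding_dim]
h, h_mean = encoder(x)
assert h.shape == (2, 20, 128)      # per-node embeddings
assert h_mean.shape == (2, 128)     # graph embedding: mean over nodes
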
models/nets/attention_model/multi_head_attention.py
ADDED
@@ -0,0 +1,188 @@
import math

import torch
from torch import nn


class AttentionScore(nn.Module):
    r"""
    A helper class for attention operations.
    There are no parameters in this module.
    This module computes the alignment score with mask
    and returns only the attention score.

    The default operation is

    .. math::
        \pmb{u} = \mathrm{Attention}(q,\pmb{k}, \mathrm{mask})

    where for each key :math:`k_j`, we have

    .. math::
        u_j =
        \begin{cases}
        &\frac{q^Tk_j}{\sqrt{\smash{d_q}}} & \text{ if } j \notin \mathrm{mask}\\
        &-\infty & \text{ otherwise. }
        \end{cases}

    If ``use_tanh`` is ``True``, apply clipping on the logits :math:`u_j` before masking:

    .. math::
        u_j =
        \begin{cases}
        &C\mathrm{tanh}\left(\frac{q^Tk_j}{\sqrt{\smash{d_q}}}\right) & \text{ if } j \notin \mathrm{mask}\\
        &-\infty & \text{ otherwise. }
        \end{cases}

    Args:
        use_tanh: if True, use clipping on the logits
        C: the range of the clipping [-C,C]
    Inputs: query, keys, mask
        * **query** : [..., 1, h_dim]
        * **keys**: [..., graph_size, h_dim]
        * **mask**: [..., graph_size] ``logits[...,j]==-inf`` if ``mask[...,j]==True``.
    Outputs: logits
        * **logits**: [..., 1, graph_size] The attention score for each key.
    """

    def __init__(self, use_tanh=False, C=10):
        super(AttentionScore, self).__init__()
        self.use_tanh = use_tanh
        self.C = C

    def forward(self, query, key, mask=torch.zeros([], dtype=torch.bool)):
        u = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        if self.use_tanh:
            logits = torch.tanh(u) * self.C
        else:
            logits = u

        logits[mask.expand_as(logits)] = float("-inf")  # masked after clipping
        return logits


class MultiHeadAttention(nn.Module):
    r"""
    Compute the multi-head attention.

    .. math::
        q^\prime = \mathrm{MultiHeadAttention}(q,\pmb{k},\pmb{v},\mathrm{mask})

    The following is computed:

    .. math::
        \begin{aligned}
        \pmb{a}^{(j)} &= \mathrm{Softmax}(\mathrm{AttentionScore}(q^{(j)},\pmb{k}^{(j)}, \mathrm{mask}))\\
        h^{(j)} &= \sum\nolimits_i \pmb{a}^{(j)}_i\pmb{v}_i \\
        q^\prime &= W^O \left[h^{(1)},...,h^{(J)}\right]
        \end{aligned}

    Args:
        embedding_dim: dimension of the query, keys, values
        n_heads: number of heads
    Inputs: query, keys, value, mask
        * **query** : [batch, n_querys, embedding_dim]
        * **keys**: [batch, n_keys, embedding_dim]
        * **value**: [batch, n_keys, embedding_dim]
        * **mask**: [batch, 1, n_keys] ``logits[batch,j]==-inf`` if ``mask[batch, 0, j]==True``
    Outputs: out
        * **out**: [batch, 1, embedding_dim] The output of the multi-head attention
    """

    def __init__(self, embedding_dim, n_heads=8):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.attentionScore = AttentionScore()
        self.project_out = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, query, key, value, mask):
        query_heads = self._make_heads(query)
        key_heads = self._make_heads(key)
        value_heads = self._make_heads(value)

        # [n_heads, batch, 1, n_keys]
        compatibility = self.attentionScore(query_heads, key_heads, mask)

        # [n_heads, batch, 1, head_dim]
        out_heads = torch.matmul(torch.softmax(compatibility, dim=-1), value_heads)

        # from multi-head [n_heads, batch, 1, head_dim] -> [batch, 1, n_heads * head_dim]
        out = self.project_out(self._unmake_heads(out_heads))
        return out

    def _make_heads(self, v):
        batch_size, nkeys, h_dim = v.shape
        # [batch_size, ..., n_heads * head_dim] --> [n_heads, batch_size, ..., head_dim]
        out = v.reshape(batch_size, nkeys, self.n_heads, h_dim // self.n_heads).movedim(-2, 0)
        return out

    def _unmake_heads(self, v):
        # [n_heads, batch_size, ..., head_dim] --> [batch_size, ..., n_heads * head_dim]
        out = v.movedim(0, -2).flatten(-2)
        return out


class MultiHeadAttentionProj(nn.Module):
    r"""
    Compute the multi-head attention with projection.
    Different from :class:`.MultiHeadAttention`, which accepts precomputed query, keys, and values,
    this module computes linear projections from the inputs to query, keys, and values.

    .. math::
        q^\prime = \mathrm{MultiHeadAttentionProj}(q_0,\pmb{h},\mathrm{mask})

    The following is computed:

    .. math::
        \begin{aligned}
        q, \pmb{k}, \pmb{v} &= W^Qq_0, W^K\pmb{h}, W^V\pmb{h}\\
        \pmb{a}^{(j)} &= \mathrm{Softmax}(\mathrm{AttentionScore}(q^{(j)},\pmb{k}^{(j)}, \mathrm{mask}))\\
        h^{(j)} &= \sum\nolimits_i \pmb{a}^{(j)}_i\pmb{v}_i \\
        q^\prime &= W^O \left[h^{(1)},...,h^{(J)}\right]
        \end{aligned}

    If :math:`\pmb{h}` is not given, this module computes the self attention of :math:`q_0`.

    .. warning::
        The results of the in-projection of query, key, value differ
        slightly (on the order of ``1e-6``) from the original implementation.
        This is a matter of numerical accuracy:
        the two implementations multiply the matrices in a different order,
        so different internal pytorch kernels are called
        and the results are slightly different.
        See the pytorch docs on `numerical accuracy <https://pytorch.org/docs/stable/notes/numerical_accuracy.html>`_ for details.

    Args:
        embedding_dim: dimension of the query, keys, values
        n_heads: number of heads
    Inputs: q, h, mask
        * **q** : [batch, n_querys, embedding_dim]
        * **h**: [batch, n_keys, embedding_dim]
        * **mask**: [batch, n_keys] ``logits[batch,j]==-inf`` if ``mask[batch,j]==True``
    Outputs: out
        * **out**: [batch, n_querys, embedding_dim] The output of the multi-head attention
    """

    def __init__(self, embedding_dim, n_heads=8):
        super(MultiHeadAttentionProj, self).__init__()

        self.queryEncoder = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.keyEncoder = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.valueEncoder = nn.Linear(embedding_dim, embedding_dim, bias=False)

        self.MHA = MultiHeadAttention(embedding_dim, n_heads)

    def forward(self, q, h=None, mask=torch.zeros([], dtype=torch.bool)):

        if h is None:
            h = q  # compute self-attention

        query = self.queryEncoder(q)
        key = self.keyEncoder(h)
        value = self.valueEncoder(h)

        out = self.MHA(query, key, value, mask)

        return out
ppo.py
ADDED
@@ -0,0 +1,311 @@
# Retrieved from https://github.com/vwxyzjn/cleanrl/blob/28fd178ca182bd83c75ed0d49d52e235ca6cdc88/cleanrl/ppo.py
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppopy

import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="CartPole-v1",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=500000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--num-envs", type=int, default=4,
        help="the number of parallel game environments")
    parser.add_argument("--num-steps", type=int, default=128,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the generalized advantage estimation")
    parser.add_argument("--num-minibatches", type=int, default=4,
        help="the number of mini-batches")
    parser.add_argument("--update-epochs", type=int, default=4,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.2,
        help="the surrogate clipping coefficient")
    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
    parser.add_argument("--ent-coef", type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
        help="the target KL divergence threshold")
    args = parser.parse_args()
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    # fmt: on
    return args


def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        env = gym.make(env_id)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
    def __init__(self, envs):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
        )

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
    )
    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    terminateds = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs = torch.Tensor(envs.reset()[0]).to(device)
    next_terminated = torch.zeros(args.num_envs).to(device)
    num_updates = args.total_timesteps // args.batch_size

    for update in range(1, num_updates + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            global_step += 1 * args.num_envs
            obs[step] = next_obs
            terminateds[step] = next_terminated

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, terminated, _, info = envs.step(action.cpu().numpy())
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_terminated = torch.Tensor(next_obs).to(device), torch.Tensor(terminated).to(device)

            if "episode" in info:
                first_idx = info["_episode"].nonzero()[0][0]
                r = info["episode"]["r"][first_idx]
                l = info["episode"]["l"][first_idx]
                print(f"global_step={global_step}, episodic_return={r}")
                writer.add_scalar("charts/episodic_return", r, global_step)
                writer.add_scalar("charts/episodic_length", l, global_step)

        # bootstrap value if not terminated
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_terminated
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - terminateds[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    envs.close()
    writer.close()
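
The bootstrapping block above is the standard generalized advantage estimation (GAE) recursion. A self-contained sketch of the same computation, isolated as a function on toy numbers (the `gae` helper is hypothetical, not part of the script):

# Sketch: the GAE recursion from ppo.py, extracted for illustration.
import torch

def gae(rewards, values, terminateds, next_value, next_terminated, gamma=0.99, lam=0.95):
    # rewards/values/terminateds: [num_steps, num_envs]; next_value/next_terminated: [num_envs]
    advantages = torch.zeros_like(rewards)
    lastgaelam = torch.zeros_like(next_value)
    num_steps = rewards.shape[0]
    for t in reversed(range(num_steps)):
        # a terminal transition cuts the bootstrapped tail to zero
        nextnonterminal = 1.0 - (next_terminated if t == num_steps - 1 else terminateds[t + 1])
        nextvalues = next_value if t == num_steps - 1 else values[t + 1]
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        advantages[t] = lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
    return advantages, advantages + values  # (advantages, returns)

# toy rollout: reward 1 each step, zero value estimates, episode ends at the last step
adv, ret = gae(
    rewards=torch.ones(4, 2),
    values=torch.zeros(4, 2),
    terminateds=torch.zeros(4, 2),
    next_value=torch.zeros(2),
    next_terminated=torch.ones(2),
)
print(ret[:, 0])  # discounted-by-(gamma * lam) reward sums, largest at t=0
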
ppo_or.py
ADDED
@@ -0,0 +1,397 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppopy

import argparse
import os
import random
import shutil
import time
from distutils.util import strtobool

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--problem", type=str, default="cvrp",
        help="the OR problem we are trying to solve, it will be passed to the agent")
    parser.add_argument("--env-id", type=str, default="cvrp-v0",
        help="the id of the environment")
    parser.add_argument("--env-entry-point", type=str, default="envs.cvrp_vector_env:CVRPVectorEnv",
        help="the path to the definition of the environment, for example `envs.cvrp_vector_env:CVRPVectorEnv` if the `CVRPVectorEnv` class is defined in ./envs/cvrp_vector_env.py")
    parser.add_argument("--total-timesteps", type=int, default=6_000_000_000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=1e-3,
        help="the learning rate of the optimizer")
    parser.add_argument("--weight-decay", type=float, default=0,
        help="the weight decay of the optimizer")
    parser.add_argument("--num-envs", type=int, default=1024,
        help="the number of parallel game environments")
    parser.add_argument("--num-steps", type=int, default=100,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the generalized advantage estimation")
    parser.add_argument("--num-minibatches", type=int, default=8,
        help="the number of mini-batches")
    parser.add_argument("--update-epochs", type=int, default=2,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.2,
        help="the surrogate clipping coefficient")
    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
    parser.add_argument("--ent-coef", type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
        help="the target KL divergence threshold")
    parser.add_argument("--n-traj", type=int, default=50,
        help="number of trajectories in a vectorized sub-environment")
    parser.add_argument("--n-test", type=int, default=1000,
        help="the number of test instances")
    parser.add_argument("--multi-greedy-inference", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="whether to use multiple trajectory greedy inference")
    args = parser.parse_args()
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    # fmt: on
    return args


from wrappers.recordWrapper import RecordEpisodeStatistics


def make_env(env_id, seed, cfg={}):
    def thunk():
        env = gym.make(env_id, **cfg)
        env = RecordEpisodeStatistics(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


from models.attention_model_wrapper import Agent
from wrappers.syncVectorEnvPomo import SyncVectorEnv

if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s"
        % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )
    os.makedirs(os.path.join(f"runs/{run_name}", "ckpt"), exist_ok=True)
    shutil.copy(__file__, os.path.join(f"runs/{run_name}", "main.py"))
    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    ########################
    #### Env definition ####
    ########################

    gym.envs.register(
        id=args.env_id,
        entry_point=args.env_entry_point,
    )

    # training env setup
    envs = SyncVectorEnv([make_env(args.env_id, args.seed + i) for i in range(args.num_envs)])
    # evaluation env setup: 1.) from a fixed dataset, or 2.) generated with a seed
    # 1.) use test instances from a fixed dataset
    test_envs = SyncVectorEnv(
        [
            make_env(
                args.env_id,
                args.seed + i,
                cfg={"eval_data": True, "eval_partition": "eval", "eval_data_idx": i},
            )
            for i in range(args.n_test)
        ]
    )
    # # 2.) use generated evaluation instances instead
    # import logging
    # logging.warning('Using generated evaluation instances. For benchmarking, please download the fixed dataset.')
    # test_envs = SyncVectorEnv([make_env(args.env_id, args.seed + args.num_envs + i) for i in range(args.n_test)])

    assert isinstance(
        envs.single_action_space, gym.spaces.MultiDiscrete
    ), "only discrete action space is supported"

    ########################
    ### Agent definition ###
    ########################

    agent = Agent(device=device, name=args.problem).to(device)
    # agent.backbone.load_state_dict(torch.load('./vrp50.pt'))
    optimizer = optim.Adam(
        agent.parameters(), lr=args.learning_rate, eps=1e-5, weight_decay=args.weight_decay
    )

    ########################
    # Algorithm definition #
    ########################

    # ALGO Logic: Storage setup
    obs = [None] * args.num_steps
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(
        device
    )
    logprobs = torch.zeros((args.num_steps, args.num_envs, args.n_traj)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs, args.n_traj)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs, args.n_traj)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs, args.n_traj)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs = envs.reset()
    next_done = torch.zeros(args.num_envs, args.n_traj).to(device)
    num_updates = args.total_timesteps // args.batch_size
    for update in range(1, num_updates + 1):
        agent.train()
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow
        next_obs = envs.reset()
        encoder_state = agent.backbone.encode(next_obs)
        next_done = torch.zeros(args.num_envs, args.n_traj).to(device)
        r = []
        for step in range(0, args.num_steps):
            global_step += 1 * args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value, _ = agent.get_action_and_value_cached(
                    next_obs, state=encoder_state
                )
                action = action.view(args.num_envs, args.n_traj)
                values[step] = value.view(args.num_envs, args.n_traj)
            actions[step] = action
            logprobs[step] = logprob.view(args.num_envs, args.n_traj)
            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, done, info = envs.step(action.cpu().numpy())
            rewards[step] = torch.tensor(reward).to(device)
            next_obs, next_done = next_obs, torch.Tensor(done).to(device)

            for item in info:
                if "episode" in item.keys():
                    r.append(item)
        print("completed_episodes=", len(r))
        avg_episodic_return = np.mean([rollout["episode"]["r"].mean() for rollout in r])
        max_episodic_return = np.mean([rollout["episode"]["r"].max() for rollout in r])
        avg_episodic_length = np.mean([rollout["episode"]["l"].mean() for rollout in r])
        print(
            f"[Train] global_step={global_step}\n \
            avg_episodic_return={avg_episodic_return}\n \
            max_episodic_return={max_episodic_return}\n \
            avg_episodic_length={avg_episodic_length}"
        )
        writer.add_scalar("charts/episodic_return_mean", avg_episodic_return, global_step)
        writer.add_scalar("charts/episodic_return_max", max_episodic_return, global_step)
        writer.add_scalar("charts/episodic_length", avg_episodic_length, global_step)
        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value_cached(next_obs, encoder_state).squeeze(-1)  # B x T
            advantages = torch.zeros_like(rewards).to(device)  # steps x B x T
            lastgaelam = torch.zeros(args.num_envs, args.n_traj).to(device)  # B x T
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done  # next_done: B
                    nextvalues = next_value  # B x T
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]  # B x T
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = (
                    delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
                )
            returns = advantages + values

        # flatten the batch
        b_obs = {
            k: np.concatenate([obs_[k] for obs_ in obs]) for k in envs.single_observation_space
        }

        # Edited
        b_logprobs = logprobs.reshape(-1, args.n_traj)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1, args.n_traj)
        b_returns = returns.reshape(-1, args.n_traj)
        b_values = values.reshape(-1, args.n_traj)

        # Optimizing the policy and value network
        assert args.num_envs % args.num_minibatches == 0
        envsperbatch = args.num_envs // args.num_minibatches
        envinds = np.arange(args.num_envs)
        flatinds = np.arange(args.batch_size).reshape(args.num_steps, args.num_envs)

        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(envinds)
            for start in range(0, args.num_envs, envsperbatch):
                end = start + envsperbatch
                mbenvinds = envinds[start:end]  # mini batch env id
                mb_inds = flatinds[:, mbenvinds].ravel()  # be really careful about the index
                r_inds = np.tile(np.arange(envsperbatch), args.num_steps)

                cur_obs = {k: v[mbenvinds] for k, v in obs[0].items()}
                encoder_state = agent.backbone.encode(cur_obs)
                _, newlogprob, entropy, newvalue, _ = agent.get_action_and_value_cached(
                    {k: v[mb_inds] for k, v in b_obs.items()},
                    b_actions.long()[mb_inds],
                    (embedding[r_inds, :] for embedding in encoder_state),
                )
                # _, newlogprob, entropy, newvalue = agent.get_action_and_value({k:v[mb_inds] for k,v in b_obs.items()}, b_actions.long()[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                        mb_advantages.std() + 1e-8
                    )

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(
                    ratio, 1 - args.clip_coef, 1 + args.clip_coef
                )
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()
                # Value loss
                newvalue = newvalue.view(-1, args.n_traj)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
        if update % 1000 == 0 or update == num_updates:
            torch.save(agent.state_dict(), f"runs/{run_name}/ckpt/{update}.pt")
        if update % 100 == 0 or update == num_updates:
            agent.eval()
            test_obs = test_envs.reset()
            r = []
            for step in range(0, args.num_steps):
                # ALGO LOGIC: action logic
                with torch.no_grad():
                    action, logits = agent(test_obs)
                if step == 0:
                    if args.multi_greedy_inference:
                        if args.problem == 'tsp':
                            action = torch.arange(args.n_traj).repeat(args.n_test, 1)
                        elif args.problem == 'cvrp':
                            action = torch.arange(1, args.n_traj + 1).repeat(args.n_test, 1)
                # TRY NOT TO MODIFY: execute the game and log data.
                test_obs, _, _, test_info = test_envs.step(action.cpu().numpy())

                for item in test_info:
                    if "episode" in item.keys():
                        r.append(item)

            avg_episodic_return = np.mean([rollout["episode"]["r"].mean() for rollout in r])
            max_episodic_return = np.mean([rollout["episode"]["r"].max() for rollout in r])
            avg_episodic_length = np.mean([rollout["episode"]["l"].mean() for rollout in r])
            print(f"[test] episodic_return={max_episodic_return}")
            writer.add_scalar("test/episodic_return_mean", avg_episodic_return, global_step)
            writer.add_scalar("test/episodic_return_max", max_episodic_return, global_step)
            writer.add_scalar("test/episodic_length", avg_episodic_length, global_step)

    envs.close()
    writer.close()
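
Unlike ppo.py, minibatches here are formed over whole environments rather than arbitrary transitions, so the encoder output computed once per instance (`encoder_state`) can be reused for every step of that instance. A self-contained sketch of the indexing on toy sizes (names reused from the script; illustration only):

# Sketch: the per-environment minibatch indexing above, with
# num_steps=3, num_envs=4, num_minibatches=2.
import numpy as np

num_steps, num_envs, num_minibatches = 3, 4, 2
envsperbatch = num_envs // num_minibatches            # 2 envs per minibatch
flatinds = np.arange(num_steps * num_envs).reshape(num_steps, num_envs)
mbenvinds = np.array([1, 3])                          # the envs in this minibatch
mb_inds = flatinds[:, mbenvinds].ravel()              # rows of the flattened batch
r_inds = np.tile(np.arange(envsperbatch), num_steps)  # rows of the re-encoded state
print(mb_inds)  # [ 1  3  5  7  9 11]: all steps of envs 1 and 3
print(r_inds)   # [0 1 0 1 0 1]: encoder states 0/1 reused at every step
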
runs/cvrp-v0__ppo_or__1__1678159979/ckpt/12000.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4f3be054c6d5517a8bb522713cfb8b8a5b8841a6c2729dedb65f0286f54df856
size 2865239
runs/tsp-v0__ppo_or__1__1678160003/ckpt/12000.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ab6ea48bdcef97438227ad0751836ee7fe8e49ee11d220201c4fbb8a6ddc5bd6
size 2928924
wrappers/recordWrapper.py
ADDED
@@ -0,0 +1,62 @@
import time
from collections import deque

import gym
import numpy as np


class RecordEpisodeStatistics(gym.Wrapper):
    def __init__(self, env, deque_size=100):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.n_traj = env.n_traj
        self.t0 = time.perf_counter()
        self.episode_count = 0
        self.episode_returns = None
        self.episode_lengths = None
        self.return_queue = deque(maxlen=deque_size)
        self.length_queue = deque(maxlen=deque_size)
        self.is_vector_env = getattr(env, "is_vector_env", False)

    def reset(self, **kwargs):
        observations = super().reset(**kwargs)
        self.episode_returns = np.zeros((self.num_envs, self.n_traj), dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        self.finished = [False] * self.num_envs
        return observations

    def step(self, action):
        observations, rewards, dones, infos = super().step(action)
        self.episode_returns += rewards
        self.episode_lengths += 1
        if not self.is_vector_env:
            infos = [infos]
            dones = [dones]
        else:
            infos = list(infos)  # Convert infos to mutable type
        for i in range(len(dones)):
            if dones[i].all() and not self.finished[i]:
                infos[i] = infos[i].copy()
                episode_return = self.episode_returns[i]
                episode_length = self.episode_lengths[i]
                episode_info = {
                    "r": episode_return.copy(),
                    "l": episode_length,
                    "t": round(time.perf_counter() - self.t0, 6),
                }
                infos[i]["episode"] = episode_info
                self.return_queue.append(episode_return)
                self.length_queue.append(episode_length)
                self.episode_count += 1
                self.episode_returns[i] = 0
                self.episode_lengths[i] = 0
                self.finished[i] = True

        if self.is_vector_env:
            infos = tuple(infos)
        return (
            observations,
            rewards,
            dones if self.is_vector_env else dones[0],
            infos if self.is_vector_env else infos[0],
        )
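
A sketch of what this wrapper reports, using a hypothetical toy environment (the fixed-length episodes and `n_traj = 2` are assumptions for illustration; the real environments live in `envs/`):

# Sketch: per-trajectory episode statistics on a dummy env (not a repo env).
import gym
import numpy as np
from gym import spaces
from wrappers.recordWrapper import RecordEpisodeStatistics

class DummyEnv(gym.Env):
    """Hypothetical env: episodes last 3 steps; each of n_traj trajectories earns 1 per step."""
    n_traj = 2
    observation_space = spaces.Discrete(1)
    action_space = spaces.MultiDiscrete([1, 1])

    def reset(self):
        self.t = 0
        return 0

    def step(self, action):
        self.t += 1
        done = np.full(self.n_traj, self.t >= 3)
        return 0, np.ones(self.n_traj), done, {}

env = RecordEpisodeStatistics(DummyEnv())
env.reset()
info = {}
for _ in range(3):
    _, _, done, info = env.step(env.action_space.sample())
print(info["episode"]["r"])  # [3. 3.]: one return per trajectory
print(info["episode"]["l"])  # 3: shared episode length
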
wrappers/syncVectorEnvPomo.py
ADDED
@@ -0,0 +1,195 @@
from copy import deepcopy
from typing import List, Optional, Union

import numpy as np
from gym.vector.utils import concatenate, create_empty_array, iterate
from gym.vector.vector_env import VectorEnv

__all__ = ["SyncVectorEnv"]


class SyncVectorEnv(VectorEnv):
    """Vectorized environment that serially runs multiple environments.

    Parameters
    ----------
    env_fns : iterable of callable
        Functions that create the environments.

    observation_space : :class:`gym.spaces.Space`, optional
        Observation space of a single environment. If ``None``, then the
        observation space of the first environment is taken.

    action_space : :class:`gym.spaces.Space`, optional
        Action space of a single environment. If ``None``, then the action space
        of the first environment is taken.

    copy : bool
        If ``True``, then the :meth:`reset` and :meth:`step` methods return a
        copy of the observations.

    Raises
    ------
    RuntimeError
        If the observation space of some sub-environment does not match
        :obj:`observation_space` (or, by default, the observation space of
        the first sub-environment).

    Example
    -------

    .. code-block::

        >>> env = gym.vector.SyncVectorEnv([
        ...     lambda: gym.make("Pendulum-v0", g=9.81),
        ...     lambda: gym.make("Pendulum-v0", g=1.62)
        ... ])
        >>> env.reset()
        array([[-0.8286432 ,  0.5597771 ,  0.90249056],
               [-0.85009176,  0.5266346 ,  0.60007906]], dtype=float32)
    """

    def __init__(self, env_fns, observation_space=None, action_space=None, copy=True):
        self.env_fns = env_fns
        self.envs = [env_fn() for env_fn in env_fns]
        self.copy = copy
        self.metadata = self.envs[0].metadata
        self.n_traj = self.envs[0].n_traj

        if (observation_space is None) or (action_space is None):
            observation_space = observation_space or self.envs[0].observation_space
            action_space = action_space or self.envs[0].action_space
        super().__init__(
            num_envs=len(env_fns),
            observation_space=observation_space,
            action_space=action_space,
        )

        self._check_spaces()
        self.observations = create_empty_array(
            self.single_observation_space, n=self.num_envs, fn=np.zeros
        )
        self._rewards = np.zeros((self.num_envs, self.n_traj), dtype=np.float64)
        self._dones = np.zeros((self.num_envs, self.n_traj), dtype=np.bool_)
        self._actions = None

    def seed(self, seed=None):
        super().seed(seed=seed)
        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        if isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert len(seed) == self.num_envs

        for env, single_seed in zip(self.envs, seed):
            env.seed(single_seed)

    def reset_wait(
        self,
        seed: Optional[Union[int, List[int]]] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        if isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert len(seed) == self.num_envs

        self._dones[:] = False
        observations = []
        data_list = []
        for env, single_seed in zip(self.envs, seed):

            kwargs = {}
            if single_seed is not None:
                kwargs["seed"] = single_seed
            if options is not None:
                kwargs["options"] = options
            if return_info:
                kwargs["return_info"] = return_info

            if not return_info:
                observation = env.reset(**kwargs)
                observations.append(observation)
            else:
                observation, data = env.reset(**kwargs)
                observations.append(observation)
                data_list.append(data)

        self.observations = concatenate(
            self.single_observation_space, observations, self.observations
        )
        if not return_info:
            return deepcopy(self.observations) if self.copy else self.observations
        else:
            return (deepcopy(self.observations) if self.copy else self.observations), data_list

    def step_async(self, actions):
        self._actions = iterate(self.action_space, actions)

    def step_wait(self):
        observations, infos = [], []
        for i, (env, action) in enumerate(zip(self.envs, self._actions)):
            observation, self._rewards[i], self._dones[i], info = env.step(action)
            # if self._dones[i].all():
            #     observation = env.reset()
            observations.append(observation)
            infos.append(info)
        self.observations = concatenate(
            self.single_observation_space, observations, self.observations
        )

        return (
            deepcopy(self.observations) if self.copy else self.observations,
            np.copy(self._rewards),
            np.copy(self._dones),
            infos,
        )

    def call(self, name, *args, **kwargs):
        results = []
        for env in self.envs:
            function = getattr(env, name)
            if callable(function):
                results.append(function(*args, **kwargs))
            else:
                results.append(function)

        return tuple(results)

    def set_attr(self, name, values):
        if not isinstance(values, (list, tuple)):
            values = [values for _ in range(self.num_envs)]
        if len(values) != self.num_envs:
            raise ValueError(
                "Values must be a list or tuple with length equal to the "
                f"number of environments. Got `{len(values)}` values for "
                f"{self.num_envs} environments."
            )

        for env, value in zip(self.envs, values):
            setattr(env, name, value)

    def close_extras(self, **kwargs):
        """Close the environments."""
        [env.close() for env in self.envs]

    def _check_spaces(self):
        for env in self.envs:
            if not (env.observation_space == self.single_observation_space):
                raise RuntimeError(
                    "Some environments have an observation space different from "
                    f"`{self.single_observation_space}`. In order to batch observations, "
                    "the observation spaces from all environments must be equal."
                )

            if not (env.action_space == self.single_action_space):
                raise RuntimeError(
                    "Some environments have an action space different from "
                    f"`{self.single_action_space}`. In order to batch actions, the "
                    "action spaces from all environments must be equal."
                )

        else:
            return True
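
Relative to `gym.vector.SyncVectorEnv`, the essential change is that `_rewards` and `_dones` keep a trailing `n_traj` axis, so each sub-environment reports one reward and one done flag per POMO trajectory. A construction sketch (assumes execution from the repository root, and that a randomly sampled action is structurally valid for the environment, which the training loop never relies on):

# Sketch: batching two CVRP instances, each rolling out n_traj trajectories.
import gym
from wrappers.recordWrapper import RecordEpisodeStatistics
from wrappers.syncVectorEnvPomo import SyncVectorEnv

gym.envs.register(id="cvrp-v0", entry_point="envs.cvrp_vector_env:CVRPVectorEnv")
envs = SyncVectorEnv(
    [lambda: RecordEpisodeStatistics(gym.make("cvrp-v0")) for _ in range(2)]
)
obs = envs.reset()
obs, rewards, dones, infos = envs.step(envs.action_space.sample())
print(rewards.shape)  # (2, n_traj): one reward per environment per POMO trajectory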