Spaces:
Running
on
Zero
Running
on
Zero
derektan
commited on
Commit
·
f118874
1
Parent(s):
9bf5bc2
[UPDATE] Saved to different TTA config folders depends on button press
Browse files- app.py +9 -5
- test_multi_robot_worker.py +9 -5
app.py
CHANGED
|
@@ -187,15 +187,19 @@ def process(
|
|
| 187 |
|
| 188 |
# ------------------------------------------------------------------
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
# Empty gifs_path folder
|
| 191 |
-
if os.path.exists(
|
| 192 |
-
for file in os.listdir(
|
| 193 |
-
os.remove(os.path.join(
|
| 194 |
|
| 195 |
# Optionally you may want to reset episode index or make it configurable.
|
| 196 |
# For now we hard-code episode 0, mirroring the snippet.
|
| 197 |
# Set execute_tta flag depending on button pressed
|
| 198 |
-
planner.execute_tta = with_tta
|
| 199 |
|
| 200 |
t = threading.Thread(target=planner.run_episode, args=(0,), daemon=True)
|
| 201 |
t.start()
|
|
@@ -205,7 +209,7 @@ def process(
|
|
| 205 |
try:
|
| 206 |
while t.is_alive():
|
| 207 |
# discover any new pngs written by TestWorker
|
| 208 |
-
pngs = glob.glob(os.path.join(
|
| 209 |
pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
|
| 210 |
for fp in pngs:
|
| 211 |
if fp not in sent:
|
|
|
|
| 187 |
|
| 188 |
# ------------------------------------------------------------------
|
| 189 |
|
| 190 |
+
# Define save gif dir
|
| 191 |
+
planner.execute_tta = with_tta
|
| 192 |
+
gifs_save_dir = os.path.join(gifs_path, "no_tta") if not with_tta else os.path.join(gifs_path, "with_tta")
|
| 193 |
+
planner.gifs_path = gifs_save_dir
|
| 194 |
+
|
| 195 |
# Empty gifs_path folder
|
| 196 |
+
if os.path.exists(gifs_save_dir):
|
| 197 |
+
for file in os.listdir(gifs_save_dir):
|
| 198 |
+
os.remove(os.path.join(gifs_save_dir, file))
|
| 199 |
|
| 200 |
# Optionally you may want to reset episode index or make it configurable.
|
| 201 |
# For now we hard-code episode 0, mirroring the snippet.
|
| 202 |
# Set execute_tta flag depending on button pressed
|
|
|
|
| 203 |
|
| 204 |
t = threading.Thread(target=planner.run_episode, args=(0,), daemon=True)
|
| 205 |
t.start()
|
|
|
|
| 209 |
try:
|
| 210 |
while t.is_alive():
|
| 211 |
# discover any new pngs written by TestWorker
|
| 212 |
+
pngs = glob.glob(os.path.join(gifs_save_dir, "*.png"))
|
| 213 |
pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
|
| 214 |
for fp in pngs:
|
| 215 |
if fp not in sent:
|
test_multi_robot_worker.py
CHANGED
|
@@ -56,6 +56,9 @@ class TestWorker:
|
|
| 56 |
self.perf_metrics = dict()
|
| 57 |
self.bad_mask_init = False
|
| 58 |
|
|
|
|
|
|
|
|
|
|
| 59 |
# # TEMP - EXPORT START POSES FOR BASELINES
|
| 60 |
# json_path = "eval_start_positions.json"
|
| 61 |
# sat_to_start_pose_dict = {}
|
|
@@ -152,7 +155,7 @@ class TestWorker:
|
|
| 152 |
n_smooth_iter=2, # smoothing parameter
|
| 153 |
ignore_label=-1,
|
| 154 |
plot=False, # NOTE: Set to false since using app.py
|
| 155 |
-
gifs_dir = gifs_path
|
| 156 |
)
|
| 157 |
# Fit & predict (this will also plot the clusters before & after smoothing)
|
| 158 |
map_shape = (int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0])), int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0])))
|
|
@@ -284,13 +287,14 @@ class TestWorker:
|
|
| 284 |
robots_route = []
|
| 285 |
for robot in self.robot_list:
|
| 286 |
robots_route.append([robot.xPoints, robot.yPoints])
|
| 287 |
-
|
| 288 |
-
|
|
|
|
| 289 |
sound_id_override = None if self.clip_seg_tta.sound_ids == [] else self.clip_seg_tta.sound_ids[0]
|
| 290 |
|
| 291 |
## NOTE: Replaced since using app.py
|
| 292 |
if TAXABIND_TTA and USE_CLIP_PREDS:
|
| 293 |
-
self.env.plot_heatmap(gifs_path, step, max(travel_dist_list), robots_route)
|
| 294 |
# if TAXABIND_TTA and USE_CLIP_PREDS:
|
| 295 |
# self.env.plot_env(
|
| 296 |
# self.global_step,
|
|
@@ -363,7 +367,7 @@ class TestWorker:
|
|
| 363 |
|
| 364 |
# save gif
|
| 365 |
if self.save_image:
|
| 366 |
-
path = gifs_path
|
| 367 |
self.make_gif(path, curr_episode)
|
| 368 |
|
| 369 |
print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)
|
|
|
|
| 56 |
self.perf_metrics = dict()
|
| 57 |
self.bad_mask_init = False
|
| 58 |
|
| 59 |
+
# NOTE: Option to override gifs_path to interface with app.py
|
| 60 |
+
self.gifs_path = gifs_path
|
| 61 |
+
|
| 62 |
# # TEMP - EXPORT START POSES FOR BASELINES
|
| 63 |
# json_path = "eval_start_positions.json"
|
| 64 |
# sat_to_start_pose_dict = {}
|
|
|
|
| 155 |
n_smooth_iter=2, # smoothing parameter
|
| 156 |
ignore_label=-1,
|
| 157 |
plot=False, # NOTE: Set to false since using app.py
|
| 158 |
+
gifs_dir = self.gifs_path # NOTE: Set to self.gifs_path since using app.py
|
| 159 |
)
|
| 160 |
# Fit & predict (this will also plot the clusters before & after smoothing)
|
| 161 |
map_shape = (int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0])), int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0])))
|
|
|
|
| 287 |
robots_route = []
|
| 288 |
for robot in self.robot_list:
|
| 289 |
robots_route.append([robot.xPoints, robot.yPoints])
|
| 290 |
+
# NOTE: Set to self.gifs_path since using app.py
|
| 291 |
+
if not os.path.exists(self.gifs_path):
|
| 292 |
+
os.makedirs(self.gifs_path)
|
| 293 |
sound_id_override = None if self.clip_seg_tta.sound_ids == [] else self.clip_seg_tta.sound_ids[0]
|
| 294 |
|
| 295 |
## NOTE: Replaced since using app.py
|
| 296 |
if TAXABIND_TTA and USE_CLIP_PREDS:
|
| 297 |
+
self.env.plot_heatmap(self.gifs_path, step, max(travel_dist_list), robots_route)
|
| 298 |
# if TAXABIND_TTA and USE_CLIP_PREDS:
|
| 299 |
# self.env.plot_env(
|
| 300 |
# self.global_step,
|
|
|
|
| 367 |
|
| 368 |
# save gif
|
| 369 |
if self.save_image:
|
| 370 |
+
path = self.gifs_path # NOTE: Set to self.gifs_path since using app.py
|
| 371 |
self.make_gif(path, curr_episode)
|
| 372 |
|
| 373 |
print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)
|