derektan commited on
Commit
f118874
·
1 Parent(s): 9bf5bc2

[UPDATE] Saved to different TTA config folders depends on button press

Browse files
Files changed (2) hide show
  1. app.py +9 -5
  2. test_multi_robot_worker.py +9 -5
app.py CHANGED
@@ -187,15 +187,19 @@ def process(
187
 
188
  # ------------------------------------------------------------------
189
 
 
 
 
 
 
190
  # Empty gifs_path folder
191
- if os.path.exists(gifs_path):
192
- for file in os.listdir(gifs_path):
193
- os.remove(os.path.join(gifs_path, file))
194
 
195
  # Optionally you may want to reset episode index or make it configurable.
196
  # For now we hard-code episode 0, mirroring the snippet.
197
  # Set execute_tta flag depending on button pressed
198
- planner.execute_tta = with_tta
199
 
200
  t = threading.Thread(target=planner.run_episode, args=(0,), daemon=True)
201
  t.start()
@@ -205,7 +209,7 @@ def process(
205
  try:
206
  while t.is_alive():
207
  # discover any new pngs written by TestWorker
208
- pngs = glob.glob(os.path.join(gifs_path, "*.png"))
209
  pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
210
  for fp in pngs:
211
  if fp not in sent:
 
187
 
188
  # ------------------------------------------------------------------
189
 
190
+ # Define save gif dir
191
+ planner.execute_tta = with_tta
192
+ gifs_save_dir = os.path.join(gifs_path, "no_tta") if not with_tta else os.path.join(gifs_path, "with_tta")
193
+ planner.gifs_path = gifs_save_dir
194
+
195
  # Empty gifs_path folder
196
+ if os.path.exists(gifs_save_dir):
197
+ for file in os.listdir(gifs_save_dir):
198
+ os.remove(os.path.join(gifs_save_dir, file))
199
 
200
  # Optionally you may want to reset episode index or make it configurable.
201
  # For now we hard-code episode 0, mirroring the snippet.
202
  # Set execute_tta flag depending on button pressed
 
203
 
204
  t = threading.Thread(target=planner.run_episode, args=(0,), daemon=True)
205
  t.start()
 
209
  try:
210
  while t.is_alive():
211
  # discover any new pngs written by TestWorker
212
+ pngs = glob.glob(os.path.join(gifs_save_dir, "*.png"))
213
  pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
214
  for fp in pngs:
215
  if fp not in sent:
test_multi_robot_worker.py CHANGED
@@ -56,6 +56,9 @@ class TestWorker:
56
  self.perf_metrics = dict()
57
  self.bad_mask_init = False
58
 
 
 
 
59
  # # TEMP - EXPORT START POSES FOR BASELINES
60
  # json_path = "eval_start_positions.json"
61
  # sat_to_start_pose_dict = {}
@@ -152,7 +155,7 @@ class TestWorker:
152
  n_smooth_iter=2, # smoothing parameter
153
  ignore_label=-1,
154
  plot=False, # NOTE: Set to false since using app.py
155
- gifs_dir = gifs_path
156
  )
157
  # Fit & predict (this will also plot the clusters before & after smoothing)
158
  map_shape = (int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0])), int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0])))
@@ -284,13 +287,14 @@ class TestWorker:
284
  robots_route = []
285
  for robot in self.robot_list:
286
  robots_route.append([robot.xPoints, robot.yPoints])
287
- if not os.path.exists(gifs_path):
288
- os.makedirs(gifs_path)
 
289
  sound_id_override = None if self.clip_seg_tta.sound_ids == [] else self.clip_seg_tta.sound_ids[0]
290
 
291
  ## NOTE: Replaced since using app.py
292
  if TAXABIND_TTA and USE_CLIP_PREDS:
293
- self.env.plot_heatmap(gifs_path, step, max(travel_dist_list), robots_route)
294
  # if TAXABIND_TTA and USE_CLIP_PREDS:
295
  # self.env.plot_env(
296
  # self.global_step,
@@ -363,7 +367,7 @@ class TestWorker:
363
 
364
  # save gif
365
  if self.save_image:
366
- path = gifs_path
367
  self.make_gif(path, curr_episode)
368
 
369
  print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)
 
56
  self.perf_metrics = dict()
57
  self.bad_mask_init = False
58
 
59
+ # NOTE: Option to override gifs_path to interface with app.py
60
+ self.gifs_path = gifs_path
61
+
62
  # # TEMP - EXPORT START POSES FOR BASELINES
63
  # json_path = "eval_start_positions.json"
64
  # sat_to_start_pose_dict = {}
 
155
  n_smooth_iter=2, # smoothing parameter
156
  ignore_label=-1,
157
  plot=False, # NOTE: Set to false since using app.py
158
+ gifs_dir = self.gifs_path # NOTE: Set to self.gifs_path since using app.py
159
  )
160
  # Fit & predict (this will also plot the clusters before & after smoothing)
161
  map_shape = (int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0])), int(np.sqrt(self.clip_seg_tta.patch_embeds.shape[0])))
 
287
  robots_route = []
288
  for robot in self.robot_list:
289
  robots_route.append([robot.xPoints, robot.yPoints])
290
+ # NOTE: Set to self.gifs_path since using app.py
291
+ if not os.path.exists(self.gifs_path):
292
+ os.makedirs(self.gifs_path)
293
  sound_id_override = None if self.clip_seg_tta.sound_ids == [] else self.clip_seg_tta.sound_ids[0]
294
 
295
  ## NOTE: Replaced since using app.py
296
  if TAXABIND_TTA and USE_CLIP_PREDS:
297
+ self.env.plot_heatmap(self.gifs_path, step, max(travel_dist_list), robots_route)
298
  # if TAXABIND_TTA and USE_CLIP_PREDS:
299
  # self.env.plot_env(
300
  # self.global_step,
 
367
 
368
  # save gif
369
  if self.save_image:
370
+ path = self.gifs_path # NOTE: Set to self.gifs_path since using app.py
371
  self.make_gif(path, curr_episode)
372
 
373
  print(YELLOW, f"[Eps {curr_episode} Completed] Time Taken: {time()-eps_start:.2f}s, Steps: {step+1}", NC)