derektan committed on
Commit
3cbeaeb
Β·
1 Parent(s): f118874

[NEW] Able to launch both gifs concurrently

Browse files
Files changed (2) hide show
  1. app.py +128 -126
  2. env.py +3 -2
app.py CHANGED
@@ -21,6 +21,7 @@ import os, glob, threading, time
21
  import torch
22
  from PIL import Image
23
  import json
 
24
  import spaces # integration with ZeroGPU on hf
25
 
26
  # Import configuration & RL / TTA utilities -------------------------------------------------
@@ -66,8 +67,8 @@ print('Model loaded!')
66
 
67
  # Init Taxabind here (only need to init once)
68
  if TAXABIND_TTA:
69
- # self.clip_seg_tta = None
70
- clip_seg_tta = ClipSegTTA(
71
  img_dir=TAXABIND_IMG_DIR,
72
  imo_dir=TAXABIND_IMO_DIR,
73
  json_path=TAXABIND_INAT_JSON_PATH,
@@ -85,9 +86,28 @@ if TAXABIND_TTA:
85
  sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
86
  # sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
87
  )
88
- print("ClipSegTTA Loaded!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  else:
90
- clip_seg_tta = None
91
 
92
  # Load metadata json
93
  tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
@@ -102,132 +122,120 @@ tgts_metadata = json.load(open(tgts_metadata_json_path))
102
  # object. By defining explicit generator functions (with `yield from`) we ensure
103
  # `inspect.isgeneratorfunction` evaluates to True and Gradio streams correctly.
104
 
105
- def process_with_tta(
106
- sat_path: str | None,
107
- ground_path: str | None,
108
- taxonomy: str | None = None,
109
- ):
110
- """Stream search episode **with** TTA enabled while disabling buttons."""
111
- # Disable buttons initially (image reset)
112
- yield gr.update(interactive=False), gr.update(interactive=False), gr.update(value=None), gr.update(value="Initializing model…")
113
-
114
- last_img = None
115
- for img in process(sat_path, ground_path, taxonomy, True):
116
- last_img = img
117
- yield gr.update(interactive=False), gr.update(interactive=False), img, gr.update(value="Running…")
118
-
119
- # Re-enable buttons at the end
120
- yield gr.update(interactive=True), gr.update(interactive=True), last_img, gr.update(value="Done.")
121
-
122
-
123
-
124
- def process_no_tta(
125
  sat_path: str | None,
126
  ground_path: str | None,
127
  taxonomy: str | None = None,
128
  ):
129
- """Stream search episode **without** TTA enabled while disabling buttons."""
130
- yield gr.update(interactive=False), gr.update(interactive=False), gr.update(value=None), gr.update(value="Initializing model…")
131
-
132
- last_img = None
133
- for img in process(sat_path, ground_path, taxonomy, False):
134
- last_img = img
135
- yield gr.update(interactive=False), gr.update(interactive=False), img, gr.update(value="Running…")
136
-
137
- yield gr.update(interactive=True), gr.update(interactive=True), last_img, gr.update(value="Done.")
138
 
 
 
139
 
140
- # # integration with ZeroGPU on hf
141
- # @spaces.GPU
142
- def process(
143
- sat_path: str | None,
144
- ground_path: str | None,
145
- taxonomy: str | None = None,
146
- with_tta: bool = True,
147
- ):
148
- """Callback executed when the user presses **Run** in the UI.
149
-
150
- At test-time we simply trigger the RL search episode via
151
- ``planner.run_episode`` and return its performance metrics.
152
- The image inputs are currently *not* used directly here but are
153
- retained to conform to the requested interface.
154
- """
155
- # If no satellite image is provided we bail out early.
156
  if sat_path is None:
157
- return None
 
158
 
159
- # ------------------------------------------------------------------
160
- # Load images from paths and configure ClipSegTTA inputs
161
  sat_img = Image.open(sat_path).convert("RGB")
162
  ground_img_pil = Image.open(ground_path).convert("RGB") if ground_path else None
163
- tgts = [tuple(tgt) for tgt in tgts_metadata[taxonomy]["target_positions"]]
164
-
165
- clip_seg_tta.img_paths = [ground_path] if ground_path else []
166
- clip_seg_tta.imo_path = sat_path
167
- clip_seg_tta.imgs = ([clip_seg_tta.dataset.img_transform(ground_img_pil).to(device)]
168
- if ground_img_pil else [])
169
- clip_seg_tta.imo = clip_seg_tta.dataset.imo_transform(sat_img).to(device)
170
- clip_seg_tta.sounds = []
171
- clip_seg_tta.sound_ids = [] # None
172
- clip_seg_tta.species_name = taxonomy or ""
173
- clip_seg_tta.gt_mask_name = taxonomy.replace(" ", "_") # None
174
- clip_seg_tta.target_positions = tgts if tgts != [] else [(0,0)]
175
-
176
- # Define TestWorker
177
- planner = TestWorker(
178
- meta_agent_id=0,
179
- n_agent=1,
180
- policy_net=policy_net,
181
- global_step=-1,
182
- device=device,
183
- greedy=True,
184
- save_image=SAVE_GIFS,
185
- clip_seg_tta=clip_seg_tta
186
- )
187
-
188
- # ------------------------------------------------------------------
189
 
190
- # Define save gif dir
191
- planner.execute_tta = with_tta
192
- gifs_save_dir = os.path.join(gifs_path, "no_tta") if not with_tta else os.path.join(gifs_path, "with_tta")
193
- planner.gifs_path = gifs_save_dir
194
-
195
- # Empty gifs_path folder
196
- if os.path.exists(gifs_save_dir):
197
- for file in os.listdir(gifs_save_dir):
198
- os.remove(os.path.join(gifs_save_dir, file))
199
-
200
- # Optionally you may want to reset episode index or make it configurable.
201
- # For now we hard-code episode 0, mirroring the snippet.
202
- # Set execute_tta flag depending on button pressed
203
-
204
- t = threading.Thread(target=planner.run_episode, args=(0,), daemon=True)
205
- t.start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
- sent: set[str] = set()
208
- last_img: str | None = None
209
  try:
210
- while t.is_alive():
211
- # discover any new pngs written by TestWorker
212
- pngs = glob.glob(os.path.join(gifs_save_dir, "*.png"))
 
 
 
 
 
 
 
 
 
213
  pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
214
  for fp in pngs:
215
- if fp not in sent:
216
- sent.add(fp)
217
- last_img = fp
218
- yield fp # stream update
 
 
 
 
 
 
219
  time.sleep(POLL_INTERVAL)
220
  finally:
221
- # This block runs when the generator is cancelled (e.g. page refresh)
222
- if t.is_alive():
223
- _stop_thread(t)
224
- t.join(timeout=1)
225
-
226
- # If the episode finished naturally, send the last frame once more
227
- if last_img is not None:
228
- yield last_img
229
 
230
- print("planner.perf_metrics: ", planner.perf_metrics)
 
231
 
232
 
233
  # ────────────────────────── Gradio UI ─────────────────────────────────
@@ -280,15 +288,14 @@ with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
280
  type="filepath",
281
  height=320,
282
  )
283
- run_tta_btn = gr.Button("Run (with TTA)", variant="primary")
284
- run_no_tta_btn = gr.Button("Run (without TTA)", variant="secondary")
285
 
286
  with gr.Column():
287
  gr.Markdown("### Live Heatmap (with TTA)")
288
- display_img_tta = gr.Image(label="Heatmap (TTA)", type="filepath", height=512)
289
  status_tta = gr.Markdown("")
290
  gr.Markdown("### Live Heatmap (without TTA)")
291
- display_img_no_tta = gr.Image(label="Heatmap (no TTA)", type="filepath", height=512)
292
  status_no_tta = gr.Markdown("")
293
 
294
  # Bind callback
@@ -321,21 +328,16 @@ with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
321
  ],
322
  ],
323
  inputs=[sat_input, ground_input, taxonomy_input],
324
- outputs=[run_tta_btn, run_no_tta_btn, display_img_tta, status_tta],
325
- fn=process_with_tta,
326
  cache_examples=False,
327
  )
328
 
329
 
330
- run_tta_btn.click(
331
- fn=process_with_tta,
332
- inputs=[sat_input, ground_input, taxonomy_input],
333
- outputs=[run_tta_btn, run_no_tta_btn, display_img_tta, status_tta],
334
- )
335
- run_no_tta_btn.click(
336
- fn=process_no_tta,
337
  inputs=[sat_input, ground_input, taxonomy_input],
338
- outputs=[run_tta_btn, run_no_tta_btn, display_img_no_tta, status_no_tta],
339
  )
340
 
341
  # Footer to point out to model and data from app page.
 
21
  import torch
22
  from PIL import Image
23
  import json
24
+ import copy
25
  import spaces # integration with ZeroGPU on hf
26
 
27
  # Import configuration & RL / TTA utilities -------------------------------------------------
 
67
 
68
  # Init Taxabind here (only need to init once)
69
  if TAXABIND_TTA:
70
+ # Instantiate TWO independent ClipSegTTA objects (one per concurrent run)
71
+ clip_seg_tta_1 = ClipSegTTA(
72
  img_dir=TAXABIND_IMG_DIR,
73
  imo_dir=TAXABIND_IMO_DIR,
74
  json_path=TAXABIND_INAT_JSON_PATH,
 
86
  sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
87
  # sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
88
  )
89
+ clip_seg_tta_2 = ClipSegTTA(
90
+ img_dir=TAXABIND_IMG_DIR,
91
+ imo_dir=TAXABIND_IMO_DIR,
92
+ json_path=TAXABIND_INAT_JSON_PATH,
93
+ sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
94
+ patch_size=TAXABIND_PATCH_SIZE,
95
+ sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
96
+ sample_index = -1, # Set using 'reset' in worker
97
+ blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
98
+ device=device,
99
+ sat_to_img_ids_json_is_train_dict=False,
100
+ tax_to_filter_val=QUERY_TAX,
101
+ load_model=USE_CLIP_PREDS,
102
+ initial_modality=INITIAL_MODALITY,
103
+ sound_data_path=TAXABIND_SOUND_DATA_PATH,
104
+ sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
105
+ )
106
+ print("ClipSegTTA instances loaded!")
107
+ # Keep original name for single-run mode compatibility
108
+ clip_seg_tta = clip_seg_tta_1
109
  else:
110
+ clip_seg_tta_1 = clip_seg_tta_2 = clip_seg_tta = None
111
 
112
  # Load metadata json
113
  tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
 
122
  # object. By defining explicit generator functions (with `yield from`) we ensure
123
  # `inspect.isgeneratorfunction` evaluates to True and Gradio streams correctly.
124
 
125
+ # # # integration with ZeroGPU on hf
126
+ # @spaces.GPU
127
def process_search_tta(
    sat_path: str | None,
    ground_path: str | None,
    taxonomy: str | None = None,
):
    """Run the TTA and non-TTA search episodes concurrently and stream both heat-maps.

    Generator callback for the Gradio UI. Every yielded 5-tuple is bound, in
    order, to ``[run_btn, display_img_tta, display_img_no_tta, status_tta,
    status_no_tta]``.

    Args:
        sat_path: File path of the satellite image. When ``None`` the run is
            aborted with a status message and the Run button is re-enabled.
        ground_path: Optional file path of the ground-level image.
        taxonomy: Optional taxonomy name; used as the species name and as the
            lookup key into ``tgts_metadata`` for target positions.

    Yields:
        Tuples of (run-button update, newest TTA frame path or ``None``,
        newest no-TTA frame path or ``None``, TTA status update,
        no-TTA status update).
    """

    def latest_new_frame(frame_dir: str, seen: set[str]) -> str | None:
        """Mark unseen ``<step>.png`` files in *frame_dir* as seen; return the newest."""
        fresh = [
            fp
            for fp in glob.glob(os.path.join(frame_dir, "*.png"))
            if fp not in seen
        ]
        if not fresh:
            return None
        # Frames are named by their integer step, so sort numerically (not lexically).
        fresh.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
        seen.update(fresh)
        return fresh[-1]

    # Disable Run button and clear image/status outputs
    yield (
        gr.update(interactive=False),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value="Initializing model…"),
        gr.update(value="Initializing model…"),
    )

    # Bail early if satellite image missing
    if sat_path is None:
        yield (
            gr.update(interactive=True),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value="No satellite image provided."),
            gr.update(value=""),
        )
        return

    # Prepare PIL images
    sat_img = Image.open(sat_path).convert("RGB")
    ground_img_pil = Image.open(ground_path).convert("RGB") if ground_path else None

    # Lookup target positions metadata (may be empty)
    tgt_positions = []
    if taxonomy and taxonomy in tgts_metadata:
        tgt_positions = [tuple(t) for t in tgts_metadata[taxonomy]["target_positions"]]

    def build_planner(enable_tta: bool, save_dir: str, clip_obj):
        """Build a TestWorker that writes its frames to *save_dir*, with or without TTA."""
        # The pre-instantiated ClipSegTTA is mutated in place (it is NOT copied),
        # which is why each concurrent run must be given its own instance.
        local_clip = clip_obj
        if local_clip is not None:
            local_clip.img_paths = [ground_path] if ground_path else []
            local_clip.imo_path = sat_path
            local_clip.imgs = (
                [local_clip.dataset.img_transform(ground_img_pil).to(device)]
                if ground_img_pil
                else []
            )
            local_clip.imo = local_clip.dataset.imo_transform(sat_img).to(device)
            local_clip.sounds = []
            local_clip.sound_ids = []
            local_clip.species_name = taxonomy or ""
            local_clip.gt_mask_name = taxonomy.replace(" ", "_") if taxonomy else ""
            # Fall back to a single dummy target when no metadata is available.
            local_clip.target_positions = tgt_positions if tgt_positions else [(0, 0)]

        planner = TestWorker(
            meta_agent_id=0,
            n_agent=1,
            policy_net=policy_net,
            global_step=-1,
            device=device,
            greedy=True,
            save_image=SAVE_GIFS,
            clip_seg_tta=local_clip,
        )
        planner.execute_tta = enable_tta
        planner.gifs_path = save_dir
        return planner

    # Prepare gif directories
    gifs_dir_tta = os.path.join(gifs_path, "with_tta")
    gifs_dir_no = os.path.join(gifs_path, "no_tta")
    for d in (gifs_dir_tta, gifs_dir_no):
        os.makedirs(d, exist_ok=True)
        # Remove every leftover file of the previous run (not only the PNGs),
        # otherwise stale frames would be streamed as if they were new.
        for f in os.listdir(d):
            os.remove(os.path.join(d, f))

    planner_tta = build_planner(True, gifs_dir_tta, clip_seg_tta_1)
    planner_no = build_planner(False, gifs_dir_no, clip_seg_tta_2)

    # Launch both planners in background threads
    thread_tta = threading.Thread(target=planner_tta.run_episode, args=(0,), daemon=True)
    thread_no = threading.Thread(target=planner_no.run_episode, args=(0,), daemon=True)
    thread_tta.start()
    thread_no.start()

    sent_tta: set[str] = set()
    sent_no: set[str] = set()
    last_tta = None
    last_no = None

    try:
        while thread_tta.is_alive() or thread_no.is_alive():
            new_tta = latest_new_frame(gifs_dir_tta, sent_tta)
            new_no = latest_new_frame(gifs_dir_no, sent_no)
            last_tta = new_tta or last_tta
            last_no = new_no or last_no

            # Only push an update to the UI when at least one stream has a new frame.
            if new_tta or new_no:
                status_tta = "Running…" if thread_tta.is_alive() else "Done."
                status_no = "Running…" if thread_no.is_alive() else "Done."
                yield (
                    gr.update(interactive=False),
                    last_tta,
                    last_no,
                    gr.update(value=status_tta),
                    gr.update(value=status_no),
                )

            time.sleep(POLL_INTERVAL)
    finally:
        # Runs on normal completion and when the generator is cancelled/closed:
        # make sure no background thread outlives this request.
        for th in (thread_tta, thread_no):
            if th.is_alive():
                _stop_thread(th)
                th.join(timeout=1)

    # A worker can write its last frames after the final poll and exit while we
    # sleep; scan once more so the final emit shows the true last frame of each run.
    last_tta = latest_new_frame(gifs_dir_tta, sent_tta) or last_tta
    last_no = latest_new_frame(gifs_dir_no, sent_no) or last_no

    # Final emit after both finish. This must stay outside ``finally``: yielding
    # while the generator is being closed raises RuntimeError.
    yield (
        gr.update(interactive=True),
        last_tta,
        last_no,
        gr.update(value="Done."),
        gr.update(value="Done."),
    )
239
 
240
 
241
  # ────────────────────────── Gradio UI ─────────────────────────────────
 
288
  type="filepath",
289
  height=320,
290
  )
291
+ run_btn = gr.Button("Run Search-TTA", variant="primary")
 
292
 
293
  with gr.Column():
294
  gr.Markdown("### Live Heatmap (with TTA)")
295
+ display_img_tta = gr.Image(label="Heatmap (TTA)", type="filepath", height=400) # 512
296
  status_tta = gr.Markdown("")
297
  gr.Markdown("### Live Heatmap (without TTA)")
298
+ display_img_no_tta = gr.Image(label="Heatmap (no TTA)", type="filepath", height=400) # 512
299
  status_no_tta = gr.Markdown("")
300
 
301
  # Bind callback
 
328
  ],
329
  ],
330
  inputs=[sat_input, ground_input, taxonomy_input],
331
+ outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta],
332
+ fn=process_search_tta,
333
  cache_examples=False,
334
  )
335
 
336
 
337
+ run_btn.click(
338
+ fn=process_search_tta,
 
 
 
 
 
339
  inputs=[sat_input, ground_input, taxonomy_input],
340
+ outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta],
341
  )
342
 
343
  # Footer to point out to model and data from app page.
env.py CHANGED
@@ -816,8 +816,9 @@ class Env():
816
  cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
817
  cbar.set_label("Normalized Probs")
818
 
819
- plt.suptitle('Targets Found: {}/{} Coverage ratio: {:.4g} Travel Dist: {:.4g}'.format(self.num_targets_found, \
820
- len(self.target_positions), self.explored_rate, travel_dist))
 
821
 
822
  os.makedirs(save_dir, exist_ok=True)
823
  out_path = os.path.join(save_dir, f"{step}.png")
 
816
  cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
817
  cbar.set_label("Normalized Probs")
818
 
819
+ # Change coverage to 1dp
820
+ plt.suptitle('Targets Found: {}/{} Coverage: {:.1f}% Steps: {}/{}'.format(self.num_targets_found, \
821
+ len(self.target_positions), self.explored_rate*100, step, NUM_EPS_STEPS))
822
 
823
  os.makedirs(save_dir, exist_ok=True)
824
  out_path = os.path.join(save_dir, f"{step}.png")