derektan commited on
Commit
e93ae13
·
1 Parent(s): 9780b6e

Keep track of threads running per instance, to kill only those when switching tabs

Browse files
Files changed (1) hide show
  1. app.py +34 -12
app.py CHANGED
@@ -161,9 +161,13 @@ def process_search_tta(
161
  sat_path: str | None,
162
  ground_path: str | None,
163
  taxonomy: str | None = None,
 
164
  ):
165
  """Run both TTA and non-TTA search episodes concurrently and stream both heat-maps."""
166
 
 
 
 
167
  # Disable Run button and clear image/status outputs, hide sliders, clear frame states
168
  yield (
169
  gr.update(interactive=False),
@@ -175,11 +179,23 @@ def process_search_tta(
175
  gr.update(visible=False),
176
  [],
177
  [],
 
178
  )
179
 
180
  # Bail early if satellite image missing
181
  if sat_path is None:
182
- yield gr.update(interactive=True), gr.update(value=None), gr.update(value=None), gr.update(value="No satellite image provided.", visible=True), gr.update(value="", visible=True)
 
 
 
 
 
 
 
 
 
 
 
183
  return
184
 
185
  # Prepare PIL images
@@ -282,8 +298,8 @@ def process_search_tta(
282
  args=(False, gifs_dir_no, None, "no"),
283
  daemon=True,
284
  )
285
- _register_thread(thread_tta)
286
- _register_thread(thread_no)
287
  thread_tta.start()
288
  thread_no.start()
289
 
@@ -358,6 +374,7 @@ def process_search_tta(
358
  gr.update(visible=False),
359
  gr.update(),
360
  gr.update(),
 
361
  )
362
  prev_status_tta = status_tta
363
  prev_status_no = status_no
@@ -372,10 +389,8 @@ def process_search_tta(
372
 
373
  # Remove finished threads from global registry
374
  with _running_threads_lock:
375
- if thread_tta in _running_threads:
376
- _running_threads.remove(thread_tta)
377
- if thread_no in _running_threads:
378
- _running_threads.remove(thread_no)
379
 
380
  # Small delay to ensure last frame files are fully flushed
381
  time.sleep(0.2)
@@ -412,6 +427,7 @@ def process_search_tta(
412
  gr.update(visible=True, minimum=1, maximum=n_no, value=n_no),
413
  frames_tta,
414
  frames_no,
 
415
  )
416
 
417
 
@@ -482,6 +498,7 @@ with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
482
 
483
  frames_state_tta = gr.State([])
484
  frames_state_no = gr.State([])
 
485
 
486
  # Slider callbacks (updates image when user drags slider)
487
  def _show_frame(idx: int, frames: list[str]):
@@ -532,8 +549,8 @@ with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
532
 
533
  run_btn.click(
534
  fn=process_search_tta,
535
- inputs=[sat_input, ground_input, taxonomy_input],
536
- outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta, slider_tta, slider_no, frames_state_tta, frames_state_no],
537
  )
538
 
539
  # Footer to point out to model and data from app page.
@@ -561,14 +578,19 @@ if __name__ == "__main__":
561
 
562
  outputs_on_tab = [_cleanup_status]
563
 
564
- def _on_tab_change(evt: gr.SelectData):
565
  # evt.value contains the name of the newly-selected tab.
566
  if evt.value == "Multimodal Inference":
567
- _kill_running_threads()
 
 
 
 
 
568
  return "Stopped running Search-TTA threads."
569
  return ""
570
 
571
- tabs.select(_on_tab_change, outputs=outputs_on_tab)
572
 
573
  root.queue(max_size=15)
574
  root.launch(share=True)
 
161
  sat_path: str | None,
162
  ground_path: str | None,
163
  taxonomy: str | None = None,
164
+ session_threads: list[threading.Thread] | None = None,
165
  ):
166
  """Run both TTA and non-TTA search episodes concurrently and stream both heat-maps."""
167
 
168
+ if session_threads is None:
169
+ session_threads = []
170
+
171
  # Disable Run button and clear image/status outputs, hide sliders, clear frame states
172
  yield (
173
  gr.update(interactive=False),
 
179
  gr.update(visible=False),
180
  [],
181
  [],
182
+ session_threads,
183
  )
184
 
185
  # Bail early if satellite image missing
186
  if sat_path is None:
187
+ yield (
188
+ gr.update(interactive=True),
189
+ gr.update(value=None),
190
+ gr.update(value=None),
191
+ gr.update(value="No satellite image provided.", visible=True),
192
+ gr.update(value="", visible=True),
193
+ gr.update(visible=False),
194
+ gr.update(visible=False),
195
+ [],
196
+ [],
197
+ session_threads,
198
+ )
199
  return
200
 
201
  # Prepare PIL images
 
298
  args=(False, gifs_dir_no, None, "no"),
299
  daemon=True,
300
  )
301
+ # Track threads for this user session
302
+ session_threads.extend([thread_tta, thread_no])
303
  thread_tta.start()
304
  thread_no.start()
305
 
 
374
  gr.update(visible=False),
375
  gr.update(),
376
  gr.update(),
377
+ session_threads,
378
  )
379
  prev_status_tta = status_tta
380
  prev_status_no = status_no
 
389
 
390
  # Remove finished threads from global registry
391
  with _running_threads_lock:
392
+ # Clear session thread list
393
+ session_threads.clear()
 
 
394
 
395
  # Small delay to ensure last frame files are fully flushed
396
  time.sleep(0.2)
 
427
  gr.update(visible=True, minimum=1, maximum=n_no, value=n_no),
428
  frames_tta,
429
  frames_no,
430
+ session_threads,
431
  )
432
 
433
 
 
498
 
499
  frames_state_tta = gr.State([])
500
  frames_state_no = gr.State([])
501
+ session_threads_state = gr.State([])
502
 
503
  # Slider callbacks (updates image when user drags slider)
504
  def _show_frame(idx: int, frames: list[str]):
 
549
 
550
  run_btn.click(
551
  fn=process_search_tta,
552
+ inputs=[sat_input, ground_input, taxonomy_input, session_threads_state],
553
+ outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta, slider_tta, slider_no, frames_state_tta, frames_state_no, session_threads_state],
554
  )
555
 
556
  # Footer to point out to model and data from app page.
 
578
 
579
  outputs_on_tab = [_cleanup_status]
580
 
581
+ def _on_tab_change(evt: gr.SelectData, session_threads: list[threading.Thread]):
582
  # evt.value contains the name of the newly-selected tab.
583
  if evt.value == "Multimodal Inference":
584
+ # Stop only threads started in this session
585
+ for th in list(session_threads):
586
+ if th is not None and th.is_alive():
587
+ _stop_thread(th)
588
+ th.join(timeout=1)
589
+ session_threads.clear()
590
  return "Stopped running Search-TTA threads."
591
  return ""
592
 
593
+ tabs.select(_on_tab_change, inputs=[session_threads_state], outputs=outputs_on_tab)
594
 
595
  root.queue(max_size=15)
596
  root.launch(share=True)