derektan committed on
Commit
40d2b47
·
1 Parent(s): 68a0d40

[UPDATE] Moved Search-tta init to process thread to solve ZeroGPU issue

Browse files
Files changed (2) hide show
  1. app.py +64 -47
  2. app_multimodal_inference.py +1 -0
app.py CHANGED
@@ -84,49 +84,47 @@ policy_net.load_state_dict(checkpoint['policy_model'])
84
  print('Model loaded!')
85
  # print(next(policy_net.parameters()).device)
86
 
87
- # Init Taxabind here (only need to init once)
88
- if TAXABIND_TTA:
89
- # Instantiate TWO independent ClipSegTTA objects (one per concurrent run)
90
- clip_seg_tta_1 = ClipSegTTA(
91
- img_dir=TAXABIND_IMG_DIR,
92
- imo_dir=TAXABIND_IMO_DIR,
93
- json_path=TAXABIND_INAT_JSON_PATH,
94
- sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
95
- patch_size=TAXABIND_PATCH_SIZE,
96
- sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
97
- sample_index = -1, # Set using 'reset' in worker
98
- blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
99
- device=device, # device,
100
- sat_to_img_ids_json_is_train_dict=False, # for search ds val
101
- tax_to_filter_val=QUERY_TAX,
102
- load_model=USE_CLIP_PREDS,
103
- initial_modality=INITIAL_MODALITY,
104
- sound_data_path = TAXABIND_SOUND_DATA_PATH,
105
- sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
106
- # sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
107
- )
108
- clip_seg_tta_2 = ClipSegTTA(
109
- img_dir=TAXABIND_IMG_DIR,
110
- imo_dir=TAXABIND_IMO_DIR,
111
- json_path=TAXABIND_INAT_JSON_PATH,
112
- sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
113
- patch_size=TAXABIND_PATCH_SIZE,
114
- sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
115
- sample_index = -1, # Set using 'reset' in worker
116
- blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
117
- device=device,
118
- sat_to_img_ids_json_is_train_dict=False,
119
- tax_to_filter_val=QUERY_TAX,
120
- load_model=USE_CLIP_PREDS,
121
- initial_modality=INITIAL_MODALITY,
122
- sound_data_path=TAXABIND_SOUND_DATA_PATH,
123
- sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
124
- )
125
- print("ClipSegTTA instances loaded!")
126
- # Keep original name for single-run mode compatibility
127
- clip_seg_tta = clip_seg_tta_1
128
- else:
129
- clip_seg_tta_1 = clip_seg_tta_2 = clip_seg_tta = None
130
 
131
  # Load metadata json
132
  tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
@@ -179,7 +177,26 @@ def process_search_tta(
179
 
180
  # Helper to build a TestWorker with/without TTA
181
  def build_planner(enable_tta: bool, save_dir: str, clip_obj):
182
- local_clip = clip_obj # re-use the pre-instantiated ClipSegTTA
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  if local_clip is not None:
184
  # Feed inputs to ClipSegTTA copy
185
  local_clip.img_paths = [ground_path] if ground_path else []
@@ -235,12 +252,12 @@ def process_search_tta(
235
  # Launch both planners in background threads – preparation included
236
  thread_tta = threading.Thread(
237
  target=_planner_thread,
238
- args=(True, gifs_dir_tta, clip_seg_tta_1, "tta"),
239
  daemon=True,
240
  )
241
  thread_no = threading.Thread(
242
  target=_planner_thread,
243
- args=(False, gifs_dir_no, clip_seg_tta_2, "no"),
244
  daemon=True,
245
  )
246
  _register_thread(thread_tta)
@@ -383,7 +400,7 @@ with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
383
  """
384
  # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
385
  Click on any of the <b>examples below</b> and run the <b>TTA demo</b>. Check out the <b>multimodal heatmap generation feature</b> by switching to the next tab above. <br>
386
- Note that the model initialization, RL planner, and TTA updates are not fully optimized on GPU for this huggingface demo, and hence may experience some lag during execution. If you encounter an 'Error' status, refresh the browser and rerun the demo. We will improve this in the future. <br>
387
  <a href="https://search-tta.github.io">Project Website</a>
388
  """
389
  )
 
84
  print('Model loaded!')
85
  # print(next(policy_net.parameters()).device)
86
 
87
+ # # (ClipSegTTA will now be instantiated lazily inside each planner thread)
88
+ # clip_seg_tta_1 = clip_seg_tta_2 = None # placeholder; real instances created per thread
89
+ # if False and TAXABIND_TTA:
90
+ # # Instantiate TWO independent ClipSegTTA objects (one per concurrent run)
91
+ # clip_seg_tta_1 = ClipSegTTA(
92
+ # img_dir=TAXABIND_IMG_DIR,
93
+ # imo_dir=TAXABIND_IMO_DIR,
94
+ # json_path=TAXABIND_INAT_JSON_PATH,
95
+ # sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
96
+ # patch_size=TAXABIND_PATCH_SIZE,
97
+ # sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
98
+ # sample_index = -1, # Set using 'reset' in worker
99
+ # blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
100
+ # device=device, # device,
101
+ # sat_to_img_ids_json_is_train_dict=False, # for search ds val
102
+ # tax_to_filter_val=QUERY_TAX,
103
+ # load_model=USE_CLIP_PREDS,
104
+ # initial_modality=INITIAL_MODALITY,
105
+ # sound_data_path = TAXABIND_SOUND_DATA_PATH,
106
+ # sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
107
+ # # sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
108
+ # )
109
+ # clip_seg_tta_2 = ClipSegTTA(
110
+ # img_dir=TAXABIND_IMG_DIR,
111
+ # imo_dir=TAXABIND_IMO_DIR,
112
+ # json_path=TAXABIND_INAT_JSON_PATH,
113
+ # sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
114
+ # patch_size=TAXABIND_PATCH_SIZE,
115
+ # sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
116
+ # sample_index = -1, # Set using 'reset' in worker
117
+ # blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
118
+ # device=device,
119
+ # sat_to_img_ids_json_is_train_dict=False,
120
+ # tax_to_filter_val=QUERY_TAX,
121
+ # load_model=USE_CLIP_PREDS,
122
+ # initial_modality=INITIAL_MODALITY,
123
+ # sound_data_path=TAXABIND_SOUND_DATA_PATH,
124
+ # sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
125
+ # )
126
+
127
+
 
 
128
 
129
  # Load metadata json
130
  tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
 
177
 
178
  # Helper to build a TestWorker with/without TTA
179
  def build_planner(enable_tta: bool, save_dir: str, clip_obj):
180
+ # Lazily (re)create a ClipSegTTA instance per thread if not provided
181
+ local_clip = clip_obj
182
+ if TAXABIND_TTA and local_clip is None:
183
+ local_clip = ClipSegTTA(
184
+ img_dir=TAXABIND_IMG_DIR,
185
+ imo_dir=TAXABIND_IMO_DIR,
186
+ json_path=TAXABIND_INAT_JSON_PATH,
187
+ sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
188
+ patch_size=TAXABIND_PATCH_SIZE,
189
+ sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
190
+ sample_index=-1,
191
+ blur_kernel=TAXABIND_GAUSSIAN_BLUR_KERNEL,
192
+ device=device,
193
+ sat_to_img_ids_json_is_train_dict=False,
194
+ tax_to_filter_val=QUERY_TAX,
195
+ load_model=USE_CLIP_PREDS,
196
+ initial_modality=INITIAL_MODALITY,
197
+ sound_data_path=TAXABIND_SOUND_DATA_PATH,
198
+ sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
199
+ )
200
  if local_clip is not None:
201
  # Feed inputs to ClipSegTTA copy
202
  local_clip.img_paths = [ground_path] if ground_path else []
 
252
  # Launch both planners in background threads – preparation included
253
  thread_tta = threading.Thread(
254
  target=_planner_thread,
255
+ args=(True, gifs_dir_tta, None, "tta"),
256
  daemon=True,
257
  )
258
  thread_no = threading.Thread(
259
  target=_planner_thread,
260
+ args=(False, gifs_dir_no, None, "no"),
261
  daemon=True,
262
  )
263
  _register_thread(thread_tta)
 
400
  """
401
  # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
402
  Click on any of the <b>examples below</b> and run the <b>TTA demo</b>. Check out the <b>multimodal heatmap generation feature</b> by switching to the next tab above. <br>
403
+ Note that the model initialization, RL planner, and TTA updates are not fully optimized on GPU for this huggingface demo, and hence may experience some lag during execution. If you encounter an 'Error' status, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future. <br>
404
  <a href="https://search-tta.github.io">Project Website</a>
405
  """
406
  )
app_multimodal_inference.py CHANGED
@@ -185,6 +185,7 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
185
  """
186
  # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
187
  Click on any of the <b>examples below</b> and run the <b>multimodal inference demo</b>. Check out the <b>test-time adaptation feature</b> by switching to the previous tab above. <br>
 
188
  <a href="https://search-tta.github.io">Project Website</a>
189
  """
190
  )
 
185
  """
186
  # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
187
  Click on any of the <b>examples below</b> and run the <b>multimodal inference demo</b>. Check out the <b>test-time adaptation feature</b> by switching to the previous tab above. <br>
188
+ If you encounter any errors, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future. <br>
189
  <a href="https://search-tta.github.io">Project Website</a>
190
  """
191
  )