Spaces:
Running
on
Zero
Running
on
Zero
derektan
commited on
Commit
·
40d2b47
1
Parent(s):
68a0d40
[UPDATE] Moved Search-tta init to process thread to solve ZeroGPU issue
Browse files- app.py +64 -47
- app_multimodal_inference.py +1 -0
app.py
CHANGED
@@ -84,49 +84,47 @@ policy_net.load_state_dict(checkpoint['policy_model'])
|
|
84 |
print('Model loaded!')
|
85 |
# print(next(policy_net.parameters()).device)
|
86 |
|
87 |
-
#
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
else:
|
129 |
-
clip_seg_tta_1 = clip_seg_tta_2 = clip_seg_tta = None
|
130 |
|
131 |
# Load metadata json
|
132 |
tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
|
@@ -179,7 +177,26 @@ def process_search_tta(
|
|
179 |
|
180 |
# Helper to build a TestWorker with/without TTA
|
181 |
def build_planner(enable_tta: bool, save_dir: str, clip_obj):
|
182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
if local_clip is not None:
|
184 |
# Feed inputs to ClipSegTTA copy
|
185 |
local_clip.img_paths = [ground_path] if ground_path else []
|
@@ -235,12 +252,12 @@ def process_search_tta(
|
|
235 |
# Launch both planners in background threads – preparation included
|
236 |
thread_tta = threading.Thread(
|
237 |
target=_planner_thread,
|
238 |
-
args=(True, gifs_dir_tta,
|
239 |
daemon=True,
|
240 |
)
|
241 |
thread_no = threading.Thread(
|
242 |
target=_planner_thread,
|
243 |
-
args=(False, gifs_dir_no,
|
244 |
daemon=True,
|
245 |
)
|
246 |
_register_thread(thread_tta)
|
@@ -383,7 +400,7 @@ with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
|
|
383 |
"""
|
384 |
# Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
|
385 |
Click on any of the <b>examples below</b> and run the <b>TTA demo</b>. Check out the <b>multimodal heatmap generation feature</b> by switching to the next tab above. <br>
|
386 |
-
Note that the model initialization, RL planner, and TTA updates are not fully optimized on GPU for this huggingface demo, and hence may experience some lag during execution. If you encounter an 'Error' status, refresh the browser and rerun the demo. We will improve this in the future. <br>
|
387 |
<a href="https://search-tta.github.io">Project Website</a>
|
388 |
"""
|
389 |
)
|
|
|
84 |
print('Model loaded!')
|
85 |
# print(next(policy_net.parameters()).device)
|
86 |
|
87 |
+
# # (ClipSegTTA will now be instantiated lazily inside each planner thread)
|
88 |
+
# clip_seg_tta_1 = clip_seg_tta_2 = None # placeholder; real instances created per thread
|
89 |
+
# if False and TAXABIND_TTA:
|
90 |
+
# # Instantiate TWO independent ClipSegTTA objects (one per concurrent run)
|
91 |
+
# clip_seg_tta_1 = ClipSegTTA(
|
92 |
+
# img_dir=TAXABIND_IMG_DIR,
|
93 |
+
# imo_dir=TAXABIND_IMO_DIR,
|
94 |
+
# json_path=TAXABIND_INAT_JSON_PATH,
|
95 |
+
# sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
|
96 |
+
# patch_size=TAXABIND_PATCH_SIZE,
|
97 |
+
# sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
|
98 |
+
# sample_index = -1, # Set using 'reset' in worker
|
99 |
+
# blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
|
100 |
+
# device=device, # device,
|
101 |
+
# sat_to_img_ids_json_is_train_dict=False, # for search ds val
|
102 |
+
# tax_to_filter_val=QUERY_TAX,
|
103 |
+
# load_model=USE_CLIP_PREDS,
|
104 |
+
# initial_modality=INITIAL_MODALITY,
|
105 |
+
# sound_data_path = TAXABIND_SOUND_DATA_PATH,
|
106 |
+
# sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
|
107 |
+
# # sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
|
108 |
+
# )
|
109 |
+
# clip_seg_tta_2 = ClipSegTTA(
|
110 |
+
# img_dir=TAXABIND_IMG_DIR,
|
111 |
+
# imo_dir=TAXABIND_IMO_DIR,
|
112 |
+
# json_path=TAXABIND_INAT_JSON_PATH,
|
113 |
+
# sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
|
114 |
+
# patch_size=TAXABIND_PATCH_SIZE,
|
115 |
+
# sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
|
116 |
+
# sample_index = -1, # Set using 'reset' in worker
|
117 |
+
# blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
|
118 |
+
# device=device,
|
119 |
+
# sat_to_img_ids_json_is_train_dict=False,
|
120 |
+
# tax_to_filter_val=QUERY_TAX,
|
121 |
+
# load_model=USE_CLIP_PREDS,
|
122 |
+
# initial_modality=INITIAL_MODALITY,
|
123 |
+
# sound_data_path=TAXABIND_SOUND_DATA_PATH,
|
124 |
+
# sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
|
125 |
+
# )
|
126 |
+
|
127 |
+
|
|
|
|
|
128 |
|
129 |
# Load metadata json
|
130 |
tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
|
|
|
177 |
|
178 |
# Helper to build a TestWorker with/without TTA
|
179 |
def build_planner(enable_tta: bool, save_dir: str, clip_obj):
|
180 |
+
# Lazily (re)create a ClipSegTTA instance per thread if not provided
|
181 |
+
local_clip = clip_obj
|
182 |
+
if TAXABIND_TTA and local_clip is None:
|
183 |
+
local_clip = ClipSegTTA(
|
184 |
+
img_dir=TAXABIND_IMG_DIR,
|
185 |
+
imo_dir=TAXABIND_IMO_DIR,
|
186 |
+
json_path=TAXABIND_INAT_JSON_PATH,
|
187 |
+
sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
|
188 |
+
patch_size=TAXABIND_PATCH_SIZE,
|
189 |
+
sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
|
190 |
+
sample_index=-1,
|
191 |
+
blur_kernel=TAXABIND_GAUSSIAN_BLUR_KERNEL,
|
192 |
+
device=device,
|
193 |
+
sat_to_img_ids_json_is_train_dict=False,
|
194 |
+
tax_to_filter_val=QUERY_TAX,
|
195 |
+
load_model=USE_CLIP_PREDS,
|
196 |
+
initial_modality=INITIAL_MODALITY,
|
197 |
+
sound_data_path=TAXABIND_SOUND_DATA_PATH,
|
198 |
+
sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
|
199 |
+
)
|
200 |
if local_clip is not None:
|
201 |
# Feed inputs to ClipSegTTA copy
|
202 |
local_clip.img_paths = [ground_path] if ground_path else []
|
|
|
252 |
# Launch both planners in background threads – preparation included
|
253 |
thread_tta = threading.Thread(
|
254 |
target=_planner_thread,
|
255 |
+
args=(True, gifs_dir_tta, None, "tta"),
|
256 |
daemon=True,
|
257 |
)
|
258 |
thread_no = threading.Thread(
|
259 |
target=_planner_thread,
|
260 |
+
args=(False, gifs_dir_no, None, "no"),
|
261 |
daemon=True,
|
262 |
)
|
263 |
_register_thread(thread_tta)
|
|
|
400 |
"""
|
401 |
# Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
|
402 |
Click on any of the <b>examples below</b> and run the <b>TTA demo</b>. Check out the <b>multimodal heatmap generation feature</b> by switching to the next tab above. <br>
|
403 |
+
Note that the model initialization, RL planner, and TTA updates are not fully optimized on GPU for this huggingface demo, and hence may experience some lag during execution. If you encounter an 'Error' status, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future. <br>
|
404 |
<a href="https://search-tta.github.io">Project Website</a>
|
405 |
"""
|
406 |
)
|
app_multimodal_inference.py
CHANGED
@@ -185,6 +185,7 @@ with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
|
|
185 |
"""
|
186 |
# Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
|
187 |
Click on any of the <b>examples below</b> and run the <b>multimodal inference demo</b>. Check out the <b>test-time adaptation feature</b> by switching to the previous tab above. <br>
|
|
|
188 |
<a href="https://search-tta.github.io">Project Website</a>
|
189 |
"""
|
190 |
)
|
|
|
185 |
"""
|
186 |
# Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
|
187 |
Click on any of the <b>examples below</b> and run the <b>multimodal inference demo</b>. Check out the <b>test-time adaptation feature</b> by switching to the previous tab above. <br>
|
188 |
+
If you encounter any errors, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future. <br>
|
189 |
<a href="https://search-tta.github.io">Project Website</a>
|
190 |
"""
|
191 |
)
|