Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Simplified Gradio demo for Search-TTA evaluation. | |
| This version mirrors the layout of `app_BACKUP.py` but: | |
| 1. Loads no OpenCLIP / CLAP / Satellite encoders at import-time. | |
| 2. Keeps only the Satellite and Ground-level image inputs. | |
| 3. Exposes the high-level wrapper classes `ClipSegTTA` and | |
| `TestWorker` and calls `TestWorker.run_episode` inside the | |
| `process` callback. | |
| """ | |
| # ────────────────────────── imports ─────────────────────────────────── | |
| from pathlib import Path | |
| # Use non-GUI backend to avoid Tkinter errors in background threads | |
| import matplotlib | |
| matplotlib.use("Agg", force=True) | |
| import gradio as gr | |
| import ctypes # for safely stopping background threads | |
| import os, glob, threading, time | |
| import torch | |
| from PIL import Image | |
| import json | |
| import spaces # integration with ZeroGPU on hf | |
| # Import configuration & RL / TTA utilities ------------------------------------------------- | |
| # NOTE: we import * so that the global names (e.g. USE_GPU, MODEL_NAME, etc.) | |
| # are available exactly as referenced later in the unchanged snippet. | |
| from test_parameter import * # noqa: F403, F401 (wild-import is intentional here) | |
| from model import PolicyNet # noqa: E402 β after wild import on purpose | |
| from test_multi_robot_worker import TestWorker # noqa: E402 | |
| from Taxabind.TaxaBind.SatBind.clip_seg_tta import ClipSegTTA # noqa: E402 | |
| # Helper to kill a Python thread by injecting SystemExit | |
| def _stop_thread(thread: threading.Thread): | |
| """Forcefully raise SystemExit in the given thread (best-effort).""" | |
| if thread is None or not thread.is_alive(): | |
| return | |
| tid = thread.ident | |
| if tid is None: | |
| return | |
| # Ask CPython to raise SystemExit in the thread context | |
| res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(SystemExit)) | |
| if res > 1: | |
| # If it returned >1, cleanup and fail safe | |
| ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), None) | |
# CHANGE ME!
POLL_INTERVAL = 0.1  # Seconds slept between scans for new frames (visualization)

# Prepare the model
device = torch.device('cuda') if USE_GPU else torch.device('cpu')
policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device)

script_dir = Path(__file__).resolve().parent
print("real_script_dir: ", script_dir)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU machine still loads when running on CPU.
checkpoint = torch.load(f'{model_path}/{MODEL_NAME}', map_location=device)
policy_net.load_state_dict(checkpoint['policy_model'])
print('Model loaded!')
# Create the shared ClipSegTTA instance once at import time.
# It stays None when TAXABIND_TTA is disabled.
clip_seg_tta = None
if TAXABIND_TTA:
    _clip_seg_tta_kwargs = dict(
        img_dir=TAXABIND_IMG_DIR,
        imo_dir=TAXABIND_IMO_DIR,
        json_path=TAXABIND_INAT_JSON_PATH,
        sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
        patch_size=TAXABIND_PATCH_SIZE,
        sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
        sample_index=-1,  # Set using 'reset' in worker
        blur_kernel=TAXABIND_GAUSSIAN_BLUR_KERNEL,
        device=device,
        sat_to_img_ids_json_is_train_dict=False,  # for search ds val
        tax_to_filter_val=QUERY_TAX,
        load_model=USE_CLIP_PREDS,
        initial_modality=INITIAL_MODALITY,
        sound_data_path=TAXABIND_SOUND_DATA_PATH,
        sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
    )
    clip_seg_tta = ClipSegTTA(**_clip_seg_tta_kwargs)
    print("ClipSegTTA Loaded!")
# Load metadata json (indexed below by taxonomy name -> "target_positions")
tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
# Use a context manager so the file handle is closed deterministically,
# and read as UTF-8 regardless of the host locale.
with open(tgts_metadata_json_path, encoding="utf-8") as metadata_file:
    tgts_metadata = json.load(metadata_file)
| # ────────────────────────── Gradio process fn ───────────────────────── | |
| # Helper wrappers so that Gradio recognises streaming (generator) functions | |
| # NOTE: A lambda that *returns* a generator is NOT itself a generator *function*, | |
| # hence Gradio fails to detect streaming and treats the return value as a plain | |
| # object. By defining explicit generator functions (with `yield from`) we ensure | |
| # `inspect.isgeneratorfunction` evaluates to True and Gradio streams correctly. | |
def _stream_episode(
    sat_path: str | None,
    ground_path: str | None,
    taxonomy: str | None,
    with_tta: bool,
):
    """Stream frames from ``process`` while keeping both run buttons locked.

    Every yielded item is a 3-tuple matching the Gradio outputs
    ``(run_tta_btn, run_no_tta_btn, <heatmap image>)``:

    1. first, both buttons are disabled and the image is cleared;
    2. then one update per frame produced by ``process``;
    3. finally both buttons are re-enabled, repeating the last frame.

    If ``process`` raises, the buttons are re-enabled *before* the exception
    propagates; otherwise the UI would be left with both buttons disabled.
    Cancellation (``GeneratorExit``) is intentionally not intercepted.
    """
    # Disable buttons initially (image reset)
    yield gr.update(interactive=False), gr.update(interactive=False), gr.update(value=None)

    last_img = None
    try:
        for img in process(sat_path, ground_path, taxonomy, with_tta):
            last_img = img
            yield gr.update(interactive=False), gr.update(interactive=False), img
    except Exception:
        # Unlock the UI, then let the original error surface to Gradio.
        yield gr.update(interactive=True), gr.update(interactive=True), last_img
        raise

    # Re-enable buttons at the end
    yield gr.update(interactive=True), gr.update(interactive=True), last_img


def process_with_tta(
    sat_path: str | None,
    ground_path: str | None,
    taxonomy: str | None = None,
):
    """Stream search episode **with** TTA enabled while disabling buttons."""
    # `yield from` keeps this a generator *function*, which is what Gradio
    # inspects to decide whether to stream.
    yield from _stream_episode(sat_path, ground_path, taxonomy, True)


def process_no_tta(
    sat_path: str | None,
    ground_path: str | None,
    taxonomy: str | None = None,
):
    """Stream search episode **without** TTA enabled while disabling buttons."""
    yield from _stream_episode(sat_path, ground_path, taxonomy, False)
| # integration with ZeroGPU on hf | |
def _unseen_frames(frame_dir: str, seen: set[str]) -> list[str]:
    """Return the PNG paths in ``frame_dir`` not in ``seen``, ordered by numeric file stem."""
    frames = [fp for fp in glob.glob(os.path.join(frame_dir, "*.png")) if fp not in seen]
    frames.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
    return frames


def process(
    sat_path: str | None,
    ground_path: str | None,
    taxonomy: str | None = None,
    with_tta: bool = True,
):
    """Run one search episode and stream the PNG frames it writes.

    The satellite image (and the optional ground-level image) are loaded and
    installed on the shared ``clip_seg_tta`` object, a ``TestWorker`` is
    started on a background thread, and ``gifs_path`` is polled every
    ``POLL_INTERVAL`` seconds for new ``*.png`` files.

    Args:
        sat_path: Path of the satellite image. If ``None`` nothing is yielded.
        ground_path: Optional path of a ground-level image.
        taxonomy: Optional taxonomy name; used as the key into
            ``tgts_metadata``. When missing or unknown, no target positions
            are found and the existing ``[(0, 0)]`` placeholder is used.
        with_tta: Value assigned to ``planner.execute_tta``.

    Yields:
        File paths of newly discovered frames in numeric order, followed by
        the last frame once more when the episode ends.
    """
    # If no satellite image is provided we bail out early.
    if sat_path is None:
        return

    # ------------------------------------------------------------------
    # Load images from paths and configure ClipSegTTA inputs.
    # `convert` returns a fully loaded image, so the files can be closed.
    with Image.open(sat_path) as sat_file:
        sat_img = sat_file.convert("RGB")
    ground_img_pil = None
    if ground_path:
        with Image.open(ground_path) as ground_file:
            ground_img_pil = ground_file.convert("RGB")

    # Taxonomy is optional in the UI: tolerate None / unknown names instead
    # of raising KeyError / AttributeError.
    species = taxonomy or ""
    raw_targets = tgts_metadata.get(species, {}).get("target_positions", [])
    tgts = [tuple(tgt) for tgt in raw_targets]

    clip_seg_tta.img_paths = [ground_path] if ground_path else []
    clip_seg_tta.imo_path = sat_path
    clip_seg_tta.imgs = (
        [clip_seg_tta.dataset.img_transform(ground_img_pil).to(device)]
        if ground_img_pil else []
    )
    clip_seg_tta.imo = clip_seg_tta.dataset.imo_transform(sat_img).to(device)
    clip_seg_tta.sounds = []
    clip_seg_tta.sound_ids = []
    clip_seg_tta.species_name = species
    clip_seg_tta.gt_mask_name = species.replace(" ", "_")
    clip_seg_tta.target_positions = tgts if tgts else [(0, 0)]

    # Define TestWorker
    # NOTE(review): device is hard-coded to 'cuda' here while the module-level
    # `device` follows USE_GPU — confirm whether TestWorker should use `device`.
    planner = TestWorker(
        meta_agent_id=0,
        n_agent=1,
        policy_net=policy_net,
        global_step=-1,
        device='cuda',
        greedy=True,
        save_image=SAVE_GIFS,
        clip_seg_tta=clip_seg_tta,
    )

    # ------------------------------------------------------------------
    # Empty gifs_path folder so frames from a previous run are not replayed
    if os.path.exists(gifs_path):
        for file in os.listdir(gifs_path):
            os.remove(os.path.join(gifs_path, file))

    # Set execute_tta flag depending on button pressed
    planner.execute_tta = with_tta

    # Episode index is hard-coded to 0.
    t = threading.Thread(target=planner.run_episode, args=(0,), daemon=True)
    t.start()

    sent: set[str] = set()
    last_img: str | None = None
    try:
        while t.is_alive():
            # discover any new pngs written by TestWorker
            for fp in _unseen_frames(gifs_path, sent):
                sent.add(fp)
                last_img = fp
                yield fp  # stream update
            time.sleep(POLL_INTERVAL)
        # Final sweep: frames written after the last poll but before the
        # worker exited would otherwise never be streamed.
        for fp in _unseen_frames(gifs_path, sent):
            sent.add(fp)
            last_img = fp
            yield fp
    finally:
        # This block runs when the generator is cancelled (e.g. page refresh)
        if t.is_alive():
            _stop_thread(t)
            t.join(timeout=1)

    # If the episode finished naturally, send the last frame once more
    if last_img is not None:
        yield last_img
    print("planner.perf_metrics: ", planner.perf_metrics)
| # ────────────────────────── Gradio UI ───────────────────────────────── | |
# Build the Gradio interface. Components are created in declaration order
# inside nested context managers, so statement order here defines the layout.
# NOTE(review): the original indentation was lost; the nesting below
# (buttons inside the first column, examples at Blocks level) is a
# reconstruction — confirm against the deployed app.
with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
    gr.Markdown(
        """
        # Search-TTA β Simplified Demo
        **Satellite β Ground-level Visual Search** via RL Test-Time Adaptation.
        """
    )
    with gr.Row(variant="panel"):
        # First column: the three inputs and the two run buttons.
        with gr.Column():
            sat_input = gr.Image(
                label="Satellite Image",
                sources=["upload"],
                type="filepath",
                height=320,
            )
            taxonomy_input = gr.Textbox(
                label="Full Taxonomy Name (optional)",
                placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
            )
            ground_input = gr.Image(
                label="Ground-level Image (optional)",
                sources=["upload"],
                type="filepath",
                height=320,
            )
            run_tta_btn = gr.Button("Run (with TTA)", variant="primary")
            run_no_tta_btn = gr.Button("Run (without TTA)", variant="secondary")
        # Second column: one streamed heatmap display per run mode.
        with gr.Column():
            gr.Markdown("### Live Heatmap (with TTA)")
            display_img_tta = gr.Image(label="Heatmap (TTA)", type="filepath", height=512)
            gr.Markdown("### Live Heatmap (without TTA)")
            display_img_no_tta = gr.Image(label="Heatmap (no TTA)", type="filepath", height=512)
    # EXAMPLES — copied from original demo (satellite, ground, taxonomy only).
    # Each example row is [satellite image path, ground image path, taxonomy],
    # matching `inputs=[sat_input, ground_input, taxonomy_input]`.
    # NOTE(review): with cache_examples=False, `fn`/`outputs` presumably only
    # matter if Gradio is configured to run examples on click — confirm.
    with gr.Row():
        gr.Markdown("### In-Domain Taxonomy")
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
                    "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg",
                    "Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris",
                ],
                [
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
                    "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
                    "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
                ],
                [
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
                    "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
                ],
            ],
            inputs=[sat_input, ground_input, taxonomy_input],
            outputs=[run_tta_btn, run_no_tta_btn, display_img_tta],
            fn=process_with_tta,
            cache_examples=False,
        )
    with gr.Row():
        gr.Markdown("### Out-Domain Taxonomy")
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg",
                    "Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
                    "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/yosemite_v3_resized.png",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/248820933.jpeg",
                    "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
                    "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
                ],
            ],
            inputs=[sat_input, ground_input, taxonomy_input],
            outputs=[run_tta_btn, run_no_tta_btn, display_img_tta],
            fn=process_with_tta,
            cache_examples=False,
        )
    # Bind callbacks. Both handlers are generator functions whose yielded
    # 3-tuples map onto (run_tta_btn, run_no_tta_btn, <heatmap image>);
    # the two buttons differ only in which heatmap component they update.
    run_tta_btn.click(
        fn=process_with_tta,
        inputs=[sat_input, ground_input, taxonomy_input],
        outputs=[run_tta_btn, run_no_tta_btn, display_img_tta],
    )
    run_no_tta_btn.click(
        fn=process_no_tta,
        inputs=[sat_input, ground_input, taxonomy_input],
        outputs=[run_tta_btn, run_no_tta_btn, display_img_no_tta],
    )
| # ────────────────────────── unchanged worker initialisation ─────────── | |
| # NOTE: **Do NOT modify the code below.** It is copied verbatim from the | |
| # user-provided snippet so that the exact same objects are created. | |
| # The variables referenced here come from `test_parameter` which we | |
| # imported with a wildcard earlier. | |
| # if def main | |
if __name__ == "__main__":
    # Turn on the request queue (at most 15 pending requests), then start
    # the Gradio server with a public share link.
    demo.queue(max_size=15).launch(share=True)