""" Simplified Gradio demo for Search-TTA evaluation. This version mirrors the layout of `app_BACKUP.py` but: 1. Loads no OpenCLIP / CLAP / Satellite encoders at import-time. 2. Keeps only the Satellite and Ground-level image inputs. 3. Exposes the high-level wrapper classes `ClipSegTTA` and `TestWorker` and calls `TestWorker.run_episode` inside the `process` callback. """ # ────────────────────────── imports ─────────────────────────────────── from pathlib import Path # Use non-GUI backend to avoid Tkinter errors in background threads import matplotlib matplotlib.use("Agg", force=True) import gradio as gr import os, glob, threading, time import torch from PIL import Image # Import configuration & RL / TTA utilities ------------------------------------------------- # NOTE: we import * so that the global names (e.g. USE_GPU, MODEL_NAME, etc.) # are available exactly as referenced later in the unchanged snippet. from test_parameter import * # noqa: F403, F401 (wild-import is intentional here) from model import PolicyNet # noqa: E402 – after wild import on purpose from test_multi_robot_worker import TestWorker # noqa: E402 from Taxabind.TaxaBind.SatBind.clip_seg_tta import ClipSegTTA # noqa: E402 # CHANGE ME! POLL_INTERVAL = 0.1 # For visualization # Prepare the model # device = torch.device('cpu') #if USE_GPU_TRAINING else torch.device('cpu') device = torch.device('cuda') if USE_GPU else torch.device('cpu') policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device) # script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = Path(__file__).resolve().parent print("real_script_dir: ", script_dir) # checkpoint = torch.load(f'{script_dir}/modules/vlm_search/{model_path}/{MODEL_NAME}') checkpoint = torch.load(f'{model_path}/{MODEL_NAME}') policy_net.load_state_dict(checkpoint['policy_model']) print('Model loaded!') # print(next(policy_net.parameters()).device) # Init Taxabind here (only need to init once) if TAXABIND_TTA: # self.clip_seg_tta = None clip_seg_tta = ClipSegTTA( img_dir=TAXABIND_IMG_DIR, imo_dir=TAXABIND_IMO_DIR, json_path=TAXABIND_INAT_JSON_PATH, sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH, patch_size=TAXABIND_PATCH_SIZE, sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH, sample_index = -1, # Set using 'reset' in worker blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL, device=device, sat_to_img_ids_json_is_train_dict=False, # for search ds val tax_to_filter_val=QUERY_TAX, load_model=USE_CLIP_PREDS, initial_modality=INITIAL_MODALITY, sound_data_path = TAXABIND_SOUND_DATA_PATH, sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH, # sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH, ) print("ClipSegTTA Loaded!") else: clip_seg_tta = None # ────────────────────────── Gradio process fn ───────────────────────── def process( sat_path: str | None, ground_path: str | None, taxonomy: str | None = None, ): """Callback executed when the user presses **Run** in the UI. At test-time we simply trigger the RL search episode via ``planner.run_episode`` and return its performance metrics. The image inputs are currently *not* used directly here but are retained to conform to the requested interface. """ # If no satellite image is provided we bail out early. 
# ────────────────────────── Gradio process fn ─────────────────────────
def process(
    sat_path: str | None,
    ground_path: str | None,
    taxonomy: str | None = None,
):
    """Callback executed when the user presses **Run** in the UI.

    The inputs configure ``clip_seg_tta`` for the query, after which the RL
    search episode is triggered via ``planner.run_episode`` in a background
    thread. Heatmap frames written by the worker are streamed back to the UI,
    and the episode's performance metrics are printed once it finishes.
    """
    # If no satellite image is provided we bail out early.
    if sat_path is None:
        return None

    # ------------------------------------------------------------------
    # Load images from paths and configure ClipSegTTA inputs
    sat_img = Image.open(sat_path).convert("RGB")
    ground_img_pil = Image.open(ground_path).convert("RGB") if ground_path else None

    clip_seg_tta.img_paths = [ground_path] if ground_path else []
    clip_seg_tta.imo_path = sat_path
    clip_seg_tta.imgs = (
        [clip_seg_tta.dataset.img_transform(ground_img_pil).to(device)]
        if ground_img_pil else []
    )
    clip_seg_tta.imo = clip_seg_tta.dataset.imo_transform(sat_img).to(device)
    clip_seg_tta.sounds = []
    clip_seg_tta.sound_ids = []  # None
    clip_seg_tta.species_name = taxonomy or ""
    clip_seg_tta.target_positions = [(0, 0)]  # PLACEHOLDERS
    clip_seg_tta.gt_mask_name = (taxonomy or "").replace(" ", "_")  # None

    # Define TestWorker
    planner = TestWorker(
        meta_agent_id=0,
        n_agent=1,
        policy_net=policy_net,
        global_step=-1,
        device=device,  # follow the USE_GPU selection above instead of hard-coding 'cuda'
        greedy=True,
        save_image=SAVE_GIFS,
        clip_seg_tta=clip_seg_tta,
    )

    # ------------------------------------------------------------------
    # Empty gifs_path folder
    if os.path.exists(gifs_path):
        for file in os.listdir(gifs_path):
            os.remove(os.path.join(gifs_path, file))

    # Optionally you may want to reset episode index or make it configurable.
    # For now we hard-code episode 0, mirroring the snippet.
    t = threading.Thread(target=planner.run_episode, args=(0,), daemon=True)
    t.start()
    # planner.run_episode(0)

    sent = set()
    last_img = None
    while t.is_alive():
        # discover any new pngs written by TestWorker
        pngs = glob.glob(os.path.join(gifs_path, "*.png"))
        pngs.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
        for fp in pngs:
            if fp not in sent:
                sent.add(fp)
                last_img = fp
                yield fp  # stream update
        time.sleep(POLL_INTERVAL)

    # one final yield after the loop finishes
    yield last_img
    print("planner.perf_metrics: ", planner.perf_metrics)
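# Example usage (illustrative, not part of the original snippet): the callback
# can also be driven outside the Gradio UI. The paths below come from the
# example lists further down and are assumed to exist relative to this script.
#
#   for frame in process(
#           "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
#           "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
#           "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
#       ):
#       print("new heatmap frame:", frame)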
# ────────────────────────── Gradio UI ─────────────────────────────────
with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:
    gr.Markdown(
        """
        # Search-TTA – Simplified Demo
        **Satellite ↔ Ground-level Visual Search** via RL Test-Time Adaptation.
        """
    )

    with gr.Row(variant="panel"):
        with gr.Column():
            sat_input = gr.Image(
                label="Satellite Image",
                sources=["upload"],
                type="filepath",
                height=320,
            )
            taxonomy_input = gr.Textbox(
                label="Full Taxonomy Name (optional)",
                placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
            )
            ground_input = gr.Image(
                label="Ground-level Image (optional)",
                sources=["upload"],
                type="filepath",
                height=320,
            )
            run_btn = gr.Button("Run", variant="primary")

        with gr.Column():
            gr.Markdown("### Live Heatmap")
            display_img = gr.Image(label="Current Heatmap", type="filepath", height=512)

    # EXAMPLES – copied from original demo (satellite, ground, taxonomy only)
    with gr.Row():
        gr.Markdown("### In-Domain Taxonomy")
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
                    "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg",
                    "Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris",
                ],
                [
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
                    "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
                    "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
                ],
                [
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
                    "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
                ],
            ],
            inputs=[sat_input, ground_input, taxonomy_input],
            outputs=[display_img],
            fn=lambda sat, grd, tax: process(sat, grd, tax),
            cache_examples=False,
        )

    with gr.Row():
        gr.Markdown("### Out-Domain Taxonomy")
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg",
                    "Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
                    "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/yosemite_v3_resized.png",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/248820933.jpeg",
                    "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
                    "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
                ],
            ],
            inputs=[sat_input, ground_input, taxonomy_input],
            outputs=[display_img],
            fn=lambda sat, grd, tax: process(sat, grd, tax),
            cache_examples=False,
        )

    # Bind callback
    run_btn.click(
        fn=process,
        inputs=[sat_input, ground_input, taxonomy_input],
        outputs=display_img,
    )

# ────────────────────────── unchanged worker initialisation ───────────
# NOTE: **Do NOT modify the code below.** It is copied verbatim from the
# user-provided snippet so that the exact same objects are created.
# The variables referenced here come from `test_parameter` which we
# imported with a wildcard earlier.

if __name__ == "__main__":
    # Finally launch the Gradio interface (queue for concurrency).
    demo.queue(max_size=15)
    demo.launch(share=True)
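    # Notes (not part of the original snippet): `demo.queue(max_size=15)` bounds
    # how many requests may wait in Gradio's event queue, and `share=True`
    # exposes a temporary public *.gradio.live URL in addition to the local one;
    # set `share=False` to keep the demo local-only.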