""" Search-TTA demo """ # ────────────────────────── imports ─────────────────────────────────── import cv2 import gradio as gr import torch import numpy as np from PIL import Image import matplotlib.pyplot as plt import io import torchaudio import spaces # integration with ZeroGPU on hf from torchvision import transforms import open_clip from clip_vision_per_patch_model import CLIPVisionPerPatchModel from transformers import ClapAudioModelWithProjection from transformers import ClapProcessor # ────────────────────────── global config & models ──────────────────── device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # BioCLIP (ground-image & text encoder) bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip") bio_model = bio_model.to(device).eval() bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip") # Satellite patch encoder CLIP-L-336 per-patch) sat_model: CLIPVisionPerPatchModel = ( CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta-sat") .to(device) .eval() ) # Sound CLAP model sound_model: ClapAudioModelWithProjection = ( ClapAudioModelWithProjection.from_pretrained("derektan95/search-tta-sound") .to(device) .eval() ) sound_processor: ClapProcessor = ClapProcessor.from_pretrained("derektan95/search-tta-sound") SAMPLE_RATE = 48000 logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) logit_scale = logit_scale.exp() blur_kernel = (5,5) # ────────────────────────── transforms (exact spec) ─────────────────── img_transform = transforms.Compose( [ transforms.Resize((256, 256)), transforms.CenterCrop((224, 224)), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ), ] ) imo_transform = transforms.Compose( [ transforms.Resize((336, 336)), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ), ] ) def get_audio_clap(path_to_audio,format="mp3",padding="repeatpad",truncation="fusion"): track, sr = torchaudio.load(path_to_audio, format=format) # torchaudio.load(path_to_audio) track = track.mean(axis=0) track = torchaudio.functional.resample(track, orig_freq=sr, new_freq=SAMPLE_RATE) output = sound_processor(audios=track, sampling_rate=SAMPLE_RATE, max_length_s=10, return_tensors="pt",padding=padding,truncation=truncation) return output # ────────────────────────── helpers ─────────────────────────────────── @torch.no_grad() def _encode_ground(img_pil: Image.Image) -> torch.Tensor: img = img_transform(img_pil).unsqueeze(0).to(device) img_embeds, *_ = bio_model(img) return img_embeds @torch.no_grad() def _encode_text(text: str) -> torch.Tensor: toks = bio_tokenizer(text).to(device) _, txt_embeds, _ = bio_model(text=toks) return txt_embeds @torch.no_grad() def _encode_sat(img_pil: Image.Image) -> torch.Tensor: imo = imo_transform(img_pil).unsqueeze(0).to(device) imo_embeds = sat_model(imo) return imo_embeds @torch.no_grad() def _encode_sound(sound) -> torch.Tensor: processed_sound = get_audio_clap(sound) for k in processed_sound.keys(): processed_sound[k] = processed_sound[k].to(device) unnormalized_audio_embeds = sound_model(**processed_sound).audio_embeds sound_embeds = torch.nn.functional.normalize(unnormalized_audio_embeds, dim=-1) return sound_embeds def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray: sims = torch.matmul(query, patches.t()) * logit_scale sims = sims.t().sigmoid() sims = sims[1:].squeeze() # drop CLS token side = int(np.sqrt(len(sims))) sims = sims.reshape(side, 

def _array_to_pil(arr: np.ndarray) -> Image.Image:
    """
    Render arr with viridis, automatically stretching its own min→max to 0→1
    so that the most-similar patches appear yellow.
    """
    # Gaussian smoothing
    if blur_kernel != (0, 0):
        arr = cv2.GaussianBlur(arr, blur_kernel, 0)

    # --- contrast-stretch to local 0-1 range --------------------------
    arr_min, arr_max = float(arr.min()), float(arr.max())
    if arr_max - arr_min < 1e-6:  # avoid /0 when the heat-map is flat
        arr_scaled = np.zeros_like(arr)
    else:
        arr_scaled = (arr - arr_min) / (arr_max - arr_min)
    # ------------------------------------------------------------------

    fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96)
    ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0)
    ax.axis("off")
    buf = io.BytesIO()
    plt.tight_layout(pad=0)
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


# ────────────────────────── main inference ────────────────────────────
@spaces.GPU  # integration with ZeroGPU on HF
def process(
    sat_img: Image.Image,
    taxonomy: str,
    ground_img: Image.Image | None,
    sound: str | None,  # filepath from gr.Audio(type="filepath")
):
    if sat_img is None:
        return None, None, None  # one slot per heat-map output

    patches = _encode_sat(sat_img)

    heat_ground, heat_text, heat_sound = None, None, None
    if ground_img is not None:
        q_img = _encode_ground(ground_img)
        heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches))
    if taxonomy and taxonomy.strip():
        q_txt = _encode_text(taxonomy.strip())
        heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
    if sound is not None:
        q_sound = _encode_sound(sound)
        heat_sound = _array_to_pil(_similarity_heatmap(q_sound, patches))

    return heat_ground, heat_text, heat_sound
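
# Hedged usage sketch: `process` can also be exercised without the UI, e.g.
# for smoke-testing. The paths below are hypothetical placeholders.
#
#   sat = Image.open("examples/satellite.jpg")
#   heat_ground, heat_text, heat_sound = process(
#       sat,
#       "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
#       ground_img=None,
#       sound=None,
#   )
#   heat_text.save("heatmap_text.png")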

# ────────────────────────── Gradio UI ─────────────────────────────────
with gr.Blocks(title="Search-TTA", theme=gr.themes.Base()) as demo:
    gr.Markdown(
        """
        # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo

        Project Website

        [Work in Progress]
        """
    )
    # Alternative header with full paper credits (disabled):
    # with gr.Row():
    #     gr.Markdown(
    #         """
    #         # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild
    #         Project Website | [Work in Progress]
    #
    #         **WACV 2025**
    #
    #         Derek M. S. Tan, Shailesh, Boyang Liu, Alok Raj, Qi Xuan Ang,
    #         Weiheng Dai, Tanishq Duhan, Jimmy Chiun, Yuhong Cao,
    #         Florian Shkurti, Guillaume Sartoretti
    #
    #         National University of Singapore, University of Toronto,
    #         IIT-Dhanbad, Singapore Technologies Engineering
    #         """
    #     )

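    # Layout: the left column gathers the query inputs (satellite image,
    # taxonomy text, optional sound clip); the right column holds the optional
    # ground-level image plus one heat-map output per query modality.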
    with gr.Row(variant="panel"):
        # LEFT COLUMN (satellite, taxonomy, sound, run)
        with gr.Column():
            sat_input = gr.Image(
                label="Satellite Image",
                sources=["upload"],
                type="pil",
                height=320,
            )
            taxonomy_input = gr.Textbox(
                label="Full Taxonomy Name (optional)",
                placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
            )
            # ─── sound input ────────────────────────────────
            sound_input = gr.Audio(
                label="Sound Input (optional)",
                sources=["upload"],  # or "microphone" as preferred
                type="filepath",     # or "numpy" for raw arrays
            )
            run_btn = gr.Button("Run", variant="primary")

        # RIGHT COLUMN (ground image + three heat-maps)
        with gr.Column():
            ground_input = gr.Image(
                label="Ground-level Image (optional)",
                sources=["upload"],
                type="pil",
                height=320,
            )
            gr.Markdown("### Heat-map Results")
            with gr.Row():
                # Separate label and image to avoid overlap
                with gr.Column(scale=1, min_width=100):
                    gr.Markdown("**Ground Image Query**", elem_id="label-ground")
                    heat_ground_out = gr.Image(
                        show_label=False,
                        height=160,
                    )
                with gr.Column(scale=1, min_width=100):
                    gr.Markdown("**Text Query**", elem_id="label-text")
                    heat_text_out = gr.Image(
                        show_label=False,
                        height=160,
                    )
                with gr.Column(scale=1, min_width=100):
                    gr.Markdown("**Sound Query**", elem_id="label-sound")
                    heat_sound_out = gr.Image(
                        show_label=False,
                        height=160,
                    )

            # ─── sound output (disabled) ───────────────────
            # sound_output = gr.Audio(
            #     label="Playback",
            # )

    # EXAMPLES (in-domain)
    with gr.Row():
        gr.Markdown("### In-Domain Taxonomy")
    with gr.Row():
        gr.Examples(
            # Column order matches the `process` signature so the rows can
            # also be fed straight through `fn`.
            examples=[
                [
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
                    "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/89758229.mp3",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
                    "Animalia Chordata Mammalia Rodentia Caviidae Hydrochoerus hydrochaeris",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/1b8064f8-7deb-4b30-98cd-69da98ba6a3d.jpg",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/166631961.mp3",
                ],
                [
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
                    "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
                    "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/59677071.mp3",
                ],
                [
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
                    "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
                    None,
                ],
            ],
            inputs=[sat_input, taxonomy_input, ground_input, sound_input],
            outputs=[heat_ground_out, heat_text_out, heat_sound_out],
            fn=process,
            cache_examples=False,
        )

    # EXAMPLES (out-of-domain)
    with gr.Row():
        gr.Markdown("### Out-of-Domain Taxonomy")
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
                    "Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3aac526d-c921-452a-af6a-cb4f2f52e2c4.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/3123948.mp3",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
                    "Animalia Chordata Mammalia Carnivora Canidae Canis aureus",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/37faabd2-a613-4461-b27e-82fe5955ecaf.jpg",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/189318716.mp3",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/yosemite_v3_resized.png",
                    "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Ursidae_Ursus_americanus/248820933.jpeg",
                    None,
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
                    "Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis",
                    "examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/0cbdfbf2-6cfe-4d61-9602-c949f24d0293.jpg",
                    None,
                ],
            ],
            inputs=[sat_input, taxonomy_input, ground_input, sound_input],
            outputs=[heat_ground_out, heat_text_out, heat_sound_out],
            fn=process,
            cache_examples=False,
        )

    # CALLBACK
    run_btn.click(
        fn=process,
        inputs=[sat_input, taxonomy_input, ground_input, sound_input],
        outputs=[heat_ground_out, heat_text_out, heat_sound_out],
    )

    # Footer pointing to the models and data behind this demo.
    gr.Markdown(
        """
        The satellite image CLIP encoder is fine-tuned on [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/)
        satellite images paired with taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/).
        The sound CLIP encoder is fine-tuned on a subset of the same taxonomy images and their corresponding
        sounds from [iNaturalist](https://inaturalist.org/). Note that while some of the examples above yield
        poor initial probability distributions, these are improved by our test-time adaptation framework
        during the search process.
        """
    )

# LAUNCH
if __name__ == "__main__":
    demo.queue(max_size=15)
    demo.launch(share=True)
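
# Hedged usage sketch (assumes `gradio_client` is installed and the demo is
# running at the default local URL; the endpoint name follows gradio's
# fn-derived default and may differ):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   heat_ground, heat_text, heat_sound = client.predict(
#       handle_file("satellite.jpg"),
#       "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
#       None,
#       None,
#       api_name="/process",
#   )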