""" | |
EcoMonitor β’ multimodal heat-map demo (with custom preprocessing) | |
""" | |
# ────────────────────────── imports ───────────────────────────────────
import cv2
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from torchvision import transforms
import open_clip
from clip_vision_per_patch_model import CLIPVisionPerPatchModel
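# NOTE: `clip_vision_per_patch_model` is assumed to be a local module shipped
# with this Space; its CLIPVisionPerPatchModel is expected to return one CLS
# embedding followed by the per-patch embeddings of the satellite image.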
# ────────────────────────── global config & models ────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1️⃣ BioCLIP (ground-image & text encoder)
bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip")
bio_model = bio_model.to(device).eval()
bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")

# 2️⃣ Satellite patch encoder (CLIP-L-336 per-patch)
sat_model: CLIPVisionPerPatchModel = (
    CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta")
    .to(device)
    .eval()
)

logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
logit_scale = logit_scale.exp()
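# The temperature follows CLIP's usual initialisation: exp(log(1/0.07)) ≈ 14.3.
# It scales the query-to-patch similarities in _similarity_heatmap before the
# sigmoid so the per-patch scores are not all squashed around 0.5.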
blur_kernel = (5, 5)

# ────────────────────────── transforms (exact spec) ───────────────────
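# Both pipelines use the standard ImageNet normalisation statistics; ground
# images are centre-cropped to 224 px for BioCLIP, while satellite images are
# resized to 336 px to match the CLIP-L-336 per-patch encoder above.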
img_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
imo_transform = transforms.Compose(
    [
        transforms.Resize((336, 336)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ────────────────────────── helpers ───────────────────────────────────
# def _tensor_ground(img_pil: Image.Image) -> torch.Tensor:
#     return img_transform(img_pil).unsqueeze(0).to(device)
# def _tensor_sat(img_pil: Image.Image) -> torch.Tensor:
#     return imo_transform(img_pil).unsqueeze(0).to(device)

def _encode_ground(img_pil: Image.Image) -> torch.Tensor:
    """Embed a ground-level image with the BioCLIP image tower."""
    img = img_transform(img_pil).unsqueeze(0).to(device)
    img_embeds, *_ = bio_model(img)
    return img_embeds
    # feats = bio_model.encode_image(_tensor_ground(img_pil))
    # return torch.nn.functional.normalize(feats, dim=-1)

def _encode_text(text: str) -> torch.Tensor:
    """Embed a taxonomy string with the BioCLIP text tower."""
    toks = bio_tokenizer(text).to(device)
    _, txt_embeds, _ = bio_model(text=toks)
    return txt_embeds
    # return torch.nn.functional.normalize(feats, dim=-1)

def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
    """Embed a satellite image into per-patch embeddings (CLS token first)."""
    imo = imo_transform(img_pil).unsqueeze(0).to(device)
    imo_embeds = sat_model(imo)
    return imo_embeds
    # out = sat_model(_tensor_sat(img_pil))
    # if hasattr(out, "last_hidden_state"):
    #     out = out.last_hidden_state
    # return torch.nn.functional.normalize(out.squeeze(0), dim=-1)  # (P, D)
    # return out
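# Shape note (assumption, not verified here): with a CLIP-L/14 backbone at
# 336 px, _encode_sat yields 1 CLS embedding + 24 * 24 = 576 patch embeddings,
# so the heat-map below becomes a 24 x 24 grid once the CLS row is dropped.
# Hypothetical sanity check:
#     patches = _encode_sat(Image.open("examples/NAIP_yosemite_v3_resized.png").convert("RGB"))
#     assert patches.shape[0] == 1 + 24 * 24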
def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
    """Score every satellite patch against the query and reshape into a square grid."""
    sims = torch.matmul(query, patches.t()) * logit_scale
    sims = sims.t().sigmoid()
    # sims = torch.sigmoid(patches @ query.squeeze(0))  # (P,)
    sims = sims[1:].squeeze()  # drop CLS token
    side = int(np.sqrt(len(sims)))
    sims = sims.reshape(side, side)
    return sims.cpu().detach().numpy()
    # return sims[: side * side].view(side, side).cpu().numpy()
def _array_to_pil(arr: np.ndarray) -> Image.Image:
    """
    Render arr with viridis, automatically stretching its own min-max to 0-1
    so that the most-similar patches appear yellow.
    """
    # Gaussian smoothing
    if blur_kernel != (0, 0):
        arr = cv2.GaussianBlur(arr, blur_kernel, 0)
    # --- contrast-stretch to local 0-1 range --------------------------
    arr_min, arr_max = float(arr.min()), float(arr.max())
    if arr_max - arr_min < 1e-6:  # avoid /0 when the heat-map is flat
        arr_scaled = np.zeros_like(arr)
    else:
        arr_scaled = (arr - arr_min) / (arr_max - arr_min)
    # ------------------------------------------------------------------
    fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96)
    ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0)
    ax.axis("off")
    buf = io.BytesIO()
    plt.tight_layout(pad=0)
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# ────────────────────────── main inference ────────────────────────────
def process(
    sat_img: Image.Image,
    taxonomy: str,
    ground_img: Image.Image | None,
):
    """Return (ground-query heat-map, text-query heat-map) for a satellite image."""
    if sat_img is None:
        return None, None
    patches = _encode_sat(sat_img)
    heat_ground, heat_text = None, None
    if ground_img is not None:
        q_img = _encode_ground(ground_img)
        heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches))
    if taxonomy and taxonomy.strip():
        q_txt = _encode_text(taxonomy.strip())
        heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))
    return heat_ground, heat_text
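# Local smoke test (sketch, not used by the Space UI); assumes the bundled
# example assets are available:
#     sat = Image.open("examples/NAIP_yosemite_v3_resized.png").convert("RGB")
#     ground = Image.open("examples/american_black_bear_inat_248820933.jpeg").convert("RGB")
#     hm_ground, hm_text = process(
#         sat,
#         "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
#         ground,
#     )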
# ────────────────────────── Gradio UI ─────────────────────────────────
with gr.Blocks(title="EcoMonitor", theme=gr.themes.Base()) as demo:
    with gr.Row():
        gr.Markdown(
            """
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <div>
            <h1>Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild</h1>
            <span></span>
            <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
                <a href="https://search-tta.github.io">Project Website</a>
            </h2>
            </div>
            </div>
            """
            # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>
            # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
            # <a href="https://derektan95.github.io">Derek M. S. Tan</a>,
            # <a href="https://chinchinati.github.io/">Shailesh</a>,
            # <a href="https://www.linkedin.com/in/boyang-liu-nus">Boyang Liu</a>,
            # <a href="https://www.linkedin.com/in/loki-silvres">Alok Raj</a>,
            # <a href="https://www.linkedin.com/in/ang-qi-xuan-714347142">Qi Xuan Ang</a>,
            # <a href="https://weihengdai.top">Weiheng Dai</a>,
            # <a href="https://www.linkedin.com/in/tanishqduhan">Tanishq Duhan</a>,
            # <a href="https://www.linkedin.com/in/jimmychiun">Jimmy Chiun</a>,
            # <a href="https://www.yuhongcao.online/">Yuhong Cao</a>,
            # <a href="https://www.cs.toronto.edu/~florian/">Florian Shkurti</a>,
            # <a href="https://www.marmotlab.org/bio.html">Guillaume Sartoretti</a>
            # </h2>
            # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>National University of Singapore, University of Toronto, IIT-Dhanbad, Singapore Technologies Engineering</h2>
        )
    with gr.Row(variant="panel"):
        # LEFT COLUMN (satellite, taxonomy, run)
        with gr.Column():
            sat_input = gr.Image(
                label="Satellite Image",
                sources=["upload"],
                type="pil",
                height=320,
            )
            taxonomy_input = gr.Textbox(
                label="Full Taxonomy Name (optional)",
                placeholder="e.g. Animalia Chordata Mammalia Carnivora Ursidae Ursus arctos",
            )
            run_btn = gr.Button("Run", variant="primary")
        # RIGHT COLUMN (ground image + two heat-maps)
        with gr.Column():
            ground_input = gr.Image(
                label="Ground-level Image (optional)",
                sources=["upload"],
                type="pil",
                height=320,
            )
            heat_ground_out = gr.Image(
                label="Heat-map (Ground query)",
                height=160,
            )
            heat_text_out = gr.Image(
                label="Heat-map (Text query)",
                height=160,
            )
    # EXAMPLES
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/NAIP_yosemite_v3_resized.png",
                    "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
                    "examples/american_black_bear_inat_248820933.jpeg",
                ],
                # [
                #     "examples/satellite_coast.png",
                #     "",
                #     "examples/ground_gull.jpg",
                # ],
                # [
                #     "examples/satellite_coast.png",
                #     "Animalia Chordata Aves Charadriiformes Laridae Larus argentatus",
                #     None,
                # ],
            ],
            inputs=[sat_input, taxonomy_input, ground_input],
            outputs=[heat_ground_out, heat_text_out],
            fn=process,
            cache_examples=False,
        )
    # CALLBACK
    run_btn.click(
        fn=process,
        inputs=[sat_input, taxonomy_input, ground_input],
        outputs=[heat_ground_out, heat_text_out],
    )
    # Footer pointing to the model's training data sources.
    gr.Markdown(
        """
        This model is fine-tuned on [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite images, together with species images and observation locations from [iNaturalist](https://inaturalist.org/).
        """
    )
# LAUNCH
if __name__ == "__main__":
    demo.queue(max_size=15)
    demo.launch(share=True)