# search-tta-demo / app.py
"""
EcoMonitor • multimodal heat-map demo (with custom preprocessing)
"""
# ────────────────────────── imports ───────────────────────────────────
import cv2
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from torchvision import transforms
import open_clip
from clip_vision_per_patch_model import CLIPVisionPerPatchModel
# ────────────────────────── global config & models ────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 1️⃣ BioCLIP (ground-image & text encoder)
bio_model, _, _ = open_clip.create_model_and_transforms("hf-hub:imageomics/bioclip")
bio_model = bio_model.to(device).eval()
bio_tokenizer = open_clip.get_tokenizer("hf-hub:imageomics/bioclip")
# 2️⃣ Satellite patch encoder (CLIP-L-336 per-patch)
sat_model: CLIPVisionPerPatchModel = (
    CLIPVisionPerPatchModel.from_pretrained("derektan95/search-tta")
    .to(device)
    .eval()
)
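# Fixed CLIP-style temperature, exp(log(1/0.07)) ≈ 14.3, scales the raw query-patch
# similarities before the sigmoid in _similarity_heatmap; blur_kernel is the Gaussian
# kernel used to smooth the rendered heat-map ((0, 0) disables smoothing).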
logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
logit_scale = logit_scale.exp()
blur_kernel = (5, 5)
# ────────────────────────── transforms (exact spec) ───────────────────
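# Ground-level photos are resized to 256 and center-cropped to 224 for BioCLIP; satellite
# images are resized to 336 for the CLIP-L-336 per-patch encoder. Both use ImageNet
# mean/std normalisation.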
img_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
imo_transform = transforms.Compose(
    [
        transforms.Resize((336, 336)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ────────────────────────── helpers ───────────────────────────────────
# def _tensor_ground(img_pil: Image.Image) -> torch.Tensor:
#     return img_transform(img_pil).unsqueeze(0).to(device)
# def _tensor_sat(img_pil: Image.Image) -> torch.Tensor:
#     return imo_transform(img_pil).unsqueeze(0).to(device)
@torch.no_grad()
def _encode_ground(img_pil: Image.Image) -> torch.Tensor:
    img = img_transform(img_pil).unsqueeze(0).to(device)
    img_embeds, *_ = bio_model(img)
    return img_embeds
    # feats = bio_model.encode_image(_tensor_ground(img_pil))
    # return torch.nn.functional.normalize(feats, dim=-1)
@torch.no_grad()
def _encode_text(text: str) -> torch.Tensor:
    toks = bio_tokenizer(text).to(device)
    _, txt_embeds, _ = bio_model(text=toks)
    return txt_embeds
    # return torch.nn.functional.normalize(feats, dim=-1)
@torch.no_grad()
def _encode_sat(img_pil: Image.Image) -> torch.Tensor:
    imo = imo_transform(img_pil).unsqueeze(0).to(device)
    imo_embeds = sat_model(imo)
    return imo_embeds
    # out = sat_model(_tensor_sat(img_pil))
    # if hasattr(out, "last_hidden_state"):
    #     out = out.last_hidden_state
    # return torch.nn.functional.normalize(out.squeeze(0), dim=-1)  # (P, D)
    # return out
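# Compare one query embedding against the satellite token embeddings (CLS + P patches):
# scaled dot products are squashed with a sigmoid, the CLS entry is dropped, and the
# remaining P scores are reshaped into a sqrt(P) x sqrt(P) grid.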
def _similarity_heatmap(query: torch.Tensor, patches: torch.Tensor) -> np.ndarray:
    sims = torch.matmul(query, patches.t()) * logit_scale
    sims = sims.t().sigmoid()
    # sims = torch.sigmoid(patches @ query.squeeze(0))  # (P,)
    sims = sims[1:].squeeze()  # drop CLS token
    side = int(np.sqrt(len(sims)))
    sims = sims.reshape(side, side)
    return sims.cpu().detach().numpy()
    # return sims[: side * side].view(side, side).cpu().numpy()
def _array_to_pil(arr: np.ndarray) -> Image.Image:
"""
Render arr with viridis, automatically stretching its own min→max to 0→1
so that the most-similar patches appear yellow.
"""
    # Gaussian smoothing
    if blur_kernel != (0, 0):
        arr = cv2.GaussianBlur(arr, blur_kernel, 0)
    # --- contrast-stretch to local 0-1 range --------------------------
    arr_min, arr_max = float(arr.min()), float(arr.max())
    if arr_max - arr_min < 1e-6:  # avoid /0 when the heat-map is flat
        arr_scaled = np.zeros_like(arr)
    else:
        arr_scaled = (arr - arr_min) / (arr_max - arr_min)
    # ------------------------------------------------------------------
    fig, ax = plt.subplots(figsize=(2.6, 2.6), dpi=96)
    ax.imshow(arr_scaled, cmap="viridis", vmin=0.0, vmax=1.0)
    ax.axis("off")
    buf = io.BytesIO()
    plt.tight_layout(pad=0)
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# ────────────────────────── main inference ────────────────────────────
def process(
    sat_img: Image.Image,
    taxonomy: str,
    ground_img: Image.Image | None,
):
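    """
    Encode the satellite image into per-patch embeddings, then build one
    heat-map per available query (ground-level image and/or taxonomy text).
    Returns (heat_ground, heat_text); either may be None.
    """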
    if sat_img is None:
        return None, None

    patches = _encode_sat(sat_img)
    heat_ground, heat_text = None, None

    if ground_img is not None:
        q_img = _encode_ground(ground_img)
        heat_ground = _array_to_pil(_similarity_heatmap(q_img, patches))

    if taxonomy.strip():
        q_txt = _encode_text(taxonomy.strip())
        heat_text = _array_to_pil(_similarity_heatmap(q_txt, patches))

    return heat_ground, heat_text
# ────────────────────────── Gradio UI ─────────────────────────────────
with gr.Blocks(title="EcoMonitor", theme=gr.themes.Base()) as demo:

    with gr.Row():
        gr.Markdown(
            """
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                <div>
                    <h1>Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild</h1>
                    <span></span>
                    <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
                        <a href="https://search-tta.github.io">Project Website</a>
                    </h2>
                </div>
            </div>
            """
            # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>
            # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
            #     <a href="https://derektan95.github.io">Derek M. S. Tan</a>,
            #     <a href="https://chinchinati.github.io/">Shailesh</a>,
            #     <a href="https://www.linkedin.com/in/boyang-liu-nus">Boyang Liu</a>,
            #     <a href="https://www.linkedin.com/in/loki-silvres">Alok Raj</a>,
            #     <a href="https://www.linkedin.com/in/ang-qi-xuan-714347142">Qi Xuan Ang</a>,
            #     <a href="https://weihengdai.top">Weiheng Dai</a>,
            #     <a href="https://www.linkedin.com/in/tanishqduhan">Tanishq Duhan</a>,
            #     <a href="https://www.linkedin.com/in/jimmychiun">Jimmy Chiun</a>,
            #     <a href="https://www.yuhongcao.online/">Yuhong Cao</a>,
            #     <a href="https://www.cs.toronto.edu/~florian/">Florian Shkurti</a>,
            #     <a href="https://www.marmotlab.org/bio.html">Guillaume Sartoretti</a>
            # </h2>
            # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>National University of Singapore, University of Toronto, IIT-Dhanbad, Singapore Technologies Engineering</h2>
        )
    with gr.Row(variant="panel"):
        # LEFT COLUMN (satellite, taxonomy, run)
        with gr.Column():
            sat_input = gr.Image(
                label="Satellite Image",
                sources=["upload"],
                type="pil",
                height=320,
            )
            taxonomy_input = gr.Textbox(
                label="Full Taxonomy Name (optional)",
                placeholder="e.g. Animalia Chordata Mammalia Carnivora Ursidae Ursus arctos",
            )
            run_btn = gr.Button("Run", variant="primary")

        # RIGHT COLUMN (ground image + two heat-maps)
        with gr.Column():
            ground_input = gr.Image(
                label="Ground-level Image (optional)",
                sources=["upload"],
                type="pil",
                height=320,
            )
            heat_ground_out = gr.Image(
                label="Heat-map (Ground query)",
                height=160,
            )
            heat_text_out = gr.Image(
                label="Heat-map (Text query)",
                height=160,
            )
    # EXAMPLES
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/NAIP_yosemite_v3_resized.png",
                    "Animalia Chordata Mammalia Carnivora Ursidae Ursus americanus",
                    "examples/american_black_bear_inat_248820933.jpeg",
                ],
                # [
                #     "examples/satellite_coast.png",
                #     "",
                #     "examples/ground_gull.jpg",
                # ],
                # [
                #     "examples/satellite_coast.png",
                #     "Animalia Chordata Aves Charadriiformes Laridae Larus argentatus",
                #     None,
                # ],
            ],
            inputs=[sat_input, taxonomy_input, ground_input],
            outputs=[heat_ground_out, heat_text_out],
            fn=process,
            cache_examples=False,
        )
    # CALLBACK
    run_btn.click(
        fn=process,
        inputs=[sat_input, taxonomy_input, ground_input],
        outputs=[heat_ground_out, heat_text_out],
    )
    # Footer pointing readers to the model and training data from the app page.
    gr.Markdown(
        """
        This model is fine-tuned on [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite images paired with taxonomy-labelled images and observation locations from [iNaturalist](https://inaturalist.org/).
        """
    )
# LAUNCH
if __name__ == "__main__":
    demo.queue(max_size=15)
    demo.launch(share=True)
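# To run the demo locally (assuming the model weights and the assets under examples/ are
# available): python app.py
# share=True additionally requests a temporary public gradio.live link.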