|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import io
|
|
import math
|
|
import urllib.request
|
|
from functools import lru_cache
|
|
from typing import Optional, Tuple, Dict, List
|
|
|
|
import gradio as gr
|
|
import numpy as np
|
|
from PIL import Image, ImageDraw
|
|
import torch
|
|
from torchvision import transforms
|
|
from transformers import AutoModel
|
|
from matplotlib import colormaps as cm
|
|
|
|
|
|
# Hugging Face model IDs offered by the demo, in display order. Every ID has the
# shape "facebook/dinov3-<arch>-pretrain-<dataset>"; build_model_maps() parses
# them back into (dataset, arch) pairs.
MODEL_ID_LIST = [
    f"facebook/dinov3-{arch}-pretrain-{pretrain_set}"
    for pretrain_set, archs in (
        (
            "lvd1689m",
            (
                "vits16",
                "vits16plus",
                "vitb16",
                "vitl16",
                "vith16plus",
                "vit7b16",
                "convnext-tiny",
                "convnext-small",
                "convnext-base",
                "convnext-large",
            ),
        ),
        ("sat493m", ("vitl16", "vit7b16")),
    )
    for arch in archs
]

# Label shown in the dataset radio -> dataset suffix used in the model IDs.
DATASET_LABELS = dict(
    zip(("LVD-1689M", "SAT-493M"), ("lvd1689m", "sat493m"))
)
|
|
|
|
def build_model_maps(model_ids: List[str]):
    """Index DINOv3 model IDs by (dataset key, short architecture name).

    Each ID is expected to look like ``facebook/dinov3-<short>-pretrain-<dataset>``.
    IDs that do not contain the prefix, lack a ``-pretrain-<dataset>`` suffix, or
    have an empty short name are skipped.

    Returns:
        valid_map[(dataset_key, short_name)] -> full_model_id
        options_by_dataset[dataset_key] -> [short_name, ...] (display order preserved,
            duplicates dropped). Only the "lvd1689m" and "sat493m" keys exist; short
            names of other datasets appear in ``valid_map`` only.
    """
    prefix = "facebook/dinov3-"
    marker = "-pretrain-"

    valid_map: Dict[Tuple[str, str], str] = {}
    options_by_dataset: Dict[str, List[str]] = {"lvd1689m": [], "sat493m": []}

    for mid in model_ids:
        try:
            start = mid.index(prefix) + len(prefix)
            pre_idx = mid.index("-pretrain", start)
        except ValueError:
            # Not shaped like a DINOv3 pretrain ID; ignore it.
            continue

        short = mid[start:pre_idx]
        # Search only the tail so the dataset comes from the "-pretrain-" that follows
        # the architecture; an ID without it has no dataset and is skipped.
        _, sep, dataset = mid[pre_idx:].rpartition(marker)
        dataset = dataset.strip()
        if not sep or not short or not dataset:
            continue

        valid_map[(dataset, short)] = mid
        if dataset in options_by_dataset and short not in options_by_dataset[dataset]:
            options_by_dataset[dataset].append(short)

    return valid_map, options_by_dataset
|
|
|
|
# Parse MODEL_ID_LIST once at import time:
#   VALID_MODEL_MAP[(dataset_key, short_name)] -> full Hugging Face model id
#   MODEL_OPTIONS_BY_DATASET[dataset_key] -> short model names in display order
VALID_MODEL_MAP, MODEL_OPTIONS_BY_DATASET = build_model_maps(MODEL_ID_LIST)

# Image loaded by init_states() when neither side yields an image.
DEFAULT_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Initial selection of the dataset radio.
DEFAULT_DATASET_LABEL = "LVD-1689M"
# Initial overlay opacity for blend_overlay (0 = image only, 1 = heat map only).
DEFAULT_OVERLAY_ALPHA = 0.55
# Initial state of the "Show patch grid" checkbox.
DEFAULT_SHOW_GRID = True

# Sample images per dataset key, as (dropdown label, URL) pairs. The first entry of
# each list is a placeholder whose URL is the empty string.
SAMPLE_URL_CHOICES: Dict[str, List[Tuple[str, str]]] = {
    "lvd1689m": [
        ("– choose a sample –", ""),
        ("COCO: 2 Cats on sofa (039769)", "http://images.cocodataset.org/val2017/000000039769.jpg"),
        ("COCO: Person skiing (000785)", "http://images.cocodataset.org/val2017/000000000785.jpg"),
        ("COCO: People running (000872)", "http://images.cocodataset.org/val2017/000000000872.jpg"),
        ("Picsum: Mountain (ID=1000)", "https://picsum.photos/id/1000/800/600"),
        ("Picsum: Kayak (ID=1011)", "https://picsum.photos/id/1011/800/600"),
        ("Picsum: Man and dog (ID=1012)", "https://picsum.photos/id/1012/800/600"),
    ],
    # Earth-observation imagery, offered when the SAT-493M dataset is selected.
    "sat493m": [
        ("– choose a satellite sample –", ""),
        ("Blue Marble (NASA)", "https://upload.wikimedia.org/wikipedia/commons/9/9d/The_Blue_Marble_%28remastered%29.jpg"),
        ("GOES-16 Hurricane Florence (2018)", "https://upload.wikimedia.org/wikipedia/commons/5/5e/Hurricane_Florence_GOES-16_2018-09-12_1510Z.jpg"),
        ("NASA Earth Observatory: Philippines", "https://eoimages.gsfc.nasa.gov/images/imagerecords/151000/151639/philippines_tmo_2020118_lrg.jpg"),
    ],
}
|
|
|
|
def _sample_labels_for(dataset_label: str):
    """Return the sample-dropdown labels for a dataset radio label.

    Unknown labels fall back to the LVD-1689M samples.
    """
    dataset_key = DATASET_LABELS.get(dataset_label, "lvd1689m")
    choices = SAMPLE_URL_CHOICES.get(dataset_key, [])
    return [entry[0] for entry in choices]
|
|
|
|
def _apply_sample(dataset_label: str, sample_label: str):
    """Fill the URL textbox with the chosen sample URL and clear any uploaded image.

    Returns a ``(url_textbox_update, image_value)`` pair. When the selected label has
    no URL (the "– choose a sample –" placeholder, or a label not in the current
    dataset's list), both outputs are left unchanged.
    """
    key = DATASET_LABELS.get(dataset_label, "lvd1689m")
    url = dict(SAMPLE_URL_CHOICES.get(key, [])).get(sample_label, "")
    if not url:
        # Changing the dataset resets this dropdown to its placeholder, and Gradio's
        # `.change` also fires for programmatic updates. Without this guard a dataset
        # switch would wipe the user's typed URL and uploaded image.
        return gr.update(), gr.update()
    return gr.update(value=url), None
|
|
|
|
|
|
def load_image_from_any(src: Optional[Image.Image], url: Optional[str]) -> Optional[Image.Image]:
    """Return an RGB PIL image from an http(s) URL or an uploaded image.

    A URL starting with http:// or https:// (surrounding whitespace ignored) takes
    precedence over ``src``. Other URL schemes are ignored. Returns None when neither
    source yields an image.

    Raises:
        OSError: if the download fails or times out (urllib's URLError/HTTPError), or
            if the bytes are not a readable image (PIL's UnidentifiedImageError).
    """
    url_str = str(url).strip() if url else ""
    if url_str.lower().startswith(("http://", "https://")):
        # Some image hosts (e.g. Wikimedia) reject urllib's default User-Agent, so send
        # a descriptive one. The timeout stops a stalled host from hanging the handler.
        request = urllib.request.Request(
            url_str,
            headers={"User-Agent": "dinov3-patch-similarity-demo/1.0 (Python urllib)"},
        )
        with urllib.request.urlopen(request, timeout=30) as resp:
            data = resp.read()
        return Image.open(io.BytesIO(data)).convert("RGB")
    if isinstance(src, Image.Image):
        return src.convert("RGB")
    return None
|
|
|
|
def pad_to_multiple(pil_img: Image.Image, multiple: int = 16) -> Tuple[Image.Image, Tuple[int, int, int, int]]:
    """Zero-pad an image on the right/bottom so both sides are multiples of ``multiple``.

    Returns the padded image and the padding as (left, top, right, bottom); left and
    top are always 0. If no padding is needed, the input image itself is returned.
    """
    width, height = pil_img.size
    padded_w = int(multiple * math.ceil(width / multiple))
    padded_h = int(multiple * math.ceil(height / multiple))
    pad_right = padded_w - width
    pad_bottom = padded_h - height

    if not (pad_right or pad_bottom):
        return pil_img, (0, 0, 0, 0)

    # Black canvas with the original pasted at the top-left corner.
    canvas = Image.new("RGB", (padded_w, padded_h), (0, 0, 0))
    canvas.paste(pil_img, (0, 0))
    return canvas, (0, 0, pad_right, pad_bottom)
|
|
|
|
def preprocess_no_resize(
    pil_img: Image.Image,
    multiple: int = 16,
    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
    std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
):
    """Pad (never resize) an image to a multiple of ``multiple`` and normalize it.

    Args:
        pil_img: input image (callers pass RGB images).
        multiple: both sides are zero-padded on the right/bottom up to a multiple of this.
        mean, std: per-channel normalization statistics. The defaults are the ImageNet
            values this function has always used.
            NOTE(review): upstream DINOv3 docs list different statistics for the
            SAT-493M checkpoints — confirm and pass them here for those models.

    Returns:
        A tuple ``({"pixel_values": tensor}, padded image as a uint8 numpy array,
        padding box (left, top, right, bottom))``. For an RGB input the tensor has
        shape (1, 3, H_pad, W_pad).
    """
    img_padded, pad_box = pad_to_multiple(pil_img, multiple=multiple)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])
    pixel_tensor = transform(img_padded).unsqueeze(0)
    disp_np = np.array(img_padded, dtype=np.uint8)
    return {"pixel_values": pixel_tensor}, disp_np, pad_box
|
|
|
|
def upsample_nearest(arr: np.ndarray, H: int, W: int, ps: int) -> np.ndarray:
    """Nearest-neighbour upsample a per-patch grid by ``ps`` along rows and columns.

    Accepts a (rows, cols) or (rows, cols, channels) array and returns a
    (rows*ps, cols*ps) or (rows*ps, cols*ps, channels) array. ``H`` and ``W`` are
    accepted for interface compatibility but are not used.

    Raises:
        ValueError: if ``arr`` is neither 2-D nor 3-D.
    """
    if arr.ndim not in (2, 3):
        raise ValueError("upsample_nearest expects (rows,cols) or (rows,cols,channels)")
    taller = np.repeat(arr, ps, axis=0)
    return np.repeat(taller, ps, axis=1)
|
|
|
|
def blend_overlay(base_uint8: np.ndarray, overlay_rgb_float: np.ndarray, alpha: float) -> np.ndarray:
    """Alpha-blend a [0, 1] float RGB overlay onto a uint8 image.

    Computes ``(1 - alpha) * base + alpha * overlay * 255`` in float32, clips the
    result to [0, 255] and truncates it to uint8.
    """
    weight_base = 1.0 - alpha
    base_f = base_uint8.astype(np.float32)
    over_f = np.asarray(overlay_rgb_float * 255.0, dtype=np.float32)
    mixed = weight_base * base_f + alpha * over_f
    return mixed.clip(0, 255).astype(np.uint8)
|
|
|
|
def draw_grid(img: Image.Image, rows: int, cols: int, ps: int):
    """Draw 1-px white lines on the interior patch boundaries of ``img``, in place.

    Horizontal lines are drawn at y = ps, 2*ps, ..., (rows-1)*ps, then vertical lines
    at x = ps, ..., (cols-1)*ps. Each line spans the full image.
    """
    pen = ImageDraw.Draw(img)
    width, height = img.size
    white = (255, 255, 255)
    horizontal = [[(0, r * ps), (width, r * ps)] for r in range(1, rows)]
    vertical = [[(c * ps, 0), (c * ps, height)] for c in range(1, cols)]
    for segment in horizontal + vertical:
        pen.line(segment, fill=white, width=1)
|
|
|
|
def rc_to_idx(r: int, c: int, cols: int) -> int:
    """Flatten a (row, col) patch coordinate into a row-major index."""
    row, col = int(r), int(c)
    return col + row * cols
|
|
|
|
def idx_to_rc(i: int, cols: int) -> Tuple[int, int]:
    """Split a row-major patch index into its (row, col) coordinate."""
    return divmod(int(i), cols)
|
|
|
|
|
|
@lru_cache(maxsize=3)
def load_model_cached(full_model_id: str, device_str: str):
    """Load a Hugging Face model onto ``device_str`` in eval mode.

    Results are memoised per (model id, device) pair. Up to three pairs are kept by
    the LRU cache.
    """
    target = torch.device(device_str)
    loaded = AutoModel.from_pretrained(full_model_id)
    loaded = loaded.to(target)
    loaded.eval()
    return loaded
|
|
|
|
def infer_patch_size(model, default: int = 16) -> int:
    """Return the model's patch size in pixels.

    Looks at ``model.config.patch_size`` first, then ``model.patch_size``. A tuple or
    list size yields its first element. Falls back to ``default`` when neither exists.
    """
    for holder in (getattr(model, "config", None), model):
        if not hasattr(holder, "patch_size"):
            continue
        size = holder.patch_size
        return int(size[0]) if isinstance(size, (tuple, list)) else int(size)
    return default
|
|
|
|
|
|
class PatchImageState:
    """Patch features of one image plus the padded image used for display.

    Attributes set in ``__init__``:
        pil: the PIL image passed in.
        ps: patch size in pixels.
        disp: padded image as a uint8 numpy array (overlays are drawn on this).
        H, W: padded height/width in pixels, taken from the model input tensor.
        rows, cols: patch grid size, ``H // ps`` by ``W // ps``.
        D: feature dimension of the model's last hidden state.
        X: (rows*cols, D) patch features, row-major.
        Xn: ``X`` with every row L2-normalized (used for cosine similarity).

    Raises:
        RuntimeError: if the model returns no more tokens than there are patches.
    """

    def __init__(self, pil_img: Image.Image, model, device_str: str, ps: int):
        self.pil = pil_img
        self.ps = ps
        inputs, disp_np, _ = preprocess_no_resize(pil_img, multiple=ps)
        self.disp = disp_np
        pv = inputs["pixel_values"].to(device_str)
        _, _, H, W = pv.shape
        self.H, self.W = int(H), int(W)
        self.rows, self.cols = self.H // ps, self.W // ps

        with torch.no_grad():
            out = model(pixel_values=pv)
        # Upcast before .numpy(): numpy has no bfloat16, so a half/bf16 model would
        # otherwise fail here. This is a no-op for float32 outputs.
        hs = out.last_hidden_state.squeeze(0).detach().float().cpu().numpy()

        T, D = hs.shape
        n_patches = self.rows * self.cols
        # Tokens beyond the patch grid are treated as special tokens (e.g. CLS/register)
        # and assumed to come FIRST in the sequence — TODO confirm for each architecture.
        n_special = T - n_patches
        if n_special < 1:
            raise RuntimeError(
                f"Token mismatch: T={T}, rows*cols={n_patches}, HxW={self.H}x{self.W}, ps={ps}"
            )
        self.D = D
        patches = hs[n_special:, :].reshape(self.rows, self.cols, D)
        self.X = patches.reshape(-1, D)
        # The epsilon avoids division by zero for all-zero feature vectors.
        self.Xn = self.X / (np.linalg.norm(self.X, axis=1, keepdims=True) + 1e-8)
|
|
|
|
|
|
def render_with_cosmap(
    st: PatchImageState,
    cos_map: Optional[np.ndarray],
    overlay_alpha: float,
    show_grid_flag: bool,
    select_idx: Optional[int] = None,
    best_idx: Optional[int] = None,
) -> Image.Image:
    """Render a similarity heat map over ``st``'s padded image.

    ``cos_map`` (rows x cols) is min-max normalized and colored with "magma". When
    ``cos_map`` is None, a uniform mid-level color is used. The selected patch is
    filled red in the heat map and boxed in red. ``best_idx`` is boxed in yellow.
    """
    ps, rows, cols = st.ps, st.rows, st.cols

    if cos_map is not None:
        lo, hi = float(cos_map.min()), float(cos_map.max())
        # Guard against a constant map, which would otherwise divide by zero.
        span = (hi - lo) if hi > lo else 1e-8
        norm_map = (cos_map - lo) / span
    else:
        norm_map = np.full((rows, cols), 0.5, dtype=np.float32)

    colors = cm["magma"](norm_map)[..., :3]

    if select_idx is not None:
        sel_r, sel_c = idx_to_rc(select_idx, cols)
        colors[sel_r, sel_c, :] = (1.0, 0.0, 0.0)

    overlay_full = upsample_nearest(colors, st.H, st.W, ps)
    canvas = Image.fromarray(blend_overlay(st.disp, overlay_full, float(overlay_alpha)))

    pen = ImageDraw.Draw(canvas)
    if show_grid_flag:
        draw_grid(canvas, rows, cols, ps)

    # Red box for the query patch first, then yellow box for the best cross match.
    for marked, outline in ((select_idx, (255, 0, 0)), (best_idx, (255, 255, 0))):
        if marked is None:
            continue
        row, col = idx_to_rc(marked, cols)
        left, top = col * ps, row * ps
        pen.rectangle([(left, top), (left + ps - 1, top + ps - 1)], outline=outline, width=2)

    return canvas
|
|
|
|
def compute_self_and_cross(
    src: PatchImageState,
    tgt: Optional[PatchImageState],
    q_idx: int,
):
    """Cosine similarity of patch ``q_idx`` of ``src`` against ``src`` and optionally ``tgt``.

    Returns ``(self_map, cross_map, (self_min, self_max), cross_info)``:
        self_map: (src.rows, src.cols) similarities within ``src``.
        cross_map: (tgt.rows, tgt.cols) similarities in ``tgt``, or None without ``tgt``.
        cross_info: ``(cross_min, cross_max, best_flat_idx)`` or None without ``tgt``.
    """
    query = src.X[q_idx]
    query_unit = query / (np.linalg.norm(query) + 1e-8)

    self_map = (src.Xn @ query_unit).reshape(src.rows, src.cols)
    self_range = (float(self_map.min()), float(self_map.max()))

    if tgt is None:
        return self_map, None, self_range, None

    cross_flat = tgt.Xn @ query_unit
    cross_map = cross_flat.reshape(tgt.rows, tgt.cols)
    cross_info = (
        float(cross_map.min()),
        float(cross_map.max()),
        int(np.argmax(cross_flat)),
    )
    return self_map, cross_map, self_range, cross_info
|
|
|
|
|
|
def dataset_label_to_key(label: str) -> str:
    """Map a dataset radio label to its model-ID suffix; unknown labels map to "lvd1689m"."""
    if label in DATASET_LABELS:
        return DATASET_LABELS[label]
    return "lvd1689m"
|
|
|
|
def update_model_dropdown(dataset_label: str):
    """Build a Dropdown update listing the dataset's models, selecting the first (or None)."""
    choices = MODEL_OPTIONS_BY_DATASET.get(dataset_label_to_key(dataset_label), [])
    first_choice = next(iter(choices), None)
    return gr.update(choices=choices, value=first_choice)
|
|
|
|
def update_model_and_samples(dataset_label: str):
    """Refresh the model dropdown and both sample dropdowns after a dataset change.

    Returns ``(model_update, left_sample_update, right_sample_update)``. Both sample
    dropdowns get the same update, which resets them to the dataset's first label.
    """
    models = update_model_dropdown(dataset_label)
    sample_labels = _sample_labels_for(dataset_label)
    samples = gr.update(choices=sample_labels, value=next(iter(sample_labels), None))
    return models, samples, samples
|
|
|
|
def resolve_full_model_id(dataset_label: str, short_name: str) -> Optional[str]:
    """Return the full Hugging Face model id for a dataset label and short name, or None."""
    dataset_key = dataset_label_to_key(dataset_label)
    return VALID_MODEL_MAP.get((dataset_key, short_name))
|
|
|
|
|
|
def _center_patch_idx(st: PatchImageState) -> int:
    """Return the row-major index of the patch at the centre of ``st``'s grid."""
    return (st.rows // 2) * st.cols + (st.cols // 2)


def init_states(
    left_img_in: Optional[Image.Image],
    left_url: str,
    right_img_in: Optional[Image.Image],
    right_url: str,
    dataset_label: str,
    short_model: str,
    show_grid_flag: bool,
    overlay_alpha: float,
):
    """Load images and model, extract patch features, and render the initial overlays.

    On each side an http(s) URL takes precedence over the uploaded image. If neither
    side yields an image, DEFAULT_URL is loaded on the LEFT. The initial query patch
    is the centre patch of the LEFT image, or of the RIGHT image when there is no LEFT.

    Returns:
        A 9-tuple in the order of the ``init_btn.click`` outputs: (left view, right
        view, left state, right state, active side (0 = left, 1 = right), left query
        idx, right query idx, patch size, status text). On failure the views are left
        unchanged, both states are cleared, and the status text explains the problem.
    """

    def _fail(message: str):
        # Same shape as the success tuple so Gradio can map it onto the outputs.
        return (gr.update(), gr.update(), None, None, 0, -1, -1, 16, message)

    # Validate the model choice first so an invalid selection doesn't trigger downloads.
    full_model_id = resolve_full_model_id(dataset_label, short_model)
    if not full_model_id:
        return _fail(f"❌ Model not available: {dataset_label} / {short_model}")

    try:
        left_img = load_image_from_any(left_img_in, left_url)
        right_img = load_image_from_any(right_img_in, right_url)
        if left_img is None and right_img is None:
            left_img = load_image_from_any(None, DEFAULT_URL)
    except (OSError, ValueError) as exc:
        # urllib's URLError/HTTPError, timeouts and PIL's UnidentifiedImageError are
        # OSError subclasses; report them instead of failing the whole event.
        return _fail(f"❌ Could not load image: {exc}")

    device_str = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        model = load_model_cached(full_model_id, device_str)
    except OSError as exc:
        # transformers reports missing, unreachable or gated checkpoints as OSError.
        return _fail(f"❌ Could not load model {full_model_id}: {exc}")
    ps = infer_patch_size(model, 16)

    try:
        left_state = PatchImageState(left_img, model, device_str, ps) if left_img is not None else None
        right_state = PatchImageState(right_img, model, device_str, ps) if right_img is not None else None
    except RuntimeError as exc:
        # Token/grid mismatch raised by PatchImageState, or a device error (e.g. CUDA OOM).
        return _fail(f"❌ Feature extraction failed: {exc}")

    active_side = 0 if left_state is not None else 1
    status = f"✔ Loaded: {full_model_id} | ps={ps}"
    out_left, out_right = None, None

    if left_state is not None and right_state is not None:
        # Two images: LEFT holds the query, RIGHT shows the cross map and best match.
        q_idx = _center_patch_idx(left_state)
        cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(left_state, right_state, q_idx)
        best_idx = cross_info[2] if cross_info else None
        out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
                                      select_idx=q_idx, best_idx=None)
        out_right = render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag,
                                       select_idx=None, best_idx=best_idx)
        status += (f" | LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}] "
                   f"| RIGHT cross best={best_idx}")
        left_idx, right_idx = q_idx, _center_patch_idx(right_state)
    elif left_state is not None:
        q_idx = _center_patch_idx(left_state)
        cos_self, _, (smin, smax), _ = compute_self_and_cross(left_state, None, q_idx)
        out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
                                      select_idx=q_idx, best_idx=None)
        status += f" | Single LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}]"
        left_idx, right_idx = q_idx, -1
    else:
        q_idx = _center_patch_idx(right_state)
        cos_self, _, (smin, smax), _ = compute_self_and_cross(right_state, None, q_idx)
        out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
                                       select_idx=q_idx, best_idx=None)
        status += f" | Single RIGHT {right_state.rows}x{right_state.cols} self∈[{smin:.3f},{smax:.3f}]"
        left_idx, right_idx = -1, q_idx

    return (
        out_left, out_right,
        left_state, right_state,
        active_side,
        left_idx, right_idx,
        ps,
        status
    )
|
|
|
|
def _coords_to_idx(x: int, y: int, st: PatchImageState) -> int:
    """Map a pixel click (x, y) to the row-major index of the patch under it.

    The row and column are clamped to the patch grid.
    """
    row, col = (
        int(np.clip(coord // st.ps, 0, limit - 1))
        for coord, limit in ((y, st.rows), (x, st.cols))
    )
    return rc_to_idx(row, col, st.cols)
|
|
|
|
def on_select_left(
    evt: gr.SelectData,
    left_state: Optional[PatchImageState],
    right_state: Optional[PatchImageState],
    show_grid_flag: bool,
    overlay_alpha: float,
    ps: int,
):
    """Handle a click on the LEFT view: the clicked patch becomes the query.

    LEFT is redrawn with its self-similarity map. If a RIGHT image exists, RIGHT is
    redrawn with the cross-similarity map and its best match boxed.
    Returns ``(left_view, right_view, active_side, left_idx, right_idx, status)``.
    ``ps`` is accepted for the event wiring but is not used.
    """
    if left_state is None:
        return gr.update(), gr.update(), 0, -1, -1, "Upload/Load a LEFT image first."

    click_x, click_y = evt.index
    q_idx = _coords_to_idx(click_x, click_y, left_state)
    sims_self, sims_cross, (lo, hi), cross_info = compute_self_and_cross(left_state, right_state, q_idx)
    left_img = render_with_cosmap(left_state, sims_self, overlay_alpha, show_grid_flag,
                                  select_idx=q_idx, best_idx=None)

    if right_state is None:
        message = f"Single LEFT • idx={q_idx} • self∈[{lo:.3f},{hi:.3f}]"
        return left_img, gr.update(), 0, q_idx, -1, message

    best = cross_info[2]
    right_img = render_with_cosmap(right_state, sims_cross, overlay_alpha, show_grid_flag,
                                   select_idx=None, best_idx=best)
    message = (f"LEFT {left_state.rows}x{left_state.cols} self∈[{lo:.3f},{hi:.3f}] | "
               f"RIGHT cross best idx={best}")
    return left_img, right_img, 0, q_idx, -1, message
|
|
|
|
def on_select_right(
    evt: gr.SelectData,
    left_state: Optional[PatchImageState],
    right_state: Optional[PatchImageState],
    show_grid_flag: bool,
    overlay_alpha: float,
    ps: int,
):
    """Handle a click on the RIGHT view: the clicked patch becomes the query.

    RIGHT is redrawn with its self-similarity map. If a LEFT image exists, LEFT is
    redrawn with the cross-similarity map and its best match boxed.
    Returns ``(left_view, right_view, active_side, left_idx, right_idx, status)``.
    ``ps`` is accepted for the event wiring but is not used.
    """
    if right_state is None:
        return gr.update(), gr.update(), 1, -1, -1, "Upload/Load a RIGHT image first."

    click_x, click_y = evt.index
    q_idx = _coords_to_idx(click_x, click_y, right_state)
    sims_self, sims_cross, (lo, hi), cross_info = compute_self_and_cross(right_state, left_state, q_idx)
    right_img = render_with_cosmap(right_state, sims_self, overlay_alpha, show_grid_flag,
                                   select_idx=q_idx, best_idx=None)

    if left_state is None:
        message = f"Single RIGHT • idx={q_idx} • self∈[{lo:.3f},{hi:.3f}]"
        return gr.update(), right_img, 1, -1, q_idx, message

    best = cross_info[2]
    left_img = render_with_cosmap(left_state, sims_cross, overlay_alpha, show_grid_flag,
                                  select_idx=None, best_idx=best)
    message = (f"RIGHT {right_state.rows}x{right_state.cols} self∈[{lo:.3f},{hi:.3f}] | "
               f"LEFT cross best idx={best}")
    return left_img, right_img, 1, -1, q_idx, message
|
|
|
|
def rebuild_with_settings(
    left_state: Optional[PatchImageState],
    right_state: Optional[PatchImageState],
    active_side: int,
    left_idx: int,
    right_idx: int,
    show_grid_flag: bool,
    overlay_alpha: float,
    ps: int,
):
    """Re-render the overlays after a grid/alpha change, keeping the current query patch.

    A negative stored index means "no selection" and falls back to the centre patch.
    Returns ``(left_view, right_view, status)``; ``ps`` is not used.
    """
    if left_state is None and right_state is None:
        return gr.update(), gr.update(), "Load an image first."

    def pick(stored_idx, st):
        # Stored selection if any, otherwise the centre patch of the grid.
        return stored_idx if stored_idx >= 0 else (st.rows // 2) * st.cols + (st.cols // 2)

    def paint(st, cos_map, selected, best):
        return render_with_cosmap(st, cos_map, overlay_alpha, show_grid_flag,
                                  select_idx=selected, best_idx=best)

    if right_state is None:
        query = pick(left_idx, left_state)
        sims, _, _, _ = compute_self_and_cross(left_state, None, query)
        return paint(left_state, sims, query, None), gr.update(), "Updated overlays."

    if left_state is None:
        query = pick(right_idx, right_state)
        sims, _, _, _ = compute_self_and_cross(right_state, None, query)
        return gr.update(), paint(right_state, sims, query, None), "Updated overlays."

    # Two images: the active side holds the query; the other side shows the cross map.
    if active_side == 0:
        query = pick(left_idx, left_state)
        sims_self, sims_cross, _, info = compute_self_and_cross(left_state, right_state, query)
        out_left = paint(left_state, sims_self, query, None)
        out_right = paint(right_state, sims_cross, None, info[2])
    else:
        query = pick(right_idx, right_state)
        sims_self, sims_cross, _, info = compute_self_and_cross(right_state, left_state, query)
        out_right = paint(right_state, sims_self, query, None)
        out_left = paint(left_state, sims_cross, None, info[2])
    return out_left, out_right, "Updated overlays."
|
|
|
|
|
|
# Gradio UI. Component layout follows the order in which components are created
# inside this context, so statement order here matters.
with gr.Blocks(title="DINOv3 Patch Similarity (Self & Cross)") as demo:
    # Usage instructions shown at the top of the page.
    gr.Markdown(
        """
# DINOv3 Patch Similarity (Self & Cross)

1) Pick **Dataset** (LVD-1689M / SAT-493M).

2) Pick **Model**.

3) Upload one or two images (or paste URLs) and press **Initialize / Update**.

- Click on a patch to update overlays.

- In two-image mode, the non-active image hides the red selection and shows **yellow** best match.
"""
    )

    # Dataset radio and model dropdown. The dropdown's choices depend on the dataset.
    with gr.Row():
        dataset_radio = gr.Radio(
            label="Dataset",
            choices=list(DATASET_LABELS.keys()),
            value=DEFAULT_DATASET_LABEL,
            interactive=True
        )
        initial_key = DATASET_LABELS[DEFAULT_DATASET_LABEL]
        initial_models = MODEL_OPTIONS_BY_DATASET.get(initial_key, [])
        model_dropdown = gr.Dropdown(
            label="Model name",
            choices=initial_models,
            value=(initial_models[0] if initial_models else None),
            interactive=True
        )

    # Sample labels for the initially selected dataset (first entry is the placeholder).
    initial_sample_labels = [label for label, _ in SAMPLE_URL_CHOICES.get(initial_key, [])]

    # Per-side inputs: upload widget, URL textbox, and a sample-URL dropdown.
    with gr.Row():
        with gr.Column():
            left_input = gr.Image(label="Left Image (upload)", type="pil",
                                  sources=["upload", "clipboard", "webcam"], interactive=True)
            left_url = gr.Textbox(label="Left Image URL (optional)", placeholder="https://...")
            left_sample = gr.Dropdown(label="Use a sample URL",
                                      choices=initial_sample_labels,
                                      value=(initial_sample_labels[0] if initial_sample_labels else None),
                                      interactive=True)
        with gr.Column():
            right_input = gr.Image(label="Right Image (upload)", type="pil",
                                   sources=["upload", "clipboard", "webcam"], interactive=True)
            right_url = gr.Textbox(label="Right Image URL (optional)", placeholder="https://...")
            right_sample = gr.Dropdown(label="Use a sample URL",
                                       choices=initial_sample_labels,
                                       value=(initial_sample_labels[0] if initial_sample_labels else None),
                                       interactive=True)

    # Rendering options. Changing either re-renders through rebuild_with_settings.
    with gr.Accordion("Overlay Settings", open=True):
        show_grid = gr.Checkbox(label="Show patch grid", value=DEFAULT_SHOW_GRID)
        overlay_alpha = gr.Slider(label="Overlay alpha", minimum=0.0, maximum=1.0,
                                  value=DEFAULT_OVERLAY_ALPHA, step=0.01)

    init_btn = gr.Button("Initialize / Update", variant="primary")

    # Output views. Clicking a patch fires the .select handlers wired below.
    with gr.Row():
        left_view = gr.Image(label="LEFT (click to select patch)", interactive=True)
        right_view = gr.Image(label="RIGHT (click to select patch)", interactive=True)

    status = gr.Markdown("")

    # Per-session state shared between event handlers:
    #   left_state/right_state: PatchImageState per side (None when that side is empty)
    #   active_side: 0 = LEFT holds the query patch, 1 = RIGHT does
    #   left_idx/right_idx: selected patch index per side, -1 = none
    #   ps_state: patch size reported by init_states
    left_state = gr.State(None)
    right_state = gr.State(None)
    active_side = gr.State(0)
    left_idx = gr.State(-1)
    right_idx = gr.State(-1)
    ps_state = gr.State(16)

    # Dataset change: refresh the model list and reset both sample dropdowns.
    # NOTE(review): Gradio's .change also fires on programmatic updates, so this reset
    # re-runs _apply_sample for both sides — confirm it does not wipe user input.
    dataset_radio.change(
        fn=update_model_and_samples,
        inputs=[dataset_radio],
        outputs=[model_dropdown, left_sample, right_sample]
    )

    # Picking a sample writes its URL into the textbox and clears the uploaded image.
    left_sample.change(
        fn=_apply_sample,
        inputs=[dataset_radio, left_sample],
        outputs=[left_url, left_input]
    )
    right_sample.change(
        fn=_apply_sample,
        inputs=[dataset_radio, right_sample],
        outputs=[right_url, right_input]
    )

    # Load images and model, extract features, and draw the initial overlays.
    init_btn.click(
        fn=init_states,
        inputs=[left_input, left_url, right_input, right_url, dataset_radio, model_dropdown, show_grid, overlay_alpha],
        outputs=[left_view, right_view, left_state, right_state, active_side, left_idx, right_idx, ps_state, status],
        show_progress=True
    )

    # Patch clicks on either view select a new query patch.
    left_view.select(
        fn=on_select_left,
        inputs=[left_state, right_state, show_grid, overlay_alpha, ps_state],
        outputs=[left_view, right_view, active_side, left_idx, right_idx, status]
    )
    right_view.select(
        fn=on_select_right,
        inputs=[left_state, right_state, show_grid, overlay_alpha, ps_state],
        outputs=[left_view, right_view, active_side, left_idx, right_idx, status]
    )

    # Re-render with the current query whenever an overlay setting changes.
    show_grid.change(
        fn=rebuild_with_settings,
        inputs=[left_state, right_state, active_side, left_idx, right_idx, show_grid, overlay_alpha, ps_state],
        outputs=[left_view, right_view, status]
    )
    overlay_alpha.change(
        fn=rebuild_with_settings,
        inputs=[left_state, right_state, active_side, left_idx, right_idx, show_grid, overlay_alpha, ps_state],
        outputs=[left_view, right_view, status]
    )
|
|
|
|
if __name__ == "__main__":
    # Route events through Gradio's request queue, then start the web server.
    app = demo.queue()
    app.launch()
|
|
|