manu02's picture
Initial commit
faca9e1 verified
raw
history blame
18.4 kB
# filepath: DINOv3-PatchSimilarity/DINOv3CosSimilarity2Images.py
# filepath: DINOv3-PatchSimilarity/DINOv3CosSimilarity.py
# Interactive DINOv3 patch similarity viewer for one or two images
# (NO AutoImageProcessor, NO resize: images are only padded to a multiple of the patch size).
# - Single-image mode (0 or 1 image given): shows the self-similarity heatmap of the selected patch.
# - Two-image mode (2 images given): when you click or move on one image,
#   shows BOTH overlays (self on source, cross on target).
#   The red selection rectangle is hidden on the non-active image.
import sys, math, io, urllib.request, argparse, os
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import matplotlib
try:
matplotlib.use("TkAgg")
except Exception:
pass
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from transformers import AutoModel
# ---------- Defaults / knobs ----------
# Image shown by main() when no image argument is supplied on the command line.
DEFAULT_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Default value of the --model argument (repo id handed to AutoModel.from_pretrained).
DEFAULT_MODEL_ID = "facebook/dinov3-vits16-pretrain-lvd1689m"
# OR-ed with the --show_grid flag in main(); while this is True the grid is always drawn.
SHOW_GRID = True
# OR-ed with the --annotate_indices flag in main().
ANNOTATE_INDICES = False
# Default value of the --overlay_alpha argument (opacity of the heatmap overlay).
OVERLAY_ALPHA = 0.55
PATCH_SIZE_OVERRIDE = None # set 16 to force; None = read from model if available
# ---------- Image I/O ----------
def load_image(path_or_url):
    """Load an image from a local path or an http(s) URL as an RGB PIL image.

    Remote sources are downloaded completely into memory before decoding;
    anything that does not start with ``http://`` or ``https://`` is handed
    to ``Image.open`` unchanged.
    """
    source = path_or_url
    if str(path_or_url).startswith(("http://", "https://")):
        with urllib.request.urlopen(path_or_url) as resp:
            # Wrap the downloaded bytes so PIL can treat them like a file.
            source = io.BytesIO(resp.read())
    return Image.open(source).convert("RGB")
# ---------- Preprocessing (custom, no resize) ----------
def pad_to_multiple(pil_img, multiple=16):
    """Pad an image on the right/bottom so both sides are multiples of `multiple`.

    The original pixels stay anchored at the top-left corner and the added
    area is filled with black.

    Returns:
        A pair ``(image, (0, 0, pad_right, pad_bottom))``. When the image is
        already aligned, the very same image object is returned together
        with an all-zero pad box.
    """
    width, height = pil_img.size
    padded_h = int(math.ceil(height / multiple) * multiple)
    padded_w = int(math.ceil(width / multiple) * multiple)

    already_aligned = padded_h == height and padded_w == width
    if already_aligned:
        return pil_img, (0, 0, 0, 0)

    # Black canvas of the padded size with the original pasted at the origin.
    padded = Image.new("RGB", (padded_w, padded_h), (0, 0, 0))
    padded.paste(pil_img, (0, 0))
    return padded, (0, 0, padded_w - width, padded_h - height)
def preprocess_image_no_resize(pil_img, multiple=16):
    """Turn a PIL image into model inputs without resizing it.

    The image is only padded (via ``pad_to_multiple``) so that its sides are
    multiples of `multiple`, then converted to a tensor and normalized with
    the mean/std constants below.

    Returns:
        A triple ``(inputs, disp_np, pad_box)`` where ``inputs`` is
        ``{"pixel_values": tensor}`` with a leading batch axis of size 1,
        ``disp_np`` is the padded image as a uint8 array for display, and
        ``pad_box`` is the padding reported by ``pad_to_multiple``.
    """
    img_padded, pad_box = pad_to_multiple(pil_img, multiple=multiple)

    # Array used purely for showing the (padded) picture on screen.
    disp_np = np.array(img_padded, dtype=np.uint8)  # (H,W,3)

    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    batch = to_normalized_tensor(img_padded).unsqueeze(0)  # (1,3,H,W)
    return {"pixel_values": batch}, disp_np, pad_box
# ---------- Utilities ----------
def upsample_nearest(arr, H, W, ps):
    """Blow up a per-patch array so every cell covers a `ps` x `ps` pixel square.

    Args:
        arr: Array of shape (rows, cols) or (rows, cols, channels).
        H, W: Accepted for call-site compatibility; not used by the
            computation (the output size is rows*ps by cols*ps).
        ps: Repetition factor along both spatial axes.

    Raises:
        ValueError: If `arr` is neither 2-D nor 3-D.
    """
    if arr.ndim not in (2, 3):
        raise ValueError("upsample_nearest expects (rows,cols) or (rows,cols,channels)")
    # Repeating along axis 0 then axis 1 leaves any trailing channel axis
    # untouched, so the same expression serves both the 2-D and 3-D case.
    stretched_rows = arr.repeat(ps, axis=0)
    return stretched_rows.repeat(ps, axis=1)
def draw_grid(ax, rows, cols, ps):
    """Draw the interior patch boundaries of a rows x cols grid on `ax`.

    Lines are placed at multiples of `ps` shifted by -0.5; the outer image
    border is not drawn.
    """
    line_style = dict(lw=0.8, alpha=0.6, color="white", zorder=3)
    # Horizontal separators first, then vertical ones.
    for row in range(1, rows):
        ax.axhline(row * ps - 0.5, **line_style)
    for col in range(1, cols):
        ax.axvline(col * ps - 0.5, **line_style)
def draw_indices(ax, rows, cols, ps):
    """Write each patch's flat (row-major) index at the centre of its cell."""
    label_style = dict(ha="center", va="center", fontsize=7,
                       color="white", alpha=0.95, zorder=4)
    half = ps / 2
    for idx in range(rows * cols):
        # Recover the grid position from the row-major index.
        r, c = divmod(idx, cols)
        ax.text(c * ps + half, r * ps + half, str(idx), **label_style)
def rc_to_idx(r, c, cols):
    """Convert a (row, column) grid position to its row-major flat index."""
    row = int(r)
    col = int(c)
    return row * cols + col
def idx_to_rc(i, cols):
    """Convert a row-major flat index to its (row, column) grid position."""
    return divmod(int(i), cols)
# ---------- Per-image embeddings ----------
class PatchImageState:
    """Per-image bundle: padded pixels, patch embeddings and viewer handles.

    Construction runs the model once on the padded image and stores the
    patch-token embeddings both as a (rows, cols, D) grid and as a flat
    (N, D) matrix, plus an L2-normalized copy of the latter.
    """

    def __init__(self, pil_img, model, device, ps):
        """Embed `pil_img` with `model` on `device` using patch size `ps`.

        Raises:
            RuntimeError: If the model returns no more tokens than there
                are patches (i.e. no leading special token is present).
        """
        self.pil = pil_img
        self.ps = ps

        inputs, disp_np, _ = preprocess_image_no_resize(pil_img, multiple=ps)
        self.disp = disp_np
        self.pixel_values = inputs["pixel_values"].to(device)  # (1,3,H,W)
        _batch, _channels, self.H, self.W = self.pixel_values.shape
        self.rows = self.H // ps
        self.cols = self.W // ps

        with torch.no_grad():
            out = model(pixel_values=self.pixel_values)
        hidden = out.last_hidden_state.squeeze(0)
        hs = hidden.detach().cpu().numpy()  # (T,D)

        T, D = hs.shape
        n_patches = self.rows * self.cols
        # Tokens beyond the patch count are treated as leading special
        # tokens (class + possible register tokens) and skipped.
        n_special = T - n_patches
        if n_special < 1:
            raise RuntimeError(
                f"[error] Token shape mismatch. T={T}, rows*cols={n_patches}, HxW={self.H}x{self.W}, ps={ps}"
            )

        self.D = D
        patch_tokens = hs[n_special:, :]
        self.patch_embs = patch_tokens.reshape(self.rows, self.cols, D)  # (rows,cols,D)
        self.X = self.patch_embs.reshape(-1, D)  # (N,D)
        row_norms = np.linalg.norm(self.X, axis=1, keepdims=True)
        self.Xn = self.X / (row_norms + 1e-8)  # unit-length rows

        # UI handles; filled in later by the viewer functions.
        self.ax = None
        self.overlay_im = None
        self.sel_rect = None
        self.best_rect = None
# ---------- Single-image mode ----------
def run_single_image(img_path, model, device, ps, show_grid, annotate_indices, overlay_alpha):
    """Open an interactive viewer showing patch self-similarity for one image.

    The image is embedded once through ``PatchImageState``. Every selection
    recolors the overlay with the cosine similarity between the selected
    patch and all patches of the same image. The call blocks in
    ``plt.show()`` until the window is closed.

    Args:
        img_path: Local path or http(s) URL, forwarded to ``load_image``.
        model: Model forwarded to ``PatchImageState``.
        device: Device forwarded to ``PatchImageState``.
        ps: Patch size in pixels.
        show_grid: Draw the patch grid when true.
        annotate_indices: Write the flat patch index into every cell when true.
        overlay_alpha: Opacity applied to the similarity overlay.
    """
    img = load_image(img_path)
    st = PatchImageState(img, model, device, ps)

    fig, ax = plt.subplots(figsize=(9, 9))
    st.ax = ax
    ax.imshow(st.disp, zorder=0)
    ax.set_axis_off()
    if show_grid:
        draw_grid(ax, st.rows, st.cols, st.ps)
    if annotate_indices:
        draw_indices(ax, st.rows, st.cols, st.ps)

    cmap = plt.get_cmap("magma")

    # neutral overlay to start
    init_scalar = 0.5 * np.ones((st.rows, st.cols), dtype=np.float32)
    rgba_up = upsample_nearest(cmap(init_scalar), st.H, st.W, st.ps)
    st.overlay_im = ax.imshow(rgba_up, alpha=overlay_alpha, zorder=1)

    st.sel_rect = Rectangle((0, 0), st.ps, st.ps, fill=False, lw=2.0, ec="red", zorder=5)
    ax.add_patch(st.sel_rect)

    # Start with the patch in the middle of the grid.
    current_idx = (st.rows // 2) * st.cols + st.cols // 2

    def update(idx):
        """Select patch `idx` (clamped to the grid) and redraw everything."""
        nonlocal current_idx
        current_idx = int(np.clip(idx, 0, st.rows * st.cols - 1))
        r, c = idx_to_rc(current_idx, st.cols)

        q = st.X[current_idx]
        qn = q / (np.linalg.norm(q) + 1e-8)
        cos_map = (st.Xn @ qn).reshape(st.rows, st.cols)

        # Min-max rescale to [0, 1] for the colormap. np.ptp() is used
        # because the ndarray.ptp() method was removed in NumPy 2.0.
        disp = (cos_map - cos_map.min()) / (np.ptp(cos_map) + 1e-8)
        rgba = cmap(disp)

        # Force selected cell to pure RED (and full alpha in the RGBA array)
        rgba[r, c, 0:3] = np.array([1.0, 0.0, 0.0])
        rgba[r, c, 3] = 1.0

        st.overlay_im.set_data(upsample_nearest(rgba, st.H, st.W, st.ps))
        st.overlay_im.set_alpha(overlay_alpha)  # global alpha
        st.sel_rect.set_xy((c * st.ps, r * st.ps))
        ax.set_title(
            f"Single-image • idx={current_idx} (r={r}, c={c}) • cos∈[{cos_map.min():.3f},{cos_map.max():.3f}]",
            fontsize=11
        )
        fig.canvas.draw_idle()

    def on_click(event):
        """Select the patch under the mouse; clicks outside the axes are ignored."""
        if event.inaxes != ax or event.xdata is None or event.ydata is None:
            return
        r = int(np.clip(event.ydata // st.ps, 0, st.rows - 1))
        c = int(np.clip(event.xdata // st.ps, 0, st.cols - 1))
        update(rc_to_idx(r, c, st.cols))

    def on_key(event):
        """Move the selection with the arrow keys (clamped); 'q' closes the window."""
        r, c = idx_to_rc(current_idx, st.cols)
        if event.key == "left":
            c = max(0, c - 1)
        elif event.key == "right":
            c = min(st.cols - 1, c + 1)
        elif event.key == "up":
            r = max(0, r - 1)
        elif event.key == "down":
            r = min(st.rows - 1, r + 1)
        elif event.key == "q":
            plt.close(fig)
            return
        else:
            return
        update(rc_to_idx(r, c, st.cols))

    fig.canvas.mpl_connect("button_press_event", on_click)
    fig.canvas.mpl_connect("key_press_event", on_key)

    update(current_idx)
    print("[single-image] Controls: click to select • arrows to move • 'q' to quit")
    plt.tight_layout()
    plt.show()
# ---------- Two-image mode (shows BOTH overlays; hides red rect on non-active) ----------
def run_two_images(img1_path, img2_path, model, device, ps, show_grid, annotate_indices, overlay_alpha):
    """Open an interactive side-by-side viewer for two images.

    One side is "active" at a time. Its selected patch is the query: the
    active image shows the query's self-similarity map and the other image
    shows the cross-similarity map, with the best-matching patch outlined in
    yellow. The red selection rectangle is only visible on the active side.
    The call blocks in ``plt.show()`` until the window is closed.

    Args:
        img1_path: Path/URL of the left image, forwarded to ``load_image``.
        img2_path: Path/URL of the right image, forwarded to ``load_image``.
        model: Model forwarded to ``PatchImageState`` for both images.
        device: Device forwarded to ``PatchImageState``.
        ps: Patch size in pixels.
        show_grid: Draw the patch grid when true.
        annotate_indices: Write the flat patch index into every cell when true.
        overlay_alpha: Opacity applied to both similarity overlays.

    Raises:
        RuntimeError: If the two images yield different embedding sizes.
    """
    img1, img2 = load_image(img1_path), load_image(img2_path)
    S = [PatchImageState(img1, model, device, ps),
         PatchImageState(img2, model, device, ps)]
    if S[0].D != S[1].D:
        raise RuntimeError("Embedding dims differ — use the same model for both images.")

    fig, (axL, axR) = plt.subplots(1, 2, figsize=(12, 6))
    axs = [axL, axR]
    cmap = plt.get_cmap("magma")

    for ax, st in zip(axs, S):
        st.ax = ax
        ax.imshow(st.disp, zorder=0)
        ax.set_axis_off()
        if show_grid:
            draw_grid(ax, st.rows, st.cols, st.ps)
        if annotate_indices:
            draw_indices(ax, st.rows, st.cols, st.ps)

        # start overlays (hidden until first render)
        init_scalar = 0.5 * np.ones((st.rows, st.cols), dtype=np.float32)
        rgba_up = upsample_nearest(cmap(init_scalar), st.H, st.W, st.ps)
        st.overlay_im = ax.imshow(rgba_up, alpha=0.0, zorder=1)

        st.sel_rect = Rectangle((0, 0), st.ps, st.ps, fill=False, lw=2.0, ec="red", zorder=5)
        st.best_rect = Rectangle((0, 0), st.ps, st.ps, fill=False, lw=2.0, ec="yellow", zorder=6)
        ax.add_patch(st.sel_rect)
        ax.add_patch(st.best_rect)
        st.best_rect.set_visible(False)

    active_side = 0  # 0=left, 1=right
    # Each side remembers its own selection; both start at the grid centre.
    current_idx = [(S[0].rows // 2) * S[0].cols + S[0].cols // 2,
                   (S[1].rows // 2) * S[1].cols + S[1].cols // 2]

    def set_titles(src_i=None, self_stats=None, cross_stats=None):
        """Refresh per-axes titles and the figure title (with stats if given)."""
        axs[0].set_title(f"LEFT • {S[0].rows}x{S[0].cols} patches • {'ACTIVE' if active_side==0 else ''}", fontsize=10)
        axs[1].set_title(f"RIGHT • {S[1].rows}x{S[1].cols} patches • {'ACTIVE' if active_side==1 else ''}", fontsize=10)
        if src_i is not None and self_stats is not None and cross_stats is not None:
            src_name = "LEFT" if src_i == 0 else "RIGHT"
            tgt_name = "RIGHT" if src_i == 0 else "LEFT"
            fig.suptitle(
                f"Source: {src_name} | Self cos∈[{self_stats[0]:.3f},{self_stats[1]:.3f}] • "
                f"{tgt_name} cos∈[{cross_stats[0]:.3f},{cross_stats[1]:.3f}] | "
                f"Controls: click=select • arrows=move • '1'/'2'/'t'=switch side • 'q'=quit",
                fontsize=11
            )
        else:
            fig.suptitle(
                "Controls: click=select • arrows=move • '1'/'2'/'t'=switch side • 'q'=quit",
                fontsize=11
            )

    def clamp_idx(i, st):
        """Clamp a flat patch index to the valid range of `st`."""
        return int(np.clip(i, 0, st.rows * st.cols - 1))

    def update_selection_rects():
        """Reposition both red rectangles; show only the active side's one."""
        for i, st in enumerate(S):
            r, c = idx_to_rc(current_idx[i], st.cols)
            st.sel_rect.set_xy((c * st.ps, r * st.ps))
            st.sel_rect.set_visible(i == active_side)

    def similarity_rgba(st, qn):
        """Return (cos_map, rgba) of `st`'s patches against unit query `qn`."""
        cos_map = (st.Xn @ qn).reshape(st.rows, st.cols)
        # Min-max rescale to [0, 1] for the colormap. np.ptp() is used
        # because the ndarray.ptp() method was removed in NumPy 2.0.
        disp = (cos_map - cos_map.min()) / (np.ptp(cos_map) + 1e-8)
        return cos_map, cmap(disp)

    def compute_and_show_both_from_src(src_i):
        """Show self-similarity on src and cross-similarity on the other image."""
        src = S[src_i]
        tgt = S[1 - src_i]

        q_idx = clamp_idx(current_idx[src_i], src)
        q = src.X[q_idx]
        qn = q / (np.linalg.norm(q) + 1e-8)

        # --- Self on src ---
        cos_map_self, rgba_self = similarity_rgba(src, qn)
        r0, c0 = idx_to_rc(q_idx, src.cols)
        # Paint the query cell pure red with full alpha in the RGBA array.
        rgba_self[r0, c0, 0:3] = np.array([1.0, 0.0, 0.0])
        rgba_self[r0, c0, 3] = 1.0
        src.overlay_im.set_data(upsample_nearest(rgba_self, src.H, src.W, src.ps))
        src.overlay_im.set_alpha(overlay_alpha)

        # --- Cross on tgt ---
        cos_map_cross, rgba_cross = similarity_rgba(tgt, qn)
        tgt.overlay_im.set_data(upsample_nearest(rgba_cross, tgt.H, tgt.W, tgt.ps))
        tgt.overlay_im.set_alpha(overlay_alpha)

        # highlight best match on target (argmax over the flattened map)
        best = int(np.argmax(cos_map_cross))
        br, bc = idx_to_rc(best, tgt.cols)
        tgt.best_rect.set_xy((bc * tgt.ps, br * tgt.ps))
        tgt.best_rect.set_visible(True)
        # Hide best on source (self best is the selected cell)
        src.best_rect.set_visible(False)

        set_titles(src_i, (cos_map_self.min(), cos_map_self.max()),
                   (cos_map_cross.min(), cos_map_cross.max()))
        fig.canvas.draw_idle()

    def on_click(event):
        """Make the clicked image active and select the patch under the mouse."""
        nonlocal active_side
        if event.inaxes is None or event.xdata is None or event.ydata is None:
            return
        if event.inaxes is axs[0]:
            side = 0
        elif event.inaxes is axs[1]:
            side = 1
        else:
            return
        st = S[side]
        r = int(np.clip(event.ydata // st.ps, 0, st.rows - 1))
        c = int(np.clip(event.xdata // st.ps, 0, st.cols - 1))
        current_idx[side] = rc_to_idx(r, c, st.cols)
        active_side = side
        update_selection_rects()
        compute_and_show_both_from_src(active_side)

    def on_key(event):
        """Handle 'q' (quit), 't'/'1'/'2' (switch side) and the arrow keys."""
        nonlocal active_side
        if event.key == "q":
            plt.close(fig)
            return

        # Side switching: 't' toggles, '1'/'2' pick left/right explicitly.
        if event.key in ("t", "T", "1", "2"):
            if event.key == "1":
                active_side = 0
            elif event.key == "2":
                active_side = 1
            else:
                active_side = 1 - active_side
            update_selection_rects()
            compute_and_show_both_from_src(active_side)
            return

        st = S[active_side]
        r, c = idx_to_rc(current_idx[active_side], st.cols)
        if event.key == "left":
            c = max(0, c - 1)
        elif event.key == "right":
            c = min(st.cols - 1, c + 1)
        elif event.key == "up":
            r = max(0, r - 1)
        elif event.key == "down":
            r = min(st.rows - 1, r + 1)
        else:
            return
        current_idx[active_side] = rc_to_idx(r, c, st.cols)
        update_selection_rects()  # keep visibility rule consistent
        compute_and_show_both_from_src(active_side)

    update_selection_rects()  # initialize positions + visibility
    set_titles()
    compute_and_show_both_from_src(active_side)

    fig.canvas.mpl_connect("button_press_event", on_click)
    fig.canvas.mpl_connect("key_press_event", on_key)

    print("[two-image BOTH] Controls:")
    print(" • Click on LEFT/RIGHT to select query patch (shows self + cross overlays)")
    print(" • Arrow keys move selection on ACTIVE side")
    print(" • '1'/'2'/'t' to switch side • 'q' to quit")
    plt.tight_layout()
    plt.show()
# ---------- Main ----------
def main():
    """Parse the command line, load the model and start the matching viewer.

    Two-image mode is used when both a first image (``--image1`` or
    ``--image``) and ``--image2`` are given. Otherwise single-image mode runs
    on whichever image was supplied, falling back to ``DEFAULT_URL`` when no
    image was given at all.
    """
    parser = argparse.ArgumentParser(description="DINOv3 Patch Similarity Viewer (1 or 2 images; two-image shows BOTH overlays)")
    # Accept either --image (single) or --image1/--image2 (two)
    parser.add_argument("--image", type=str, default=None,
                        help="Path/URL to image (single-image mode if only this is provided)")
    parser.add_argument("--image1", type=str, default=None,
                        help="Path/URL to first image (two-image mode if image2 is also provided)")
    parser.add_argument("--image2", type=str, default=None,
                        help="Path/URL to second image (two-image mode when given)")
    parser.add_argument("--model", type=str, default=DEFAULT_MODEL_ID,
                        help="DINOv3 model repo id (e.g., facebook/dinov3-vits16-pretrain-lvd1689m)")
    parser.add_argument("--show_grid", action="store_true", help="Draw patch grid")
    parser.add_argument("--annotate_indices", action="store_true", help="Write patch indices on cells")
    parser.add_argument("--overlay_alpha", type=float, default=OVERLAY_ALPHA, help="Heatmap alpha")
    parser.add_argument("--patch_size", type=int, default=(PATCH_SIZE_OVERRIDE or -1),
                        help="Override patch size. Set 16 to force. Default: model's patch size")
    args = parser.parse_args()

    # The module-level knobs act as "always on" switches on top of the flags.
    show_grid = args.show_grid or SHOW_GRID
    annotate_indices = args.annotate_indices or ANNOTATE_INDICES
    overlay_alpha = args.overlay_alpha

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[info] Device: {device}")
    model = AutoModel.from_pretrained(args.model).to(device)
    model.eval()

    def get_patch_size_from_model(model, default=16):
        """Read patch_size from model.config, then from model, else `default`."""
        for holder in (getattr(model, "config", None), model):
            ps = getattr(holder, "patch_size", None)
            if ps is None:
                continue
            # Some configs store (height, width); the first entry is used.
            if isinstance(ps, (tuple, list)):
                return ps[0]
            return ps
        return default

    ps = args.patch_size if args.patch_size and args.patch_size > 0 else get_patch_size_from_model(model, 16)
    print(f"[info] Using patch size: {ps}")

    # Routing logic:
    img1 = args.image1 or args.image
    img2 = args.image2
    if img1 and img2:
        run_two_images(img1, img2, model, device, ps, show_grid, annotate_indices, overlay_alpha)
    else:
        # A lone --image2 is shown in single-image mode rather than being
        # silently replaced by the default image.
        img_single = img1 or img2 or DEFAULT_URL
        run_single_image(img_single, model, device, ps, show_grid, annotate_indices, overlay_alpha)


if __name__ == "__main__":
    main()