Initial commit
Browse files- DINOv3CosSimilarity.py +434 -0
- LICENSE +9 -0
- PatchCosSimilarity.ipynb +0 -0
- README.md +161 -14
- app.py +622 -0
- requirements.txt +10 -0
DINOv3CosSimilarity.py
ADDED
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# filepath: DINOv3-PatchSimilarity/DINOv3CosSimilarity2Images.py
|
2 |
+
# filepath: DINOv3-PatchSimilarity/DINOv3CosSimilarity.py
|
3 |
+
# Interactive DINOv3 patch similarity viewer for one or two images
|
4 |
+
# Interactive DINOv3 patch similarity viewer (NO AutoImageProcessor, NO resize)
|
5 |
+
# - Single-image mode (0 or 1 image given): original behavior
|
6 |
+
# - Two-image mode (2 images given): when you click or move on one image,
|
7 |
+
# shows BOTH overlays (self on source, cross on target).
|
8 |
+
# NEW: Red selection rectangle is hidden on the non-active image.
|
9 |
+
|
10 |
+
import sys, math, io, urllib.request, argparse, os
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
import torch
|
14 |
+
from torchvision import transforms
|
15 |
+
import matplotlib
|
16 |
+
try:
|
17 |
+
matplotlib.use("TkAgg")
|
18 |
+
except Exception:
|
19 |
+
pass
|
20 |
+
import matplotlib.pyplot as plt
|
21 |
+
from matplotlib.patches import Rectangle
|
22 |
+
from transformers import AutoModel
|
23 |
+
|
24 |
+
# ---------- Defaults / knobs ----------
|
25 |
+
DEFAULT_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
26 |
+
DEFAULT_MODEL_ID = "facebook/dinov3-vits16-pretrain-lvd1689m"
|
27 |
+
|
28 |
+
SHOW_GRID = True
|
29 |
+
ANNOTATE_INDICES = False
|
30 |
+
OVERLAY_ALPHA = 0.55
|
31 |
+
PATCH_SIZE_OVERRIDE = None # set 16 to force; None = read from model if available
|
32 |
+
|
33 |
+
# ---------- Image I/O ----------
|
34 |
+
def load_image(path_or_url):
    """Load an image from a local path or an http(s) URL and return it as RGB.

    Args:
        path_or_url: Filesystem path, or a string starting with ``http://`` or
            ``https://``.

    Returns:
        A PIL image converted to mode ``"RGB"``.
    """
    source = path_or_url
    is_remote = str(path_or_url).startswith(("http://", "https://"))
    if is_remote:
        # Read the whole payload into memory so the HTTP response is closed
        # before PIL starts decoding.
        with urllib.request.urlopen(path_or_url) as response:
            payload = response.read()
        source = io.BytesIO(payload)
    return Image.open(source).convert("RGB")
|
40 |
+
|
41 |
+
# ---------- Preprocessing (custom, no resize) ----------
|
42 |
+
def pad_to_multiple(pil_img, multiple=16):
    """Pad an image with black pixels so both sides are multiples of ``multiple``.

    The original image is pasted at the top-left corner of the new canvas, so
    any added pixels end up on the right and bottom edges.

    Args:
        pil_img: Image object exposing ``.size`` as ``(width, height)``.
        multiple: Each side is rounded up to a multiple of this value.

    Returns:
        ``(image, pad_box)`` where ``pad_box`` is
        ``(0, 0, added_width, added_height)``. When no padding is needed the
        input object itself is returned together with ``(0, 0, 0, 0)``.
    """
    width, height = pil_img.size
    padded_h = int(math.ceil(height / multiple) * multiple)
    padded_w = int(math.ceil(width / multiple) * multiple)
    extra_w = padded_w - width
    extra_h = padded_h - height

    # Already aligned: hand the input back untouched.
    if extra_w == 0 and extra_h == 0:
        return pil_img, (0, 0, 0, 0)

    padded = Image.new("RGB", (padded_w, padded_h), (0, 0, 0))
    padded.paste(pil_img, (0, 0))
    return padded, (0, 0, extra_w, extra_h)
|
51 |
+
|
52 |
+
def preprocess_image_no_resize(pil_img, multiple=16):
    """Pad an image to a patch-size multiple and turn it into model input.

    No resizing is performed; the image is only padded (see
    ``pad_to_multiple``) and normalised channel-wise.

    Args:
        pil_img: Source PIL image.
        multiple: Side lengths are padded up to a multiple of this value.

    Returns:
        A 3-tuple ``(inputs, display, pad_box)``:
        ``inputs`` is ``{"pixel_values": tensor}`` with a leading batch axis
        of size 1, ``display`` is the padded image as a ``uint8`` NumPy array
        for plotting, and ``pad_box`` is the padding tuple reported by
        ``pad_to_multiple``.
    """
    # Per-channel normalisation statistics (the commonly used ImageNet values).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    padded_img, pad_box = pad_to_multiple(pil_img, multiple=multiple)

    to_model_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])
    batch = to_model_input(padded_img).unsqueeze(0)  # add batch axis
    display = np.array(padded_img, dtype=np.uint8)   # un-normalised copy for imshow

    return {"pixel_values": batch}, display, pad_box
|
62 |
+
|
63 |
+
# ---------- Utilities ----------
|
64 |
+
def upsample_nearest(arr, H, W, ps):
    """Enlarge a per-patch array so every cell becomes a ``ps`` x ``ps`` block.

    Args:
        arr: NumPy array shaped ``(rows, cols)`` or ``(rows, cols, channels)``.
        H: Unused; kept for call-site compatibility.
        W: Unused; kept for call-site compatibility.
        ps: Repetition factor applied along both the row and column axes.

    Returns:
        Array shaped ``(rows * ps, cols * ps)`` or
        ``(rows * ps, cols * ps, channels)``.

    Raises:
        ValueError: If ``arr`` is neither 2-D nor 3-D.
    """
    if arr.ndim not in (2, 3):
        raise ValueError("upsample_nearest expects (rows,cols) or (rows,cols,channels)")
    # Repeating along axis 0 and then axis 1 leaves any trailing channel axis
    # untouched, so one expression serves both the 2-D and 3-D cases.
    stretched_rows = np.repeat(arr, ps, axis=0)
    return np.repeat(stretched_rows, ps, axis=1)
|
73 |
+
|
74 |
+
def draw_grid(ax, rows, cols, ps):
    """Draw thin white separator lines between patch cells on ``ax``.

    Args:
        ax: Target axes; must provide ``axhline`` and ``axvline``.
        rows: Number of patch rows.
        cols: Number of patch columns.
        ps: Patch side length in pixels.
    """
    line_style = dict(lw=0.8, alpha=0.6, color="white", zorder=3)
    # Lines are offset by half a pixel from each multiple of the patch size;
    # only interior boundaries (1 .. n-1) are drawn.
    for row in range(1, rows):
        ax.axhline(row * ps - 0.5, **line_style)
    for col in range(1, cols):
        ax.axvline(col * ps - 0.5, **line_style)
|
79 |
+
|
80 |
+
def draw_indices(ax, rows, cols, ps):
    """Write the flat (row-major) patch index at the centre of every cell.

    Args:
        ax: Target axes; must provide ``text``.
        rows: Number of patch rows.
        cols: Number of patch columns.
        ps: Patch side length in pixels.
    """
    label_style = dict(ha="center", va="center", fontsize=7,
                       color="white", alpha=0.95, zorder=4)
    cells = ((row, col) for row in range(rows) for col in range(cols))
    for row, col in cells:
        flat_index = row * cols + col
        ax.text(col * ps + ps / 2, row * ps + ps / 2, str(flat_index), **label_style)
|
87 |
+
|
88 |
+
def rc_to_idx(r, c, cols):
    """Return the flat row-major index of cell ``(r, c)`` in a grid ``cols`` wide."""
    row = int(r)
    col = int(c)
    return row * cols + col
|
89 |
+
def idx_to_rc(i, cols):
    """Return ``(row, col)`` for flat row-major index ``i`` in a grid ``cols`` wide."""
    return divmod(int(i), cols)
|
90 |
+
|
91 |
+
# ---------- Per-image embeddings ----------
|
92 |
+
class PatchImageState:
    """Per-image bundle of preprocessed pixels, patch embeddings and UI handles.

    Attributes set by ``__init__``:
        pil, ps: The input image and patch size, stored as given.
        disp: Padded image as a NumPy array for display.
        pixel_values: Model input tensor moved to ``device``.
        H, W: Height and width read from ``pixel_values``.
        rows, cols: Patch grid size (``H // ps``, ``W // ps``).
        D: Embedding width.
        patch_embs: Patch embeddings shaped ``(rows, cols, D)``.
        X: The same embeddings flattened to ``(rows * cols, D)``.
        Xn: ``X`` with each row divided by its L2 norm (plus a small epsilon).
        ax, overlay_im, sel_rect, best_rect: Plot handles, initialised to
            ``None`` and filled in later by the viewer functions.
    """

    def __init__(self, pil_img, model, device, ps):
        """Run ``model`` once on ``pil_img`` and cache its patch embeddings.

        Args:
            pil_img: Source PIL image.
            model: Callable accepting ``pixel_values=`` and returning an object
                with a ``last_hidden_state`` tensor.
            device: Torch device the input tensor is moved to.
            ps: Patch size in pixels; also used as the padding multiple.

        Raises:
            RuntimeError: If the model returns no more tokens than there are
                patches, i.e. there is not at least one extra (non-patch) token.
        """
        self.pil = pil_img
        self.ps = ps

        model_inputs, display_rgb, _unused_pad = preprocess_image_no_resize(pil_img, multiple=ps)
        self.disp = display_rgb
        self.pixel_values = model_inputs["pixel_values"].to(device)
        self.H, self.W = self.pixel_values.shape[2:]
        self.rows = self.H // ps
        self.cols = self.W // ps

        with torch.no_grad():
            outputs = model(pixel_values=self.pixel_values)
            tokens = outputs.last_hidden_state.squeeze(0)
            hs = tokens.detach().cpu().numpy()

        T, D = hs.shape
        n_patches = self.rows * self.cols
        # Anything beyond the patch count is treated as leading special tokens
        # (class token and, possibly, register tokens) and skipped below.
        n_special = T - n_patches
        if n_special < 1:
            raise RuntimeError(
                f"[error] Token shape mismatch. T={T}, rows*cols={n_patches}, HxW={self.H}x{self.W}, ps={ps}"
            )

        self.D = D
        self.patch_embs = hs[n_special:, :].reshape(self.rows, self.cols, D)
        self.X = self.patch_embs.reshape(-1, D)
        row_norms = np.linalg.norm(self.X, axis=1, keepdims=True)
        self.Xn = self.X / (row_norms + 1e-8)  # epsilon guards against zero-norm rows

        # Plot handles; the viewer functions populate these.
        self.ax = None
        self.overlay_im = None
        self.sel_rect = None
        self.best_rect = None
|
124 |
+
|
125 |
+
# ---------- Single-image mode ----------
|
126 |
+
def run_single_image(img_path, model, device, ps, show_grid, annotate_indices, overlay_alpha):
    """Open an interactive window showing patch self-similarity for one image.

    Clicking a patch (or moving with the arrow keys) recomputes the cosine
    similarity between that patch and every other patch and redraws the
    heatmap overlay. Pressing ``q`` closes the window. Blocks in
    ``plt.show()`` until the window is closed.

    Args:
        img_path: Local path or URL of the image.
        model: Model passed through to ``PatchImageState``.
        device: Torch device passed through to ``PatchImageState``.
        ps: Patch size in pixels.
        show_grid: Draw patch boundary lines when true.
        annotate_indices: Write the flat patch index in each cell when true.
        overlay_alpha: Global alpha applied to the heatmap overlay.
    """
    img = load_image(img_path)
    st = PatchImageState(img, model, device, ps)

    fig, ax = plt.subplots(figsize=(9, 9))
    st.ax = ax
    ax.imshow(st.disp, zorder=0)
    ax.set_axis_off()
    if show_grid:
        draw_grid(ax, st.rows, st.cols, st.ps)
    if annotate_indices:
        draw_indices(ax, st.rows, st.cols, st.ps)

    cmap = plt.get_cmap("magma")

    # Neutral (mid-colormap) overlay to start with; replaced on first update.
    init_scalar = 0.5 * np.ones((st.rows, st.cols), dtype=np.float32)
    rgba_up = upsample_nearest(cmap(init_scalar), st.H, st.W, st.ps)
    st.overlay_im = ax.imshow(rgba_up, alpha=overlay_alpha, zorder=1)

    st.sel_rect = Rectangle((0, 0), st.ps, st.ps, fill=False, lw=2.0, ec="red", zorder=5)
    ax.add_patch(st.sel_rect)

    # Start with the centre patch selected.
    current_idx = (st.rows // 2) * st.cols + st.cols // 2

    def update(idx):
        """Select patch ``idx`` (clamped to the grid) and redraw the overlay."""
        nonlocal current_idx
        current_idx = int(np.clip(idx, 0, st.rows * st.cols - 1))
        r, c = idx_to_rc(current_idx, st.cols)

        q = st.X[current_idx]
        qn = q / (np.linalg.norm(q) + 1e-8)
        cos_map = (st.Xn @ qn).reshape(st.rows, st.cols)

        # Min-max normalise for display. np.ptp() is used instead of the
        # ndarray.ptp() method, which was removed in NumPy 2.0.
        disp = (cos_map - cos_map.min()) / (np.ptp(cos_map) + 1e-8)
        rgba = cmap(disp)
        # Force selected cell to pure RED (and full alpha in the RGBA array)
        rgba[r, c, 0:3] = np.array([1.0, 0.0, 0.0])
        rgba[r, c, 3] = 1.0

        st.overlay_im.set_data(upsample_nearest(rgba, st.H, st.W, st.ps))
        st.overlay_im.set_alpha(overlay_alpha)  # global alpha

        st.sel_rect.set_xy((c * st.ps, r * st.ps))
        ax.set_title(
            f"Single-image • idx={current_idx} (r={r}, c={c}) • cos∈[{cos_map.min():.3f},{cos_map.max():.3f}]",
            fontsize=11
        )
        fig.canvas.draw_idle()

    def on_click(event):
        """Select the patch under the mouse; ignore clicks outside the axes."""
        if event.inaxes != ax or event.xdata is None or event.ydata is None:
            return
        r = int(np.clip(event.ydata // st.ps, 0, st.rows - 1))
        c = int(np.clip(event.xdata // st.ps, 0, st.cols - 1))
        update(rc_to_idx(r, c, st.cols))

    def on_key(event):
        """Move the selection with the arrow keys, or quit on 'q'."""
        r, c = idx_to_rc(current_idx, st.cols)
        if event.key == "left":
            c = max(0, c - 1)
        elif event.key == "right":
            c = min(st.cols - 1, c + 1)
        elif event.key == "up":
            r = max(0, r - 1)
        elif event.key == "down":
            r = min(st.rows - 1, r + 1)
        elif event.key == "q":
            plt.close(fig)
            return
        else:
            return
        update(rc_to_idx(r, c, st.cols))

    fig.canvas.mpl_connect("button_press_event", on_click)
    fig.canvas.mpl_connect("key_press_event", on_key)

    update(current_idx)
    print("[single-image] Controls: click to select • arrows to move • 'q' to quit")
    plt.tight_layout()
    plt.show()
|
208 |
+
|
209 |
+
# ---------- Two-image mode (shows BOTH overlays; hides red rect on non-active) ----------
|
210 |
+
def run_two_images(img1_path, img2_path, model, device, ps, show_grid, annotate_indices, overlay_alpha):
    """Open an interactive side-by-side viewer for two images.

    Selecting a patch on one image (the "source") shows its self-similarity
    heatmap on that image and its cross-similarity heatmap on the other image
    (the "target"). The best-matching target patch is outlined in yellow, and
    the red selection rectangle is shown only on the active side. Blocks in
    ``plt.show()`` until the window is closed.

    Args:
        img1_path: Local path or URL of the left image.
        img2_path: Local path or URL of the right image.
        model: Model passed through to ``PatchImageState``.
        device: Torch device passed through to ``PatchImageState``.
        ps: Patch size in pixels.
        show_grid: Draw patch boundary lines when true.
        annotate_indices: Write the flat patch index in each cell when true.
        overlay_alpha: Global alpha applied to both heatmap overlays.

    Raises:
        RuntimeError: If the two images produced embeddings of different width.
    """
    img1, img2 = load_image(img1_path), load_image(img2_path)
    S = [PatchImageState(img1, model, device, ps),
         PatchImageState(img2, model, device, ps)]
    if S[0].D != S[1].D:
        raise RuntimeError("Embedding dims differ — use the same model for both images.")

    cmap = plt.get_cmap("magma")

    fig, (axL, axR) = plt.subplots(1, 2, figsize=(12, 6))
    axs = [axL, axR]
    for ax, st in zip(axs, S):
        st.ax = ax
        ax.imshow(st.disp, zorder=0)
        ax.set_axis_off()
        if show_grid:
            draw_grid(ax, st.rows, st.cols, st.ps)
        if annotate_indices:
            draw_indices(ax, st.rows, st.cols, st.ps)
        # Start overlays fully transparent; they become visible on first render.
        init_scalar = 0.5 * np.ones((st.rows, st.cols), dtype=np.float32)
        rgba_up = upsample_nearest(cmap(init_scalar), st.H, st.W, st.ps)
        st.overlay_im = ax.imshow(rgba_up, alpha=0.0, zorder=1)

        st.sel_rect = Rectangle((0, 0), st.ps, st.ps, fill=False, lw=2.0, ec="red", zorder=5)
        st.best_rect = Rectangle((0, 0), st.ps, st.ps, fill=False, lw=2.0, ec="yellow", zorder=6)
        ax.add_patch(st.sel_rect)
        ax.add_patch(st.best_rect)
        st.best_rect.set_visible(False)

    active_side = 0  # 0=left, 1=right
    # Both sides start with their centre patch selected.
    current_idx = [(S[0].rows // 2) * S[0].cols + S[0].cols // 2,
                   (S[1].rows // 2) * S[1].cols + S[1].cols // 2]

    def normalize_for_display(cos_map):
        """Min-max scale a similarity map to [0, 1] for the colormap.

        Uses np.ptp() rather than the ndarray.ptp() method, which was removed
        in NumPy 2.0.
        """
        return (cos_map - cos_map.min()) / (np.ptp(cos_map) + 1e-8)

    def set_titles(src_i=None, self_stats=None, cross_stats=None):
        """Refresh per-axes titles and the figure title (stats are optional)."""
        axs[0].set_title(f"LEFT • {S[0].rows}x{S[0].cols} patches • {'ACTIVE' if active_side==0 else ''}", fontsize=10)
        axs[1].set_title(f"RIGHT • {S[1].rows}x{S[1].cols} patches • {'ACTIVE' if active_side==1 else ''}", fontsize=10)
        if src_i is not None and self_stats is not None and cross_stats is not None:
            src_name = "LEFT" if src_i == 0 else "RIGHT"
            tgt_name = "RIGHT" if src_i == 0 else "LEFT"
            fig.suptitle(
                f"Source: {src_name} | Self cos∈[{self_stats[0]:.3f},{self_stats[1]:.3f}] • "
                f"{tgt_name} cos∈[{cross_stats[0]:.3f},{cross_stats[1]:.3f}] | "
                f"Controls: click=select • arrows=move • '1'/'2'/'t'=switch side • 'q'=quit",
                fontsize=11
            )
        else:
            fig.suptitle(
                "Controls: click=select • arrows=move • '1'/'2'/'t'=switch side • 'q'=quit",
                fontsize=11
            )

    def clamp_idx(i, st):
        """Clamp flat index ``i`` into the valid patch range of ``st``."""
        return int(np.clip(i, 0, st.rows * st.cols - 1))

    def update_selection_rects():
        """Reposition both red rectangles and show only the active side's."""
        for i, st in enumerate(S):
            r, c = idx_to_rc(current_idx[i], st.cols)
            st.sel_rect.set_xy((c * st.ps, r * st.ps))
            st.sel_rect.set_visible(i == active_side)

    def compute_and_show_both_from_src(src_i):
        """Show self-similarity on src and cross-similarity on the other image."""
        src = S[src_i]
        tgt = S[1 - src_i]

        q_idx = clamp_idx(current_idx[src_i], src)
        q = src.X[q_idx]
        qn = q / (np.linalg.norm(q) + 1e-8)

        # --- Self on src ---
        cos_map_self = (src.Xn @ qn).reshape(src.rows, src.cols)
        rgba_self = cmap(normalize_for_display(cos_map_self))
        r0, c0 = idx_to_rc(q_idx, src.cols)
        # Paint the query cell pure red with full alpha.
        rgba_self[r0, c0, 0:3] = np.array([1.0, 0.0, 0.0])
        rgba_self[r0, c0, 3] = 1.0
        src.overlay_im.set_data(upsample_nearest(rgba_self, src.H, src.W, src.ps))
        src.overlay_im.set_alpha(overlay_alpha)

        # --- Cross on tgt ---
        cos_cross = tgt.Xn @ qn
        cos_map_cross = cos_cross.reshape(tgt.rows, tgt.cols)
        rgba_cross = cmap(normalize_for_display(cos_map_cross))
        tgt.overlay_im.set_data(upsample_nearest(rgba_cross, tgt.H, tgt.W, tgt.ps))
        tgt.overlay_im.set_alpha(overlay_alpha)

        # Highlight the best match on the target.
        best = int(np.argmax(cos_cross))
        br, bc = idx_to_rc(best, tgt.cols)
        tgt.best_rect.set_xy((bc * tgt.ps, br * tgt.ps))
        tgt.best_rect.set_visible(True)

        # Hide best on source (self best is the selected cell).
        src.best_rect.set_visible(False)

        set_titles(src_i, (cos_map_self.min(), cos_map_self.max()),
                   (cos_map_cross.min(), cos_map_cross.max()))
        fig.canvas.draw_idle()

    def on_click(event):
        """Make the clicked image active and select the patch under the mouse."""
        nonlocal active_side
        if event.inaxes is None or event.xdata is None or event.ydata is None:
            return
        if event.inaxes is axs[0]:
            side = 0
        elif event.inaxes is axs[1]:
            side = 1
        else:
            return
        st = S[side]
        r = int(np.clip(event.ydata // st.ps, 0, st.rows - 1))
        c = int(np.clip(event.xdata // st.ps, 0, st.cols - 1))
        current_idx[side] = rc_to_idx(r, c, st.cols)
        active_side = side
        update_selection_rects()
        compute_and_show_both_from_src(active_side)

    def on_key(event):
        """Handle quit, side switching ('1', '2', 't') and arrow-key movement."""
        nonlocal active_side
        key = event.key
        if key == "q":
            plt.close(fig)
            return

        # Side switching: 't' toggles, '1'/'2' pick left/right explicitly.
        if key in ("t", "T"):
            new_side = 1 - active_side
        elif key == "1":
            new_side = 0
        elif key == "2":
            new_side = 1
        else:
            new_side = None
        if new_side is not None:
            active_side = new_side
            update_selection_rects()
            compute_and_show_both_from_src(active_side)
            return

        st = S[active_side]
        r, c = idx_to_rc(current_idx[active_side], st.cols)
        if key == "left":
            c = max(0, c - 1)
        elif key == "right":
            c = min(st.cols - 1, c + 1)
        elif key == "up":
            r = max(0, r - 1)
        elif key == "down":
            r = min(st.rows - 1, r + 1)
        else:
            return
        current_idx[active_side] = rc_to_idx(r, c, st.cols)
        update_selection_rects()
        compute_and_show_both_from_src(active_side)

    update_selection_rects()  # initialize positions + visibility
    set_titles()
    compute_and_show_both_from_src(active_side)

    fig.canvas.mpl_connect("button_press_event", on_click)
    fig.canvas.mpl_connect("key_press_event", on_key)

    print("[two-image BOTH] Controls:")
    print("  • Click on LEFT/RIGHT to select query patch (shows self + cross overlays)")
    print("  • Arrow keys move selection on ACTIVE side")
    print("  • '1'/'2'/'t' to switch side • 'q' to quit")
    plt.tight_layout()
    plt.show()
|
376 |
+
|
377 |
+
# ---------- Main ----------
|
378 |
+
def main():
    """Parse command-line options, load the model and launch the viewer.

    Two-image mode is used when both a first image (``--image1`` or, failing
    that, ``--image``) and ``--image2`` are given. Otherwise the single-image
    viewer runs on the first image, or on ``DEFAULT_URL`` if none was given.

    The ``--show_grid`` and ``--annotate_indices`` flags are OR-ed with the
    module constants ``SHOW_GRID`` and ``ANNOTATE_INDICES``, so a constant set
    to true cannot be switched off from the command line.
    """
    parser = argparse.ArgumentParser(description="DINOv3 Patch Similarity Viewer (1 or 2 images; two-image shows BOTH overlays)")

    # Accept either --image (single) or --image1/--image2 (two)
    image_options = (
        ("--image", "Path/URL to image (single-image mode if only this is provided)"),
        ("--image1", "Path/URL to first image (two-image mode if image2 is also provided)"),
        ("--image2", "Path/URL to second image (two-image mode when given)"),
    )
    for flag, help_text in image_options:
        parser.add_argument(flag, type=str, default=None, help=help_text)

    parser.add_argument(
        "--model", type=str, default=DEFAULT_MODEL_ID,
        help="DINOv3 model repo id (e.g., facebook/dinov3-vits16-pretrain-lvd1689m)",
    )
    parser.add_argument("--show_grid", action="store_true", help="Draw patch grid")
    parser.add_argument("--annotate_indices", action="store_true", help="Write patch indices on cells")
    parser.add_argument("--overlay_alpha", type=float, default=OVERLAY_ALPHA, help="Heatmap alpha")
    parser.add_argument(
        "--patch_size", type=int, default=(PATCH_SIZE_OVERRIDE or -1),
        help="Override patch size. Set 16 to force. Default: model's patch size",
    )
    args = parser.parse_args()

    show_grid = args.show_grid or SHOW_GRID
    annotate_indices = args.annotate_indices or ANNOTATE_INDICES
    overlay_alpha = args.overlay_alpha

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[info] Device: {device}")
    model = AutoModel.from_pretrained(args.model).to(device)
    model.eval()

    def get_patch_size_from_model(model, default=16):
        """Return the patch size from ``model.config`` or ``model``, else ``default``."""
        # Look on the config first, then on the model object itself.
        for holder in (getattr(model, "config", None), model):
            if hasattr(holder, "patch_size"):
                found = holder.patch_size
                # Some configs store (height, width); use the first entry.
                return found[0] if isinstance(found, (tuple, list)) else found
        return default

    # A positive --patch_size wins; anything else defers to the model.
    if args.patch_size and args.patch_size > 0:
        ps = args.patch_size
    else:
        ps = get_patch_size_from_model(model, 16)
    print(f"[info] Using patch size: {ps}")

    # Routing logic:
    first_image = args.image1 or args.image
    second_image = args.image2

    if first_image and second_image:
        run_two_images(first_image, second_image, model, device, ps,
                       show_grid, annotate_indices, overlay_alpha)
    else:
        run_single_image(first_image or DEFAULT_URL, model, device, ps,
                         show_grid, annotate_indices, overlay_alpha)
|
432 |
+
|
433 |
+
if __name__ == "__main__":
|
434 |
+
main()
|
LICENSE
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 [Your Name or Organization]
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
6 |
+
|
7 |
+
1. The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
8 |
+
|
9 |
+
2. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
PatchCosSimilarity.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
README.md
CHANGED
@@ -1,14 +1,161 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DINOv3 Patch Similarity Viewer
|
2 |
+
|
3 |
+

|
4 |
+
|
5 |
+
> **Note:** This README and repository are for educational purposes. The creation of this repo was inspired by the DINOv3 paper to help visualize and understand the output of the model.
|
6 |
+
|
7 |
+
## Purpose
|
8 |
+
|
9 |
+
This repository provides interactive tools to visualize and explore patch-wise similarity in images using the DINOv3 vision transformer model. It is designed for researchers, students, and practitioners interested in understanding how self-supervised vision transformers perceive and relate different regions of an image.
|
10 |
+
|
11 |
+
## About DINOv3
|
12 |
+
|
13 |
+
- **Paper:** [DINOv3](https://arxiv.org/abs/2508.10104)
|
14 |
+
- **Meta Research Page:** [Meta DINOv3 Publication](https://ai.meta.com/dinov3/)
|
15 |
+
- **Official GitHub:** [facebookresearch/dinov3](https://github.com/facebookresearch/dinov3)
|
16 |
+
|
17 |
+
**Note:**
|
18 |
+
The DINOv3 model weights require access approval.
|
19 |
+
You can request access via the [Meta Research page](https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/) or by selecting the desired model on [Hugging Face model collection](https://huggingface.co/collections/facebook/dinov3-68924841bd6b561778e31009).
|
20 |
+
|
21 |
+
## Features
|
22 |
+
|
23 |
+
- **Interactive Visualization:** Click on image patches or use arrow keys to explore patch similarity heatmaps.
|
24 |
+
- **Single or Two-Image Mode:** If one image is specified, shows self-similarity. If two images are specified, shows both self-similarity and cross-image similarity overlays interactively.
|
25 |
+
- **Image Preprocessing:** Loads and pads images without resizing, preserving the original aspect ratio.
|
26 |
+
- **Cosine Similarity Calculation:** Computes and visualizes cosine similarity between image patches.
|
27 |
+
- **Robust Fallback:** If an image URL fails to load, a default image is used.
|
28 |
+
|
29 |
+
## Installation
|
30 |
+
|
31 |
+
Install dependencies with:
|
32 |
+
|
33 |
+
```bash
|
34 |
+
pip install -r requirements.txt
|
35 |
+
```
|
36 |
+
|
37 |
+
## Model Selection
|
38 |
+
|
39 |
+
You can choose from several DINOv3 models available on Hugging Face (click to view each model card):
|
40 |
+
|
41 |
+
LVD-1689M Dataset (Web data)
|
42 |
+
- ViT
|
43 |
+
- [facebook/dinov3-vit7b16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vit7b16-pretrain-lvd1689m)
|
44 |
+
- [facebook/dinov3-vits16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vits16-pretrain-lvd1689m)
|
45 |
+
- [facebook/dinov3-vits16plus-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vits16plus-pretrain-lvd1689m)
|
46 |
+
- [facebook/dinov3-vitb16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m)
|
47 |
+
- [facebook/dinov3-vitl16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vitl16-pretrain-lvd1689m)
|
48 |
+
- [facebook/dinov3-vith16plus-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vith16plus-pretrain-lvd1689m)
|
49 |
+
|
50 |
+
- ConvNeXt
|
51 |
+
- [facebook/dinov3-convnext-tiny-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-convnext-tiny-pretrain-lvd1689m)
|
52 |
+
- [facebook/dinov3-convnext-small-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-convnext-small-pretrain-lvd1689m)
|
53 |
+
- [facebook/dinov3-convnext-base-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-convnext-base-pretrain-lvd1689m)
|
54 |
+
- [facebook/dinov3-convnext-large-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-convnext-large-pretrain-lvd1689m)
|
55 |
+
|
56 |
+
SAT-493M Dataset (Satellite data)
|
57 |
+
- ViT
|
58 |
+
- [facebook/dinov3-vitl16-pretrain-sat493m](https://huggingface.co/facebook/dinov3-vitl16-pretrain-sat493m)
|
59 |
+
- [facebook/dinov3-vit7b16-pretrain-sat493m](https://huggingface.co/facebook/dinov3-vit7b16-pretrain-sat493m)
|
60 |
+
|
61 |
+
## Usage
|
62 |
+
|
63 |
+
### Gradio app
|
64 |
+
|
65 |
+
Run the Gradio app:
|
66 |
+
|
67 |
+
```bash
|
68 |
+
python app.py
|
69 |
+
```
|
70 |
+
|
71 |
+
After starting the app, open [http://localhost:7860/](http://localhost:7860/) in your browser to use it.
|
72 |
+
|
73 |
+
Then:
|
74 |
+
- Choose Dataset and model name
|
75 |
+
- For Single image similarity:
|
76 |
+
- Choose only one file or URL
|
77 |
+
- For 2 image similarity:
|
78 |
+
- Choose images from file and/or URL
|
79 |
+
- Click button "Initialize / Update "
|
80 |
+
- Select the desired patch from the image
|
81 |
+
- Watch the results
|
82 |
+
|
83 |
+
**Note:**
|
84 |
+
*Overlay alpha* controls the opacity of the similarity heatmap drawn on top of the image.
|
85 |
+
|
86 |
+
### Python Script
|
87 |
+
|
88 |
+
Run the interactive viewer with the default COCO image:
|
89 |
+
|
90 |
+
```bash
|
91 |
+
python DINOv3CosSimilarity.py
|
92 |
+
```
|
93 |
+
|
94 |
+
#### Single Image Mode
|
95 |
+
|
96 |
+
Specify your own image (local path or URL):
|
97 |
+
|
98 |
+
```bash
|
99 |
+
python DINOv3CosSimilarity.py --image path/to/your/image.jpg
|
100 |
+
python DINOv3CosSimilarity.py --image https://yourdomain.com/image.png
|
101 |
+
```
|
102 |
+
|
103 |
+
#### Two Image Mode
|
104 |
+
|
105 |
+
Specify two images (local paths or URLs):
|
106 |
+
|
107 |
+
```bash
|
108 |
+
python DINOv3CosSimilarity.py --image1 path/to/image1.jpg --image2 path/to/image2.jpg
|
109 |
+
python DINOv3CosSimilarity.py --image1 https://yourdomain.com/image1.png --image2 https://yourdomain.com/image2.png
|
110 |
+
```
|
111 |
+
|
112 |
+
#### Model Selection
|
113 |
+
|
114 |
+
Specify the model with `--model` (default is vits16):
|
115 |
+
|
116 |
+
```bash
|
117 |
+
python DINOv3CosSimilarity.py --model facebook/dinov3-vitb16-pretrain-lvd1689m
|
118 |
+
```
|
119 |
+
|
120 |
+
#### Other Options
|
121 |
+
|
122 |
+
- `--show_grid` : Draw patch grid
|
123 |
+
- `--annotate_indices` : Write patch indices on cells
|
124 |
+
- `--overlay_alpha <float>` : Set heatmap alpha (default 0.55)
|
125 |
+
- `--patch_size <int>` : Override patch size (default: model's patch size)
|
126 |
+
|
127 |
+
#### Controls
|
128 |
+
|
129 |
+
- Mouse click to select a patch
|
130 |
+
- Arrow keys to move selection
|
131 |
+
- '1', '2', or 't' to switch active image (in two-image mode)
|
132 |
+
- 'q' to quit
|
133 |
+
|
134 |
+
## Demo Single Image
|
135 |
+
|
136 |
+

|
137 |
+
|
138 |
+
## Demo 2 Images
|
139 |
+
|
140 |
+

|
141 |
+
|
142 |
+
### Jupyter Notebook
|
143 |
+
|
144 |
+
1. Open `PatchCosSimilarity.ipynb` in Jupyter Notebook.
|
145 |
+
2. Run the cells to load an image and visualize patch similarities.
|
146 |
+
3. Set `url1` for single-image mode, or both `url1` and `url2` for two-image mode.
|
147 |
+
4. If an image fails to load, a default image will be used automatically.
|
148 |
+
5. Set the `model_id` variable to any of the models listed above (see commented lines at the top of the notebook).
|
149 |
+
|
150 |
+
**Notebook Controls:**
|
151 |
+
- Mouse click to select a patch
|
152 |
+
- Arrow keys to move selection
|
153 |
+
- '1', '2', or 't' to switch active image (in two-image mode)
|
154 |
+
|
155 |
+
## License
|
156 |
+
|
157 |
+
This project is licensed under the MIT License. See the `LICENSE` file for details.
|
158 |
+
|
159 |
+
## Acknowledgments
|
160 |
+
|
161 |
+
This project utilizes the DINOv3 model from Hugging Face's Transformers library, along with PyTorch, Matplotlib, and Pillow.
|
app.py
ADDED
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app.py
|
2 |
+
# Gradio UI for interactive DINOv3 patch similarity (single or dual image)
|
3 |
+
# - No AutoImageProcessor, no resize (only pad to multiple of patch size)
|
4 |
+
# - Single image: click to show self-similarity; selected cell outlined in RED
|
5 |
+
# - Two images: click on one side -> self overlay on source, cross overlay on target; best match on target outlined in YELLOW
|
6 |
+
# - Red selection rectangle is hidden on the non-active image
|
7 |
+
# - Patch size inferred from model (no override). Patch indices are not annotated.
|
8 |
+
# - Dataset selector (LVD-1689M / SAT-493M); model dropdown shows only the short name between "dinov3-" and "-pretrain".
|
9 |
+
# - Sample URL dropdowns switch between LVD (COCO/Picsum) and SAT (satellite imagery) and auto-fill / clear uploads.
|
10 |
+
|
11 |
+
import io
|
12 |
+
import math
|
13 |
+
import urllib.request
|
14 |
+
from functools import lru_cache
|
15 |
+
from typing import Optional, Tuple, Dict, List
|
16 |
+
|
17 |
+
import gradio as gr
|
18 |
+
import numpy as np
|
19 |
+
from PIL import Image, ImageDraw
|
20 |
+
import torch
|
21 |
+
from torchvision import transforms
|
22 |
+
from transformers import AutoModel
|
23 |
+
from matplotlib import colormaps as cm
|
24 |
+
|
25 |
+
# ---------- Provided model IDs (ground truth list) ----------
|
26 |
+
MODEL_ID_LIST = [
|
27 |
+
"facebook/dinov3-vits16-pretrain-lvd1689m",
|
28 |
+
"facebook/dinov3-vits16plus-pretrain-lvd1689m",
|
29 |
+
"facebook/dinov3-vitb16-pretrain-lvd1689m",
|
30 |
+
"facebook/dinov3-vitl16-pretrain-lvd1689m",
|
31 |
+
"facebook/dinov3-vith16plus-pretrain-lvd1689m",
|
32 |
+
"facebook/dinov3-vit7b16-pretrain-lvd1689m",
|
33 |
+
"facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
|
34 |
+
"facebook/dinov3-convnext-small-pretrain-lvd1689m",
|
35 |
+
"facebook/dinov3-convnext-base-pretrain-lvd1689m",
|
36 |
+
"facebook/dinov3-convnext-large-pretrain-lvd1689m",
|
37 |
+
"facebook/dinov3-vitl16-pretrain-sat493m",
|
38 |
+
"facebook/dinov3-vit7b16-pretrain-sat493m",
|
39 |
+
]
|
40 |
+
|
41 |
+
DATASET_LABELS = {
|
42 |
+
"LVD-1689M": "lvd1689m",
|
43 |
+
"SAT-493M": "sat493m",
|
44 |
+
}
|
45 |
+
|
46 |
+
def build_model_maps(model_ids: List[str]):
    """
    Index model IDs of the form "facebook/dinov3-<short>-pretrain-<dataset>".

    IDs that do not match the pattern are skipped. This includes IDs that
    contain "-pretrain" but have no non-empty "<dataset>" part after
    "-pretrain-"; such IDs would otherwise be indexed under a bogus dataset key.

    Returns:
        valid_map[(dataset_key, short_name)] -> full_model_id
        options_by_dataset[dataset_key] -> [short_name,...] (display order preserved)
    """
    prefix = "facebook/dinov3-"
    marker = "-pretrain-"
    valid_map: Dict[Tuple[str, str], str] = {}
    options_by_dataset: Dict[str, List[str]] = {"lvd1689m": [], "sat493m": []}

    for mid in model_ids:
        if not isinstance(mid, str):
            continue

        prefix_at = mid.find(prefix)
        if prefix_at < 0:
            continue
        start = prefix_at + len(prefix)

        pre_idx = mid.find("-pretrain", start)
        # Require the full "-pretrain-<dataset>" tail, not just "-pretrain".
        if pre_idx < 0 or marker not in mid[pre_idx:]:
            continue

        short = mid[start:pre_idx]
        dataset = mid.split(marker)[-1].strip()
        if not dataset:
            continue

        valid_map[(dataset, short)] = mid
        # Only datasets pre-declared above get a dropdown list.
        if dataset in options_by_dataset and short not in options_by_dataset[dataset]:
            options_by_dataset[dataset].append(short)

    return valid_map, options_by_dataset
|
73 |
+
|
74 |
+
VALID_MODEL_MAP, MODEL_OPTIONS_BY_DATASET = build_model_maps(MODEL_ID_LIST)
|
75 |
+
|
76 |
+
# ---------- Defaults / knobs ----------
|
77 |
+
DEFAULT_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
78 |
+
DEFAULT_DATASET_LABEL = "LVD-1689M" # initial radio
|
79 |
+
DEFAULT_OVERLAY_ALPHA = 0.55
|
80 |
+
DEFAULT_SHOW_GRID = True
|
81 |
+
|
82 |
+
# ---------- Sample image URLs (dependent on dataset) ----------
|
83 |
+
SAMPLE_URL_CHOICES: Dict[str, List[Tuple[str, str]]] = {
|
84 |
+
# LVD: current ones
|
85 |
+
"lvd1689m": [
|
86 |
+
("– choose a sample –", ""),
|
87 |
+
("COCO: 2 Cats on sofa (039769)", "http://images.cocodataset.org/val2017/000000039769.jpg"),
|
88 |
+
("COCO: Person skiing (000785)", "http://images.cocodataset.org/val2017/000000000785.jpg"),
|
89 |
+
("COCO: People running (000872)", "http://images.cocodataset.org/val2017/000000000872.jpg"),
|
90 |
+
("Picsum: Mountain (ID=1000)", "https://picsum.photos/id/1000/800/600"),
|
91 |
+
("Picsum: Kayak (ID=1011)", "https://picsum.photos/id/1011/800/600"),
|
92 |
+
("Picsum: Man and dog (ID=1012)", "https://picsum.photos/id/1012/800/600"),
|
93 |
+
],
|
94 |
+
# SAT: satellite imagery examples
|
95 |
+
"sat493m": [
|
96 |
+
("– choose a satellite sample –", ""),
|
97 |
+
("Blue Marble (NASA)", "https://upload.wikimedia.org/wikipedia/commons/9/9d/The_Blue_Marble_%28remastered%29.jpg"),
|
98 |
+
("GOES-16 Hurricane Florence (2018)", "https://upload.wikimedia.org/wikipedia/commons/5/5e/Hurricane_Florence_GOES-16_2018-09-12_1510Z.jpg"),
|
99 |
+
("NASA Earth Observatory: Philippines", "https://eoimages.gsfc.nasa.gov/images/imagerecords/151000/151639/philippines_tmo_2020118_lrg.jpg"),
|
100 |
+
],
|
101 |
+
}
|
102 |
+
|
103 |
+
def _sample_labels_for(dataset_label: str):
    """Return the sample-dropdown labels for the dataset chosen in the UI.

    Unknown dataset labels fall back to the "lvd1689m" sample list.
    """
    dataset_key = DATASET_LABELS.get(dataset_label, "lvd1689m")
    labels = []
    for label, _url in SAMPLE_URL_CHOICES.get(dataset_key, []):
        labels.append(label)
    return labels
|
106 |
+
|
107 |
+
def _apply_sample(dataset_label: str, sample_label: str):
    """Fill textbox with chosen sample URL and clear any uploaded image.

    The placeholder entry (and any unknown label) maps to an empty URL. In that
    case nothing was actually chosen, so both outputs are left untouched.
    Without this guard, a reset of the dropdown to its placeholder (as done when
    the dataset changes) would wipe the user's URL and uploaded image.

    Returns:
        (textbox update, image-upload update)
    """
    key = DATASET_LABELS.get(dataset_label, "lvd1689m")
    url = dict(SAMPLE_URL_CHOICES.get(key, [])).get(sample_label, "")
    if not url:
        return gr.update(), gr.update()  # no-op for both components
    return gr.update(value=url), None  # (textbox update, clear upload)
|
113 |
+
|
114 |
+
# ---------- Utility ----------
|
115 |
+
def load_image_from_any(src: Optional[Image.Image], url: Optional[str]) -> Optional[Image.Image]:
    """Return an RGB image from a URL (preferred) or an already-loaded PIL image.

    Args:
        src: Uploaded image, or None.
        url: Optional http(s) URL; surrounding whitespace is ignored.

    Returns:
        The image converted to RGB, or None when neither input is usable.

    Network and image-decoding errors are not caught here and propagate to
    the caller.
    """
    # Prefer URL if present. Strip once so the validated string is the one fetched.
    clean_url = str(url).strip() if url else ""
    if clean_url.lower().startswith(("http://", "https://")):
        # Some hosts reject urllib's default User-Agent, so send an explicit one;
        # the timeout keeps a stalled download from hanging the UI callback.
        request = urllib.request.Request(
            clean_url,
            headers={"User-Agent": "DINOv3-PatchSimilarity/1.0 (Gradio demo)"},
        )
        with urllib.request.urlopen(request, timeout=30) as resp:
            data = resp.read()
        return Image.open(io.BytesIO(data)).convert("RGB")
    if isinstance(src, Image.Image):
        return src.convert("RGB")
    return None
|
124 |
+
|
125 |
+
def pad_to_multiple(pil_img: Image.Image, multiple: int = 16) -> Tuple[Image.Image, Tuple[int, int, int, int]]:
    """Pad an image on its right/bottom edges so both sides are multiples of `multiple`.

    Returns:
        (image, (0, 0, extra_width, extra_height)). When no padding is needed the
        input image object itself is returned with an all-zero box; otherwise a
        new black RGB canvas with the input pasted at the top-left corner.
    """
    width, height = pil_img.size
    target_h = int(math.ceil(height / multiple) * multiple)
    target_w = int(math.ceil(width / multiple) * multiple)
    extra_w, extra_h = target_w - width, target_h - height

    if extra_w == 0 and extra_h == 0:
        return pil_img, (0, 0, 0, 0)

    padded = Image.new("RGB", (target_w, target_h), (0, 0, 0))
    padded.paste(pil_img, (0, 0))
    return padded, (0, 0, extra_w, extra_h)
|
134 |
+
|
135 |
+
def preprocess_no_resize(pil_img: Image.Image, multiple: int = 16):
    """Pad (never resize) an image and turn it into normalized model input.

    Returns:
        ({"pixel_values": tensor with a leading batch axis},
         uint8 array of the padded image for display,
         pad box from `pad_to_multiple`)
    """
    img_padded, pad_box = pad_to_multiple(pil_img, multiple=multiple)

    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    pixel_tensor = normalize(to_tensor(img_padded)).unsqueeze(0)  # (1,3,H,W)

    disp_np = np.array(img_padded, dtype=np.uint8)
    return {"pixel_values": pixel_tensor}, disp_np, pad_box
|
145 |
+
|
146 |
+
def upsample_nearest(arr: np.ndarray, H: int, W: int, ps: int) -> np.ndarray:
    """Enlarge a per-patch array by repeating every cell `ps` times along rows and columns.

    Accepts (rows, cols) or (rows, cols, channels) input. `H` and `W` are accepted
    for interface compatibility but not used by the computation.

    Raises:
        ValueError: if `arr` is neither 2-D nor 3-D.
    """
    if arr.ndim not in (2, 3):
        raise ValueError("upsample_nearest expects (rows,cols) or (rows,cols,channels)")

    enlarged = arr.repeat(ps, 0).repeat(ps, 1)
    if arr.ndim == 2:
        return enlarged

    rows, cols, ch = arr.shape
    return enlarged.reshape(rows * ps, cols * ps, ch)
|
154 |
+
|
155 |
+
def blend_overlay(base_uint8: np.ndarray, overlay_rgb_float: np.ndarray, alpha: float) -> np.ndarray:
    """Alpha-blend a float overlay (scaled by 255) onto a uint8 base image.

    Computes (1 - alpha) * base + alpha * overlay * 255, clips to [0, 255]
    and returns the result as uint8.
    """
    base_weight = 1.0 - alpha
    overlay_255 = (overlay_rgb_float * 255.0).astype(np.float32)
    mixed = base_weight * base_uint8.astype(np.float32) + alpha * overlay_255
    return np.clip(mixed, 0, 255).astype(np.uint8)
|
160 |
+
|
161 |
+
def draw_grid(img: Image.Image, rows: int, cols: int, ps: int):
    """Draw 1-pixel white lines on `img` (in place) at every interior patch boundary."""
    pen = ImageDraw.Draw(img)
    width, height = img.size
    white = (255, 255, 255)

    # Horizontal separators between patch rows.
    for y in (r * ps for r in range(1, rows)):
        pen.line([(0, y), (width, y)], fill=white, width=1)

    # Vertical separators between patch columns.
    for x in (c * ps for c in range(1, cols)):
        pen.line([(x, 0), (x, height)], fill=white, width=1)
|
170 |
+
|
171 |
+
def rc_to_idx(r: int, c: int, cols: int) -> int:
    """Convert a (row, column) grid position into a row-major flat index."""
    row, col = int(r), int(c)
    return row * cols + col
|
173 |
+
|
174 |
+
def idx_to_rc(i: int, cols: int) -> Tuple[int, int]:
    """Convert a row-major flat index into its (row, column) grid position."""
    return divmod(int(i), cols)
|
176 |
+
|
177 |
+
# ---------- Model cache ----------
|
178 |
+
@lru_cache(maxsize=3)
def load_model_cached(full_model_id: str, device_str: str):
    """Load a model with `AutoModel.from_pretrained`, move it to the device, set eval mode.

    Results are memoized per (full_model_id, device_str); at most three entries
    are kept by the LRU cache.
    """
    target = torch.device(device_str)
    net = AutoModel.from_pretrained(full_model_id)
    net = net.to(target)
    net.eval()
    return net
|
184 |
+
|
185 |
+
def infer_patch_size(model, default: int = 16) -> int:
    """Read the patch size from `model.config.patch_size`, then `model.patch_size`.

    A tuple/list value contributes its first element. When neither attribute
    exists, `default` is returned.
    """
    # The config is consulted first; the model object itself is the fallback.
    for holder in (getattr(model, "config", None), model):
        if not hasattr(holder, "patch_size"):
            continue
        value = holder.patch_size
        if isinstance(value, (tuple, list)):
            value = value[0]
        return int(value)
    return default
|
195 |
+
|
196 |
+
# ---------- Per-image state ----------
|
197 |
+
class PatchImageState:
    """Padded image together with the per-patch features the model produced for it.

    Attributes set by the constructor:
        pil, ps    -- the input image and the patch size
        disp       -- uint8 array of the padded image, used for display
        H, W       -- padded height/width in pixels
        rows, cols -- patch grid size (H // ps, W // ps)
        D          -- feature dimension
        X, Xn      -- (rows*cols, D) patch features, raw and L2-normalized
    """

    def __init__(self, pil_img: Image.Image, model, device_str: str, ps: int):
        self.pil = pil_img
        self.ps = ps

        inputs, padded_np, _ = preprocess_no_resize(pil_img, multiple=ps)
        self.disp = padded_np

        pv = inputs["pixel_values"].to(device_str)  # (1,3,H,W)
        _, _, height, width = pv.shape
        self.H, self.W = int(height), int(width)
        self.rows, self.cols = self.H // ps, self.W // ps

        with torch.no_grad():
            out = model(pixel_values=pv)
        hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy()  # (T,D)

        T, D = hs.shape
        n_patches = self.rows * self.cols
        # Tokens beyond the patch grid: class + maybe registers.
        n_special = T - n_patches
        if n_special < 1:
            raise RuntimeError(
                f"Token mismatch: T={T}, rows*cols={n_patches}, HxW={self.H}x{self.W}, ps={ps}"
            )

        self.D = D
        # NOTE(review): assumes the special tokens come first in the sequence and
        # the patch tokens follow in row-major order -- confirm per model family.
        grid = hs[n_special:, :].reshape(self.rows, self.cols, D)
        self.X = grid.reshape(-1, D)
        norms = np.linalg.norm(self.X, axis=1, keepdims=True)
        self.Xn = self.X / (norms + 1e-8)
|
223 |
+
|
224 |
+
# ---------- Rendering / compute ----------
|
225 |
+
def render_with_cosmap(
    st: PatchImageState,
    cos_map: Optional[np.ndarray],
    overlay_alpha: float,
    show_grid_flag: bool,
    select_idx: Optional[int] = None,
    best_idx: Optional[int] = None,
) -> Image.Image:
    """Render the display image with a "magma" similarity heatmap blended on top.

    Args:
        st: Image state supplying the display array and patch-grid geometry.
        cos_map: (rows, cols) similarity values, min-max normalized before
            colouring; None yields a uniform mid-scale overlay.
        overlay_alpha: Blend weight of the heatmap.
        show_grid_flag: Whether to draw the patch grid.
        select_idx: Flat index of the selected patch (filled and outlined in red).
        best_idx: Flat index of the best-match patch (outlined in yellow).
    """
    ps, cols = st.ps, st.cols

    if cos_map is not None:
        lo, hi = float(cos_map.min()), float(cos_map.max())
        span = hi - lo if hi > lo else 1e-8  # avoid dividing by zero on flat maps
        disp = (cos_map - lo) / span
    else:
        disp = np.full((st.rows, cols), 0.5, dtype=np.float32)

    rgb = cm.get_cmap("magma")(disp)[..., :3]

    if select_idx is not None:
        sel_r, sel_c = idx_to_rc(select_idx, cols)
        rgb[sel_r, sel_c, :] = np.array([1.0, 0.0, 0.0], dtype=np.float32)

    overlay = upsample_nearest(rgb, st.H, st.W, ps)
    pil = Image.fromarray(blend_overlay(st.disp, overlay, float(overlay_alpha)))

    draw = ImageDraw.Draw(pil)
    if show_grid_flag:
        draw_grid(pil, st.rows, cols, ps)

    # Red outline for the selection first, then yellow for the best match.
    for idx, colour in ((select_idx, (255, 0, 0)), (best_idx, (255, 255, 0))):
        if idx is None:
            continue
        r, c = idx_to_rc(idx, cols)
        x0, y0 = c * ps, r * ps
        draw.rectangle([(x0, y0), (x0 + ps - 1, y0 + ps - 1)], outline=colour, width=2)

    return pil
|
272 |
+
|
273 |
+
def compute_self_and_cross(
    src: PatchImageState,
    tgt: Optional[PatchImageState],
    q_idx: int,
):
    """Cosine similarity of one source patch against all source (and target) patches.

    Returns:
        (cos_map_self, cos_map_cross, self_stats, cross_result) where
        cos_map_self is (src.rows, src.cols), self_stats is (min, max),
        and, when `tgt` is given, cos_map_cross is (tgt.rows, tgt.cols) and
        cross_result is (min, max, flat index of the most similar target patch).
        Without a target both cross values are None.
    """
    query = src.X[q_idx]
    qn = query / (np.linalg.norm(query) + 1e-8)

    cos_map_self = (src.Xn @ qn).reshape(src.rows, src.cols)
    self_stats = (float(cos_map_self.min()), float(cos_map_self.max()))

    if tgt is None:
        return cos_map_self, None, self_stats, None

    cos_cross = tgt.Xn @ qn
    cos_map_cross = cos_cross.reshape(tgt.rows, tgt.cols)
    cross_result = (
        float(cos_map_cross.min()),
        float(cos_map_cross.max()),
        int(np.argmax(cos_cross)),
    )
    return cos_map_self, cos_map_cross, self_stats, cross_result
|
295 |
+
|
296 |
+
# ---------- Gradio helpers for model & samples ----------
|
297 |
+
def dataset_label_to_key(label: str) -> str:
    """Map a dataset radio label to its internal key, defaulting to "lvd1689m"."""
    try:
        return DATASET_LABELS[label]
    except KeyError:
        return "lvd1689m"
|
299 |
+
|
300 |
+
def update_model_dropdown(dataset_label: str):
    """Build the dropdown update listing the models of the selected dataset.

    The first available model becomes the selected value; with no models the
    value is None.
    """
    opts = MODEL_OPTIONS_BY_DATASET.get(dataset_label_to_key(dataset_label), [])
    if opts:
        return gr.update(choices=opts, value=opts[0])
    return gr.update(choices=opts, value=None)
|
305 |
+
|
306 |
+
def update_model_and_samples(dataset_label: str):
    """Return updates for the model dropdown and both sample dropdowns.

    Both sample dropdowns receive the same update: the dataset's sample labels
    with the first label selected (None if there are no labels).
    """
    labels = _sample_labels_for(dataset_label)
    first_label = labels[0] if labels else None
    sample_update = gr.update(choices=labels, value=first_label)
    return update_model_dropdown(dataset_label), sample_update, sample_update
|
313 |
+
|
314 |
+
def resolve_full_model_id(dataset_label: str, short_name: str) -> Optional[str]:
    """Look up the full model ID for a (dataset label, short model name) pair.

    Returns None when the combination is not in VALID_MODEL_MAP.
    """
    dataset_key = dataset_label_to_key(dataset_label)
    return VALID_MODEL_MAP.get((dataset_key, short_name))
|
317 |
+
|
318 |
+
# ---------- Gradio callbacks ----------
|
319 |
+
def init_states(
    left_img_in: Optional[Image.Image],
    left_url: str,
    right_img_in: Optional[Image.Image],
    right_url: str,
    dataset_label: str,
    short_model: str,
    show_grid_flag: bool,
    overlay_alpha: float,
):
    """Load images and model, build per-image states and render the initial overlays.

    The initial query patch is the centre patch of the left image when present,
    otherwise of the right image. If neither image is supplied, DEFAULT_URL is
    loaded as the left image.

    Returns a 9-tuple:
        (left view, right view, left state, right state, active side,
         left index, right index, patch size, status text)
    """
    def center_idx(state):
        return (state.rows // 2) * state.cols + (state.cols // 2)

    def draw(state, cos_map, select_idx=None, best_idx=None):
        return render_with_cosmap(state, cos_map, overlay_alpha, show_grid_flag,
                                  select_idx=select_idx, best_idx=best_idx)

    # Resolve images
    left_img = load_image_from_any(left_img_in, left_url)
    right_img = load_image_from_any(right_img_in, right_url)
    if left_img is None and right_img is None:
        left_img = load_image_from_any(None, DEFAULT_URL)

    # Resolve model
    full_model_id = resolve_full_model_id(dataset_label, short_model)
    if not full_model_id:
        return (gr.update(), gr.update(), None, None, 0, -1, -1, 16,
                f"❌ Model not available: {dataset_label} / {short_model}")

    device_str = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model_cached(full_model_id, device_str)
    ps = infer_patch_size(model, 16)

    left_state = None if left_img is None else PatchImageState(left_img, model, device_str, ps)
    right_state = None if right_img is None else PatchImageState(right_img, model, device_str, ps)

    active_side = 1 if left_state is None else 0
    status = f"✔ Loaded: {full_model_id} | ps={ps}"
    out_left = out_right = None

    if left_state is None:
        # Right image only.
        q_idx = center_idx(right_state)
        cos_self, _, (smin, smax), _ = compute_self_and_cross(right_state, None, q_idx)
        out_right = draw(right_state, cos_self, select_idx=q_idx)
        status += f" | Single RIGHT {right_state.rows}x{right_state.cols} self∈[{smin:.3f},{smax:.3f}]"
        left_idx, right_idx = -1, q_idx
    elif right_state is None:
        # Left image only.
        q_idx = center_idx(left_state)
        cos_self, _, (smin, smax), _ = compute_self_and_cross(left_state, None, q_idx)
        out_left = draw(left_state, cos_self, select_idx=q_idx)
        status += f" | Single LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}]"
        left_idx, right_idx = q_idx, -1
    else:
        # Two images: left is the source, right shows the cross-similarity.
        q_idx = center_idx(left_state)
        cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(left_state, right_state, q_idx)
        best_idx = cross_info[2] if cross_info else None
        out_left = draw(left_state, cos_self, select_idx=q_idx)
        out_right = draw(right_state, cos_cross, best_idx=best_idx)
        status += (f" | LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}] "
                   f"| RIGHT cross best={best_idx}")
        left_idx, right_idx = q_idx, center_idx(right_state)

    return (
        out_left, out_right,
        left_state, right_state,
        active_side,
        left_idx, right_idx,
        ps,
        status
    )
|
387 |
+
|
388 |
+
def _coords_to_idx(x: int, y: int, st: PatchImageState) -> int:
    """Map a pixel coordinate to the flat index of the patch containing it.

    Row and column are clamped to the patch grid, so coordinates outside the
    image resolve to the nearest edge patch.
    """
    row = int(min(max(y // st.ps, 0), st.rows - 1))
    col = int(min(max(x // st.ps, 0), st.cols - 1))
    return row * st.cols + col
|
392 |
+
|
393 |
+
def on_select_left(
    evt: gr.SelectData,
    left_state: Optional[PatchImageState],
    right_state: Optional[PatchImageState],
    show_grid_flag: bool,
    overlay_alpha: float,
    ps: int,
):
    """Handle a click on the LEFT view: the clicked patch becomes the query.

    Returns (left view, right view, active side, left index, right index, status).
    The active side is always 0; the right index is reset to -1.
    """
    if left_state is None:
        return gr.update(), gr.update(), 0, -1, -1, "Upload/Load a LEFT image first."

    click_x, click_y = evt.index
    q_idx = _coords_to_idx(click_x, click_y, left_state)

    if right_state is None:
        # Single-image mode: only the left view is re-rendered.
        cos_self, _, (smin, smax), _ = compute_self_and_cross(left_state, None, q_idx)
        out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
                                      select_idx=q_idx, best_idx=None)
        status = f"Single LEFT • idx={q_idx} • self∈[{smin:.3f},{smax:.3f}]"
        return out_left, gr.update(), 0, q_idx, -1, status

    # Two-image mode: self-similarity on the left, cross-similarity on the right.
    cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(left_state, right_state, q_idx)
    best_idx = cross_info[2]
    out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
                                  select_idx=q_idx, best_idx=None)
    out_right = render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag,
                                   select_idx=None, best_idx=best_idx)
    status = (f"LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}] | "
              f"RIGHT cross best idx={best_idx}")
    return out_left, out_right, 0, q_idx, -1, status
|
423 |
+
|
424 |
+
def on_select_right(
    evt: gr.SelectData,
    left_state: Optional[PatchImageState],
    right_state: Optional[PatchImageState],
    show_grid_flag: bool,
    overlay_alpha: float,
    ps: int,
):
    """Handle a click on the RIGHT view: the clicked patch becomes the query.

    Returns (left view, right view, active side, left index, right index, status).
    The active side is always 1; the left index is reset to -1.
    """
    if right_state is None:
        return gr.update(), gr.update(), 1, -1, -1, "Upload/Load a RIGHT image first."

    click_x, click_y = evt.index
    q_idx = _coords_to_idx(click_x, click_y, right_state)

    if left_state is None:
        # Single-image mode: only the right view is re-rendered.
        cos_self, _, (smin, smax), _ = compute_self_and_cross(right_state, None, q_idx)
        out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
                                       select_idx=q_idx, best_idx=None)
        status = f"Single RIGHT • idx={q_idx} • self∈[{smin:.3f},{smax:.3f}]"
        return gr.update(), out_right, 1, -1, q_idx, status

    # Two-image mode: self-similarity on the right, cross-similarity on the left.
    cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(right_state, left_state, q_idx)
    best_idx = cross_info[2]
    out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
                                   select_idx=q_idx, best_idx=None)
    out_left = render_with_cosmap(left_state, cos_cross, overlay_alpha, show_grid_flag,
                                  select_idx=None, best_idx=best_idx)
    status = (f"RIGHT {right_state.rows}x{right_state.cols} self∈[{smin:.3f},{smax:.3f}] | "
              f"LEFT cross best idx={best_idx}")
    return out_left, out_right, 1, -1, q_idx, status
|
454 |
+
|
455 |
+
def rebuild_with_settings(
    left_state: Optional[PatchImageState],
    right_state: Optional[PatchImageState],
    active_side: int,
    left_idx: int,
    right_idx: int,
    show_grid_flag: bool,
    overlay_alpha: float,
    ps: int,
):
    """Re-render the overlays after a display setting (grid / alpha) changed.

    The stored selection of the active image is reused; a negative stored index
    falls back to that image's centre patch.

    Returns (left view, right view, status text).
    """
    if left_state is None and right_state is None:
        return gr.update(), gr.update(), "Load an image first."

    def query_index(stored, state):
        return stored if stored >= 0 else (state.rows//2)*state.cols + (state.cols//2)

    def draw(state, cos_map, select_idx=None, best_idx=None):
        return render_with_cosmap(state, cos_map, overlay_alpha, show_grid_flag,
                                  select_idx=select_idx, best_idx=best_idx)

    if right_state is None:
        # Left image only.
        q_idx = query_index(left_idx, left_state)
        cos_self, _, _, _ = compute_self_and_cross(left_state, None, q_idx)
        return draw(left_state, cos_self, select_idx=q_idx), gr.update(), "Updated overlays."

    if left_state is None:
        # Right image only.
        q_idx = query_index(right_idx, right_state)
        cos_self, _, _, _ = compute_self_and_cross(right_state, None, q_idx)
        return gr.update(), draw(right_state, cos_self, select_idx=q_idx), "Updated overlays."

    # Two images: the active side is the source, the other side shows the cross map.
    if active_side == 0:
        q_idx = query_index(left_idx, left_state)
        cos_self, cos_cross, _, cross_info = compute_self_and_cross(left_state, right_state, q_idx)
        best_idx = cross_info[2]
        out_left = draw(left_state, cos_self, select_idx=q_idx)
        out_right = draw(right_state, cos_cross, best_idx=best_idx)
    else:
        q_idx = query_index(right_idx, right_state)
        cos_self, cos_cross, _, cross_info = compute_self_and_cross(right_state, left_state, q_idx)
        best_idx = cross_info[2]
        out_right = draw(right_state, cos_self, select_idx=q_idx)
        out_left = draw(left_state, cos_cross, best_idx=best_idx)
    return out_left, out_right, "Updated overlays."
|
498 |
+
|
499 |
+
# ---------- Gradio UI ----------
|
500 |
+
# Layout and event wiring of the app. Components are created in display order,
# so the order of the statements below defines the page layout.
with gr.Blocks(title="DINOv3 Patch Similarity (Self & Cross)") as demo:
    gr.Markdown(
        """
        # DINOv3 Patch Similarity (Self & Cross)
        1) Pick **Dataset** (LVD-1689M / SAT-493M).
        2) Pick **Model**.
        3) Upload one or two images (or paste URLs) and press **Initialize / Update**.
        - Click on a patch to update overlays.
        - In two-image mode, the non-active image hides the red selection and shows **yellow** best match.
        """
    )

    # Dataset radio and the model dropdown that depends on it
    with gr.Row():
        dataset_radio = gr.Radio(
            label="Dataset",
            choices=list(DATASET_LABELS.keys()),
            value=DEFAULT_DATASET_LABEL,
            interactive=True
        )
        initial_key = DATASET_LABELS[DEFAULT_DATASET_LABEL]
        initial_models = MODEL_OPTIONS_BY_DATASET.get(initial_key, [])
        model_dropdown = gr.Dropdown(
            label="Model name",
            choices=initial_models,
            value=(initial_models[0] if initial_models else None),
            interactive=True
        )

    # initial sample labels based on default dataset
    initial_sample_labels = [label for label, _ in SAMPLE_URL_CHOICES.get(initial_key, [])]

    # Image inputs: upload, URL textbox and sample dropdown for each side
    with gr.Row():
        with gr.Column():
            left_input = gr.Image(label="Left Image (upload)", type="pil",
                                  sources=["upload", "clipboard", "webcam"], interactive=True)
            left_url = gr.Textbox(label="Left Image URL (optional)", placeholder="https://...")
            left_sample = gr.Dropdown(label="Use a sample URL",
                                      choices=initial_sample_labels,
                                      value=(initial_sample_labels[0] if initial_sample_labels else None),
                                      interactive=True)
        with gr.Column():
            right_input = gr.Image(label="Right Image (upload)", type="pil",
                                   sources=["upload", "clipboard", "webcam"], interactive=True)
            right_url = gr.Textbox(label="Right Image URL (optional)", placeholder="https://...")
            right_sample = gr.Dropdown(label="Use a sample URL",
                                       choices=initial_sample_labels,
                                       value=(initial_sample_labels[0] if initial_sample_labels else None),
                                       interactive=True)

    with gr.Accordion("Overlay Settings", open=True):
        show_grid = gr.Checkbox(label="Show patch grid", value=DEFAULT_SHOW_GRID)
        overlay_alpha = gr.Slider(label="Overlay alpha", minimum=0.0, maximum=1.0,
                                  value=DEFAULT_OVERLAY_ALPHA, step=0.01)

    init_btn = gr.Button("Initialize / Update", variant="primary")

    # Rendered overlays; clicking them fires the select handlers wired below
    with gr.Row():
        left_view = gr.Image(label="LEFT (click to select patch)", interactive=True)
        right_view = gr.Image(label="RIGHT (click to select patch)", interactive=True)

    status = gr.Markdown("")

    # Hidden states
    left_state = gr.State(None)
    right_state = gr.State(None)
    active_side = gr.State(0)
    left_idx = gr.State(-1)
    right_idx = gr.State(-1)
    ps_state = gr.State(16)

    # Update model dropdown and sample lists when dataset changes
    dataset_radio.change(
        fn=update_model_and_samples,
        inputs=[dataset_radio],
        outputs=[model_dropdown, left_sample, right_sample]
    )

    # When a sample is chosen, set URL and clear any uploaded image (prefer URL)
    left_sample.change(
        fn=_apply_sample,
        inputs=[dataset_radio, left_sample],
        outputs=[left_url, left_input]
    )
    right_sample.change(
        fn=_apply_sample,
        inputs=[dataset_radio, right_sample],
        outputs=[right_url, right_input]
    )

    # Initialize / reload model + overlays
    init_btn.click(
        fn=init_states,
        inputs=[left_input, left_url, right_input, right_url, dataset_radio, model_dropdown, show_grid, overlay_alpha],
        outputs=[left_view, right_view, left_state, right_state, active_side, left_idx, right_idx, ps_state, status],
        show_progress=True
    )

    # Click handlers
    left_view.select(
        fn=on_select_left,
        inputs=[left_state, right_state, show_grid, overlay_alpha, ps_state],
        outputs=[left_view, right_view, active_side, left_idx, right_idx, status]
    )
    right_view.select(
        fn=on_select_right,
        inputs=[left_state, right_state, show_grid, overlay_alpha, ps_state],
        outputs=[left_view, right_view, active_side, left_idx, right_idx, status]
    )

    # Live re-render on setting changes
    show_grid.change(
        fn=rebuild_with_settings,
        inputs=[left_state, right_state, active_side, left_idx, right_idx, show_grid, overlay_alpha, ps_state],
        outputs=[left_view, right_view, status]
    )
    overlay_alpha.change(
        fn=rebuild_with_settings,
        inputs=[left_state, right_state, active_side, left_idx, right_idx, show_grid, overlay_alpha, ps_state],
        outputs=[left_view, right_view, status]
    )

# Start the app only when the file is run as a script
if __name__ == "__main__":
    demo.queue().launch()
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
matplotlib
|
4 |
+
numpy==1.26.4
|
5 |
+
ipykernel
|
6 |
+
ipywidgets
|
7 |
+
ipympl
|
8 |
+
|
9 |
+
# Install Transformers directly from GitHub source
|
10 |
+
transformers @ git+https://github.com/huggingface/transformers.git
|