manu02 committed
Commit faca9e1 (verified)
1 parent: 8d7fef2

Initial commit

Files changed (6)
  1. DINOv3CosSimilarity.py +434 -0
  2. LICENSE +9 -0
  3. PatchCosSimilarity.ipynb +0 -0
  4. README.md +161 -14
  5. app.py +622 -0
  6. requirements.txt +10 -0
DINOv3CosSimilarity.py ADDED
@@ -0,0 +1,434 @@
1
+ # filepath: DINOv3-PatchSimilarity/DINOv3CosSimilarity.py
2
+ # Interactive DINOv3 patch similarity viewer for one or two images
3
+ # (no AutoImageProcessor, no resize: images are only padded to a multiple of the patch size)
4
+ # - Single-image mode (0 or 1 image given): self-similarity heatmap for the selected patch
5
+ # - Two-image mode (2 images given): clicking or moving on one image shows BOTH overlays
6
+ #   (self-similarity on the source, cross-similarity on the target)
7
+ # - The best cross-image match is outlined on the target; the red selection
8
+ #   rectangle is hidden on the non-active image
9
+
10
+ import sys, math, io, urllib.request, argparse, os
11
+ import numpy as np
12
+ from PIL import Image
13
+ import torch
14
+ from torchvision import transforms
15
+ import matplotlib
16
+ try:
17
+ matplotlib.use("TkAgg")
18
+ except Exception:
19
+ pass
20
+ import matplotlib.pyplot as plt
21
+ from matplotlib.patches import Rectangle
22
+ from transformers import AutoModel
23
+
24
+ # ---------- Defaults / knobs ----------
25
+ DEFAULT_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
26
+ DEFAULT_MODEL_ID = "facebook/dinov3-vits16-pretrain-lvd1689m"
27
+
28
+ SHOW_GRID = True
29
+ ANNOTATE_INDICES = False
30
+ OVERLAY_ALPHA = 0.55
31
+ PATCH_SIZE_OVERRIDE = None # set 16 to force; None = read from model if available
32
+
33
+ # ---------- Image I/O ----------
34
+ def load_image(path_or_url):
35
+ if str(path_or_url).startswith(("http://", "https://")):
36
+ with urllib.request.urlopen(path_or_url) as resp:
37
+ data = resp.read()
38
+ return Image.open(io.BytesIO(data)).convert("RGB")
39
+ return Image.open(path_or_url).convert("RGB")
40
+
41
+ # ---------- Preprocessing (custom, no resize) ----------
42
+ def pad_to_multiple(pil_img, multiple=16):
43
+ W, H = pil_img.size
44
+ H_pad = int(math.ceil(H / multiple) * multiple)
45
+ W_pad = int(math.ceil(W / multiple) * multiple)
46
+ if (H_pad, W_pad) == (H, W):
47
+ return pil_img, (0, 0, 0, 0)
48
+ canvas = Image.new("RGB", (W_pad, H_pad), (0, 0, 0))
49
+ canvas.paste(pil_img, (0, 0))
50
+ return canvas, (0, 0, W_pad - W, H_pad - H)
51
+
52
+ def preprocess_image_no_resize(pil_img, multiple=16):
53
+ img_padded, pad_box = pad_to_multiple(pil_img, multiple=multiple)
54
+ transform = transforms.Compose([
55
+ transforms.ToTensor(),
56
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
57
+ std=[0.229, 0.224, 0.225])
58
+ ])
59
+ pixel_tensor = transform(img_padded).unsqueeze(0) # (1,3,H,W)
60
+ disp_np = np.array(img_padded, dtype=np.uint8) # (H,W,3) for display
61
+ return {"pixel_values": pixel_tensor}, disp_np, pad_box
62
+
63
+ # ---------- Utilities ----------
64
+ def upsample_nearest(arr, H, W, ps):
65
+ if arr.ndim == 2:
66
+ return arr.repeat(ps, 0).repeat(ps, 1)
67
+ elif arr.ndim == 3:
68
+ # arr shape: (rows, cols, channels)
69
+ rows, cols, channels = arr.shape
70
+ arr_up = arr.repeat(ps, 0).repeat(ps, 1)
71
+ return arr_up.reshape(rows * ps, cols * ps, channels)
72
+ raise ValueError("upsample_nearest expects (rows,cols) or (rows,cols,channels)")
73
+
74
+ def draw_grid(ax, rows, cols, ps):
75
+ for r in range(1, rows):
76
+ ax.axhline(r * ps - 0.5, lw=0.8, alpha=0.6, color="white", zorder=3)
77
+ for c in range(1, cols):
78
+ ax.axvline(c * ps - 0.5, lw=0.8, alpha=0.6, color="white", zorder=3)
79
+
80
+ def draw_indices(ax, rows, cols, ps):
81
+ for r in range(rows):
82
+ for c in range(cols):
83
+ idx = r * cols + c
84
+ ax.text(c * ps + ps / 2, r * ps + ps / 2, str(idx),
85
+ ha="center", va="center", fontsize=7,
86
+ color="white", alpha=0.95, zorder=4)
87
+
88
+ def rc_to_idx(r, c, cols): return int(r) * cols + int(c)
89
+ def idx_to_rc(i, cols): return (int(i) // cols, int(i) % cols)
90
+
91
+ # ---------- Per-image embeddings ----------
92
+ class PatchImageState:
93
+ def __init__(self, pil_img, model, device, ps):
94
+ self.pil = pil_img
95
+ self.ps = ps
96
+ inputs, disp_np, _ = preprocess_image_no_resize(pil_img, multiple=ps)
97
+ self.disp = disp_np
98
+ self.pixel_values = inputs["pixel_values"].to(device) # (1,3,H,W)
99
+ _, _, self.H, self.W = self.pixel_values.shape
100
+ self.rows, self.cols = self.H // ps, self.W // ps
101
+
102
+ with torch.no_grad():
103
+ out = model(pixel_values=self.pixel_values)
104
+ hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy() # (T,D)
105
+
106
+ T, D = hs.shape
107
+ n_patches = self.rows * self.cols
108
+ n_special = T - n_patches # class + possible register tokens
109
+ if n_special < 1:
110
+ raise RuntimeError(
111
+ f"[error] Token shape mismatch. T={T}, rows*cols={n_patches}, HxW={self.H}x{self.W}, ps={ps}"
112
+ )
113
+
114
+ self.D = D
115
+ self.patch_embs = hs[n_special:, :].reshape(self.rows, self.cols, D) # (rows,cols,D)
116
+ self.X = self.patch_embs.reshape(-1, D) # (N,D)
117
+ self.Xn = self.X / (np.linalg.norm(self.X, axis=1, keepdims=True) + 1e-8) # normalized
118
+
119
+ # UI bits (set later by the viewers)
120
+ self.ax = None
121
+ self.overlay_im = None
122
+ self.sel_rect = None
123
+ self.best_rect = None
124
+
125
+ # ---------- Single-image mode ----------
126
+ def run_single_image(img_path, model, device, ps, show_grid, annotate_indices, overlay_alpha):
127
+ img = load_image(img_path)
128
+ st = PatchImageState(img, model, device, ps)
129
+
130
+ fig, ax = plt.subplots(figsize=(9, 9))
131
+ st.ax = ax
132
+ ax.imshow(st.disp, zorder=0)
133
+ ax.set_axis_off()
134
+ if show_grid:
135
+ draw_grid(ax, st.rows, st.cols, st.ps)
136
+ if annotate_indices:
137
+ draw_indices(ax, st.rows, st.cols, st.ps)
138
+
139
+ # neutral overlay to start
140
+ init_scalar = 0.5 * np.ones((st.rows, st.cols), dtype=np.float32)
141
+ rgba = plt.get_cmap("magma")(init_scalar)
142
+ rgba_up = upsample_nearest(rgba, st.H, st.W, st.ps)
143
+ st.overlay_im = ax.imshow(rgba_up, alpha=overlay_alpha, zorder=1)
144
+
145
+ st.sel_rect = Rectangle((0, 0), st.ps, st.ps, fill=False, lw=2.0, ec="red", zorder=5)
146
+ ax.add_patch(st.sel_rect)
147
+
148
+ current_idx = (st.rows // 2) * st.cols + st.cols // 2
149
+ cmap = plt.get_cmap("magma")
150
+
151
+ def update(idx):
152
+ nonlocal current_idx
153
+ current_idx = int(np.clip(idx, 0, st.rows * st.cols - 1))
154
+ r, c = idx_to_rc(current_idx, st.cols)
155
+
156
+ q = st.X[current_idx]
157
+ qn = q / (np.linalg.norm(q) + 1e-8)
158
+ cos = st.Xn @ qn
159
+ cos_map = cos.reshape(st.rows, st.cols)
160
+
161
+ disp = (cos_map - cos_map.min()) / (cos_map.ptp() + 1e-8)
162
+ rgba = cmap(disp)
163
+ # Force selected cell to pure RED (and full alpha in the RGBA array)
164
+ rgba[r, c, 0:3] = np.array([1.0, 0.0, 0.0])
165
+ rgba[r, c, 3] = 1.0
166
+
167
+ st.overlay_im.set_data(upsample_nearest(rgba, st.H, st.W, st.ps))
168
+ st.overlay_im.set_alpha(overlay_alpha) # global alpha
169
+
170
+ st.sel_rect.set_xy((c * st.ps, r * st.ps))
171
+ ax.set_title(
172
+ f"Single-image • idx={current_idx} (r={r}, c={c}) • cos∈[{cos_map.min():.3f},{cos_map.max():.3f}]",
173
+ fontsize=11
174
+ )
175
+ fig.canvas.draw_idle()
176
+
177
+ def on_click(event):
178
+ if event.inaxes != ax or event.xdata is None or event.ydata is None:
179
+ return
180
+ r = int(np.clip(event.ydata // st.ps, 0, st.rows - 1))
181
+ c = int(np.clip(event.xdata // st.ps, 0, st.cols - 1))
182
+ update(rc_to_idx(r, c, st.cols))
183
+
184
+ def on_key(event):
185
+ nonlocal current_idx
186
+ r, c = idx_to_rc(current_idx, st.cols)
187
+ if event.key == "left":
188
+ c = max(0, c - 1)
189
+ elif event.key == "right":
190
+ c = min(st.cols - 1, c + 1)
191
+ elif event.key == "up":
192
+ r = max(0, r - 1)
193
+ elif event.key == "down":
194
+ r = min(st.rows - 1, r + 1)
195
+ elif event.key == "q":
196
+ plt.close(fig); return
197
+ else:
198
+ return
199
+ update(rc_to_idx(r, c, st.cols))
200
+
201
+ fig.canvas.mpl_connect("button_press_event", on_click)
202
+ fig.canvas.mpl_connect("key_press_event", on_key)
203
+
204
+ update(current_idx)
205
+ print("[single-image] Controls: click to select • arrows to move • 'q' to quit")
206
+ plt.tight_layout()
207
+ plt.show()
208
+
209
+ # ---------- Two-image mode (shows BOTH overlays; hides red rect on non-active) ----------
210
+ def run_two_images(img1_path, img2_path, model, device, ps, show_grid, annotate_indices, overlay_alpha):
211
+ img1, img2 = load_image(img1_path), load_image(img2_path)
212
+ S = [PatchImageState(img1, model, device, ps),
213
+ PatchImageState(img2, model, device, ps)]
214
+ if S[0].D != S[1].D:
215
+ raise RuntimeError("Embedding dims differ — use the same model for both images.")
216
+
217
+ fig, (axL, axR) = plt.subplots(1, 2, figsize=(12, 6))
218
+ axs = [axL, axR]
219
+ for i, (ax, st) in enumerate(zip(axs, S)):
220
+ st.ax = ax
221
+ ax.imshow(st.disp, zorder=0)
222
+ ax.set_axis_off()
223
+ if show_grid:
224
+ draw_grid(ax, st.rows, st.cols, st.ps)
225
+ if annotate_indices:
226
+ draw_indices(ax, st.rows, st.cols, st.ps)
227
+ # start overlays (hidden until first render)
228
+ init_scalar = 0.5 * np.ones((st.rows, st.cols), dtype=np.float32)
229
+ rgba = plt.get_cmap("magma")(init_scalar)
230
+ rgba_up = upsample_nearest(rgba, st.H, st.W, st.ps)
231
+ st.overlay_im = ax.imshow(rgba_up, alpha=0.0, zorder=1)
232
+
233
+ st.sel_rect = Rectangle((0, 0), st.ps, st.ps, fill=False, lw=2.0, ec="red", zorder=5)
234
+ st.best_rect = Rectangle((0, 0), st.ps, st.ps, fill=False, lw=2.0, ec="yellow", zorder=6)
235
+ ax.add_patch(st.sel_rect)
236
+ ax.add_patch(st.best_rect)
237
+ st.best_rect.set_visible(False)
238
+
239
+ active_side = 0 # 0=left, 1=right
240
+ current_idx = [ (S[0].rows//2)*S[0].cols + S[0].cols//2,
241
+ (S[1].rows//2)*S[1].cols + S[1].cols//2 ]
242
+ cmap = plt.get_cmap("magma")
243
+
244
+ def set_titles(src_i=None, self_stats=None, cross_stats=None):
245
+ axs[0].set_title(f"LEFT • {S[0].rows}x{S[0].cols} patches • {'ACTIVE' if active_side==0 else ''}", fontsize=10)
246
+ axs[1].set_title(f"RIGHT • {S[1].rows}x{S[1].cols} patches • {'ACTIVE' if active_side==1 else ''}", fontsize=10)
247
+ if src_i is not None and self_stats is not None and cross_stats is not None:
248
+ src_name = "LEFT" if src_i == 0 else "RIGHT"
249
+ tgt_name = "RIGHT" if src_i == 0 else "LEFT"
250
+ fig.suptitle(
251
+ f"Source: {src_name} | Self cos∈[{self_stats[0]:.3f},{self_stats[1]:.3f}] • "
252
+ f"{tgt_name} cos∈[{cross_stats[0]:.3f},{cross_stats[1]:.3f}] | "
253
+ f"Controls: click=select • arrows=move • '1'/'2'/'t'=switch side • 'q'=quit",
254
+ fontsize=11
255
+ )
256
+ else:
257
+ fig.suptitle(
258
+ "Controls: click=select • arrows=move • '1'/'2'/'t'=switch side • 'q'=quit",
259
+ fontsize=11
260
+ )
261
+
262
+ def clamp_idx(i, st):
263
+ return int(np.clip(i, 0, st.rows*st.cols - 1))
264
+
265
+ def update_selection_rects():
266
+ # position rects
267
+ for i, st in enumerate(S):
268
+ r, c = idx_to_rc(current_idx[i], st.cols)
269
+ st.sel_rect.set_xy((c * st.ps, r * st.ps))
270
+ # visibility: only active side shows red rect
271
+ for i, st in enumerate(S):
272
+ st.sel_rect.set_visible(i == active_side)
273
+
274
+ def compute_and_show_both_from_src(src_i):
275
+ """Show self-similarity on src and cross-similarity on the other image."""
276
+ src = S[src_i]
277
+ tgt_i = 1 - src_i
278
+ tgt = S[tgt_i]
279
+
280
+ q_idx = clamp_idx(current_idx[src_i], src)
281
+ q = src.X[q_idx]
282
+ qn = q / (np.linalg.norm(q) + 1e-8)
283
+
284
+ # --- Self on src ---
285
+ cos_self = src.Xn @ qn
286
+ cos_map_self = cos_self.reshape(src.rows, src.cols)
287
+ disp_self = (cos_map_self - cos_map_self.min()) / (cos_map_self.ptp() + 1e-8)
288
+ rgba_self = cmap(disp_self)
289
+ r0, c0 = idx_to_rc(q_idx, src.cols)
290
+ rgba_self[r0, c0, 0:3] = np.array([1.0, 0.0, 0.0])
291
+ rgba_self[r0, c0, 3] = 1.0
292
+ src.overlay_im.set_data(upsample_nearest(rgba_self, src.H, src.W, src.ps))
293
+ src.overlay_im.set_alpha(overlay_alpha)
294
+
295
+ # --- Cross on tgt ---
296
+ cos_cross = tgt.Xn @ qn
297
+ cos_map_cross = cos_cross.reshape(tgt.rows, tgt.cols)
298
+ disp_cross = (cos_map_cross - cos_map_cross.min()) / (cos_map_cross.ptp() + 1e-8)
299
+ rgba_cross = cmap(disp_cross)
300
+ tgt.overlay_im.set_data(upsample_nearest(rgba_cross, tgt.H, tgt.W, tgt.ps))
301
+ tgt.overlay_im.set_alpha(overlay_alpha)
302
+
303
+ # highlight best match on target
304
+ best = int(np.argmax(cos_cross))
305
+ br, bc = idx_to_rc(best, tgt.cols)
306
+ tgt.best_rect.set_xy((bc * tgt.ps, br * tgt.ps))
307
+ tgt.best_rect.set_visible(True)
308
+
309
+ # Hide best on source (self best is the selected cell)
310
+ src.best_rect.set_visible(False)
311
+
312
+ set_titles(src_i, (cos_map_self.min(), cos_map_self.max()),
313
+ (cos_map_cross.min(), cos_map_cross.max()))
314
+ fig.canvas.draw_idle()
315
+
316
+ def on_click(event):
317
+ nonlocal active_side
318
+ if event.inaxes is None or event.xdata is None or event.ydata is None:
319
+ return
320
+ side = 0 if event.inaxes is axs[0] else (1 if event.inaxes is axs[1] else None)
321
+ if side is None: return
322
+ st = S[side]
323
+ r = int(np.clip(event.ydata // st.ps, 0, st.rows - 1))
324
+ c = int(np.clip(event.xdata // st.ps, 0, st.cols - 1))
325
+ current_idx[side] = rc_to_idx(r, c, st.cols)
326
+ active_side = side
327
+ update_selection_rects() # <-- updates visibility
328
+ compute_and_show_both_from_src(active_side)
329
+
330
+ def on_key(event):
331
+ nonlocal active_side
332
+ if event.key == "q":
333
+ plt.close(fig); return
334
+ if event.key in ("t", "T"):
335
+ active_side = 1 - active_side
336
+ update_selection_rects() # <-- updates visibility
337
+ compute_and_show_both_from_src(active_side); return
338
+ if event.key == "1":
339
+ active_side = 0
340
+ update_selection_rects()
341
+ compute_and_show_both_from_src(active_side); return
342
+ if event.key == "2":
343
+ active_side = 1
344
+ update_selection_rects()
345
+ compute_and_show_both_from_src(active_side); return
346
+
347
+ st = S[active_side]
348
+ r, c = idx_to_rc(current_idx[active_side], st.cols)
349
+ if event.key == "left":
350
+ c = max(0, c - 1)
351
+ elif event.key == "right":
352
+ c = min(st.cols - 1, c + 1)
353
+ elif event.key == "up":
354
+ r = max(0, r - 1)
355
+ elif event.key == "down":
356
+ r = min(st.rows - 1, r + 1)
357
+ else:
358
+ return
359
+ current_idx[active_side] = rc_to_idx(r, c, st.cols)
360
+ update_selection_rects() # keep visibility rule consistent
361
+ compute_and_show_both_from_src(active_side)
362
+
363
+ update_selection_rects() # initialize positions + visibility
364
+ set_titles()
365
+ compute_and_show_both_from_src(active_side)
366
+
367
+ fig.canvas.mpl_connect("button_press_event", on_click)
368
+ fig.canvas.mpl_connect("key_press_event", on_key)
369
+
370
+ print("[two-image BOTH] Controls:")
371
+ print(" • Click on LEFT/RIGHT to select query patch (shows self + cross overlays)")
372
+ print(" • Arrow keys move selection on ACTIVE side")
373
+ print(" • '1'/'2'/'t' to switch side • 'q' to quit")
374
+ plt.tight_layout()
375
+ plt.show()
376
+
377
+ # ---------- Main ----------
378
+ def main():
379
+ parser = argparse.ArgumentParser(description="DINOv3 Patch Similarity Viewer (1 or 2 images; two-image shows BOTH overlays)")
380
+ # Accept either --image (single) or --image1/--image2 (two)
381
+ parser.add_argument("--image", type=str, default=None,
382
+ help="Path/URL to image (single-image mode if only this is provided)")
383
+ parser.add_argument("--image1", type=str, default=None,
384
+ help="Path/URL to first image (two-image mode if image2 is also provided)")
385
+ parser.add_argument("--image2", type=str, default=None,
386
+ help="Path/URL to second image (two-image mode when given)")
387
+ parser.add_argument("--model", type=str, default=DEFAULT_MODEL_ID,
388
+ help="DINOv3 model repo id (e.g., facebook/dinov3-vits16-pretrain-lvd1689m)")
389
+ parser.add_argument("--show_grid", action="store_true", help="Draw patch grid")
390
+ parser.add_argument("--annotate_indices", action="store_true", help="Write patch indices on cells")
391
+ parser.add_argument("--overlay_alpha", type=float, default=OVERLAY_ALPHA, help="Heatmap alpha")
392
+ parser.add_argument("--patch_size", type=int, default=(PATCH_SIZE_OVERRIDE or -1),
393
+ help="Override patch size. Set 16 to force. Default: model's patch size")
394
+ args = parser.parse_args()
395
+
396
+ show_grid = args.show_grid or SHOW_GRID
397
+ annotate_indices = args.annotate_indices or ANNOTATE_INDICES
398
+ overlay_alpha = args.overlay_alpha
399
+
400
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
401
+ print(f"[info] Device: {device}")
402
+ model = AutoModel.from_pretrained(args.model).to(device)
403
+ model.eval()
404
+
405
+ def get_patch_size_from_model(model, default=16):
406
+ # Try to get patch size from model config if available
407
+ if hasattr(model, "config") and hasattr(model.config, "patch_size"):
408
+ ps = model.config.patch_size
409
+ if isinstance(ps, (tuple, list)):
410
+ return ps[0]
411
+ return ps
412
+ # Try to get from model attributes
413
+ if hasattr(model, "patch_size"):
414
+ ps = model.patch_size
415
+ if isinstance(ps, (tuple, list)):
416
+ return ps[0]
417
+ return ps
418
+ return default
419
+
420
+ ps = args.patch_size if args.patch_size and args.patch_size > 0 else get_patch_size_from_model(model, 16)
421
+ print(f"[info] Using patch size: {ps}")
422
+
423
+ # Routing logic:
424
+ img1 = args.image1 or args.image
425
+ img2 = args.image2
426
+
427
+ if img1 and img2:
428
+ run_two_images(img1, img2, model, device, ps, show_grid, annotate_indices, overlay_alpha)
429
+ else:
430
+ img_single = img1 or DEFAULT_URL
431
+ run_single_image(img_single, model, device, ps, show_grid, annotate_indices, overlay_alpha)
432
+
433
+ if __name__ == "__main__":
434
+ main()
LICENSE ADDED
@@ -0,0 +1,9 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 [Your Name or Organization]
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6
+
7
+ 1. The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8
+
9
+ 2. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
PatchCosSimilarity.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,14 +1,161 @@
1
- ---
2
- title: DINOv3 Interactive Patch Cosine Similarity
3
- emoji: 🏃
4
- colorFrom: yellow
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.43.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Interactive tool to visualize patch-wise similarity in image
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # DINOv3 Patch Similarity Viewer
2
+
3
+ ![Gradio Test app](assets/GradioAppTest.gif)
4
+
5
+ > **Note:** This repository is intended for educational purposes. It was inspired by the DINOv3 paper and is meant to help visualize and understand the model's patch-level output.
6
+
7
+ ## Purpose
8
+
9
+ This repository provides interactive tools to visualize and explore patch-wise similarity in images using the DINOv3 vision transformer model. It is designed for researchers, students, and practitioners interested in understanding how self-supervised vision transformers perceive and relate different regions of an image.
10
+
11
+ ## About DINOv3
12
+
13
+ - **Paper:** [DINOv3 (arXiv:2508.10104)](https://arxiv.org/abs/2508.10104)
14
+ - **Meta Research Page:** [Meta DINOv3 Publication](https://ai.meta.com/dinov3/)
15
+ - **Official GitHub:** [facebookresearch/dinov3](https://github.com/facebookresearch/dinov3)
16
+
17
+ **Note:**
18
+ The DINOv3 model weights require access approval.
19
+ You can request access via the [Meta Research page](https://ai.meta.com/resources/models-and-libraries/dinov3-downloads/) or directly from the desired model page in the [Hugging Face model collection](https://huggingface.co/collections/facebook/dinov3-68924841bd6b561778e31009).
20
+
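+ Because the checkpoints are gated, `AutoModel.from_pretrained` typically requires your environment to be authenticated with a Hugging Face token that has been granted access. A minimal sketch (the token value is a placeholder):
+ 
+ ```python
+ # One-time setup sketch: assumes your access request was approved and you created
+ # an access token at https://huggingface.co/settings/tokens
+ from huggingface_hub import login
+ from transformers import AutoModel
+ 
+ login(token="hf_...")  # or run `huggingface-cli login` once in a terminal
+ model = AutoModel.from_pretrained("facebook/dinov3-vits16-pretrain-lvd1689m")
+ ```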
21
+ ## Features
22
+
23
+ - **Interactive Visualization:** Click on image patches or use arrow keys to explore patch similarity heatmaps.
24
+ - **Single or Two-Image Mode:** If one image is specified, shows self-similarity. If two images are specified, shows both self-similarity and cross-image similarity overlays interactively.
25
+ - **Image Preprocessing:** Loads and pads images without resizing, preserving the original aspect ratio.
26
+ - **Cosine Similarity Calculation:** Computes and visualizes cosine similarity between image patches (a minimal sketch is shown after this list).
27
+ - **Robust Fallback:** If an image URL fails to load, a default image is used.
28
+
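+ Under the hood, the script, notebook, and Gradio app all share the same computation: pad the image to a multiple of the patch size, run the backbone once, drop the class/register tokens, L2-normalize the patch embeddings, and compare every patch to the selected one with a dot product. A minimal, self-contained sketch of that pipeline (the image path is a placeholder, and it assumes you have access to the gated checkpoint):
+ 
+ ```python
+ # Sketch of the patch cosine-similarity computation used throughout this repo.
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import AutoModel
+ 
+ model_id = "facebook/dinov3-vits16-pretrain-lvd1689m"
+ model = AutoModel.from_pretrained(model_id).eval()
+ ps = 16  # patch size of the ViT-S/16 backbone
+ 
+ img = Image.open("your_image.jpg").convert("RGB")   # placeholder path
+ W, H = img.size
+ Wp, Hp = -(-W // ps) * ps, -(-H // ps) * ps         # round up to a multiple of ps
+ canvas = Image.new("RGB", (Wp, Hp))
+ canvas.paste(img, (0, 0))                           # pad with black, never resize
+ 
+ normalize = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ])
+ pixels = normalize(canvas).unsqueeze(0)             # (1, 3, Hp, Wp)
+ 
+ with torch.no_grad():
+     tokens = model(pixel_values=pixels).last_hidden_state[0].numpy()  # (T, D)
+ 
+ rows, cols = Hp // ps, Wp // ps
+ patches = tokens[tokens.shape[0] - rows * cols:]    # keep only the patch tokens
+ patches = patches / (np.linalg.norm(patches, axis=1, keepdims=True) + 1e-8)
+ 
+ query_idx = 0                                       # e.g. the top-left patch
+ cos_map = (patches @ patches[query_idx]).reshape(rows, cols)
+ print(cos_map.shape, float(cos_map.min()), float(cos_map.max()))
+ ```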
29
+ ## Installation
30
+
31
+ Install dependencies with:
32
+
33
+ ```bash
34
+ pip install -r requirements.txt
35
+ ```
36
+
37
+ ## Model Selection
38
+
39
+ You can choose from several DINOv3 models available on Hugging Face (click to view each model card):
40
+
41
+ LVD-1689M Dataset (Web data)
42
+ - ViT
43
+ - [facebook/dinov3-vit7b16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vit7b16-pretrain-lvd1689m)
44
+ - [facebook/dinov3-vits16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vits16-pretrain-lvd1689m)
45
+ - [facebook/dinov3-vits16plus-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vits16plus-pretrain-lvd1689m)
46
+ - [facebook/dinov3-vitb16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m)
47
+ - [facebook/dinov3-vitl16-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vitl16-pretrain-lvd1689m)
48
+ - [facebook/dinov3-vith16plus-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-vith16plus-pretrain-lvd1689m)
49
+
50
+ - ConvNeXt
51
+ - [facebook/dinov3-convnext-tiny-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-convnext-tiny-pretrain-lvd1689m)
52
+ - [facebook/dinov3-convnext-small-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-convnext-small-pretrain-lvd1689m)
53
+ - [facebook/dinov3-convnext-base-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-convnext-base-pretrain-lvd1689m)
54
+ - [facebook/dinov3-convnext-large-pretrain-lvd1689m](https://huggingface.co/facebook/dinov3-convnext-large-pretrain-lvd1689m)
55
+
56
+ SAT-493M Dataset (Satellite data)
57
+ - ViT
58
+ - [facebook/dinov3-vitl16-pretrain-sat493m](https://huggingface.co/facebook/dinov3-vitl16-pretrain-sat493m)
59
+ - [facebook/dinov3-vit7b16-pretrain-sat493m](https://huggingface.co/facebook/dinov3-vit7b16-pretrain-sat493m)
60
+
61
+ ## Usage
62
+
63
+ ### Gradio app
64
+
65
+ Run the Gradio app:
66
+
67
+ ```bash
68
+ python app.py
69
+ ```
70
+
71
+ After running the app, open [http://localhost:7860/](http://localhost:7860/) in your browser.
72
+
73
+ Then:
74
+ - Choose a dataset and a model name
75
+ - For single-image similarity:
76
+   - Provide only one image (file or URL)
77
+ - For two-image similarity:
78
+   - Provide both images (file upload and/or URL)
79
+ - Click the "Initialize / Update" button
80
+ - Click the desired patch in the image
81
+ - Inspect the resulting similarity overlays
82
+
83
+ **Note:**
84
+ *Overlay alpha* controls the opacity of the similarity heatmap drawn on top of the image.
85
+
86
+ ### Python Script
87
+
88
+ Run the interactive viewer with the default COCO image:
89
+
90
+ ```bash
91
+ python DINOv3CosSimilarity.py
92
+ ```
93
+
94
+ #### Single Image Mode
95
+
96
+ Specify your own image (local path or URL):
97
+
98
+ ```bash
99
+ python DINOv3CosSimilarity.py --image path/to/your/image.jpg
100
+ python DINOv3CosSimilarity.py --image https://yourdomain.com/image.png
101
+ ```
102
+
103
+ #### Two Image Mode
104
+
105
+ Specify two images (local paths or URLs):
106
+
107
+ ```bash
108
+ python DINOv3CosSimilarity.py --image1 path/to/image1.jpg --image2 path/to/image2.jpg
109
+ python DINOv3CosSimilarity.py --image1 https://yourdomain.com/image1.png --image2 https://yourdomain.com/image2.png
110
+ ```
111
+
112
+ #### Model Selection
113
+
114
+ Specify the model with `--model` (default is vits16):
115
+
116
+ ```bash
117
+ python DINOv3CosSimilarity.py --model facebook/dinov3-vitb16-pretrain-lvd1689m
118
+ ```
119
+
120
+ #### Other Options
121
+
122
+ - `--show_grid` : Draw patch grid
123
+ - `--annotate_indices` : Write patch indices on cells
124
+ - `--overlay_alpha <float>` : Set heatmap alpha (default 0.55)
125
+ - `--patch_size <int>` : Override patch size (default: model's patch size)
126
+
127
+ #### Controls
128
+
129
+ - Mouse click to select a patch
130
+ - Arrow keys to move selection
131
+ - '1', '2', or 't' to switch active image (in two-image mode)
132
+ - 'q' to quit
133
+
134
+ ## Demo Single Image
135
+
136
+ ![Interactive Patch Similarity Demo](assets/Test_Interactive_video.gif)
137
+
138
+ ## Demo 2 Images
139
+
140
+ ![Multiple Interactive Patch Similarity Demo](assets/Multiple_Interactive_test_video.gif)
141
+
142
+ ### Jupyter Notebook
143
+
144
+ 1. Open `PatchCosSimilarity.ipynb` in Jupyter Notebook.
145
+ 2. Run the cells to load an image and visualize patch similarities.
146
+ 3. Set `url1` for single-image mode, or both `url1` and `url2` for two-image mode.
147
+ 4. If an image fails to load, a default image will be used automatically.
148
+ 5. Set the `model_id` variable to any of the models listed above (see the commented lines at the top of the notebook; an example cell is sketched below).
149
+
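+ An illustrative configuration cell (the URLs are examples taken from the sample list above; check the notebook for the exact convention it uses for single- vs. two-image mode):
+ 
+ ```python
+ # Example values only; leave the second URL out for single-image mode.
+ model_id = "facebook/dinov3-vits16-pretrain-lvd1689m"
+ url1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ url2 = "http://images.cocodataset.org/val2017/000000000785.jpg"  # optional second image
+ ```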
150
+ **Notebook Controls:**
151
+ - Mouse click to select a patch
152
+ - Arrow keys to move selection
153
+ - '1', '2', or 't' to switch active image (in two-image mode)
154
+
155
+ ## License
156
+
157
+ This project is licensed under the MIT License. See the `LICENSE` file for details.
158
+
159
+ ## Acknowledgments
160
+
161
+ This project uses the DINOv3 models through the Hugging Face Transformers library, along with PyTorch, Torchvision, Gradio, NumPy, Matplotlib, and Pillow.
app.py ADDED
@@ -0,0 +1,622 @@
1
+ # app.py
2
+ # Gradio UI for interactive DINOv3 patch similarity (single or dual image)
3
+ # - No AutoImageProcessor, no resize (only pad to multiple of patch size)
4
+ # - Single image: click to show self-similarity; selected cell outlined in RED
5
+ # - Two images: click on one side -> self overlay on source, cross overlay on target; best match on target outlined in YELLOW
6
+ # - Red selection rectangle is hidden on the non-active image
7
+ # - Patch size inferred from model (no override). Patch indices are not annotated.
8
+ # - Dataset selector (LVD-1689M / SAT-493M); model dropdown shows only the short name between "dinov3-" and "-pretrain".
9
+ # - Sample URL dropdowns switch between LVD (COCO/Picsum) and SAT (satellite imagery) and auto-fill / clear uploads.
10
+
11
+ import io
12
+ import math
13
+ import urllib.request
14
+ from functools import lru_cache
15
+ from typing import Optional, Tuple, Dict, List
16
+
17
+ import gradio as gr
18
+ import numpy as np
19
+ from PIL import Image, ImageDraw
20
+ import torch
21
+ from torchvision import transforms
22
+ from transformers import AutoModel
23
+ from matplotlib import colormaps as cm
24
+
25
+ # ---------- Provided model IDs (ground truth list) ----------
26
+ MODEL_ID_LIST = [
27
+ "facebook/dinov3-vits16-pretrain-lvd1689m",
28
+ "facebook/dinov3-vits16plus-pretrain-lvd1689m",
29
+ "facebook/dinov3-vitb16-pretrain-lvd1689m",
30
+ "facebook/dinov3-vitl16-pretrain-lvd1689m",
31
+ "facebook/dinov3-vith16plus-pretrain-lvd1689m",
32
+ "facebook/dinov3-vit7b16-pretrain-lvd1689m",
33
+ "facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
34
+ "facebook/dinov3-convnext-small-pretrain-lvd1689m",
35
+ "facebook/dinov3-convnext-base-pretrain-lvd1689m",
36
+ "facebook/dinov3-convnext-large-pretrain-lvd1689m",
37
+ "facebook/dinov3-vitl16-pretrain-sat493m",
38
+ "facebook/dinov3-vit7b16-pretrain-sat493m",
39
+ ]
40
+
41
+ DATASET_LABELS = {
42
+ "LVD-1689M": "lvd1689m",
43
+ "SAT-493M": "sat493m",
44
+ }
45
+
46
+ def build_model_maps(model_ids: List[str]):
47
+ """
48
+ Returns:
49
+ valid_map[(dataset_key, short_name)] -> full_model_id
50
+ options_by_dataset[dataset_key] -> [short_name,...] (display order preserved)
51
+ """
52
+ valid_map: Dict[Tuple[str, str], str] = {}
53
+ options_by_dataset: Dict[str, List[str]] = {"lvd1689m": [], "sat493m": []}
54
+
55
+ for mid in model_ids:
56
+ # Expect pattern: "facebook/dinov3-<short>-pretrain-<dataset>"
57
+ try:
58
+ prefix = "facebook/dinov3-"
59
+ start = mid.index(prefix) + len(prefix)
60
+ pre_idx = mid.index("-pretrain", start)
61
+ short = mid[start:pre_idx]
62
+ dataset = mid.split("-pretrain-")[-1].strip()
63
+ except Exception:
64
+ # Skip anything that doesn't match the expected pattern
65
+ continue
66
+
67
+ key = (dataset, short)
68
+ valid_map[key] = mid
69
+ if dataset in options_by_dataset and short not in options_by_dataset[dataset]:
70
+ options_by_dataset[dataset].append(short)
71
+
72
+ return valid_map, options_by_dataset
73
+
74
+ VALID_MODEL_MAP, MODEL_OPTIONS_BY_DATASET = build_model_maps(MODEL_ID_LIST)
75
+
76
+ # ---------- Defaults / knobs ----------
77
+ DEFAULT_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
78
+ DEFAULT_DATASET_LABEL = "LVD-1689M" # initial radio
79
+ DEFAULT_OVERLAY_ALPHA = 0.55
80
+ DEFAULT_SHOW_GRID = True
81
+
82
+ # ---------- Sample image URLs (dependent on dataset) ----------
83
+ SAMPLE_URL_CHOICES: Dict[str, List[Tuple[str, str]]] = {
84
+ # LVD: current ones
85
+ "lvd1689m": [
86
+ ("– choose a sample –", ""),
87
+ ("COCO: 2 Cats on sofa (039769)", "http://images.cocodataset.org/val2017/000000039769.jpg"),
88
+ ("COCO: Person skiing (000785)", "http://images.cocodataset.org/val2017/000000000785.jpg"),
89
+ ("COCO: People running (000872)", "http://images.cocodataset.org/val2017/000000000872.jpg"),
90
+ ("Picsum: Mountain (ID=1000)", "https://picsum.photos/id/1000/800/600"),
91
+ ("Picsum: Kayak (ID=1011)", "https://picsum.photos/id/1011/800/600"),
92
+ ("Picsum: Man and dog (ID=1012)", "https://picsum.photos/id/1012/800/600"),
93
+ ],
94
+ # SAT: satellite imagery examples
95
+ "sat493m": [
96
+ ("– choose a satellite sample –", ""),
97
+ ("Blue Marble (NASA)", "https://upload.wikimedia.org/wikipedia/commons/9/9d/The_Blue_Marble_%28remastered%29.jpg"),
98
+ ("GOES-16 Hurricane Florence (2018)", "https://upload.wikimedia.org/wikipedia/commons/5/5e/Hurricane_Florence_GOES-16_2018-09-12_1510Z.jpg"),
99
+ ("NASA Earth Observatory: Philippines", "https://eoimages.gsfc.nasa.gov/images/imagerecords/151000/151639/philippines_tmo_2020118_lrg.jpg"),
100
+ ],
101
+ }
102
+
103
+ def _sample_labels_for(dataset_label: str):
104
+ key = DATASET_LABELS.get(dataset_label, "lvd1689m")
105
+ return [label for label, _ in SAMPLE_URL_CHOICES.get(key, [])]
106
+
107
+ def _apply_sample(dataset_label: str, sample_label: str):
108
+ """Fill textbox with chosen sample URL and clear any uploaded image."""
109
+ key = DATASET_LABELS.get(dataset_label, "lvd1689m")
110
+ sample_map = dict(SAMPLE_URL_CHOICES.get(key, []))
111
+ url = sample_map.get(sample_label, "")
112
+ return gr.update(value=url), None # (textbox update, clear upload)
113
+
114
+ # ---------- Utility ----------
115
+ def load_image_from_any(src: Optional[Image.Image], url: Optional[str]) -> Optional[Image.Image]:
116
+ # Prefer URL if present
117
+ if url and str(url).strip().lower().startswith(("http://", "https://")):
118
+ with urllib.request.urlopen(url) as resp:
119
+ data = resp.read()
120
+ return Image.open(io.BytesIO(data)).convert("RGB")
121
+ if isinstance(src, Image.Image):
122
+ return src.convert("RGB")
123
+ return None
124
+
125
+ def pad_to_multiple(pil_img: Image.Image, multiple: int = 16) -> Tuple[Image.Image, Tuple[int, int, int, int]]:
126
+ W, H = pil_img.size
127
+ H_pad = int(math.ceil(H / multiple) * multiple)
128
+ W_pad = int(math.ceil(W / multiple) * multiple)
129
+ if (H_pad, W_pad) == (H, W):
130
+ return pil_img, (0, 0, 0, 0)
131
+ canvas = Image.new("RGB", (W_pad, H_pad), (0, 0, 0))
132
+ canvas.paste(pil_img, (0, 0))
133
+ return canvas, (0, 0, W_pad - W, H_pad - H)
134
+
135
+ def preprocess_no_resize(pil_img: Image.Image, multiple: int = 16):
136
+ img_padded, pad_box = pad_to_multiple(pil_img, multiple=multiple)
137
+ transform = transforms.Compose([
138
+ transforms.ToTensor(),
139
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
140
+ std =[0.229, 0.224, 0.225]),
141
+ ])
142
+ pixel_tensor = transform(img_padded).unsqueeze(0) # (1,3,H,W)
143
+ disp_np = np.array(img_padded, dtype=np.uint8)
144
+ return {"pixel_values": pixel_tensor}, disp_np, pad_box
145
+
146
+ def upsample_nearest(arr: np.ndarray, H: int, W: int, ps: int) -> np.ndarray:
147
+ if arr.ndim == 2:
148
+ return arr.repeat(ps, 0).repeat(ps, 1)
149
+ elif arr.ndim == 3:
150
+ rows, cols, ch = arr.shape
151
+ arr2 = arr.repeat(ps, 0).repeat(ps, 1)
152
+ return arr2.reshape(rows * ps, cols * ps, ch)
153
+ raise ValueError("upsample_nearest expects (rows,cols) or (rows,cols,channels)")
154
+
155
+ def blend_overlay(base_uint8: np.ndarray, overlay_rgb_float: np.ndarray, alpha: float) -> np.ndarray:
156
+ base = base_uint8.astype(np.float32)
157
+ over = (overlay_rgb_float * 255.0).astype(np.float32)
158
+ out = (1.0 - alpha) * base + alpha * over
159
+ return np.clip(out, 0, 255).astype(np.uint8)
160
+
161
+ def draw_grid(img: Image.Image, rows: int, cols: int, ps: int):
162
+ d = ImageDraw.Draw(img)
163
+ W, H = img.size
164
+ for r in range(1, rows):
165
+ y = r * ps
166
+ d.line([(0, y), (W, y)], fill=(255, 255, 255), width=1)
167
+ for c in range(1, cols):
168
+ x = c * ps
169
+ d.line([(x, 0), (x, H)], fill=(255, 255, 255), width=1)
170
+
171
+ def rc_to_idx(r: int, c: int, cols: int) -> int:
172
+ return int(r) * cols + int(c)
173
+
174
+ def idx_to_rc(i: int, cols: int) -> Tuple[int, int]:
175
+ return int(i) // cols, int(i) % cols
176
+
177
+ # ---------- Model cache ----------
178
+ @lru_cache(maxsize=3)
179
+ def load_model_cached(full_model_id: str, device_str: str):
180
+ device = torch.device(device_str)
181
+ model = AutoModel.from_pretrained(full_model_id).to(device)
182
+ model.eval()
183
+ return model
184
+
185
+ def infer_patch_size(model, default: int = 16) -> int:
186
+ if hasattr(model, "config") and hasattr(model.config, "patch_size"):
187
+ ps = model.config.patch_size
188
+ if isinstance(ps, (tuple, list)): return int(ps[0])
189
+ return int(ps)
190
+ if hasattr(model, "patch_size"):
191
+ ps = model.patch_size
192
+ if isinstance(ps, (tuple, list)): return int(ps[0])
193
+ return int(ps)
194
+ return default
195
+
196
+ # ---------- Per-image state ----------
197
+ class PatchImageState:
198
+ def __init__(self, pil_img: Image.Image, model, device_str: str, ps: int):
199
+ self.pil = pil_img
200
+ self.ps = ps
201
+ inputs, disp_np, _ = preprocess_no_resize(pil_img, multiple=ps)
202
+ self.disp = disp_np
203
+ pv = inputs["pixel_values"].to(device_str) # (1,3,H,W)
204
+ _, _, H, W = pv.shape
205
+ self.H, self.W = int(H), int(W)
206
+ self.rows, self.cols = self.H // ps, self.W // ps
207
+
208
+ with torch.no_grad():
209
+ out = model(pixel_values=pv)
210
+ hs = out.last_hidden_state.squeeze(0).detach().cpu().numpy() # (T,D)
211
+
212
+ T, D = hs.shape
213
+ n_patches = self.rows * self.cols
214
+ n_special = T - n_patches # class + maybe registers
215
+ if n_special < 1:
216
+ raise RuntimeError(
217
+ f"Token mismatch: T={T}, rows*cols={n_patches}, HxW={self.H}x{self.W}, ps={ps}"
218
+ )
219
+ self.D = D
220
+ patches = hs[n_special:, :].reshape(self.rows, self.cols, D)
221
+ self.X = patches.reshape(-1, D)
222
+ self.Xn = self.X / (np.linalg.norm(self.X, axis=1, keepdims=True) + 1e-8)
223
+
224
+ # ---------- Rendering / compute ----------
225
+ def render_with_cosmap(
226
+ st: PatchImageState,
227
+ cos_map: Optional[np.ndarray],
228
+ overlay_alpha: float,
229
+ show_grid_flag: bool,
230
+ select_idx: Optional[int] = None,
231
+ best_idx: Optional[int] = None,
232
+ ) -> Image.Image:
233
+ H, W, ps = st.H, st.W, st.ps
234
+ rows, cols = st.rows, st.cols
235
+
236
+ if cos_map is None:
237
+ disp = np.full((rows, cols), 0.5, dtype=np.float32)
238
+ else:
239
+ vmin, vmax = float(cos_map.min()), float(cos_map.max())
240
+ rng = vmax - vmin if vmax > vmin else 1e-8
241
+ disp = (cos_map - vmin) / rng
242
+
243
+ cmap = cm.get_cmap("magma")
244
+ rgba = cmap(disp)
245
+ rgb = rgba[..., :3]
246
+
247
+ if select_idx is not None:
248
+ rs, cs = idx_to_rc(select_idx, cols)
249
+ rgb[rs, cs, :] = np.array([1.0, 0.0, 0.0], dtype=np.float32)
250
+
251
+ over_rgb_up = upsample_nearest(rgb, H, W, ps)
252
+ blended = blend_overlay(st.disp, over_rgb_up, float(overlay_alpha))
253
+ pil = Image.fromarray(blended)
254
+
255
+ draw = ImageDraw.Draw(pil)
256
+ if show_grid_flag:
257
+ draw_grid(pil, rows, cols, ps)
258
+
259
+ if select_idx is not None:
260
+ r, c = idx_to_rc(select_idx, cols)
261
+ x0, y0 = c * ps, r * ps
262
+ x1, y1 = x0 + ps - 1, y0 + ps - 1
263
+ draw.rectangle([(x0, y0), (x1, y1)], outline=(255, 0, 0), width=2)
264
+
265
+ if best_idx is not None:
266
+ r, c = idx_to_rc(best_idx, cols)
267
+ x0, y0 = c * ps, r * ps
268
+ x1, y1 = x0 + ps - 1, y0 + ps - 1
269
+ draw.rectangle([(x0, y0), (x1, y1)], outline=(255, 255, 0), width=2)
270
+
271
+ return pil
272
+
273
+ def compute_self_and_cross(
274
+ src: PatchImageState,
275
+ tgt: Optional[PatchImageState],
276
+ q_idx: int,
277
+ ):
278
+ q = src.X[q_idx]
279
+ qn = q / (np.linalg.norm(q) + 1e-8)
280
+
281
+ cos_self = src.Xn @ qn
282
+ cos_map_self = cos_self.reshape(src.rows, src.cols)
283
+ self_stats = (float(cos_map_self.min()), float(cos_map_self.max()))
284
+
285
+ cross_result = None
286
+ cos_map_cross = None
287
+ if tgt is not None:
288
+ cos_cross = tgt.Xn @ qn
289
+ cos_map_cross = cos_cross.reshape(tgt.rows, tgt.cols)
290
+ cross_min, cross_max = float(cos_map_cross.min()), float(cos_map_cross.max())
291
+ best_idx = int(np.argmax(cos_cross))
292
+ cross_result = (cross_min, cross_max, best_idx)
293
+
294
+ return cos_map_self, cos_map_cross, self_stats, cross_result
295
+
296
+ # ---------- Gradio helpers for model & samples ----------
297
+ def dataset_label_to_key(label: str) -> str:
298
+ return DATASET_LABELS.get(label, "lvd1689m")
299
+
300
+ def update_model_dropdown(dataset_label: str):
301
+ key = dataset_label_to_key(dataset_label)
302
+ opts = MODEL_OPTIONS_BY_DATASET.get(key, [])
303
+ default_val = opts[0] if opts else None
304
+ return gr.update(choices=opts, value=default_val)
305
+
306
+ def update_model_and_samples(dataset_label: str):
307
+ # Update model dropdown
308
+ model_update = update_model_dropdown(dataset_label)
309
+ # Update both sample dropdowns to dataset-specific options
310
+ labels = _sample_labels_for(dataset_label)
311
+ sample_update = gr.update(choices=labels, value=(labels[0] if labels else None))
312
+ return model_update, sample_update, sample_update
313
+
314
+ def resolve_full_model_id(dataset_label: str, short_name: str) -> Optional[str]:
315
+ key = (dataset_label_to_key(dataset_label), short_name)
316
+ return VALID_MODEL_MAP.get(key)
317
+
318
+ # ---------- Gradio callbacks ----------
319
+ def init_states(
320
+ left_img_in: Optional[Image.Image],
321
+ left_url: str,
322
+ right_img_in: Optional[Image.Image],
323
+ right_url: str,
324
+ dataset_label: str,
325
+ short_model: str,
326
+ show_grid_flag: bool,
327
+ overlay_alpha: float,
328
+ ):
329
+ # Resolve images
330
+ left_img = load_image_from_any(left_img_in, left_url)
331
+ right_img = load_image_from_any(right_img_in, right_url)
332
+ if left_img is None and right_img is None:
333
+ left_img = load_image_from_any(None, DEFAULT_URL)
334
+
335
+ # Resolve model
336
+ full_model_id = resolve_full_model_id(dataset_label, short_model)
337
+ if not full_model_id:
338
+ return (gr.update(), gr.update(), None, None, 0, -1, -1, 16,
339
+ f"❌ Model not available: {dataset_label} / {short_model}")
340
+
341
+ device_str = "cuda" if torch.cuda.is_available() else "cpu"
342
+ model = load_model_cached(full_model_id, device_str)
343
+ ps = infer_patch_size(model, 16)
344
+
345
+ left_state = PatchImageState(left_img, model, device_str, ps) if left_img is not None else None
346
+ right_state = PatchImageState(right_img, model, device_str, ps) if right_img is not None else None
347
+
348
+ active_side = 0 if left_state is not None else 1
349
+
350
+ status = f"✔ Loaded: {full_model_id} | ps={ps}"
351
+ out_left, out_right = None, None
352
+
353
+ if left_state is not None and right_state is not None:
354
+ q_idx = (left_state.rows // 2) * left_state.cols + (left_state.cols // 2)
355
+ cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(left_state, right_state, q_idx)
356
+ best_idx = cross_info[2] if cross_info else None
357
+ out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
358
+ select_idx=q_idx, best_idx=None)
359
+ out_right = render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag,
360
+ select_idx=None, best_idx=best_idx)
361
+ status += (f" | LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}] "
362
+ f"| RIGHT cross best={best_idx}")
363
+ left_idx, right_idx = q_idx, (right_state.rows // 2) * right_state.cols + (right_state.cols // 2)
364
+ elif left_state is not None:
365
+ q_idx = (left_state.rows // 2) * left_state.cols + (left_state.cols // 2)
366
+ cos_self, _, (smin, smax), _ = compute_self_and_cross(left_state, None, q_idx)
367
+ out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
368
+ select_idx=q_idx, best_idx=None)
369
+ status += f" | Single LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}]"
370
+ left_idx, right_idx = q_idx, -1
371
+ else:
372
+ q_idx = (right_state.rows // 2) * right_state.cols + (right_state.cols // 2)
373
+ cos_self, _, (smin, smax), _ = compute_self_and_cross(right_state, None, q_idx)
374
+ out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
375
+ select_idx=q_idx, best_idx=None)
376
+ status += f" | Single RIGHT {right_state.rows}x{right_state.cols} self∈[{smin:.3f},{smax:.3f}]"
377
+ left_idx, right_idx = -1, q_idx
378
+
379
+ return (
380
+ out_left, out_right,
381
+ left_state, right_state,
382
+ active_side,
383
+ left_idx, right_idx,
384
+ ps,
385
+ status
386
+ )
387
+
388
+ def _coords_to_idx(x: int, y: int, st: PatchImageState) -> int:
389
+ r = int(np.clip(y // st.ps, 0, st.rows - 1))
390
+ c = int(np.clip(x // st.ps, 0, st.cols - 1))
391
+ return rc_to_idx(r, c, st.cols)
392
+
393
+ def on_select_left(
394
+ evt: gr.SelectData,
395
+ left_state: Optional[PatchImageState],
396
+ right_state: Optional[PatchImageState],
397
+ show_grid_flag: bool,
398
+ overlay_alpha: float,
399
+ ps: int,
400
+ ):
401
+ if left_state is None:
402
+ return gr.update(), gr.update(), 0, -1, -1, "Upload/Load a LEFT image first."
403
+
404
+ x, y = evt.index
405
+ q_idx = _coords_to_idx(x, y, left_state)
406
+
407
+ if right_state is not None:
408
+ cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(left_state, right_state, q_idx)
409
+ best_idx = cross_info[2]
410
+ out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
411
+ select_idx=q_idx, best_idx=None)
412
+ out_right = render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag,
413
+ select_idx=None, best_idx=best_idx)
414
+ status = (f"LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}] | "
415
+ f"RIGHT cross best idx={best_idx}")
416
+ return out_left, out_right, 0, q_idx, -1, status
417
+ else:
418
+ cos_self, _, (smin, smax), _ = compute_self_and_cross(left_state, None, q_idx)
419
+ out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
420
+ select_idx=q_idx, best_idx=None)
421
+ status = f"Single LEFT • idx={q_idx} • self∈[{smin:.3f},{smax:.3f}]"
422
+ return out_left, gr.update(), 0, q_idx, -1, status
423
+
424
+ def on_select_right(
425
+ evt: gr.SelectData,
426
+ left_state: Optional[PatchImageState],
427
+ right_state: Optional[PatchImageState],
428
+ show_grid_flag: bool,
429
+ overlay_alpha: float,
430
+ ps: int,
431
+ ):
432
+ if right_state is None:
433
+ return gr.update(), gr.update(), 1, -1, -1, "Upload/Load a RIGHT image first."
434
+
435
+ x, y = evt.index
436
+ q_idx = _coords_to_idx(x, y, right_state)
437
+
438
+ if left_state is not None:
439
+ cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(right_state, left_state, q_idx)
440
+ best_idx = cross_info[2]
441
+ out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
442
+ select_idx=q_idx, best_idx=None)
443
+ out_left = render_with_cosmap(left_state, cos_cross, overlay_alpha, show_grid_flag,
444
+ select_idx=None, best_idx=best_idx)
445
+ status = (f"RIGHT {right_state.rows}x{right_state.cols} self∈[{smin:.3f},{smax:.3f}] | "
446
+ f"LEFT cross best idx={best_idx}")
447
+ return out_left, out_right, 1, -1, q_idx, status
448
+ else:
449
+ cos_self, _, (smin, smax), _ = compute_self_and_cross(right_state, None, q_idx)
450
+ out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
451
+ select_idx=q_idx, best_idx=None)
452
+ status = f"Single RIGHT • idx={q_idx} • self∈[{smin:.3f},{smax:.3f}]"
453
+ return gr.update(), out_right, 1, -1, q_idx, status
454
+
455
+ def rebuild_with_settings(
456
+ left_state: Optional[PatchImageState],
457
+ right_state: Optional[PatchImageState],
458
+ active_side: int,
459
+ left_idx: int,
460
+ right_idx: int,
461
+ show_grid_flag: bool,
462
+ overlay_alpha: float,
463
+ ps: int,
464
+ ):
465
+ if left_state is None and right_state is None:
466
+ return gr.update(), gr.update(), "Load an image first."
467
+
468
+ if left_state is not None and right_state is not None:
469
+ if active_side == 0:
470
+ q_idx = left_idx if left_idx >= 0 else (left_state.rows//2)*left_state.cols + (left_state.cols//2)
471
+ cos_self, cos_cross, _, cross_info = compute_self_and_cross(left_state, right_state, q_idx)
472
+ best_idx = cross_info[2]
473
+ out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
474
+ select_idx=q_idx, best_idx=None)
475
+ out_right = render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag,
476
+ select_idx=None, best_idx=best_idx)
477
+ else:
478
+ q_idx = right_idx if right_idx >= 0 else (right_state.rows//2)*right_state.cols + (right_state.cols//2)
479
+ cos_self, cos_cross, _, cross_info = compute_self_and_cross(right_state, left_state, q_idx)
480
+ best_idx = cross_info[2]
481
+ out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
482
+ select_idx=q_idx, best_idx=None)
483
+ out_left = render_with_cosmap(left_state, cos_cross, overlay_alpha, show_grid_flag,
484
+ select_idx=None, best_idx=best_idx)
485
+ return out_left, out_right, "Updated overlays."
486
+ elif left_state is not None:
487
+ q_idx = left_idx if left_idx >= 0 else (left_state.rows//2)*left_state.cols + (left_state.cols//2)
488
+ cos_self, _, _, _ = compute_self_and_cross(left_state, None, q_idx)
489
+ out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
490
+ select_idx=q_idx, best_idx=None)
491
+ return out_left, gr.update(), "Updated overlays."
492
+ else:
493
+ q_idx = right_idx if right_idx >= 0 else (right_state.rows//2)*right_state.cols + (right_state.cols//2)
494
+ cos_self, _, _, _ = compute_self_and_cross(right_state, None, q_idx)
495
+ out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
496
+ select_idx=q_idx, best_idx=None)
497
+ return gr.update(), out_right, "Updated overlays."
498
+
499
+ # ---------- Gradio UI ----------
500
+ with gr.Blocks(title="DINOv3 Patch Similarity (Self & Cross)") as demo:
501
+ gr.Markdown(
502
+ """
503
+ # DINOv3 Patch Similarity (Self & Cross)
504
+ 1) Pick **Dataset** (LVD-1689M / SAT-493M).
505
+ 2) Pick **Model**.
506
+ 3) Upload one or two images (or paste URLs) and press **Initialize / Update**.
507
+ - Click on a patch to update overlays.
508
+ - In two-image mode, the non-active image hides the red selection and shows **yellow** best match.
509
+ """
510
+ )
511
+
512
+ with gr.Row():
513
+ dataset_radio = gr.Radio(
514
+ label="Dataset",
515
+ choices=list(DATASET_LABELS.keys()),
516
+ value=DEFAULT_DATASET_LABEL,
517
+ interactive=True
518
+ )
519
+ initial_key = DATASET_LABELS[DEFAULT_DATASET_LABEL]
520
+ initial_models = MODEL_OPTIONS_BY_DATASET.get(initial_key, [])
521
+ model_dropdown = gr.Dropdown(
522
+ label="Model name",
523
+ choices=initial_models,
524
+ value=(initial_models[0] if initial_models else None),
525
+ interactive=True
526
+ )
527
+
528
+ # initial sample labels based on default dataset
529
+ initial_sample_labels = [label for label, _ in SAMPLE_URL_CHOICES.get(initial_key, [])]
530
+
531
+ with gr.Row():
532
+ with gr.Column():
533
+ left_input = gr.Image(label="Left Image (upload)", type="pil",
534
+ sources=["upload", "clipboard", "webcam"], interactive=True)
535
+ left_url = gr.Textbox(label="Left Image URL (optional)", placeholder="https://...")
536
+ left_sample = gr.Dropdown(label="Use a sample URL",
537
+ choices=initial_sample_labels,
538
+ value=(initial_sample_labels[0] if initial_sample_labels else None),
539
+ interactive=True)
540
+ with gr.Column():
541
+ right_input = gr.Image(label="Right Image (upload)", type="pil",
542
+ sources=["upload", "clipboard", "webcam"], interactive=True)
543
+ right_url = gr.Textbox(label="Right Image URL (optional)", placeholder="https://...")
544
+ right_sample = gr.Dropdown(label="Use a sample URL",
545
+ choices=initial_sample_labels,
546
+ value=(initial_sample_labels[0] if initial_sample_labels else None),
547
+ interactive=True)
548
+
549
+ with gr.Accordion("Overlay Settings", open=True):
550
+ show_grid = gr.Checkbox(label="Show patch grid", value=DEFAULT_SHOW_GRID)
551
+ overlay_alpha = gr.Slider(label="Overlay alpha", minimum=0.0, maximum=1.0,
552
+ value=DEFAULT_OVERLAY_ALPHA, step=0.01)
553
+
554
+ init_btn = gr.Button("Initialize / Update", variant="primary")
555
+
556
+ with gr.Row():
557
+ left_view = gr.Image(label="LEFT (click to select patch)", interactive=True)
558
+ right_view = gr.Image(label="RIGHT (click to select patch)", interactive=True)
559
+
560
+ status = gr.Markdown("")
561
+
562
+ # Hidden states
563
+ left_state = gr.State(None)
564
+ right_state = gr.State(None)
565
+ active_side = gr.State(0)
566
+ left_idx = gr.State(-1)
567
+ right_idx = gr.State(-1)
568
+ ps_state = gr.State(16)
569
+
570
+ # Update model dropdown and sample lists when dataset changes
571
+ dataset_radio.change(
572
+ fn=update_model_and_samples,
573
+ inputs=[dataset_radio],
574
+ outputs=[model_dropdown, left_sample, right_sample]
575
+ )
576
+
577
+ # When a sample is chosen, set URL and clear any uploaded image (prefer URL)
578
+ left_sample.change(
579
+ fn=_apply_sample,
580
+ inputs=[dataset_radio, left_sample],
581
+ outputs=[left_url, left_input]
582
+ )
583
+ right_sample.change(
584
+ fn=_apply_sample,
585
+ inputs=[dataset_radio, right_sample],
586
+ outputs=[right_url, right_input]
587
+ )
588
+
589
+ # Initialize / reload model + overlays
590
+ init_btn.click(
591
+ fn=init_states,
592
+ inputs=[left_input, left_url, right_input, right_url, dataset_radio, model_dropdown, show_grid, overlay_alpha],
593
+ outputs=[left_view, right_view, left_state, right_state, active_side, left_idx, right_idx, ps_state, status],
594
+ show_progress=True
595
+ )
596
+
597
+ # Click handlers
598
+ left_view.select(
599
+ fn=on_select_left,
600
+ inputs=[left_state, right_state, show_grid, overlay_alpha, ps_state],
601
+ outputs=[left_view, right_view, active_side, left_idx, right_idx, status]
602
+ )
603
+ right_view.select(
604
+ fn=on_select_right,
605
+ inputs=[left_state, right_state, show_grid, overlay_alpha, ps_state],
606
+ outputs=[left_view, right_view, active_side, left_idx, right_idx, status]
607
+ )
608
+
609
+ # Live re-render on setting changes
610
+ show_grid.change(
611
+ fn=rebuild_with_settings,
612
+ inputs=[left_state, right_state, active_side, left_idx, right_idx, show_grid, overlay_alpha, ps_state],
613
+ outputs=[left_view, right_view, status]
614
+ )
615
+ overlay_alpha.change(
616
+ fn=rebuild_with_settings,
617
+ inputs=[left_state, right_state, active_side, left_idx, right_idx, show_grid, overlay_alpha, ps_state],
618
+ outputs=[left_view, right_view, status]
619
+ )
620
+
621
+ if __name__ == "__main__":
622
+ demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch
2
+ torchvision
3
+ matplotlib
4
+ numpy==1.26.4
5
+ ipykernel
6
+ ipywidgets
7
+ ipympl
8
+
9
+ # Install Transformers directly from GitHub source
10
+ transformers @ git+https://github.com/huggingface/transformers.git