# File size: 27,634 Bytes
# faca9e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
# app.py
# Gradio UI for interactive DINOv3 patch similarity (single or dual image)
# - No AutoImageProcessor, no resize (only pad to multiple of patch size)
# - Single image: click to show self-similarity; selected cell outlined in RED
# - Two images: click on one side -> self overlay on source, cross overlay on target; best match on target outlined in YELLOW
# - Red selection rectangle is hidden on the non-active image
# - Patch size inferred from model (no override). Patch indices are not annotated.
# - Dataset selector (LVD-1689M / SAT-493M); model dropdown shows only the short name between "dinov3-" and "-pretrain".
# - Sample URL dropdowns switch between LVD (COCO/Picsum) and SAT (satellite imagery) and auto-fill / clear uploads.

import io
import math
import urllib.request
from functools import lru_cache
from typing import Optional, Tuple, Dict, List

import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
from transformers import AutoModel
from matplotlib import colormaps as cm

# ---------- Provided model IDs (ground truth list) ----------
# Full model repository IDs offered by the app. build_model_maps() parses each
# entry as "facebook/dinov3-<short>-pretrain-<dataset>".
MODEL_ID_LIST = [
    "facebook/dinov3-vits16-pretrain-lvd1689m",
    "facebook/dinov3-vits16plus-pretrain-lvd1689m",
    "facebook/dinov3-vitb16-pretrain-lvd1689m",
    "facebook/dinov3-vitl16-pretrain-lvd1689m",
    "facebook/dinov3-vith16plus-pretrain-lvd1689m",
    "facebook/dinov3-vit7b16-pretrain-lvd1689m",
    "facebook/dinov3-convnext-tiny-pretrain-lvd1689m",
    "facebook/dinov3-convnext-small-pretrain-lvd1689m",
    "facebook/dinov3-convnext-base-pretrain-lvd1689m",
    "facebook/dinov3-convnext-large-pretrain-lvd1689m",
    "facebook/dinov3-vitl16-pretrain-sat493m",
    "facebook/dinov3-vit7b16-pretrain-sat493m",
]

# Label shown in the dataset selector -> dataset suffix used in the model IDs above.
DATASET_LABELS = {
    "LVD-1689M": "lvd1689m",
    "SAT-493M": "sat493m",
}

def build_model_maps(model_ids: List[str]):
    """Index full model IDs by dataset and short model name.

    Each ID is expected to look like
    ``facebook/dinov3-<short>-pretrain-<dataset>``. IDs that do not contain
    that prefix and marker are silently left out.

    Returns:
        A ``(valid_map, options_by_dataset)`` pair:

        * ``valid_map[(dataset_key, short_name)]`` is the full model ID.
        * ``options_by_dataset[dataset_key]`` lists the short names for that
          dataset in first-seen order, without duplicates. Only the
          ``lvd1689m`` and ``sat493m`` datasets get an options list.
    """
    id_prefix = "facebook/dinov3-"
    full_id_by_key: Dict[Tuple[str, str], str] = {}
    shorts_by_dataset: Dict[str, List[str]] = {"lvd1689m": [], "sat493m": []}

    for model_id in model_ids:
        try:
            short_begin = model_id.index(id_prefix) + len(id_prefix)
            short_end = model_id.index("-pretrain", short_begin)
            short_name = model_id[short_begin:short_end]
            dataset_key = model_id.split("-pretrain-")[-1].strip()
        except Exception:
            # Not in the expected shape: leave it out of both maps.
            continue

        full_id_by_key[(dataset_key, short_name)] = model_id
        known_shorts = shorts_by_dataset.get(dataset_key)
        if known_shorts is not None and short_name not in known_shorts:
            known_shorts.append(short_name)

    return full_id_by_key, shorts_by_dataset

# Lookup tables built once at import time from MODEL_ID_LIST.
VALID_MODEL_MAP, MODEL_OPTIONS_BY_DATASET = build_model_maps(MODEL_ID_LIST)

# ---------- Defaults / knobs ----------
# Image fetched by init_states() when neither side has an image.
DEFAULT_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
DEFAULT_DATASET_LABEL = "LVD-1689M"  # initial radio
# Initial values of the "Overlay Settings" controls.
DEFAULT_OVERLAY_ALPHA = 0.55
DEFAULT_SHOW_GRID = True

# ---------- Sample image URLs (dependent on dataset) ----------
# Keyed by dataset key (the values of DATASET_LABELS). Each entry is a
# (dropdown label, URL) pair; the first entry of every list is a placeholder
# with an empty URL.
SAMPLE_URL_CHOICES: Dict[str, List[Tuple[str, str]]] = {
    # LVD: current ones
    "lvd1689m": [
        ("– choose a sample –", ""),
        ("COCO: 2 Cats on sofa (039769)", "http://images.cocodataset.org/val2017/000000039769.jpg"),
        ("COCO: Person skiing (000785)", "http://images.cocodataset.org/val2017/000000000785.jpg"),
        ("COCO: People running (000872)", "http://images.cocodataset.org/val2017/000000000872.jpg"),
        ("Picsum: Mountain (ID=1000)", "https://picsum.photos/id/1000/800/600"),
        ("Picsum: Kayak (ID=1011)", "https://picsum.photos/id/1011/800/600"),
        ("Picsum: Man and dog (ID=1012)", "https://picsum.photos/id/1012/800/600"),
    ],
    # SAT: satellite imagery examples
    "sat493m": [
        ("– choose a satellite sample –", ""),
        ("Blue Marble (NASA)", "https://upload.wikimedia.org/wikipedia/commons/9/9d/The_Blue_Marble_%28remastered%29.jpg"),
        ("GOES-16 Hurricane Florence (2018)", "https://upload.wikimedia.org/wikipedia/commons/5/5e/Hurricane_Florence_GOES-16_2018-09-12_1510Z.jpg"),
        ("NASA Earth Observatory: Philippines", "https://eoimages.gsfc.nasa.gov/images/imagerecords/151000/151639/philippines_tmo_2020118_lrg.jpg"),
    ],
}

def _sample_labels_for(dataset_label: str):
    """Return the sample-dropdown labels for the dataset chosen in the UI.

    Unknown dataset labels fall back to the ``lvd1689m`` samples.
    """
    dataset_key = DATASET_LABELS.get(dataset_label, "lvd1689m")
    choices = SAMPLE_URL_CHOICES.get(dataset_key, [])
    return [shown_label for shown_label, _url in choices]

def _apply_sample(dataset_label: str, sample_label: str):
    """Resolve a sample label to its URL for the URL textbox.

    Returns:
        ``(textbox_update, None)``. The ``None`` is wired to the upload
        component and clears it. Labels that are not found yield an empty URL.
    """
    dataset_key = DATASET_LABELS.get(dataset_label, "lvd1689m")
    url_by_label = {label: url for label, url in SAMPLE_URL_CHOICES.get(dataset_key, [])}
    chosen_url = url_by_label.get(sample_label, "")
    return gr.update(value=chosen_url), None

# ---------- Utility ----------
def load_image_from_any(
    src: Optional[Image.Image],
    url: Optional[str],
    timeout: float = 30.0,
) -> Optional[Image.Image]:
    """Load an RGB image from an http(s) URL or from an already-loaded image.

    The URL wins when it is present and starts with ``http://`` or
    ``https://`` (case-insensitive, surrounding whitespace ignored);
    otherwise ``src`` is used.

    Args:
        src: An uploaded PIL image, or ``None``.
        url: Optional image URL.
        timeout: Seconds to wait on the network before giving up, so a
            stalled host cannot block the caller indefinitely.

    Returns:
        The image converted to RGB, or ``None`` when neither a usable URL nor
        an image was supplied.

    Raises:
        Whatever ``urllib.request.urlopen`` or ``Image.open`` raise on a
        failed download or undecodable payload.
    """
    cleaned_url = str(url).strip() if url else ""
    if cleaned_url.lower().startswith(("http://", "https://")):
        # NOTE(review): some image hosts (e.g. Wikimedia, per its User-Agent
        # policy) reject generic library User-Agents, so send an explicit one.
        request = urllib.request.Request(
            cleaned_url,
            headers={"User-Agent": "Mozilla/5.0 (compatible; dinov3-patch-similarity-demo)"},
        )
        with urllib.request.urlopen(request, timeout=timeout) as resp:
            data = resp.read()
        return Image.open(io.BytesIO(data)).convert("RGB")
    if isinstance(src, Image.Image):
        return src.convert("RGB")
    return None

def pad_to_multiple(pil_img: Image.Image, multiple: int = 16) -> Tuple[Image.Image, Tuple[int, int, int, int]]:
    """Pad an image on the right and bottom so both sides are multiples of ``multiple``.

    Returns:
        ``(image, (0, 0, pad_right, pad_bottom))``. When no padding is needed
        the input image object itself is returned with an all-zero pad box;
        otherwise a new black RGB canvas with the input pasted at the
        top-left corner.
    """
    width, height = pil_img.size
    padded_h = int(math.ceil(height / multiple) * multiple)
    padded_w = int(math.ceil(width / multiple) * multiple)
    extra_w, extra_h = padded_w - width, padded_h - height

    if extra_w == 0 and extra_h == 0:
        return pil_img, (0, 0, 0, 0)

    canvas = Image.new("RGB", (padded_w, padded_h), (0, 0, 0))
    canvas.paste(pil_img, (0, 0))
    return canvas, (0, 0, extra_w, extra_h)

def preprocess_no_resize(pil_img: Image.Image, multiple: int = 16):
    """Turn an image into model inputs without resizing it.

    The image is only padded (via ``pad_to_multiple``), converted to a tensor
    and normalised per channel.

    Returns:
        ``(inputs, disp_np, pad_box)`` where ``inputs`` is
        ``{"pixel_values": tensor}`` with a leading batch axis of size 1,
        ``disp_np`` is the padded image as a ``uint8`` array for display, and
        ``pad_box`` is the padding reported by ``pad_to_multiple``.
    """
    padded, pad_box = pad_to_multiple(pil_img, multiple=multiple)

    to_tensor = transforms.ToTensor()
    # NOTE(review): these look like the usual ImageNet mean/std — confirm they
    # match what the checkpoints expect.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    batch = normalize(to_tensor(padded)).unsqueeze(0)  # (1,3,H,W)

    display = np.array(padded, dtype=np.uint8)
    return {"pixel_values": batch}, display, pad_box

def upsample_nearest(arr: np.ndarray, H: int, W: int, ps: int) -> np.ndarray:
    """Enlarge a patch-grid array by repeating every cell ``ps`` times per axis.

    Args:
        arr: ``(rows, cols)`` or ``(rows, cols, channels)`` array.
        H: Unused; kept so existing callers keep working.
        W: Unused; kept so existing callers keep working.
        ps: Repeat factor along the row and column axes.

    Returns:
        Array of shape ``(rows * ps, cols * ps[, channels])``.

    Raises:
        ValueError: If ``arr`` is neither 2-D nor 3-D.
    """
    if arr.ndim not in (2, 3):
        raise ValueError("upsample_nearest expects (rows,cols) or (rows,cols,channels)")
    stretched_rows = np.repeat(arr, ps, axis=0)
    return np.repeat(stretched_rows, ps, axis=1)

def blend_overlay(base_uint8: np.ndarray, overlay_rgb_float: np.ndarray, alpha: float) -> np.ndarray:
    """Alpha-blend an overlay onto a base image.

    Args:
        base_uint8: Base image; converted to float32 for the blend.
        overlay_rgb_float: Overlay colours, scaled by 255 before blending.
        alpha: Overlay weight; the base is weighted by ``1 - alpha``.

    Returns:
        The blend clipped to ``[0, 255]`` and cast to ``uint8``.
    """
    base_f = base_uint8.astype(np.float32)
    overlay_f = (overlay_rgb_float * 255.0).astype(np.float32)
    mixed = (1.0 - alpha) * base_f + alpha * overlay_f
    return mixed.clip(0, 255).astype(np.uint8)

def draw_grid(img: Image.Image, rows: int, cols: int, ps: int):
    """Draw 1-pixel white lines on ``img`` (in place) at the interior patch boundaries.

    Lines are placed every ``ps`` pixels: ``rows - 1`` horizontal and
    ``cols - 1`` vertical ones, each spanning the whole image.
    """
    pen = ImageDraw.Draw(img)
    width, height = img.size
    white = (255, 255, 255)

    for y in (r * ps for r in range(1, rows)):
        pen.line([(0, y), (width, y)], fill=white, width=1)
    for x in (c * ps for c in range(1, cols)):
        pen.line([(x, 0), (x, height)], fill=white, width=1)

def rc_to_idx(r: int, c: int, cols: int) -> int:
    """Flatten a ``(row, col)`` grid position into a row-major index."""
    row, col = int(r), int(c)
    return row * cols + col

def idx_to_rc(i: int, cols: int) -> Tuple[int, int]:
    """Split a row-major index into its ``(row, col)`` grid position."""
    return divmod(int(i), cols)

# ---------- Model cache ----------
@lru_cache(maxsize=3)
def load_model_cached(full_model_id: str, device_str: str):
    """Load a model, move it to the device and switch it to eval mode.

    Results are memoised per ``(full_model_id, device_str)``; the cache keeps
    at most three entries.
    """
    target_device = torch.device(device_str)
    loaded = AutoModel.from_pretrained(full_model_id)
    loaded = loaded.to(target_device)
    loaded.eval()
    return loaded

def infer_patch_size(model, default: int = 16) -> int:
    """Return the model's patch size, or ``default`` when it does not declare one.

    ``model.config.patch_size`` is consulted first, then ``model.patch_size``.
    A tuple/list value contributes its first element. A source whose value is
    missing, ``None`` or an empty sequence is skipped instead of raising, so
    the lookup falls through to the next source and finally to ``default``.

    Args:
        model: Any object; only the attributes named above are read.
        default: Value returned when no usable patch size is found.

    Returns:
        The patch size as an ``int``.
    """
    # Checked in priority order; a missing config simply yields no value.
    for holder in (getattr(model, "config", None), model):
        ps = getattr(holder, "patch_size", None)
        if ps is None:
            continue
        if isinstance(ps, (tuple, list)):
            if not ps:
                continue
            ps = ps[0]
        return int(ps)
    return default

# ---------- Per-image state ----------
class PatchImageState:
    """Per-image cache of patch features produced by one forward pass.

    Attributes set by the constructor:
        pil: The image as passed in.
        ps: Patch size used for padding and for the grid.
        disp: Padded image as a ``uint8`` array, used as the display base.
        H, W: Height and width of the padded input tensor.
        rows, cols: Patch-grid size (``H // ps`` and ``W // ps``).
        D: Feature dimension of the model output.
        X: Patch features, shape ``(rows * cols, D)``.
        Xn: ``X`` with every row divided by its L2 norm (plus a small epsilon).
    """

    def __init__(self, pil_img: Image.Image, model, device_str: str, ps: int):
        self.pil = pil_img
        self.ps = ps

        model_inputs, padded_rgb, _ = preprocess_no_resize(pil_img, multiple=ps)
        self.disp = padded_rgb

        pixel_values = model_inputs["pixel_values"].to(device_str)  # (1,3,H,W)
        _, _, height, width = pixel_values.shape
        self.H, self.W = int(height), int(width)
        self.rows = self.H // ps
        self.cols = self.W // ps

        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
        tokens = outputs.last_hidden_state.squeeze(0).detach().cpu().numpy()  # (T,D)

        T, D = tokens.shape
        n_patches = self.rows * self.cols
        # Everything beyond the patch count is treated as leading special
        # tokens (class + maybe registers).
        # NOTE(review): assumes special tokens come before patch tokens — confirm.
        n_special = T - n_patches
        if n_special < 1:
            raise RuntimeError(
                f"Token mismatch: T={T}, rows*cols={n_patches}, HxW={self.H}x{self.W}, ps={ps}"
            )

        self.D = D
        patch_grid = tokens[n_special:, :].reshape(self.rows, self.cols, D)
        self.X = patch_grid.reshape(-1, D)
        row_norms = np.linalg.norm(self.X, axis=1, keepdims=True)
        self.Xn = self.X / (row_norms + 1e-8)

# ---------- Rendering / compute ----------
def render_with_cosmap(
    st: PatchImageState,
    cos_map: Optional[np.ndarray],
    overlay_alpha: float,
    show_grid_flag: bool,
    select_idx: Optional[int] = None,
    best_idx: Optional[int] = None,
) -> Image.Image:
    """Render the padded image with a similarity heat map blended on top.

    Args:
        st: Image state supplying the display image and grid geometry.
        cos_map: ``(rows, cols)`` similarity values, min-max normalised before
            colouring. ``None`` produces a uniform mid-range overlay.
        overlay_alpha: Blend weight of the heat map.
        show_grid_flag: Whether to draw the patch grid.
        select_idx: Patch painted solid red in the overlay and outlined in red.
        best_idx: Patch outlined in yellow.

    Returns:
        The rendered PIL image.
    """
    ps, cols = st.ps, st.cols

    if cos_map is not None:
        lo, hi = float(cos_map.min()), float(cos_map.max())
        # A flat map would divide by zero; use a tiny range instead.
        span = (hi - lo) if hi > lo else 1e-8
        normalized = (cos_map - lo) / span
    else:
        normalized = np.full((st.rows, cols), 0.5, dtype=np.float32)

    heat_rgb = cm.get_cmap("magma")(normalized)[..., :3]  # drop the alpha channel

    if select_idx is not None:
        sel_r, sel_c = idx_to_rc(select_idx, cols)
        heat_rgb[sel_r, sel_c, :] = np.array([1.0, 0.0, 0.0], dtype=np.float32)

    heat_full = upsample_nearest(heat_rgb, st.H, st.W, ps)
    canvas = Image.fromarray(blend_overlay(st.disp, heat_full, float(overlay_alpha)))

    pen = ImageDraw.Draw(canvas)
    if show_grid_flag:
        draw_grid(canvas, st.rows, cols, ps)

    # Outlines go on last so the grid never covers them: red first, then yellow.
    for patch_idx, outline in ((select_idx, (255, 0, 0)), (best_idx, (255, 255, 0))):
        if patch_idx is None:
            continue
        r, c = idx_to_rc(patch_idx, cols)
        left, top = c * ps, r * ps
        pen.rectangle([(left, top), (left + ps - 1, top + ps - 1)], outline=outline, width=2)

    return canvas

def compute_self_and_cross(
    src: PatchImageState,
    tgt: Optional[PatchImageState],
    q_idx: int,
):
    """Compute cosine similarity of one source patch to all source and target patches.

    Args:
        src: State of the image that owns the query patch.
        tgt: State of the other image, or ``None`` in single-image mode.
        q_idx: Row-major index of the query patch in ``src``.

    Returns:
        ``(cos_map_self, cos_map_cross, self_stats, cross_result)``:
        ``cos_map_self`` is ``(src.rows, src.cols)``; ``self_stats`` is its
        ``(min, max)``. With a target, ``cos_map_cross`` is
        ``(tgt.rows, tgt.cols)`` and ``cross_result`` is
        ``(min, max, best_idx)`` where ``best_idx`` is the flat index of the
        most similar target patch. Without a target both are ``None``.
    """
    query = src.X[q_idx]
    query_unit = query / (np.linalg.norm(query) + 1e-8)

    cos_map_self = (src.Xn @ query_unit).reshape(src.rows, src.cols)
    self_stats = (float(cos_map_self.min()), float(cos_map_self.max()))

    if tgt is None:
        return cos_map_self, None, self_stats, None

    cross_flat = tgt.Xn @ query_unit
    cos_map_cross = cross_flat.reshape(tgt.rows, tgt.cols)
    cross_result = (
        float(cos_map_cross.min()),
        float(cos_map_cross.max()),
        int(np.argmax(cross_flat)),
    )
    return cos_map_self, cos_map_cross, self_stats, cross_result

# ---------- Gradio helpers for model & samples ----------
def dataset_label_to_key(label: str) -> str:
    """Map a dataset label from the UI to its dataset key; unknown labels give ``"lvd1689m"``."""
    try:
        return DATASET_LABELS[label]
    except KeyError:
        return "lvd1689m"

def update_model_dropdown(dataset_label: str):
    """Build the update that repopulates the model dropdown for a dataset.

    The first available short name becomes the selected value; with no
    options the value is ``None``.
    """
    short_names = MODEL_OPTIONS_BY_DATASET.get(dataset_label_to_key(dataset_label), [])
    if short_names:
        return gr.update(choices=short_names, value=short_names[0])
    return gr.update(choices=short_names, value=None)

def update_model_and_samples(dataset_label: str):
    """Refresh the model dropdown and both sample dropdowns for a dataset.

    Returns:
        ``(model_update, sample_update, sample_update)``; the same sample
        update is returned twice, once per sample dropdown.
    """
    sample_labels = _sample_labels_for(dataset_label)
    first_label = sample_labels[0] if sample_labels else None
    samples_update = gr.update(choices=sample_labels, value=first_label)
    return update_model_dropdown(dataset_label), samples_update, samples_update

def resolve_full_model_id(dataset_label: str, short_name: str) -> Optional[str]:
    """Look up the full model ID for a dataset label and short model name.

    Returns ``None`` when the combination is not in ``VALID_MODEL_MAP``.
    """
    dataset_key = dataset_label_to_key(dataset_label)
    return VALID_MODEL_MAP.get((dataset_key, short_name))

# ---------- Gradio callbacks ----------
def init_states(
    left_img_in: Optional[Image.Image],
    left_url: str,
    right_img_in: Optional[Image.Image],
    right_url: str,
    dataset_label: str,
    short_model: str,
    show_grid_flag: bool,
    overlay_alpha: float,
):
    """Load the model and image(s), extract patch features and render the first overlays.

    The query patch starts at the centre of the left image, or of the right
    image when only a right image is given. When neither side has an image
    the default sample image is loaded on the left.

    Returns:
        A 9-tuple in the order the init button's outputs expect:
        ``(left_view, right_view, left_state, right_state, active_side,
        left_idx, right_idx, ps, status)``.

        On failure (unknown model, image that cannot be loaded, model that
        cannot be loaded or run) nothing is raised: both views are left
        untouched, all state is reset, and ``status`` carries the reason.
    """

    def _failure(message: str):
        # Same arity as the success path so the output wiring stays valid.
        return (gr.update(), gr.update(), None, None, 0, -1, -1, 16, message)

    def _center_idx(st: PatchImageState) -> int:
        # Row-major index of the patch in the middle of the grid.
        return (st.rows // 2) * st.cols + (st.cols // 2)

    # Validate the cheap thing first so a bad model choice costs no download.
    full_model_id = resolve_full_model_id(dataset_label, short_model)
    if not full_model_id:
        return _failure(f"❌ Model not available: {dataset_label} / {short_model}")

    # Resolve images. urllib's URLError and timeouts are OSError subclasses;
    # ValueError covers malformed URLs.
    # NOTE(review): PIL decode errors are expected to be OSError too — confirm.
    try:
        left_img = load_image_from_any(left_img_in, left_url)
        right_img = load_image_from_any(right_img_in, right_url)
        if left_img is None and right_img is None:
            left_img = load_image_from_any(None, DEFAULT_URL)
    except (OSError, ValueError) as exc:
        return _failure(f"❌ Could not load image: {exc}")

    device_str = "cuda" if torch.cuda.is_available() else "cpu"
    # PatchImageState raises RuntimeError when the token count does not match
    # the patch grid.
    # NOTE(review): OSError is presumed to cover checkpoint download/access
    # failures — confirm against the transformers version in use.
    try:
        model = load_model_cached(full_model_id, device_str)
        ps = infer_patch_size(model, 16)
        left_state = PatchImageState(left_img, model, device_str, ps) if left_img is not None else None
        right_state = PatchImageState(right_img, model, device_str, ps) if right_img is not None else None
    except (OSError, RuntimeError) as exc:
        return _failure(f"❌ Could not run {full_model_id}: {exc}")

    active_side = 0 if left_state is not None else 1

    status = f"✔ Loaded: {full_model_id}  |  ps={ps}"
    out_left, out_right = None, None

    if left_state is not None and right_state is not None:
        # Two images: the left one is the query source, the right one shows the cross map.
        q_idx = _center_idx(left_state)
        cos_self, cos_cross, (smin, smax), cross_info = compute_self_and_cross(left_state, right_state, q_idx)
        best_idx = cross_info[2] if cross_info else None
        out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
                                      select_idx=q_idx, best_idx=None)
        out_right = render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag,
                                       select_idx=None, best_idx=best_idx)
        status += (f"  |  LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}]  "
                   f"|  RIGHT cross best={best_idx}")
        left_idx, right_idx = q_idx, _center_idx(right_state)
    elif left_state is not None:
        q_idx = _center_idx(left_state)
        cos_self, _, (smin, smax), _ = compute_self_and_cross(left_state, None, q_idx)
        out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag,
                                      select_idx=q_idx, best_idx=None)
        status += f"  |  Single LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}]"
        left_idx, right_idx = q_idx, -1
    else:
        q_idx = _center_idx(right_state)
        cos_self, _, (smin, smax), _ = compute_self_and_cross(right_state, None, q_idx)
        out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag,
                                       select_idx=q_idx, best_idx=None)
        status += f"  |  Single RIGHT {right_state.rows}x{right_state.cols} self∈[{smin:.3f},{smax:.3f}]"
        left_idx, right_idx = -1, q_idx

    return (
        out_left, out_right,
        left_state, right_state,
        active_side,
        left_idx, right_idx,
        ps,
        status
    )

def _coords_to_idx(x: int, y: int, st: PatchImageState) -> int:
    """Convert pixel coordinates to the row-major index of the patch containing them.

    Coordinates outside the grid are clamped to the nearest edge patch.
    """
    row = int(np.clip(y // st.ps, 0, st.rows - 1))
    col = int(np.clip(x // st.ps, 0, st.cols - 1))
    return row * st.cols + col

def on_select_left(
    evt: gr.SelectData,
    left_state: Optional[PatchImageState],
    right_state: Optional[PatchImageState],
    show_grid_flag: bool,
    overlay_alpha: float,
    ps: int,
):
    """Handle a click on the LEFT view: the clicked patch becomes the query.

    ``ps`` is accepted for the event wiring but not used here; the patch size
    comes from the image state.

    Returns:
        ``(left_view, right_view, active_side, left_idx, right_idx, status)``
        with ``active_side`` 0 and ``right_idx`` -1. With a right image
        loaded, the right view shows the cross-similarity map and its best
        match; otherwise the right view is left untouched.
    """
    if left_state is None:
        return gr.update(), gr.update(), 0, -1, -1, "Upload/Load a LEFT image first."

    click_x, click_y = evt.index
    q_idx = _coords_to_idx(click_x, click_y, left_state)

    if right_state is None:
        # Single-image mode: only the self-similarity map is drawn.
        self_map, _, (smin, smax), _ = compute_self_and_cross(left_state, None, q_idx)
        left_img = render_with_cosmap(left_state, self_map, overlay_alpha, show_grid_flag,
                                      select_idx=q_idx, best_idx=None)
        status = f"Single LEFT • idx={q_idx} • self∈[{smin:.3f},{smax:.3f}]"
        return left_img, gr.update(), 0, q_idx, -1, status

    self_map, cross_map, (smin, smax), cross_info = compute_self_and_cross(left_state, right_state, q_idx)
    best_idx = cross_info[2]
    left_img = render_with_cosmap(left_state, self_map, overlay_alpha, show_grid_flag,
                                  select_idx=q_idx, best_idx=None)
    right_img = render_with_cosmap(right_state, cross_map, overlay_alpha, show_grid_flag,
                                   select_idx=None, best_idx=best_idx)
    status = (f"LEFT {left_state.rows}x{left_state.cols} self∈[{smin:.3f},{smax:.3f}]  |  "
              f"RIGHT cross best idx={best_idx}")
    return left_img, right_img, 0, q_idx, -1, status

def on_select_right(
    evt: gr.SelectData,
    left_state: Optional[PatchImageState],
    right_state: Optional[PatchImageState],
    show_grid_flag: bool,
    overlay_alpha: float,
    ps: int,
):
    """Handle a click on the RIGHT view: the clicked patch becomes the query.

    ``ps`` is accepted for the event wiring but not used here; the patch size
    comes from the image state.

    Returns:
        ``(left_view, right_view, active_side, left_idx, right_idx, status)``
        with ``active_side`` 1 and ``left_idx`` -1. With a left image loaded,
        the left view shows the cross-similarity map and its best match;
        otherwise the left view is left untouched.
    """
    if right_state is None:
        return gr.update(), gr.update(), 1, -1, -1, "Upload/Load a RIGHT image first."

    click_x, click_y = evt.index
    q_idx = _coords_to_idx(click_x, click_y, right_state)

    if left_state is None:
        # Single-image mode: only the self-similarity map is drawn.
        self_map, _, (smin, smax), _ = compute_self_and_cross(right_state, None, q_idx)
        right_img = render_with_cosmap(right_state, self_map, overlay_alpha, show_grid_flag,
                                       select_idx=q_idx, best_idx=None)
        status = f"Single RIGHT • idx={q_idx} • self∈[{smin:.3f},{smax:.3f}]"
        return gr.update(), right_img, 1, -1, q_idx, status

    self_map, cross_map, (smin, smax), cross_info = compute_self_and_cross(right_state, left_state, q_idx)
    best_idx = cross_info[2]
    right_img = render_with_cosmap(right_state, self_map, overlay_alpha, show_grid_flag,
                                   select_idx=q_idx, best_idx=None)
    left_img = render_with_cosmap(left_state, cross_map, overlay_alpha, show_grid_flag,
                                  select_idx=None, best_idx=best_idx)
    status = (f"RIGHT {right_state.rows}x{right_state.cols} self∈[{smin:.3f},{smax:.3f}]  |  "
              f"LEFT cross best idx={best_idx}")
    return left_img, right_img, 1, -1, q_idx, status

def rebuild_with_settings(
    left_state: Optional[PatchImageState],
    right_state: Optional[PatchImageState],
    active_side: int,
    left_idx: int,
    right_idx: int,
    show_grid_flag: bool,
    overlay_alpha: float,
    ps: int,
):
    """Re-render the overlays after a display setting changed.

    The query side is the left image when it is the only one loaded or when
    both are loaded and ``active_side`` is 0; otherwise the right image. A
    negative stored index falls back to the centre patch of the query image.
    ``ps`` is accepted for the event wiring but not used here.

    Returns:
        ``(left_view, right_view, status)``; a view with no image behind it
        is left untouched.
    """
    have_left = left_state is not None
    have_right = right_state is not None
    if not have_left and not have_right:
        return gr.update(), gr.update(), "Load an image first."

    left_is_source = have_left and (not have_right or active_side == 0)
    if left_is_source:
        source, target, stored_idx = left_state, right_state, left_idx
    else:
        source, target, stored_idx = right_state, left_state, right_idx

    if stored_idx >= 0:
        q_idx = stored_idx
    else:
        q_idx = (source.rows // 2) * source.cols + (source.cols // 2)

    cos_self, cos_cross, _, cross_info = compute_self_and_cross(source, target, q_idx)
    source_img = render_with_cosmap(source, cos_self, overlay_alpha, show_grid_flag,
                                    select_idx=q_idx, best_idx=None)
    if target is None:
        target_img = gr.update()
    else:
        target_img = render_with_cosmap(target, cos_cross, overlay_alpha, show_grid_flag,
                                        select_idx=None, best_idx=cross_info[2])

    if left_is_source:
        return source_img, target_img, "Updated overlays."
    return target_img, source_img, "Updated overlays."

# ---------- Gradio UI ----------
# Layout and event wiring. Components are created in display order; the event
# bindings at the bottom connect them to the callbacks defined above.
with gr.Blocks(title="DINOv3 Patch Similarity (Self & Cross)") as demo:
    gr.Markdown(
        """

        # DINOv3 Patch Similarity (Self & Cross)

        1) Pick **Dataset** (LVD-1689M / SAT-493M).  

        2) Pick **Model**.  

        3) Upload one or two images (or paste URLs) and press **Initialize / Update**.  

        - Click on a patch to update overlays.  

        - In two-image mode, the non-active image hides the red selection and shows **yellow** best match.

        """
    )

    # Dataset selector and the model dropdown that depends on it.
    with gr.Row():
        dataset_radio = gr.Radio(
            label="Dataset",
            choices=list(DATASET_LABELS.keys()),
            value=DEFAULT_DATASET_LABEL,
            interactive=True
        )
        initial_key = DATASET_LABELS[DEFAULT_DATASET_LABEL]
        initial_models = MODEL_OPTIONS_BY_DATASET.get(initial_key, [])
        model_dropdown = gr.Dropdown(
            label="Model name",
            choices=initial_models,
            value=(initial_models[0] if initial_models else None),
            interactive=True
        )

    # initial sample labels based on default dataset
    initial_sample_labels = [label for label, _ in SAMPLE_URL_CHOICES.get(initial_key, [])]

    # Per-side inputs: an upload, an optional URL and a sample-URL picker.
    with gr.Row():
        with gr.Column():
            left_input = gr.Image(label="Left Image (upload)", type="pil",
                                  sources=["upload", "clipboard", "webcam"], interactive=True)
            left_url = gr.Textbox(label="Left Image URL (optional)", placeholder="https://...")
            left_sample = gr.Dropdown(label="Use a sample URL",
                                      choices=initial_sample_labels,
                                      value=(initial_sample_labels[0] if initial_sample_labels else None),
                                      interactive=True)
        with gr.Column():
            right_input = gr.Image(label="Right Image (upload)", type="pil",
                                   sources=["upload", "clipboard", "webcam"], interactive=True)
            right_url = gr.Textbox(label="Right Image URL (optional)", placeholder="https://...")
            right_sample = gr.Dropdown(label="Use a sample URL",
                                       choices=initial_sample_labels,
                                       value=(initial_sample_labels[0] if initial_sample_labels else None),
                                       interactive=True)

    with gr.Accordion("Overlay Settings", open=True):
        show_grid = gr.Checkbox(label="Show patch grid", value=DEFAULT_SHOW_GRID)
        overlay_alpha = gr.Slider(label="Overlay alpha", minimum=0.0, maximum=1.0,
                                  value=DEFAULT_OVERLAY_ALPHA, step=0.01)

    init_btn = gr.Button("Initialize / Update", variant="primary")

    # Rendered overlays; clicks on these drive on_select_left / on_select_right.
    with gr.Row():
        left_view = gr.Image(label="LEFT (click to select patch)", interactive=True)
        right_view = gr.Image(label="RIGHT (click to select patch)", interactive=True)

    status = gr.Markdown("")

    # Hidden states
    left_state = gr.State(None)   # PatchImageState for the left image, or None
    right_state = gr.State(None)  # PatchImageState for the right image, or None
    active_side = gr.State(0)     # 0 = left is the query side, 1 = right
    left_idx = gr.State(-1)       # selected patch index on the left; -1 = none
    right_idx = gr.State(-1)      # selected patch index on the right; -1 = none
    ps_state = gr.State(16)       # patch size reported by init_states

    # Update model dropdown and sample lists when dataset changes
    dataset_radio.change(
        fn=update_model_and_samples,
        inputs=[dataset_radio],
        outputs=[model_dropdown, left_sample, right_sample]
    )

    # When a sample is chosen, set URL and clear any uploaded image (prefer URL)
    left_sample.change(
        fn=_apply_sample,
        inputs=[dataset_radio, left_sample],
        outputs=[left_url, left_input]
    )
    right_sample.change(
        fn=_apply_sample,
        inputs=[dataset_radio, right_sample],
        outputs=[right_url, right_input]
    )

    # Initialize / reload model + overlays
    init_btn.click(
        fn=init_states,
        inputs=[left_input, left_url, right_input, right_url, dataset_radio, model_dropdown, show_grid, overlay_alpha],
        outputs=[left_view, right_view, left_state, right_state, active_side, left_idx, right_idx, ps_state, status],
        show_progress=True
    )

    # Click handlers
    # The handlers also take a gr.SelectData event as their first parameter,
    # which is not listed in `inputs`.
    left_view.select(
        fn=on_select_left,
        inputs=[left_state, right_state, show_grid, overlay_alpha, ps_state],
        outputs=[left_view, right_view, active_side, left_idx, right_idx, status]
    )
    right_view.select(
        fn=on_select_right,
        inputs=[left_state, right_state, show_grid, overlay_alpha, ps_state],
        outputs=[left_view, right_view, active_side, left_idx, right_idx, status]
    )

    # Live re-render on setting changes
    show_grid.change(
        fn=rebuild_with_settings,
        inputs=[left_state, right_state, active_side, left_idx, right_idx, show_grid, overlay_alpha, ps_state],
        outputs=[left_view, right_view, status]
    )
    overlay_alpha.change(
        fn=rebuild_with_settings,
        inputs=[left_state, right_state, active_side, left_idx, right_idx, show_grid, overlay_alpha, ps_state],
        outputs=[left_view, right_view, status]
    )

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.queue().launch()