import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
from transformers import AutoImageProcessor, AutoModel
import spaces

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DINO_MODELS = {
    "DINOv3 Base ViT": "facebook/dinov3-vitb16-pretrain-lvd1689m",
    "DINOv3 Large ViT": "facebook/dinov3-vitl16-pretrain-lvd1689m",
    "DINOv3 Large ConvNeXT": "facebook/dinov3-convnext-large-pretrain-lvd1689m",
}


def load_model(model_name):
    """Load the selected DINOv3 checkpoint into the module-level globals."""
    global processor, model
    model_path = DINO_MODELS[model_name]
    processor = AutoImageProcessor.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path).to(device)
    return f"✅ Model '{model_name}' loaded successfully!"


@spaces.GPU()
def extract_features(image):
    """Run the image through DINOv3 and return its token features."""
    original_size = image.size  # (width, height) of the uploaded image
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Side length of the (square) tensor actually fed to the model.
    model_size = inputs["pixel_values"].shape[-1]
    with torch.no_grad():
        outputs = model(**inputs)
    features = outputs.last_hidden_state
    return features, original_size, model_size


def find_correspondences(features1, features2, threshold=0.8):
    """Mutual-nearest-neighbor matching on cosine similarity of patch tokens."""
    _, N1, _ = features1.shape
    features1_norm = F.normalize(features1, dim=-1)
    features2_norm = F.normalize(features2, dim=-1)
    similarity = torch.matmul(features1_norm, features2_norm.transpose(-2, -1))  # (B, N1, N2)
    matches1 = torch.argmax(similarity, dim=-1)  # best image-2 patch for each image-1 patch
    matches2 = torch.argmax(similarity, dim=-2)  # best image-1 patch for each image-2 patch
    max_sim1 = torch.max(similarity, dim=-1)[0]
    idx1 = torch.arange(N1, device=features1.device)
    # Keep only mutual nearest neighbors whose similarity clears the threshold.
    mutual_matches = matches2[0, matches1[0]] == idx1
    good_matches = (max_sim1[0] > threshold) & mutual_matches
    return matches1[0][good_matches], idx1[good_matches], max_sim1[0][good_matches]


def patch_to_image_coords(patch_idx, original_size, model_size, patch_size=16):
    """Map a patch index back to pixel coordinates in the original image.

    The ViT checkpoints used here embed 16x16 patches (the "16" in
    "dinov3-vitb16" / "dinov3-vitl16").
    """
    orig_w, orig_h = original_size
    patches_per_side = model_size // patch_size
    if patch_idx >= patches_per_side * patches_per_side:
        return None, None
    patch_y = patch_idx // patches_per_side
    patch_x = patch_idx % patches_per_side
    # Center of the patch in model-input coordinates, rescaled to the original image.
    y_model = patch_y * patch_size + patch_size // 2
    x_model = patch_x * patch_size + patch_size // 2
    x = int(x_model * orig_w / model_size)
    y = int(y_model * orig_h / model_size)
    return x, y


def match_keypoints(image1, image2, model_name):
    if image1 is None or image2 is None:
        return None
    load_model(model_name)
    img1_pil = Image.fromarray(image1).convert("RGB")
    img2_pil = Image.fromarray(image2).convert("RGB")
    features1, original_size1, model_size1 = extract_features(img1_pil)
    features2, original_size2, model_size2 = extract_features(img2_pil)
    # Drop the CLS token (and register tokens, if the config defines them)
    # so that token indices line up with the patch grid.
    num_special = 1 + getattr(model.config, "num_register_tokens", 0)
    features1 = features1[:, num_special:, :]
    features2 = features2[:, num_special:, :]
    # Returns: patch indices in image 2, the matching indices in image 1, and similarities.
    matches2_idx, matches1_idx, similarities = find_correspondences(
        features1, features2, threshold=0.7
    )
    # Place the two images side by side on a shared canvas.
    img1_np = np.array(img1_pil)
    img2_np = np.array(img2_pil)
    h1, w1 = img1_np.shape[:2]
    h2, w2 = img2_np.shape[:2]
    result_img = np.zeros((max(h1, h2), w1 + w2, 3), dtype=np.uint8)
    result_img[:h1, :w1] = img1_np
    result_img[:h2, w1:w1 + w2] = img2_np
    for m1, m2 in zip(matches1_idx.cpu(), matches2_idx.cpu()):
        x1, y1 = patch_to_image_coords(m1.item(), original_size1, model_size1)
        x2, y2 = patch_to_image_coords(m2.item(), original_size2, model_size2)
        if x1 is not None and x2 is not None:
            # cv2 expects plain Python ints for colors.
            color = tuple(int(c) for c in np.random.randint(0, 256, size=3))
            cv2.circle(result_img, (x1, y1), 15, color, -1)
            cv2.circle(result_img, (x2 + w1, y2), 15, color, -1)
            cv2.line(result_img, (x1, y1), (x2 + w1, y2), color, 10)
    return result_img


# Load the default model once at startup so the globals exist before the first request.
load_model("DINOv3 Base ViT")

with gr.Blocks(title="DINOv3 Keypoint Matching") as demo:
    gr.Markdown("# DINOv3 For Keypoint Matching")
    gr.Markdown("DINOv3 can be used to find matching features between two images.")
    gr.Markdown(
        "Upload two images to find corresponding keypoints using DINOv3 features, "
        "and switch between different DINOv3 checkpoints."
    )
    with gr.Row():
        image1 = gr.Image(label="Image 1", type="numpy")
        image2 = gr.Image(label="Image 2", type="numpy")
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(DINO_MODELS.keys()),
                value="DINOv3 Base ViT",
                label="Select DINOv3 Model",
                info="Choose the model size. Larger models may provide better features but require more memory.",
            )
            # Status bar showing which model is currently loaded.
            status_bar = gr.Textbox(
                value="✅ Model 'DINOv3 Base ViT' loaded successfully!",
                label="Status",
                interactive=False,
                container=False,
            )
            match_btn = gr.Button("Find Correspondences", variant="primary")
        with gr.Column(scale=2):
            output_image = gr.Image(label="Matched Keypoints")

    model_selector.change(
        fn=load_model,
        inputs=[model_selector],
        outputs=[status_bar],
    )
    match_btn.click(
        fn=match_keypoints,
        inputs=[image1, image2, model_selector],
        outputs=[output_image],
    )
    gr.Examples(
        examples=[["map.jpg", "street.jpg"], ["bee.JPG", "bee_edited.jpg"]],
        inputs=[image1, image2],
    )

if __name__ == "__main__":
    demo.launch(share=True)