Spaces: Running on Zero
import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2
from transformers import AutoImageProcessor, AutoModel
import torch.nn.functional as F
import spaces
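
# DINOv3 checkpoints exposed in the UI dropdown, keyed by display name. The ViT
# variants use 16-pixel patches (the "16" suffix in their model ids).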
DINO_MODELS = {
    "DINOv3 Base ViT": "facebook/dinov3-vitb16-pretrain-lvd1689m",
    "DINOv3 Large ViT": "facebook/dinov3-vitl16-pretrain-lvd1689m",
    "DINOv3 Large ConvNeXT": "facebook/dinov3-convnext-large-pretrain-lvd1689m",
}

_default_model_name = "DINOv3 Base ViT"
processor = AutoImageProcessor.from_pretrained(DINO_MODELS[_default_model_name])
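
# Refresh the global processor for the selected checkpoint and return a status
# string for the UI; the model itself is loaded lazily in extract_features.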
def load_model(model_name):
    global processor
    model_path = DINO_MODELS[model_name]
    processor = AutoImageProcessor.from_pretrained(model_path)
    return f"✅ Model '{model_name}' loaded successfully!"
@spaces.GPU  # assumed: on ZeroGPU, CUDA is only available inside @spaces.GPU-decorated calls
def extract_features(image, model_name):
    model_id = DINO_MODELS[model_name]
    model = AutoModel.from_pretrained(model_id).to("cuda").eval()
    local_processor = AutoImageProcessor.from_pretrained(model_id)
    inputs = local_processor(images=image, return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    model_size = local_processor.size["height"]
    original_size = image.size
    with torch.no_grad():
        outputs = model(**inputs)
    features = outputs.last_hidden_state
    num_register_tokens = getattr(model.config, "num_register_tokens", 0)
    # Drop the CLS token and any register tokens, keeping only patch tokens.
    return features[:, 1 + num_register_tokens:, :].float().cpu(), original_size, model_size
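
# Mutual nearest-neighbor matching on L2-normalized patch features: a pair of
# patches is kept only if each is the other's best match and their cosine
# similarity exceeds the threshold.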
def find_correspondences(features1, features2, threshold=0.8):
    device = torch.device("cpu")
    B, N1, D = features1.shape
    _, N2, _ = features2.shape
    features1_norm = F.normalize(features1, dim=-1)
    features2_norm = F.normalize(features2, dim=-1)
    similarity = torch.matmul(features1_norm, features2_norm.transpose(-2, -1))
    matches1 = torch.argmax(similarity, dim=-1)
    matches2 = torch.argmax(similarity, dim=-2)
    max_sim1 = torch.max(similarity, dim=-1)[0]
    arange1 = torch.arange(N1, device=device)
    mutual_matches = matches2[0, matches1[0]] == arange1
    good_matches = (max_sim1[0] > threshold) & mutual_matches
    return matches1[0][good_matches].cpu(), arange1[good_matches].cpu(), max_sim1[0][good_matches].cpu()
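
# Map a flattened patch index back to pixel coordinates: recover its (row, col)
# position on the patch grid, take the patch center in model input space, then
# rescale to the original image resolution.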
def patch_to_image_coords(patch_idx, original_size, model_size, patch_size=16):
    # patch_size=16 matches the ViT checkpoints above (dinov3-vitb16 / vitl16).
    orig_w, orig_h = original_size
    patches_h = model_size // patch_size
    patches_w = model_size // patch_size
    if patch_idx >= patches_h * patches_w:
        return None, None
    patch_y = patch_idx // patches_w
    patch_x = patch_idx % patches_w
    y_model = patch_y * patch_size + patch_size // 2
    x_model = patch_x * patch_size + patch_size // 2
    x = int(x_model * orig_w / model_size)
    y = int(y_model * orig_h / model_size)
    return x, y
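
# End-to-end pipeline behind the button: extract features for both images, find
# mutual matches, and draw the correspondences on a side-by-side canvas.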
def match_keypoints(image1, image2, model_name):
    if image1 is None or image2 is None:
        return None
    load_model(model_name)
    img1_pil = Image.fromarray(image1).convert("RGB")
    img2_pil = Image.fromarray(image2).convert("RGB")
    features1, original_size1, model_size1 = extract_features(img1_pil, model_name)
    features2, original_size2, model_size2 = extract_features(img2_pil, model_name)
    matches2_idx, matches1_idx, similarities = find_correspondences(features1, features2, threshold=0.7)
    img1_np = np.array(img1_pil)
    img2_np = np.array(img2_pil)
    h1, w1 = img1_np.shape[:2]
    h2, w2 = img2_np.shape[:2]
    result_img = np.zeros((max(h1, h2), w1 + w2, 3), dtype=np.uint8)
    result_img[:h1, :w1] = img1_np
    result_img[:h2, w1:w1 + w2] = img2_np
    for m1, m2, _ in zip(matches1_idx, matches2_idx, similarities):
        x1, y1 = patch_to_image_coords(int(m1), original_size1, model_size1)
        x2, y2 = patch_to_image_coords(int(m2), original_size2, model_size2)
        if x1 is not None and x2 is not None:
            color = tuple(np.random.randint(0, 255, size=3).tolist())
            cv2.circle(result_img, (x1, y1), 6, color, -1)
            cv2.circle(result_img, (x2 + w1, y2), 6, color, -1)
            cv2.line(result_img, (x1, y1), (x2 + w1, y2), color, 2)
    return result_img
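
# Gradio UI: two image inputs, a checkpoint dropdown with a status line, and an
# output panel showing the matched keypoints side by side.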
with gr.Blocks(title="DINOv3 Keypoint Matching") as demo:
    gr.Markdown("# DINOv3 For Keypoint Matching")
    gr.Markdown("DINOv3 can be used to find matching features between two images.")
    gr.Markdown("Upload two images to find corresponding keypoints using DINOv3 features, and switch between different DINOv3 checkpoints.")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row(scale=1):
                image1 = gr.Image(label="Image 1", type="numpy")
                image2 = gr.Image(label="Image 2", type="numpy")
            model_selector = gr.Dropdown(
                choices=list(DINO_MODELS.keys()),
                value=_default_model_name,
                label="Select DINOv3 Model",
                info="Choose the model size. Larger models may provide better features but require more memory.",
            )
            status_bar = gr.Textbox(
                value=f"✅ Model '{_default_model_name}' ready.",
                label="Status",
                interactive=False,
                container=False,
            )
            match_btn = gr.Button("Find Correspondences", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Matched Keypoints")

    model_selector.change(fn=load_model, inputs=[model_selector], outputs=[status_bar])
    match_btn.click(fn=match_keypoints, inputs=[image1, image2, model_selector], outputs=[output_image])
    gr.Examples(
        examples=[["map.jpg", "street.jpg"], ["bee.JPG", "bee_edited.jpg"]],
        inputs=[image1, image2],
    )

if __name__ == "__main__":
    demo.launch()
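
# Hypothetical programmatic use outside the UI, assuming the example images
# shipped with the Space (map.jpg, street.jpg) sit next to this script:
#
#   img1 = np.array(Image.open("map.jpg").convert("RGB"))
#   img2 = np.array(Image.open("street.jpg").convert("RGB"))
#   vis = match_keypoints(img1, img2, "DINOv3 Base ViT")
#   Image.fromarray(vis).save("matches.jpg")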