merve (HF Staff) committed (verified)
Commit dde52b8 · Parent(s): 78fcdc5

Update app.py

Files changed (1)
app.py: +38 −94
app.py CHANGED
@@ -7,176 +7,120 @@ from transformers import AutoImageProcessor, AutoModel
import torch.nn.functional as F
import spaces

-
DINO_MODELS = {
-    "DINOv3 Base ViT": "facebook/dinov3-vitb16-pretrain-lvd1689m",
+    "DINOv3 Base ViT": "facebook/dinov3-vitb16-pretrain-lvd1689m",
    "DINOv3 Large ViT": "facebook/dinov3-vitl16-pretrain-lvd1689m",
-    "DINOv3 Large ConvNeXT": "facebook/dinov3-convnext-large-pretrain-lvd1689m"
+    "DINOv3 Large ConvNeXT": "facebook/dinov3-convnext-large-pretrain-lvd1689m",
}

+_default_model_name = "DINOv3 Base ViT"
+processor = AutoImageProcessor.from_pretrained(DINO_MODELS[_default_model_name])

def load_model(model_name):
-    global processor, model
+    global processor
    model_path = DINO_MODELS[model_name]
-
    processor = AutoImageProcessor.from_pretrained(model_path)
-    model = AutoModel.from_pretrained(model_path)
    return f"✅ Model '{model_name}' loaded successfully!"

-load_model("DINOv3 Base ViT")
-
@spaces.GPU()
-def extract_features(image):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+def extract_features(image, model_name):

-    model = model.to(device)
+    model_id = DINO_MODELS[model_name]
+    model = AutoModel.from_pretrained(model_id).to("cuda").eval()
+    local_processor = AutoImageProcessor.from_pretrained(model_id)
+    inputs = local_processor(images=image, return_tensors="pt")
+    inputs = {k: v.to("cuda") for k, v in inputs.items()}
+    model_size = local_processor.size["height"]
    original_size = image.size
-    inputs = processor(images=image, return_tensors="pt")
-    inputs = {k: v.to(device) for k, v in inputs.items()}
-    model_size = processor.size['height']
-
    with torch.no_grad():
        outputs = model(**inputs)
        features = outputs.last_hidden_state
-
-    return features, original_size, model_size
+    return features[:, 1:, :].float().cpu(), original_size, model_size

def find_correspondences(features1, features2, threshold=0.8):
+    device = torch.device("cpu")
    B, N1, D = features1.shape
-    B, N2, D = features2.shape
-
+    _, N2, _ = features2.shape
    features1_norm = F.normalize(features1, dim=-1)
    features2_norm = F.normalize(features2, dim=-1)
-
    similarity = torch.matmul(features1_norm, features2_norm.transpose(-2, -1))
-
    matches1 = torch.argmax(similarity, dim=-1)
    matches2 = torch.argmax(similarity, dim=-2)
-
    max_sim1 = torch.max(similarity, dim=-1)[0]
-    max_sim2 = torch.max(similarity, dim=-2)[0]
-
-    mutual_matches = matches2[0, matches1[0]] == torch.arange(N1).to(device)
+    arange1 = torch.arange(N1, device=device)
+    mutual_matches = matches2[0, matches1[0]] == arange1
    good_matches = (max_sim1[0] > threshold) & mutual_matches
-
-    return matches1[0][good_matches], torch.arange(N1).to(device)[good_matches], max_sim1[0][good_matches]
+    return matches1[0][good_matches].cpu(), arange1[good_matches].cpu(), max_sim1[0][good_matches].cpu()

def patch_to_image_coords(patch_idx, original_size, model_size, patch_size=14):
    orig_w, orig_h = original_size
    patches_h = model_size // patch_size
    patches_w = model_size // patch_size
-
    if patch_idx >= patches_h * patches_w:
        return None, None
-
    patch_y = patch_idx // patches_w
    patch_x = patch_idx % patches_w
-
    y_model = patch_y * patch_size + patch_size // 2
    x_model = patch_x * patch_size + patch_size // 2
-
    x = int(x_model * orig_w / model_size)
    y = int(y_model * orig_h / model_size)
-
    return x, y

def match_keypoints(image1, image2, model_name):
    if image1 is None or image2 is None:
        return None
-
    load_model(model_name)
-
-    img1_pil = Image.fromarray(image1).convert('RGB')
-    img2_pil = Image.fromarray(image2).convert('RGB')
-
-    features1, original_size1, model_size1 = extract_features(img1_pil)
-    features2, original_size2, model_size2 = extract_features(img2_pil)
-
-    features1 = features1[:, 1:, :]
-    features2 = features2[:, 1:, :]
-
+    img1_pil = Image.fromarray(image1).convert("RGB")
+    img2_pil = Image.fromarray(image2).convert("RGB")
+    features1, original_size1, model_size1 = extract_features(img1_pil, model_name)
+    features2, original_size2, model_size2 = extract_features(img2_pil, model_name)
    matches2_idx, matches1_idx, similarities = find_correspondences(features1, features2, threshold=0.7)
-
    img1_np = np.array(img1_pil)
    img2_np = np.array(img2_pil)
-
    h1, w1 = img1_np.shape[:2]
    h2, w2 = img2_np.shape[:2]
-
    result_img = np.zeros((max(h1, h2), w1 + w2, 3), dtype=np.uint8)
    result_img[:h1, :w1] = img1_np
-    result_img[:h2, w1:w1+w2] = img2_np
-
-    colors = []
-    keypoints1 = []
-    keypoints2 = []
-
-    for i, (m1, m2, sim) in enumerate(zip(matches1_idx.cpu(), matches2_idx.cpu(), similarities.cpu())):
-        x1, y1 = patch_to_image_coords(m1.item(), original_size1, model_size1)
-        x2, y2 = patch_to_image_coords(m2.item(), original_size2, model_size2)
-
+    result_img[:h2, w1:w1 + w2] = img2_np
+    for m1, m2, _ in zip(matches1_idx, matches2_idx, similarities):
+        x1, y1 = patch_to_image_coords(int(m1), original_size1, model_size1)
+        x2, y2 = patch_to_image_coords(int(m2), original_size2, model_size2)
        if x1 is not None and x2 is not None:
-            color = (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))
-            colors.append(color)
-            keypoints1.append((x1, y1))
-            keypoints2.append((x2 + w1, y2))
-
-            cv2.circle(result_img, (x1, y1), 15, color, -1)
-            cv2.circle(result_img, (x2 + w1, y2), 15, color, -1)
-            cv2.line(result_img, (x1, y1), (x2 + w1, y2), color, 10)
-
-
-
+            color = tuple(np.random.randint(0, 255, size=3).tolist())
+            cv2.circle(result_img, (x1, y1), 6, color, -1)
+            cv2.circle(result_img, (x2 + w1, y2), 6, color, -1)
+            cv2.line(result_img, (x1, y1), (x2 + w1, y2), color, 2)
    return result_img

-load_model("DINOv3 Base ViT")
-
with gr.Blocks(title="DINOv3 Keypoint Matching") as demo:
    gr.Markdown("# DINOv3 For Keypoint Matching")
    gr.Markdown("DINOv3 can be used to find matching features between two images.")
    gr.Markdown("Upload two images to find corresponding keypoints using DINOv3 features, switch between different DINOv3 checkpoints.")
-
    with gr.Row():
        image1 = gr.Image(label="Image 1", type="numpy")
        image2 = gr.Image(label="Image 2", type="numpy")
        with gr.Column(scale=1):
-
            model_selector = gr.Dropdown(
                choices=list(DINO_MODELS.keys()),
-                value="DINOv3 Base ViT",
+                value=_default_model_name,
                label="Select DINOv3 Model",
-                info="Choose the model size. Larger models may provide better features but require more memory."
+                info="Choose the model size. Larger models may provide better features but require more memory.",
            )
-
-            # Add status bar
            status_bar = gr.Textbox(
-                value="✅ Model 'DINOv3 Base ViT' loaded successfully!",
+                value=f"✅ Model '{_default_model_name}' ready.",
                label="Status",
                interactive=False,
-                container=False
+                container=False,
            )
-
            match_btn = gr.Button("Find Correspondences", variant="primary")
-
        with gr.Column(scale=2):
            output_image = gr.Image(label="Matched Keypoints")
-
-    model_selector.change(
-        fn=load_model,
-        inputs=[model_selector],
-        outputs=[status_bar]
-    )
-
-    match_btn.click(
-        fn=match_keypoints,
-        inputs=[image1, image2, model_selector],
-        outputs=[output_image]
-    )
-
+    model_selector.change(fn=load_model, inputs=[model_selector], outputs=[status_bar])
+    match_btn.click(fn=match_keypoints, inputs=[image1, image2, model_selector], outputs=[output_image])
    gr.Examples(
        examples=[["map.jpg", "street.jpg"], ["bee.JPG", "bee_edited.jpg"]],
-        inputs=[image1, image2]
+        inputs=[image1, image2],
    )

if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch()
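
For reference, a minimal usage sketch (not part of the commit) of how the refactored pipeline could be exercised outside the Gradio UI. It assumes the updated file is importable as app, that two local images img_a.jpg and img_b.jpg exist, and that the spaces package plus a CUDA device are available, since extract_features moves the model to "cuda".

# Minimal sketch, not part of the commit: run the keypoint-matching pipeline directly.
# Assumptions: app.py is importable, "img_a.jpg" / "img_b.jpg" exist locally, and a CUDA
# GPU is available for extract_features.
import numpy as np
from PIL import Image

from app import match_keypoints  # hypothetical local import of this Space's app.py

img_a = np.array(Image.open("img_a.jpg").convert("RGB"))
img_b = np.array(Image.open("img_b.jpg").convert("RGB"))

# Extracts DINOv3 patch features for both images, keeps mutual nearest-neighbour matches
# above the cosine-similarity threshold, and returns a side-by-side visualisation with
# matched keypoints joined by lines.
result = match_keypoints(img_a, img_b, "DINOv3 Base ViT")
Image.fromarray(result).save("matches.jpg")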