merve (HF Staff) committed · verified
Commit a1b7c96 · Parent(s): 89718c6

Update app.py

Files changed (1): app.py (+18 −29)
app.py CHANGED
@@ -15,35 +15,27 @@ DINO_MODELS = {
     "DINOv3 Large ConvNeXT": "facebook/dinov3-convnext-large-pretrain-lvd1689m"
 }
 
-current_processor = None
-current_model = None
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 def load_model(model_name):
-    global current_processor, current_model
-
+    global processor, model
     model_path = DINO_MODELS[model_name]
-
-    try:
-        current_processor = AutoImageProcessor.from_pretrained(model_path)
-        current_model = AutoModel.from_pretrained(model_path)
-        current_model = current_model.to(DEVICE)
-        return f"✅ Model '{model_name}' loaded successfully!"
-    except Exception as e:
-        return f"❌ Error loading model '{model_name}': {str(e)}"
+
+    processor = AutoImageProcessor.from_pretrained(model_path)
+    model = AutoModel.from_pretrained(model_path)
+    model = model.to(device)
+    return f"✅ Model '{model_name}' loaded successfully!"
+
+load_model("DINOv3 Base ViT")
 
 @spaces.GPU()
-def extract_features(image):
-
-    original_size = image.size
-
-    inputs = current_processor(images=image, return_tensors="pt")
-    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
-
-    model_size = current_processor.size['height']
+def extract_features(image):
+    original_size = image.size
+    inputs = processor(images=image, return_tensors="pt")
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+    model_size = processor.size['height']
 
     with torch.no_grad():
-        outputs = current_model(**inputs)
+        outputs = model(**inputs)
     features = outputs.last_hidden_state
 
     return features, original_size, model_size
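Note: this hunk swaps the guarded try/except loader for bare globals, renames current_processor/current_model to processor/model, and eagerly loads a default model at import time. (The lowercase device the new code sends tensors to is presumably defined earlier in app.py, outside the visible hunks.) For reference, a minimal standalone sketch of the extraction path, assuming torch, Pillow, and transformers are installed, reusing the ConvNeXt checkpoint named in the diff, and with example.jpg as a hypothetical input:

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

ckpt = "facebook/dinov3-convnext-large-pretrain-lvd1689m"  # checkpoint named in DINO_MODELS
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoImageProcessor.from_pretrained(ckpt)
model = AutoModel.from_pretrained(ckpt).to(device)

image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
inputs = {k: v.to(device) for k, v in processor(images=image, return_tensors="pt").items()}
with torch.no_grad():
    features = model(**inputs).last_hidden_state  # (1, num_tokens, hidden_dim)
print(features.shape)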
@@ -63,14 +55,13 @@ def find_correspondences(features1, features2, threshold=0.8):
     max_sim1 = torch.max(similarity, dim=-1)[0]
     max_sim2 = torch.max(similarity, dim=-2)[0]
 
-    mutual_matches = matches2[0, matches1[0]] == torch.arange(N1).to(DEVICE)
+    mutual_matches = matches2[0, matches1[0]] == torch.arange(N1).to(device)
     good_matches = (max_sim1[0] > threshold) & mutual_matches
 
-    return matches1[0][good_matches], torch.arange(N1).to(DEVICE)[good_matches], max_sim1[0][good_matches]
+    return matches1[0][good_matches], torch.arange(N1).to(device)[good_matches], max_sim1[0][good_matches]
 
 def patch_to_image_coords(patch_idx, original_size, model_size, patch_size=14):
     orig_w, orig_h = original_size
-
     patches_h = model_size // patch_size
     patches_w = model_size // patch_size
 
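Note: the two edited lines in find_correspondences implement mutual nearest-neighbour filtering: a pair (i, j) survives only if j is i's best match, i is j's best match, and the similarity clears the threshold. A self-contained sketch of the same idea (function name and shapes are illustrative, not the Space's exact API):

import torch
import torch.nn.functional as F

def mutual_nn_matches(feats1, feats2, threshold=0.8):
    # Cosine similarity between every patch pair
    f1 = F.normalize(feats1, dim=-1)       # (N1, D)
    f2 = F.normalize(feats2, dim=-1)       # (N2, D)
    sim = f1 @ f2.T                        # (N1, N2)
    nn12 = sim.argmax(dim=1)               # best j in image 2 for each i
    nn21 = sim.argmax(dim=0)               # best i in image 1 for each j
    idx = torch.arange(f1.shape[0])
    mutual = nn21[nn12] == idx             # i -> j and j -> i agree
    strong = sim[idx, nn12] > threshold
    keep = mutual & strong
    return idx[keep], nn12[keep], sim[idx, nn12][keep]

feats = torch.randn(196, 768)
i, j, s = mutual_nn_matches(feats, feats)  # sanity check: identity matching
assert torch.all(i == j)

Matching a feature set against itself returns every index paired with itself at similarity 1.0, which is a quick way to validate the mutual test.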
 
@@ -90,14 +81,13 @@ def patch_to_image_coords(patch_idx, original_size, model_size, patch_size=14):
 
 def match_keypoints(image1, image2, model_name):
     if image1 is None or image2 is None:
-        return None, "Please upload both images"
+        return None
 
     load_model(model_name)
 
     img1_pil = Image.fromarray(image1).convert('RGB')
     img2_pil = Image.fromarray(image2).convert('RGB')
 
-
     features1, original_size1, model_size1 = extract_features(img1_pil)
     features2, original_size2, model_size2 = extract_features(img2_pil)
 
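Note: patch_to_image_coords, the context function for this hunk, maps a flat patch index back to pixel coordinates: unflatten row-major into a (model_size // patch_size)-per-side grid, then rescale to the original resolution. A sketch of that arithmetic under the diff's patch_size=14 default (patch centres are an assumption; the full function body isn't visible here):

def patch_to_xy(patch_idx, original_size, model_size, patch_size=14):
    orig_w, orig_h = original_size
    patches_per_side = model_size // patch_size     # e.g. 224 // 14 = 16
    row, col = divmod(patch_idx, patches_per_side)  # row-major unflatten
    # patch centre in model space, rescaled to original image pixels
    x = (col + 0.5) * patch_size * orig_w / model_size
    y = (row + 0.5) * patch_size * orig_h / model_size
    return x, y

print(patch_to_xy(0, (640, 480), 224))    # top-left patch -> (20.0, 15.0)
print(patch_to_xy(255, (640, 480), 224))  # bottom-right patch -> (620.0, 465.0)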
 
@@ -170,7 +160,6 @@ with gr.Blocks(title="DINOv3 Keypoint Matching") as demo:
         with gr.Column(scale=2):
             output_image = gr.Image(label="Matched Keypoints")
 
-        # Connect model selector to status bar
     model_selector.change(
         fn=load_model,
         inputs=[model_selector],
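Note: this hunk only drops a stale comment; the surviving wiring is Gradio's standard change-listener pattern, with load_model's return string serving as the status message. A minimal standalone sketch (the status Textbox and its label are assumptions, since the outputs line falls outside this hunk):

import gradio as gr

def load_model(name):
    return f"✅ Model '{name}' loaded successfully!"

with gr.Blocks() as demo:
    selector = gr.Dropdown(choices=["DINOv3 Base ViT", "DINOv3 Large ConvNeXT"],
                           value="DINOv3 Base ViT", label="Model")
    status = gr.Textbox(label="Status")  # hypothetical status component
    selector.change(fn=load_model, inputs=[selector], outputs=[status])  # fires on every selection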
@@ -189,4 +178,4 @@ with gr.Blocks(title="DINOv3 Keypoint Matching") as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=True)
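Note: share=True makes Gradio open a temporary public *.gradio.live tunnel when the app runs locally; on a hosted Space the app is already served publicly and the flag has no effect, so this change matters mainly for local testing. That is a hedged reading, as only this one line of the launch call is visible in the diff.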
 