stevenbucaille committed on
Commit
930428f
·
verified ·
1 Parent(s): 8dd9a5c

Upload DiskForKeypointDetection

Browse files
Files changed (2) hide show
  1. config.json +2 -1
  2. modeling_disk.py +64 -0
config.json CHANGED
@@ -3,7 +3,8 @@
3
  "DiskForKeypointDetection"
4
  ],
5
  "auto_map": {
6
- "AutoConfig": "configuration_disk.DiskConfig"
 
7
  },
8
  "descriptor_decoder_dim": 128,
9
  "detection_threshold": 0.0,
 
3
  "DiskForKeypointDetection"
4
  ],
5
  "auto_map": {
6
+ "AutoConfig": "configuration_disk.DiskConfig",
7
+ "AutoModelForKeypointDetection": "modeling_disk.DiskForKeypointDetection"
8
  },
9
  "descriptor_decoder_dim": 128,
10
  "detection_threshold": 0.0,
modeling_disk.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import kornia
2
+ import torch
3
+
4
+ from .configuration_disk import DiskConfig
5
+ from transformers import AutoConfig, AutoModelForKeypointDetection, PreTrainedModel
6
+ from transformers.models.superpoint.modeling_superpoint import (
7
+ SuperPointKeypointDescriptionOutput,
8
+ )
9
+
10
+
11
class DiskForKeypointDetection(PreTrainedModel):
    """Expose kornia's DISK keypoint detector through the `PreTrainedModel` API.

    The wrapped kornia model is loaded from `config.weights`; its per-image
    detections are packed into dense, zero-padded batch tensors and returned
    as a `SuperPointKeypointDescriptionOutput`.
    """

    config_class = DiskConfig

    def __init__(self, config: DiskConfig):
        """Build the wrapper and load the DISK weights named by `config.weights`."""
        super().__init__(config)

        self.config = config
        self.model = kornia.feature.DISK.from_pretrained(self.config.weights)

    def forward(
        self, pixel_values: torch.Tensor
    ) -> SuperPointKeypointDescriptionOutput:
        """Detect keypoints on a batch of images and pad the results to a common length.

        Args:
            pixel_values: Image batch; the last two dimensions are read as
                (height, width). NOTE(review): presumably (batch, channels,
                height, width) -- confirm against the kornia DISK input contract.

        Returns:
            A `SuperPointKeypointDescriptionOutput` with:
            - keypoints: (batch, max_keypoints, 2), column 0 divided by the
              image width and column 1 by the image height;
            - scores: (batch, max_keypoints);
            - descriptors: (batch, max_keypoints, config.descriptor_decoder_dim);
            - mask: (batch, max_keypoints) integer tensor, 1 where a row holds
              a real detection and 0 where it is padding.
        """
        detections = self.model(
            pixel_values,
            n=self.config.max_num_keypoints,
            window_size=self.config.nms_window_size,
            score_threshold=self.config.detection_threshold,
            pad_if_not_divisible=self.config.pad_if_not_divisible,
        )

        batch_size = len(detections)
        device = pixel_values.device
        # `default=0` keeps an empty batch from raising ValueError in max().
        max_num_keypoints = max(
            (detection.keypoints.shape[0] for detection in detections),
            default=0,
        )

        keypoints = torch.zeros(batch_size, max_num_keypoints, 2, device=device)
        descriptors = torch.zeros(
            batch_size,
            max_num_keypoints,
            self.config.descriptor_decoder_dim,
            device=device,
        )
        scores = torch.zeros(batch_size, max_num_keypoints, device=device)
        # Integer mask, matching the SuperPoint reference implementation whose
        # output class is reused here.
        mask = torch.zeros(
            batch_size, max_num_keypoints, device=device, dtype=torch.int
        )

        # Images can yield different numbers of detections; copy each image's
        # results into the leading rows and leave the remainder as zero padding.
        for i, detection in enumerate(detections):
            keypoints[i, : detection.keypoints.shape[0]] = detection.keypoints
            descriptors[i, : detection.descriptors.shape[0]] = detection.descriptors
            scores[i, : detection.detection_scores.shape[0]] = (
                detection.detection_scores
            )
            mask[i, : detection.detection_scores.shape[0]] = 1

        # Convert pixel coordinates to coordinates relative to the image size.
        height, width = pixel_values.shape[-2], pixel_values.shape[-1]
        keypoints = keypoints / keypoints.new_tensor([width, height])

        return SuperPointKeypointDescriptionOutput(
            keypoints=keypoints,
            scores=scores,
            descriptors=descriptors,
            mask=mask,
        )