stevenbucaille committed on
Commit
12ba2dd
·
0 Parent(s):

Initial commit: Add DISK model implementation with fixed imports

Browse files
.gitignore ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.so
5
+ .Python
6
+ build/
7
+ develop-eggs/
8
+ dist/
9
+ downloads/
10
+ eggs/
11
+ .eggs/
12
+ lib/
13
+ lib64/
14
+ parts/
15
+ sdist/
16
+ var/
17
+ wheels/
18
+ *.egg-info/
19
+ .installed.cfg
20
+ *.egg
21
+ MANIFEST
disk_model/__init__.py ADDED
File without changes
disk_model/configuration_disk.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+
6
class DiskConfig(PretrainedConfig):
    """Configuration for the DISK keypoint detection wrapper.

    The values are stored as plain attributes on the config; any extra
    keyword arguments are forwarded to ``PretrainedConfig``.

    Args:
        weights: Identifier of the pretrained DISK checkpoint. The model
            wrapper passes it to ``kornia.feature.DISK.from_pretrained``.
        max_num_keypoints: Upper bound on keypoints per image, forwarded to
            the detector as ``n``. ``None`` is forwarded unchanged.
        descriptor_decoder_dim: Size of the last dimension of the padded
            descriptor tensor allocated by the model wrapper.
        nms_window_size: Forwarded to the detector as ``window_size``.
        detection_threshold: Forwarded to the detector as
            ``score_threshold``.
        pad_if_not_divisible: Forwarded to the detector under the same name.
        **kwargs: Remaining options handled by ``PretrainedConfig``.
    """

    model_type = "disk"

    def __init__(
        self,
        weights: str = "depth",
        max_num_keypoints: Optional[int] = None,
        descriptor_decoder_dim: int = 128,
        nms_window_size: int = 5,
        detection_threshold: float = 0.0,
        pad_if_not_divisible: bool = True,
        **kwargs,
    ):
        # Let the base class consume the generic options first.
        super().__init__(**kwargs)

        # Checkpoint selection.
        self.weights = weights

        # Detector / output-shape settings.
        self.max_num_keypoints = max_num_keypoints
        self.descriptor_decoder_dim = descriptor_decoder_dim
        self.nms_window_size = nms_window_size
        self.detection_threshold = detection_threshold
        self.pad_if_not_divisible = pad_if_not_divisible
26
+
27
+
28
if __name__ == "__main__":
    # Publish a default configuration to the Hub repository.
    DiskConfig().save_pretrained("stevenbucaille/disk", push_to_hub=True)
disk_model/modeling_disk.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import kornia
2
+ import torch
3
+
4
+ from configuration_disk import DiskConfig
5
+ from transformers import AutoConfig, AutoModelForKeypointDetection, PreTrainedModel
6
+ from transformers.models.superpoint.modeling_superpoint import (
7
+ SuperPointKeypointDescriptionOutput,
8
+ )
9
+
10
+
11
class DiskForKeypointDetection(PreTrainedModel):
    """Keypoint detector that wraps kornia's DISK model.

    The wrapped detector returns one detection object per image. ``forward``
    packs those variable-length results into zero-padded batch tensors and
    returns them as a ``SuperPointKeypointDescriptionOutput``.
    """

    config_class = DiskConfig

    def __init__(self, config: DiskConfig):
        """Build the wrapper and load the DISK weights named by ``config.weights``."""
        super().__init__(config)

        self.config = config
        self.model = kornia.feature.DISK.from_pretrained(self.config.weights)

    def forward(
        self, pixel_values: torch.Tensor
    ) -> SuperPointKeypointDescriptionOutput:
        """Detect keypoints and pack them into padded batch tensors.

        Args:
            pixel_values: Batch of images. The last two dimensions are read
                as (height, width) when normalizing coordinates; presumably
                the layout is (batch, channels, height, width) -- confirm
                against the caller.

        Returns:
            A ``SuperPointKeypointDescriptionOutput`` with:
              * ``keypoints``: ``(batch, max_kp, 2)``, x divided by the image
                width and y divided by the image height;
              * ``scores``: ``(batch, max_kp)``;
              * ``descriptors``: ``(batch, max_kp, descriptor_decoder_dim)``;
              * ``mask``: ``(batch, max_kp)``, 1 for real detections and 0
                for padding.
            ``max_kp`` is the largest keypoint count over the batch, or 0
            when the detector returns no per-image results.
        """
        detections = self.model(
            pixel_values,
            n=self.config.max_num_keypoints,
            window_size=self.config.nms_window_size,
            score_threshold=self.config.detection_threshold,
            pad_if_not_divisible=self.config.pad_if_not_divisible,
        )

        batch_size = len(detections)
        device = pixel_values.device
        # ``default=0`` keeps an empty detection list from raising
        # ValueError; the outputs are then zero-sized tensors.
        max_num_keypoints = max(
            (detection.keypoints.shape[0] for detection in detections),
            default=0,
        )

        # Zero-initialized buffers: rows past an image's own detections stay
        # zero and are marked as padding by ``mask``.
        keypoints = torch.zeros(batch_size, max_num_keypoints, 2, device=device)
        descriptors = torch.zeros(
            batch_size,
            max_num_keypoints,
            self.config.descriptor_decoder_dim,
            device=device,
        )
        scores = torch.zeros(batch_size, max_num_keypoints, device=device)
        mask = torch.zeros(batch_size, max_num_keypoints, device=device)

        for i, detection in enumerate(detections):
            keypoints[i, : detection.keypoints.shape[0]] = detection.keypoints
            descriptors[i, : detection.descriptors.shape[0]] = detection.descriptors
            scores[i, : detection.detection_scores.shape[0]] = (
                detection.detection_scores
            )
            mask[i, : detection.detection_scores.shape[0]] = 1

        # Convert pixel coordinates to fractions of the image size.
        width, height = pixel_values.shape[-1], pixel_values.shape[-2]
        keypoints[:, :, 0] = keypoints[:, :, 0] / width
        keypoints[:, :, 1] = keypoints[:, :, 1] / height

        return SuperPointKeypointDescriptionOutput(
            keypoints=keypoints,
            scores=scores,
            descriptors=descriptors,
            mask=mask,
        )
65
+
66
+
67
if __name__ == "__main__":
    # NOTE(review): the scraped diff lost its indentation; the registration
    # calls below are assumed to belong to this script section -- confirm
    # against the original file.
    #
    # Register the custom classes BEFORE saving/pushing: registering them
    # only after ``save_pretrained`` has run cannot affect what was already
    # written and uploaded.
    AutoConfig.register("disk", DiskConfig)
    AutoModelForKeypointDetection.register(DiskConfig, DiskForKeypointDetection)

    DiskConfig.register_for_auto_class()
    DiskForKeypointDetection.register_for_auto_class("AutoModelForKeypointDetection")

    config = DiskConfig()
    model = DiskForKeypointDetection(config)
    model.save_pretrained("stevenbucaille/disk", push_to_hub=True)