import kornia
import torch
from transformers import PreTrainedModel
from transformers.models.superpoint.modeling_superpoint import (
    SuperPointKeypointDescriptionOutput,
)

from .configuration_disk import DiskConfig


class DiskForKeypointDetection(PreTrainedModel):
    """Keypoint detector that wraps ``kornia.feature.DISK`` as a ``PreTrainedModel``.

    The wrapped model is built from ``config.descriptor_decoder_dim``; the
    remaining detection options are read from the config on every forward call.
    """

    config_class = DiskConfig

    def __init__(self, config: DiskConfig):
        """Store ``config`` and instantiate the underlying kornia DISK model."""
        super().__init__(config)
        self.config = config
        self.model = kornia.feature.DISK(self.config.descriptor_decoder_dim)

    def forward(
        self, pixel_values: torch.Tensor
    ) -> SuperPointKeypointDescriptionOutput:
        """Detect keypoints and pad the per-image results into batch tensors.

        Args:
            pixel_values: Image batch passed unchanged to the kornia model.
                Its last two dimensions are read as ``(height, width)`` when
                normalising keypoint coordinates.

        Returns:
            A ``SuperPointKeypointDescriptionOutput`` in which every field has
            one row per detection result and is zero-padded along the keypoint
            axis to the largest keypoint count found in the batch:

            * ``keypoints`` -- ``(batch, num_keypoints, 2)``; column 0 is
              divided by the image width and column 1 by the image height.
            * ``scores`` -- ``(batch, num_keypoints)`` detection scores.
            * ``descriptors`` -- ``(batch, num_keypoints,
              config.descriptor_decoder_dim)``.
            * ``mask`` -- ``(batch, num_keypoints)``; ``1`` where the entry is
              a real detection and ``0`` where it is padding.
        """
        detections = self.model(
            pixel_values,
            n=self.config.max_num_keypoints,
            window_size=self.config.nms_window_size,
            score_threshold=self.config.detection_threshold,
            pad_if_not_divisible=self.config.pad_if_not_divisible,
        )

        batch_size = len(detections)
        device = pixel_values.device
        # ``default=0`` keeps an empty detection list from raising
        # ``ValueError`` in ``max``; the outputs are then simply empty.
        padded_len = max(
            (detection.keypoints.shape[0] for detection in detections),
            default=0,
        )

        # Images may yield different numbers of keypoints, so each output is
        # allocated at the largest count and the unused tail is left at zero.
        keypoints = torch.zeros(batch_size, padded_len, 2, device=device)
        descriptors = torch.zeros(
            batch_size,
            padded_len,
            self.config.descriptor_decoder_dim,
            device=device,
        )
        scores = torch.zeros(batch_size, padded_len, device=device)
        mask = torch.zeros(batch_size, padded_len, device=device)

        for i, detection in enumerate(detections):
            num_scores = detection.detection_scores.shape[0]
            keypoints[i, : detection.keypoints.shape[0]] = detection.keypoints
            descriptors[i, : detection.descriptors.shape[0]] = (
                detection.descriptors
            )
            scores[i, :num_scores] = detection.detection_scores
            # Flag the real (non-padded) entries of this image.
            mask[i, :num_scores] = 1

        # Rescale coordinates by the image size. Presumably column 0 is x and
        # column 1 is y in pixels -- confirm against the kornia DISK output.
        height, width = pixel_values.shape[-2], pixel_values.shape[-1]
        keypoints[:, :, 0] = keypoints[:, :, 0] / width
        keypoints[:, :, 1] = keypoints[:, :, 1] / height

        return SuperPointKeypointDescriptionOutput(
            keypoints=keypoints,
            scores=scores,
            descriptors=descriptors,
            mask=mask,
        )