Upload DiskForKeypointDetection
Browse files- modeling_disk.py +2 -2
modeling_disk.py
CHANGED
@@ -2,10 +2,10 @@ import kornia
|
|
2 |
import torch
|
3 |
|
4 |
from .configuration_disk import DiskConfig
|
5 |
-
from transformers import AutoConfig, AutoModelForKeypointDetection, PreTrainedModel
|
6 |
from transformers.models.superpoint.modeling_superpoint import (
|
7 |
SuperPointKeypointDescriptionOutput,
|
8 |
)
|
|
|
9 |
|
10 |
|
11 |
class DiskForKeypointDetection(PreTrainedModel):
|
@@ -15,7 +15,7 @@ class DiskForKeypointDetection(PreTrainedModel):
|
|
15 |
super().__init__(config)
|
16 |
|
17 |
self.config = config
|
18 |
-
self.model = kornia.feature.DISK
|
19 |
|
20 |
def forward(
|
21 |
self, pixel_values: torch.Tensor
|
|
|
2 |
import torch
|
3 |
|
4 |
from .configuration_disk import DiskConfig
|
|
|
5 |
from transformers.models.superpoint.modeling_superpoint import (
|
6 |
SuperPointKeypointDescriptionOutput,
|
7 |
)
|
8 |
+
from transformers import PreTrainedModel
|
9 |
|
10 |
|
11 |
class DiskForKeypointDetection(PreTrainedModel):
|
|
|
15 |
super().__init__(config)
|
16 |
|
17 |
self.config = config
|
18 |
+
self.model = kornia.feature.DISK(self.config.descriptor_decoder_dim)
|
19 |
|
20 |
def forward(
|
21 |
self, pixel_values: torch.Tensor
|