# Source: disk / configuration_disk.py
# Author: stevenbucaille
# Commit: 8dd9a5c (verified) - "Upload DiskForKeypointDetection"
# Size: 894 bytes
from typing import Optional
from transformers import PretrainedConfig
class DiskConfig(PretrainedConfig):
    r"""
    Configuration holder for the DISK keypoint model (`model_type = "disk"`).

    Every named constructor argument is stored verbatim on the instance as an
    attribute of the same name; no validation is performed here. How each
    value is consumed is decided by the model code, which is not part of this
    file, so the descriptions below are hedged where they go beyond the names.

    Args:
        weights (`str`, *optional*, defaults to `"depth"`):
            Identifier of the pretrained-weights variant.
            NOTE(review): presumably selects between DISK checkpoints
            (e.g. "depth" vs. "epipolar") - confirm against the model code.
        max_num_keypoints (`int`, *optional*, defaults to `None`):
            Upper bound on the number of keypoints.
            NOTE(review): `None` presumably means "no limit" - confirm.
        descriptor_decoder_dim (`int`, *optional*, defaults to 128):
            Dimension used by the descriptor decoder (presumably the
            descriptor length - confirm).
        nms_window_size (`int`, *optional*, defaults to 5):
            Window size, presumably for non-maximum suppression of the
            detection scores (assumed from the name).
        detection_threshold (`float`, *optional*, defaults to 0.0):
            Score threshold for keeping detections (assumed from the name).
        pad_if_not_divisible (`bool`, *optional*, defaults to `True`):
            Whether inputs of an unsupported size are padded rather than
            rejected/cropped (assumed from the name - confirm).
        kwargs:
            Forwarded unchanged to `PretrainedConfig.__init__`.
    """

    # Identifier string for this configuration class.
    model_type = "disk"

    def __init__(
        self,
        weights: str = "depth",
        max_num_keypoints: Optional[int] = None,
        descriptor_decoder_dim: int = 128,
        nms_window_size: int = 5,
        detection_threshold: float = 0.0,
        pad_if_not_divisible: bool = True,
        **kwargs,
    ):
        # Let the base class handle every keyword it understands before the
        # DISK-specific fields are attached.
        super().__init__(**kwargs)

        # Which pretrained weights to use.
        self.weights = weights

        # Keypoint / descriptor settings, kept exactly as the caller passed them.
        self.max_num_keypoints = max_num_keypoints
        self.descriptor_decoder_dim = descriptor_decoder_dim
        self.nms_window_size = nms_window_size
        self.detection_threshold = detection_threshold

        # Input-size handling.
        self.pad_if_not_divisible = pad_if_not_divisible
if __name__ == "__main__":
    # Script entry point: build the default DISK configuration, write it to the
    # local directory "stevenbucaille/disk", and upload it to the Hugging Face Hub.
    config = DiskConfig()
    # `repo_id` is passed explicitly. Without it, `save_pretrained(push_to_hub=True)`
    # derives the repo name from the last path component ("disk"). That pushes
    # to "<token owner>/disk" rather than the intended "stevenbucaille/disk".
    config.save_pretrained(
        "stevenbucaille/disk", push_to_hub=True, repo_id="stevenbucaille/disk"
    )