| from typing import Optional | |
| from transformers import PretrainedConfig | |
class DiskConfig(PretrainedConfig):
    """Configuration holder for the ``"disk"`` model type.

    Each constructor argument is stored unchanged as an instance attribute
    of the same name; nothing is validated or derived here. Any additional
    keyword arguments are handed to ``PretrainedConfig.__init__``.

    NOTE(review): the parameter meanings below are inferred from their names
    only -- the code that consumes this config is not visible here, so
    confirm them against the model implementation.

    Args:
        weights: Presumably the name of the pretrained weight variant.
            Defaults to ``"depth"``.
        max_num_keypoints: Presumably an upper bound on the number of
            keypoints kept; ``None`` (the default) presumably means no bound.
        descriptor_decoder_dim: Presumably the descriptor dimensionality.
            Defaults to ``128``.
        nms_window_size: Presumably the non-maximum-suppression window size.
            Defaults to ``5``.
        detection_threshold: Presumably the minimum detection score.
            Defaults to ``0.0``.
        pad_if_not_divisible: Presumably toggles padding of inputs whose
            size is not divisible by some required factor. Defaults to
            ``True``.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier that transformers associates with this config.
    model_type = "disk"

    def __init__(
        self,
        weights: str = "depth",
        max_num_keypoints: Optional[int] = None,
        descriptor_decoder_dim: int = 128,
        nms_window_size: int = 5,
        detection_threshold: float = 0.0,
        pad_if_not_divisible: bool = True,
        **kwargs,
    ):
        # The base class handles the generic options before the
        # DISK-specific fields are attached.
        super().__init__(**kwargs)

        # Weight selection.
        self.weights = weights

        # Keypoint / descriptor settings, stored exactly as received.
        self.max_num_keypoints = max_num_keypoints
        self.descriptor_decoder_dim = descriptor_decoder_dim
        self.nms_window_size = nms_window_size
        self.detection_threshold = detection_threshold

        # Input handling.
        self.pad_if_not_divisible = pad_if_not_divisible
if __name__ == "__main__":
    # Script entry point: build a config with every default value, write it
    # out with save_pretrained, and upload it (push_to_hub=True).
    # NOTE(review): save_pretrained takes a local save directory as its
    # first argument -- confirm that the Hub repository derived from this
    # path is the intended one.
    save_directory = "stevenbucaille/disk"
    config = DiskConfig()
    config.save_pretrained(save_directory, push_to_hub=True)