# File size: 1,369 bytes (4187c6f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

@dataclass
class LossConfiguration:
    """Settings for the training loss.

    Only ``num_classes`` is required; every other field has a default.
    Field order defines the positional constructor signature. How each
    field is consumed is decided by the loss implementation, which is not
    part of this file.
    """

    num_classes: int  # number of classes the loss is computed over

    # Weights for the cross-entropy and Dice terms -- presumably the two
    # terms are combined with these weights; confirm in the loss code.
    xent_weight: float = 1.0
    dice_weight: float = 1.0
    # Toggle for a focal-loss variant plus its exponent (assumed to be the
    # usual focal-loss "gamma" -- TODO confirm against the implementation).
    focal_loss: bool = False
    focal_loss_gamma: float = 2.0
    # Flags declaring which extra masks the loss expects to receive.
    # NOTE: "frustrum" is a misspelling of "frustum", but the field name is
    # part of the configuration interface, so it is kept unchanged.
    requires_frustrum: bool = True
    requires_flood_mask: bool = False
    # Optional per-class weights. The type is left open (Any); presumably a
    # sequence/tensor of length num_classes -- verify against the consumer.
    class_weights: Optional[Any] = None
    label_smoothing: float = 0.1  # label-smoothing amount; defaults to 0.1

@dataclass
class BackboneConfigurationBase:
    """Fields shared by every backbone configuration.

    None of the fields has a default here. Subclasses either re-declare them
    with defaults (as DINOConfiguration does) or leave them required (as
    ResNetConfiguration does).
    """

    pretrained: bool  # presumably: load pretrained weights -- confirm in builder
    frozen: bool  # presumably: keep backbone weights fixed -- confirm in builder
    # Output feature dimension. Annotated ``int`` (it was ``bool``): the value
    # is a size such as 128, and a ``bool`` annotation makes type-validating
    # config loaders coerce or reject integer values for subclasses that
    # inherit this field.
    output_dim: int

@dataclass
class DINOConfiguration(BackboneConfigurationBase):
    """Backbone settings for a DINO image encoder.

    Re-declares the three base-class fields with defaults, so the class can
    be constructed with no arguments.
    """

    pretrained: bool = True
    frozen: bool = False
    output_dim: int = 128  # output feature dimension (presumably channels)

@dataclass
class ResNetConfiguration(BackboneConfigurationBase):
    """Backbone settings for a ResNet-style encoder.

    Every field is required. The inherited ``pretrained``, ``frozen`` and
    ``output_dim`` have no defaults and come first in the constructor,
    followed by the fields below in declaration order.
    """

    input_dim: int  # presumably number of input channels -- confirm
    encoder: str  # encoder identifier; accepted values defined by the builder
    remove_stride_from_first_conv: bool
    # May be None; presumably None selects the encoder's default depth --
    # verify against the model code.
    num_downsample: Optional[int]
    decoder_norm: str  # normalization-layer name for the decoder
    do_average_pooling: bool
    # Presumably enables activation checkpointing -- TODO confirm.
    checkpointed: bool

@dataclass
class ImageEncoderConfiguration:
    """Selects an image encoder and carries its backbone settings."""

    name: str  # encoder identifier, resolved by the model code
    # Typed Any; presumably one of the *Configuration backbone dataclasses in
    # this module (e.g. DINOConfiguration) -- confirm against the consumer.
    backbone: Any

@dataclass
class ModelConfiguration:
    """Top-level model configuration.

    Groups the segmentation-head options, the image-encoder configuration,
    the loss configuration, and scalar size/geometry settings. Field order
    defines the positional constructor signature, so it must not change.
    """

    # Free-form options for the segmentation head; the keys are interpreted
    # by the head implementation (not visible here).
    segmentation_head: Dict[str, Any]
    image_encoder: ImageEncoderConfiguration

    name: str
    # NOTE(review): loss.num_classes carries the same quantity; presumably the
    # two must agree -- nothing here enforces it.
    num_classes: int
    latent_dim: int
    # Output extents -- presumably metres along z (depth) and x (lateral);
    # TODO confirm units and axes against the model code.
    z_max: int
    x_max: int

    pixel_per_meter: int  # grid resolution (pixels per metre, per the name)
    num_scale_bins: int

    loss: LossConfiguration

    # typing.List rather than the builtin ``list[int]``, for consistency with
    # the other typing generics in this file. It also works on Python < 3.9
    # and with config loaders that predate PEP 585 builtin generics.
    # default_factory gives every instance its own fresh list.
    scale_range: List[int] = field(default_factory=lambda: [0, 9])
    z_min: Optional[int] = None  # optional lower z bound; None when unset