from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class LossConfiguration:
    """Hyper-parameters for the training loss.

    Only ``num_classes`` is required; every other field has a default.
    Field order is also the positional order of the generated ``__init__``.
    """

    num_classes: int

    # Relative weights of the loss terms (presumably cross-entropy and Dice,
    # judging by the names) — confirm with the loss implementation.
    xent_weight: float = field(default=1.0)
    dice_weight: float = field(default=1.0)

    # Focal-loss toggle and its gamma parameter.
    focal_loss: bool = field(default=False)
    focal_loss_gamma: float = field(default=2.0)

    # Flags for auxiliary masks the loss presumably needs as inputs.
    # NOTE: "frustrum" is a misspelling of "frustum", but the field name is
    # public (constructor keyword / config key), so it is kept as-is.
    requires_frustrum: bool = field(default=True)
    requires_flood_mask: bool = field(default=False)

    # Optional per-class weights. The type is left open; presumably a sequence
    # or tensor of length ``num_classes`` — verify against the consumer.
    class_weights: Optional[Any] = field(default=None)
    label_smoothing: float = field(default=0.1)

@dataclass
class BackboneConfigurationBase:
    """Fields shared by every backbone configuration.

    No field has a default here. ``DINOConfiguration`` redeclares them with
    defaults, and ``ResNetConfiguration`` appends further required fields.
    """

    # Presumably: load pretrained weights / keep backbone weights fixed.
    pretrained: bool
    frozen: bool
    # A dimension, so annotated as ``int``; this matches
    # ``DINOConfiguration.output_dim: int = 128``.
    output_dim: int

@dataclass
class DINOConfiguration(BackboneConfigurationBase):
    """Backbone settings for the DINO encoder variant.

    Redeclares every inherited field with a default, so the class can be
    built with no arguments: pretrained, not frozen, 128-wide output.
    """

    pretrained: bool = field(default=True)
    frozen: bool = field(default=False)
    output_dim: int = field(default=128)

@dataclass
class ResNetConfiguration(BackboneConfigurationBase):
    """Backbone settings for the ResNet encoder variant.

    Every field is required. The generated ``__init__`` takes the inherited
    fields first (``pretrained``, ``frozen``, ``output_dim``), followed by the
    fields declared below in order.
    """

    # Redeclared to pin the annotation to ``int``, matching
    # DINOConfiguration. Redeclaring an inherited dataclass field keeps its
    # original position, so the positional constructor order is unchanged.
    output_dim: int

    input_dim: int
    # Presumably the name of the ResNet variant to build — confirm with the
    # model factory.
    encoder: str
    remove_stride_from_first_conv: bool
    # Required, but may be None.
    num_downsample: Optional[int]
    decoder_norm: str
    do_average_pooling: bool
    # Presumably toggles activation checkpointing — confirm with the consumer.
    checkpointed: bool

@dataclass
class ImageEncoderConfiguration:
    """Selects and configures the image encoder.

    Both fields are required. ``backbone`` is typed ``Any``. Presumably it
    holds one of the backbone configurations in this module (e.g.
    ``DINOConfiguration`` or ``ResNetConfiguration``) or a raw mapping, chosen
    according to ``name`` — confirm with the consumer.
    """

    # Encoder identifier; the accepted values are not visible here.
    name: str
    backbone: Any

@dataclass
class ModelConfiguration:
    """Top-level model configuration.

    Field order is the positional order of the generated ``__init__``. Every
    field before ``scale_range`` is required.

    NOTE(review): ``num_classes`` also appears on ``loss`` (LossConfiguration),
    and nothing here checks that the two values agree.
    """

    # Nested sub-configurations; the head is a free-form mapping.
    segmentation_head: Dict[str, Any]
    image_encoder: ImageEncoderConfiguration

    # Identity and model sizes.
    name: str
    num_classes: int
    latent_dim: int

    # Spatial extents. Units are not shown here (presumably metres, given
    # ``pixel_per_meter`` below) — confirm with the consumer.
    z_max: int
    x_max: int

    # Resolution / scale discretisation.
    pixel_per_meter: int
    num_scale_bins: int

    loss: LossConfiguration

    # default_factory hands every instance its own fresh [0, 9] list.
    scale_range: list[int] = field(default_factory=lambda: [0, 9])
    # Optional; presumably the lower counterpart of ``z_max``. None when unset.
    z_min: Optional[int] = field(default=None)