from typing import Optional, Dict
import torch.nn as nn
import torch
from .schema import LossConfiguration


def dice_loss(input: torch.Tensor,
              target: torch.Tensor,
              loss_mask: torch.Tensor,
              class_weights: Optional[torch.Tensor | bool],
              smooth=1e-5):
    '''
    :param input: (B, H, W, C) Logits for each class
    :param target: (B, H, W, C) Ground truth class labels in one_hot
    :param loss_mask: (B, H, W) Mask indicating valid regions of the image
    :param class_weights: (C) Weights for each class
    :param smooth: Smoothing factor to avoid division by zero, default 1.0
    '''
    
    if isinstance(class_weights, torch.Tensor):
        class_weights = class_weights.unsqueeze(0)
    elif class_weights is None or class_weights == False:
        class_weights = torch.ones(
            1, target.size(-1), dtype=target.dtype, device=target.device)
    elif class_weights == True:
        class_weights = target.sum(1)
        class_weights = torch.reciprocal(target.mean(1) + 1e-3)
        class_weights = class_weights.clamp(min=1e-5)
        # Only consider classes that are present
        class_weights *= (target.sum(1) != 0).float()
        class_weights.requires_grad = False

    intersect = (2 * input * target)
    intersect = (intersect) + smooth

    union = (input + target)
    union = (union) + smooth

    loss = 1 - (intersect / union)  # B, H, W, C
    loss *= class_weights.unsqueeze(0).unsqueeze(0)
    loss = loss.sum(-1) / class_weights.sum()
    loss *= loss_mask
    loss = loss.sum() / loss_mask.sum()  # 1

    return loss


class EnhancedLoss(nn.Module):
    '''
    Segmentation loss combining a per-class binary cross-entropy term
    (optionally focal-modulated and/or label-smoothed) with a soft Dice term.

    Each term is enabled when its weight in the configuration is > 0, and at
    least one must be enabled. ``forward`` returns the enabled terms and
    their weighted sum under ``'total'``.
    '''

    def __init__(
        self,
        cfg: 'LossConfiguration',
    ):  # following params in the paper
        super().__init__()
        self.num_classes = cfg.num_classes
        self.xent_weight = cfg.xent_weight
        self.focal = cfg.focal_loss
        self.focal_gamma = cfg.focal_loss_gamma
        self.dice_weight = cfg.dice_weight

        if self.xent_weight == 0. and self.dice_weight == 0.:
            raise ValueError(
                "At least one of xent_weight and dice_weight must be greater than 0.")

        if self.xent_weight > 0.:
            # Element-wise loss; masking and normalisation happen in forward().
            self.xent_loss = nn.BCEWithLogitsLoss(
                reduction="none"
            )

        if self.dice_weight > 0.:
            self.dice_loss = dice_loss

        cfg_weights = cfg.class_weights
        if cfg_weights is None or isinstance(cfg_weights, bool):
            # None/False (uniform) and True (frequency-based) are flags that are
            # handed to dice_loss as-is. They must not become a tensor:
            # tensor(False) would zero every class weight.
            self.class_weights = cfg_weights
        else:
            # Explicit per-class weights. Stored as a float buffer so it follows
            # .to(device) and stays out of the state dict.
            self.register_buffer("class_weights", torch.as_tensor(
                cfg_weights, dtype=torch.float32), persistent=False)

        self.class_weights: Optional[torch.Tensor | bool]

        self.requires_frustrum = cfg.requires_frustrum
        self.requires_flood_mask = cfg.requires_flood_mask
        self.label_smoothing = cfg.label_smoothing

    def forward(self, pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor]):
        '''
        Args:
            pred: Dict containing
                - output: (B, C, H, W) Probabilities for each class (Dice term)
                - logits: (B, C, H, W) Logits for each class (cross-entropy term)
                - valid_bev: validity mask, read only when ``requires_frustrum``
                  is set; non-zero entries are valid
            data: Dict containing
                - seg_masks: (B, H, W, C) Ground truth class labels, one-hot encoded
                - flood_masks: (B, H, W), read only when ``requires_flood_mask``
                  is set; pixels where it is non-zero are excluded
        Returns:
            Dict with 'cross_entropy' and/or 'dice' (whichever terms are
            enabled) and 'total', their weighted sum.
        '''
        loss = {}

        probs = pred['output'].permute(0, 2, 3, 1)  # (B, H, W, C)
        logits = pred['logits'].permute(0, 2, 3, 1)  # (B, H, W, C)
        labels: torch.Tensor = data['seg_masks']  # (B, H, W, C)

        # Float32 regardless of the label dtype, so the mask stays a plain
        # 0/1 weight even when labels are integer, bool or half precision.
        loss_mask = torch.ones(
            labels.shape[:3], device=labels.device, dtype=torch.float32)

        if self.requires_frustrum:
            # NOTE(review): the last column of valid_bev is dropped; presumably
            # valid_bev is one column wider than the label grid -- confirm.
            frustrum_mask = pred["valid_bev"][..., :-1] != 0
            loss_mask = loss_mask * frustrum_mask.float()

        if self.requires_flood_mask:
            flood_mask = data["flood_masks"] == 0
            loss_mask = loss_mask * flood_mask.float()

        if self.xent_weight > 0.:
            # BCEWithLogitsLoss needs a floating target matching the logits,
            # in both the smoothed and the unsmoothed path.
            targets = labels.to(logits.dtype)
            if self.label_smoothing > 0.:
                targets = targets * (1 - self.label_smoothing) + \
                    self.label_smoothing / self.num_classes

            xent_loss = self.xent_loss(logits, targets)

            if self.focal:
                # Focal modulation: pt is the model's probability of the
                # target value, so confidently-correct elements are down-weighted.
                pt = torch.exp(-xent_loss)
                xent_loss = (1 - pt) ** self.focal_gamma * xent_loss

            # Sum over classes and valid pixels, then average over valid pixels.
            xent_loss = xent_loss * loss_mask.unsqueeze(-1)
            xent_loss = xent_loss.sum() / (loss_mask.sum() + 1e-5)
            loss['cross_entropy'] = xent_loss
            loss['total'] = xent_loss * self.xent_weight

        if self.dice_weight > 0.:
            dloss = self.dice_loss(
                probs, labels, loss_mask, self.class_weights)
            loss['dice'] = dloss

            weighted_dice = dloss * self.dice_weight
            if 'total' in loss:
                loss['total'] = loss['total'] + weighted_dice
            else:
                loss['total'] = weighted_dice

        return loss