import torch
import torchmetrics
import torchmetrics.classification


class PixelAccuracy(torchmetrics.Metric):
    """Fraction of prediction entries that match the ground-truth mask.

    Predictions in ``pred['output']`` are binarized at 0.5 and compared
    element-wise with ``data["seg_masks"]``, which is permuted with
    ``(0, 3, 1, 2)`` to line up with the predictions. Every compared element
    (pixel x class channel) counts once, in both the numerator and the
    denominator, so the result lies in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Integer counters; the "sum" reduction adds them across processes.
        self.add_state("correct_pixels", default=torch.tensor(0),
                       dist_reduce_fx="sum")
        self.add_state("total_pixels", default=torch.tensor(0),
                       dist_reduce_fx="sum")

    def update(self, pred, data):
        """Accumulate matching and total element counts for one batch.

        Args:
            pred: Model outputs; reads ``pred['output']`` (scores thresholded
                at 0.5). Presumably channel-first (B, C, H, W) -- TODO confirm.
            data: Batch; reads ``data["seg_masks"]``, presumably channel-last
                (B, H, W, C) given the permute below -- TODO confirm.
        """
        output_mask = pred['output'] > 0.5
        # Move the ground truth's last axis to position 1 to match predictions.
        gt_mask = data["seg_masks"].permute(0, 3, 1, 2)
        matches = output_mask == gt_mask
        self.correct_pixels += matches.sum()
        # Count exactly the elements that were compared so the ratio stays
        # within [0, 1].
        self.total_pixels += matches.numel()

    def compute(self):
        """Return correct / total over everything seen since the last reset."""
        return self.correct_pixels / self.total_pixels


class IOU(torchmetrics.Metric):
    """Accumulate per-class intersection/union pixel counts, split by region.

    Counts are kept separately for an "observable" and a "non-observable"
    region. This base class only accumulates; subclasses implement
    ``compute`` to turn the counts into a score.

    Args:
        num_classes: Number of class channels (indices ``0..num_classes-1``)
            to evaluate.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """

    def __init__(self, num_classes=3, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        # One float counter per class; "sum" reduction adds counts across
        # processes.
        for state_name in ("intersection_observable", "union_observable",
                           "intersection_non_observable",
                           "union_non_observable"):
            self.add_state(state_name, default=torch.zeros(num_classes),
                           dist_reduce_fx="sum")

    def update(self, output, data):
        """Add this batch's per-class intersection and union counts.

        Args:
            output: Model outputs. Reads ``output['output']`` (class scores
                indexed on dim 1, thresholded at 0.5) and
                ``output["valid_bev"]`` (last entry of its final dim dropped).
            data: Batch. Reads ``data["seg_masks"]`` (class indexed on the
                last dim) and, if present, ``data["confidence_map"]``.
        """
        gt = data["seg_masks"]
        pred = output['output']
        valid = output["valid_bev"][..., :-1]

        if "confidence_map" in data:
            # Within the valid area, confidence 0 marks observable pixels and
            # confidence 1 marks non-observable ones.
            observable_mask = torch.logical_and(
                valid, data["confidence_map"] == 0)
            non_observable_mask = torch.logical_and(
                valid, data["confidence_map"] == 1)
        else:
            # NOTE(review): without a confidence map, "non-observable" means
            # everything outside valid_bev, unlike the branch above where it
            # is restricted to the valid area -- confirm this is intended.
            observable_mask = valid
            non_observable_mask = torch.logical_not(observable_mask)

        for class_idx in range(self.num_classes):
            pred_mask = pred[:, class_idx] > 0.5
            gt_mask = gt[..., class_idx]
            # Compute once per class; both regions reuse these masks.
            intersection = torch.logical_and(pred_mask, gt_mask)
            union = torch.logical_or(pred_mask, gt_mask)

            self.intersection_observable[class_idx] += torch.logical_and(
                intersection, observable_mask).sum()
            self.union_observable[class_idx] += torch.logical_and(
                union, observable_mask).sum()

            self.intersection_non_observable[class_idx] += torch.logical_and(
                intersection, non_observable_mask).sum()
            self.union_non_observable[class_idx] += torch.logical_and(
                union, non_observable_mask).sum()

    def compute(self):
        """Not implemented here; subclasses define how counts become a score.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement compute(); "
            "use an IOU subclass")


class ObservableIOU(IOU):
    """IoU of one class, measured only over the observable region.

    Args:
        class_idx: Index of the class whose IoU ``compute`` returns.
        **kwargs: Forwarded to ``IOU`` (e.g. ``num_classes``).
    """

    def __init__(self, class_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.class_idx = class_idx

    def compute(self):
        """Return intersection / (union + 1e-6) for the selected class."""
        idx = self.class_idx
        hits = self.intersection_observable[idx]
        # The small epsilon keeps 0/0 (class absent everywhere) finite.
        return hits / (self.union_observable[idx] + 1e-6)


class UnobservableIOU(IOU):
    """IoU of one class, measured only over the non-observable region.

    Args:
        class_idx: Index of the class whose IoU ``compute`` returns.
        **kwargs: Forwarded to ``IOU`` (e.g. ``num_classes``).
    """

    def __init__(self, class_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.class_idx = class_idx

    def compute(self):
        """Return intersection / (union + 1e-6) for the selected class."""
        idx = self.class_idx
        hits = self.intersection_non_observable[idx]
        # The small epsilon keeps 0/0 (class absent everywhere) finite.
        return hits / (self.union_non_observable[idx] + 1e-6)


class MeanObservableIOU(IOU):
    """IoU over the observable region with counts pooled across all classes.

    NOTE(review): this divides the summed intersection by the summed union,
    i.e. a pixel-weighted (pooled) IoU. It is not the unweighted average of
    per-class IoUs that "mean IoU" often denotes.
    """

    def compute(self):
        """Return sum(intersection) / (sum(union) + 1e-6) over all classes."""
        pooled_intersection = self.intersection_observable.sum()
        pooled_union = self.union_observable.sum()
        return pooled_intersection / (pooled_union + 1e-6)


class MeanUnobservableIOU(IOU):
    """IoU over the non-observable region with counts pooled across classes.

    NOTE(review): this divides the summed intersection by the summed union,
    i.e. a pixel-weighted (pooled) IoU. It is not the unweighted average of
    per-class IoUs that "mean IoU" often denotes.
    """

    def compute(self):
        """Return sum(intersection) / (sum(union) + 1e-6) over all classes."""
        pooled_intersection = self.intersection_non_observable.sum()
        pooled_union = self.union_non_observable.sum()
        return pooled_intersection / (pooled_union + 1e-6)


class mAP(torchmetrics.classification.MultilabelPrecision):
    """Multilabel precision computed only over observable pixels.

    NOTE(review): despite the name, this reports torchmetrics' multilabel
    *precision* (the base class), not average precision.

    Args:
        num_labels: Number of labels (class channels).
        **kwargs: Forwarded to ``MultilabelPrecision``.
    """

    def __init__(self, num_labels, **kwargs):
        super().__init__(num_labels=num_labels, **kwargs)

    def update(self, output, data):
        """Select observable pixels and feed them to the base metric.

        Args:
            output: Model outputs; reads ``output['output']`` (4-D, permuted
                channel-last here) and ``output["valid_bev"]``.
            data: Batch; reads ``data['seg_masks']`` and, if present,
                ``data["confidence_map"]`` (0 marks observable pixels).
        """
        valid = output["valid_bev"][..., :-1]
        if "confidence_map" not in data:
            observable_mask = valid
        else:
            observable_mask = torch.logical_and(
                valid, data["confidence_map"] == 0)

        # Channels go last so boolean masking keeps one row of label scores
        # per selected pixel (assuming the mask spans the leading three
        # dims -- TODO confirm).
        channel_last = output['output'].permute(0, 2, 3, 1)
        selected_pred = channel_last[observable_mask]
        selected_target = data['seg_masks'][observable_mask]

        super().update(selected_pred, selected_target)