from typing import Dict, List, Union
from pathlib import Path
import datasets
import torch
import evaluate
import json
from tqdm import tqdm
from detection_metrics.pycocotools.coco import COCO 
from detection_metrics.coco_evaluate import COCOEvaluator 
from detection_metrics.utils import _TYPING_PREDICTION, _TYPING_REFERENCE

_DESCRIPTION = "This class evaluates object detection models using the COCO dataset \
    and its evaluation metrics."
_HOMEPAGE = "https://cocodataset.org"
_CITATION = """
    @misc{lin2015microsoft, \
      title={Microsoft COCO: Common Objects in Context},
      author={Tsung-Yi Lin and Michael Maire and Serge Belongie and Lubomir Bourdev and \
          Ross Girshick and James Hays and Pietro Perona and Deva Ramanan and C. Lawrence Zitnick \
              and Piotr Dollár},
      year={2015},
      eprint={1405.0312},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
"""
_REFERENCE_URLS = [
    "https://ieeexplore.ieee.org/abstract/document/9145130",
    "https://www.mdpi.com/2079-9292/10/3/279",
    "https://cocodataset.org/#detection-eval",
]
_KWARGS_DESCRIPTION = """\
Computes COCO metrics for object detection: AP(mAP) and its variants.

Args:
    coco (COCO): COCO Evaluator object for evaluating predictions.
    **kwargs: Additional keyword arguments forwarded to evaluate.Metrics.
"""

class EvaluateObjectDetection(evaluate.Metric):
    """
    COCO-style evaluation metric for object detection models.

    Each call to ``add`` feeds the batch to a wrapped ``COCOEvaluator`` (keyed
    by image id) and also stores it through the parent ``evaluate`` API.
    ``_compute`` derives the scores from the state accumulated in the
    ``COCOEvaluator``.
    """

    def __init__(self, json_gt: Union[Path, str, Dict], iou_type: str = "bbox", **kwargs):
        """
        Initializes the EvaluateObjectDetection class.

        Args:
            json_gt: Ground-truth annotations in COCO format, either as an
                already-loaded dict or as a path (``pathlib.Path`` or ``str``)
                to a JSON file containing them.
            iou_type: IoU type handed to ``COCOEvaluator`` (default "bbox").
            **kwargs: Additional keyword arguments forwarded to evaluate.Metric.

        Raises:
            FileNotFoundError: If ``json_gt`` is a path that does not exist.
            TypeError: If ``json_gt`` is neither a path nor a dict.
        """
        super().__init__(**kwargs)

        # Load the ground truth. Plain ``str`` paths are accepted as well as
        # ``Path``; previously they fell through every branch and crashed.
        if isinstance(json_gt, (str, Path)):
            json_path = Path(json_gt)
            if not json_path.exists():
                raise FileNotFoundError(f"Path {json_path} does not exist.")
            # JSON is UTF-8 by spec; do not depend on the platform default.
            with open(json_path, encoding="utf-8") as f:
                json_data = json.load(f)
        elif isinstance(json_gt, dict):
            json_data = json_gt
        else:
            raise TypeError(
                "json_gt must be a dict or a path (str or pathlib.Path), "
                f"got {type(json_gt).__name__}."
            )
        coco = COCO(json_data)

        self.coco_evaluator = COCOEvaluator(coco, [iou_type])

    def remove_classes(self, classes_to_remove: List[str]):
        """
        Exclude the named categories from evaluation.

        Names are matched case-insensitively. The remaining categories replace
        those of the evaluator's ground-truth ``COCO`` object. They also replace
        the ``cocoGt.cats`` and ``params.catIds`` of every configured
        ``COCOeval``, for whichever ``iou_type`` it was built with.

        Args:
            classes_to_remove: Category names to drop.
        """
        to_remove = {name.upper() for name in classes_to_remove}
        coco_gt = self.coco_evaluator.coco_gt
        kept_cats = {
            cat_id: cat
            for cat_id, cat in coco_gt.cats.items()
            if cat["name"].upper() not in to_remove
        }
        kept_ids = [cat["id"] for cat in kept_cats.values()]

        coco_gt.cats = kept_cats
        coco_gt.dataset["categories"] = list(kept_cats.values())
        # NOTE(review): assumes each COCOeval's cocoGt is built from the same
        # ground truth as coco_evaluator.coco_gt, as the original code did.
        # The old version only updated the "bbox" entry, which raised KeyError
        # for any other iou_type.
        for coco_eval in self.coco_evaluator.coco_eval.values():
            coco_eval.cocoGt.cats = kept_cats
            coco_eval.params.catIds = list(kept_ids)

    def _info(self):
        """
        Returns the MetricInfo object with information about the module.

        Returns:
            evaluate.MetricInfo: Metric information object.
        """
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features(
                {
                    "predictions": [
                        datasets.Features(
                            {
                                "scores": datasets.Sequence(datasets.Value("float")),
                                "labels": datasets.Sequence(datasets.Value("int64")),
                                "boxes": datasets.Sequence(
                                    datasets.Sequence(datasets.Value("float"))
                                ),
                            }
                        )
                    ],
                    "references": [
                        datasets.Features(
                            {
                                "image_id": datasets.Sequence(datasets.Value("int64")),
                            }
                        )
                    ],
                }
            ),
            # Homepage of the module for documentation
            homepage=_HOMEPAGE,
            # Additional links to the codebase or references
            reference_urls=_REFERENCE_URLS,
        )

    def _preprocess(
        self, predictions: List[Dict[str, torch.Tensor]]
    ) -> List[_TYPING_PREDICTION]:
        """
        Convert tensor values to plain Python lists.

        Every ``torch.Tensor`` value is detached, moved to CPU and turned into
        a (nested) list. Values under the "labels" key are additionally cast
        to ``int``. Non-tensor values are passed through as-is.

        Args:
            predictions (List[Dict[str, torch.Tensor]]): A list of dicts.

        Returns:
            List[_TYPING_PREDICTION]: A list of converted dicts, same order.
        """
        processed_predictions = []
        for pred in predictions:
            processed_pred: _TYPING_PREDICTION = {}
            for k, val in pred.items():
                if isinstance(val, torch.Tensor):
                    val = val.detach().cpu().tolist()
                if k == "labels":
                    val = list(map(int, val))
                processed_pred[k] = val
            processed_predictions.append(processed_pred)
        return processed_predictions

    def _clear_predictions(self, predictions):
        """Keep only the "scores", "labels" and "boxes" keys of each prediction."""
        required = ("scores", "labels", "boxes")
        return [
            {k: v for k, v in prediction.items() if k in required}
            for prediction in predictions
        ]

    def _clear_references(self, references):
        """Keep only the "image_id" key of each reference.

        This matches the reference features declared in ``_info``. The old
        ``required = [""]`` dropped every key.
        """
        required = ("image_id",)
        return [{k: v for k, v in ref.items() if k in required} for ref in references]

    def add(self, *, prediction=None, reference=None, **kwargs):
        """
        Preprocess one batch and register it with both the COCO evaluator and the parent class.

        Predictions are reduced to "scores"/"labels"/"boxes" and references
        to "image_id". Tensors in both are converted to Python lists. The COCO
        evaluator receives a mapping from each reference's first "image_id"
        element to the corresponding prediction.

        Args:
            prediction: A list of prediction dicts, one per image.
            reference: A list of reference dicts, aligned with ``prediction``.
            **kwargs: Additional keyword arguments forwarded to the parent ``add``.

        Raises:
            TypeError: If ``prediction`` or ``reference`` is missing.
            ValueError: If ``prediction`` and ``reference`` differ in length.
        """
        if prediction is None or reference is None:
            raise TypeError("add() requires both `prediction` and `reference`.")

        prediction = self._preprocess(self._clear_predictions(prediction))
        # Converting references too turns tensor image ids into plain ints, so
        # they are usable as dict keys below.
        reference = self._preprocess(self._clear_references(reference))
        if len(prediction) != len(reference):
            raise ValueError(
                f"Got {len(prediction)} predictions but {len(reference)} references; "
                "they must be aligned one-to-one."
            )

        res = {  # {image_id} : prediction
            target["image_id"][0]: output for output, target in zip(prediction, reference)
        }
        self.coco_evaluator.update(res)

        # `reference` is the parent's actual parameter name. The old
        # `references=` keyword only worked because it leaked through **kwargs.
        super(evaluate.Metric, self).add(prediction=prediction, reference=reference, **kwargs)

    def _compute(
        self,
        predictions: List[List[_TYPING_PREDICTION]],
        references: List[List[_TYPING_REFERENCE]],
    ) -> Dict[str, Dict[str, float]]:
        """
        Returns the evaluation scores.

        The ``predictions``/``references`` arguments supplied by the parent
        class are not used. Scores come from the state accumulated in the
        COCO evaluator through ``add``.

        Args:
            predictions (List[List[_TYPING_PREDICTION]]): Unused.
            references (List[List[_TYPING_REFERENCE]]): Unused.

        Returns:
            Dict: The output of ``COCOEvaluator.get_results()``.
        """
        print("Synchronizing processes")
        self.coco_evaluator.synchronize_between_processes()

        print("Accumulating values")
        self.coco_evaluator.accumulate()

        print("Summarizing results")
        self.coco_evaluator.summarize()

        return self.coco_evaluator.get_results()