# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os
from typing import Optional

import datasets
import evaluate
from seametrics.user_friendly.utils import calculate_from_payload

import wandb

_CITATION = """\
@InProceedings{huggingface:module,
title = {A great new module},
authors={huggingface, Inc.},
year={2020}
}\
@article{milan2016mot16,
  title={MOT16: A benchmark for multi-object tracking},
  author={Milan, Anton and Leal-Taix{\'e}, Laura and Reid, Ian and Roth, Stefan and Schindler, Konrad},
  journal={arXiv preprint arXiv:1603.00831},
  year={2016}
}
"""

_DESCRIPTION = """\
The MOT Metrics module evaluates multi-object tracking (MOT) algorithms by
computing metrics such as precision, recall, F1 and recognition rates from
predicted and ground truth bounding boxes, making it easier to assess and
iteratively improve tracking systems."""


_KWARGS_DESCRIPTION = """
Computes multi-object tracking metrics from a payload of predicted and ground truth bounding boxes.
Args:
    payload: payload containing the predicted and ground truth bounding boxes for the
        sequences to evaluate (see seametrics.user_friendly.utils.calculate_from_payload).
    max_iou (`float`, *optional*):
        IoU threshold used to match detections to ground truth objects. Defaults to 0.5.
    filters (`dict`, *optional*):
        Filters to apply to the payload before computing the metrics. Defaults to {}.
    recognition_thresholds (`list` of `float`, *optional*):
        Recognition thresholds used for the recognition metrics (e.g. recognition_0.3).
        Defaults to [0.3, 0.5, 0.8].
    debug (`bool`, *optional*):
        Whether to print debug information while computing the metrics. Defaults to False.
Returns:
    dict: results dictionary with 'global' and 'per_sequence' metrics.
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class UserFriendlyMetrics(evaluate.Metric):
    """TODO: Short description of my evaluation module."""

    def _info(self):
        # Describes the module and declares the input features it accepts.
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(
                        datasets.Sequence(datasets.Value("float"))
                    ),
                    "references": datasets.Sequence(
                        datasets.Sequence(datasets.Value("float"))
                    ),
                }
            ),
        )

    def _download_and_prepare(self, dl_manager):
        """Optional: download external resources useful to compute the scores"""
        pass

    def _compute(
        self,
        payload,
        max_iou: float = 0.5,
        filters=None,
        recognition_thresholds=None,
        debug: bool = False,
    ):
        """Compute the user-friendly tracking metrics for the given payload."""
        # Avoid mutable default arguments; fall back to the documented defaults.
        if filters is None:
            filters = {}
        if recognition_thresholds is None:
            recognition_thresholds = [0.3, 0.5, 0.8]
        return calculate_from_payload(
            payload, max_iou, filters, recognition_thresholds, debug
        )

    def wandb(
        self,
        results,
        wandb_section: Optional[str] = None,
        wandb_project: str = "user_friendly_metrics",
        log_plots: bool = True,
        debug: bool = False,
    ):
        """
        Logs metrics to Weights and Biases (wandb) for tracking and visualization, including categorized bar charts for global metrics.

        Args:
            results (dict): Results dictionary with 'global' and 'per_sequence' keys.
            wandb_section (str, optional): W&B section for metric grouping. Defaults to None.
            wandb_project (str, optional): The name of the wandb project. Defaults to 'user_friendly_metrics'.
            log_plots (bool, optional): Generates categorized bar charts for global metrics. Defaults to True.
            debug (bool, optional): Logs detailed summaries and histories to the terminal console. Defaults to False.
        """

        current_datetime = datetime.datetime.now()
        formatted_datetime = current_datetime.strftime("%Y-%m-%d_%H-%M-%S")
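        # Authenticate with W&B using the WANDB_API_KEY environment variable;
        # if it is unset, wandb.login falls back to any existing credentials.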
        wandb.login(key=os.getenv("WANDB_API_KEY"))

        run = wandb.init(
            project=wandb_project,
            name=f"evaluation-{formatted_datetime}",
            reinit=True,
            settings=wandb.Settings(silent=not debug),
        )

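        # Metric names grouped into categories; each non-empty category is
        # rendered as its own bar chart when log_plots is True.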
        categories = {
            "confusion_metrics": {"fp", "tp", "fn"},
            "evaluation_metrics": {"f1", "recall", "precision"},
            "recognition_metrics": {
                "recognition_0.3",
                "recognition_0.5",
                "recognition_0.8",
                "recognized_0.3",
                "recognized_0.5",
                "recognized_0.8",
            },
        }

        chart_data = {key: [] for key in categories}

        # Log global metrics
        if "global" in results:
            for global_key, global_metrics in results["global"].items():
                for metric, value in global_metrics["all"].items():
                    log_key = (
                        f"{wandb_section}/global/{global_key}/{metric}"
                        if wandb_section
                        else f"global/{global_key}/{metric}"
                    )
                    run.log({log_key: value})

                    if debug:
                        print(f"Logged to W&B: {log_key} = {value}")

                    for category, metrics in categories.items():
                        if metric in metrics:
                            chart_data[category].append([metric, value])

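        # Log one bar chart per non-empty category of global metrics.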
        if log_plots:
            for category, data in chart_data.items():
                if data:
                    table = wandb.Table(data=data, columns=["metrics", "value"])
                    run.log(
                        {
                            f"{category}_bar_chart": wandb.plot.bar(
                                table,
                                "metrics",
                                "value",
                                title=f"{category.replace('_', ' ').title()}",
                            )
                        }
                    )

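        # Log per-sequence metrics, ordered by each sequence's overall F1 score
        # (highest first).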
        if "per_sequence" in results:
            sorted_sequences = sorted(
                results["per_sequence"].items(),
                key=lambda x: x[1]
                .get("evaluation_metrics", {})
                .get("f1", {})
                .get("all", 0),
                reverse=True,
            )

            for sequence_name, sequence_data in sorted_sequences:
                for seq_key, seq_metrics in sequence_data.items():
                    for metric, value in seq_metrics["all"].items():
                        log_key = (
                            f"{wandb_section}/per_sequence/{sequence_name}/{seq_key}/{metric}"
                            if wandb_section
                            else f"per_sequence/{sequence_name}/{seq_key}/{metric}"
                        )
                        run.log({log_key: value})
                        if debug:
                            print(
                                f"Logged to W&B: {sequence_name} -> {log_key} = {value}"
                            )

        if debug:
            print("\nDebug Mode: Logging Summary and History")
            print(f"Results Summary:\n{results}")
            print(f"WandB Settings:\n{run.settings}")
            print("All metrics have been logged.")

        run.finish()
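

# The block below is a usage sketch, not part of the metric itself. It assumes
# a payload in the format expected by
# seametrics.user_friendly.utils.calculate_from_payload; build_payload() is a
# hypothetical placeholder for however that payload is constructed.
if __name__ == "__main__":

    def build_payload():
        # Hypothetical helper -- replace with real payload construction.
        raise NotImplementedError("Provide a seametrics payload here.")

    metric = UserFriendlyMetrics()
    # Calling _compute directly mirrors how this module is driven by a payload
    # rather than by the usual predictions/references flow of evaluate.
    results = metric._compute(build_payload(), max_iou=0.5, debug=True)
    metric.wandb(results, wandb_section="eval", log_plots=True)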