# %%

# Prefer the ipytorch logger when available; fall back to stdlib logging.
try:
    from ipytorch import logging
except ImportError:
    import logging

from typing import Any
from tqdm.auto import tqdm
from evaluate.evaluation_suite import EvaluationSuite
import evaluate
import datasets
import pandas as pd
from .tasks import *


class ReasoningMetric(evaluate.Metric):
    """Exact-match accuracy for reasoning tasks.

    Answers are pulled out of free-form model responses and references by the
    task-specific extractor registered on `Metrics` under `config_name`.
    """

    def _info(self):
        features = datasets.Features(
            {
                "responses": datasets.Value("string"),
                "references": datasets.Value("string"),
            }
        )

        return evaluate.EvaluationModuleInfo(
            description="Exact-match accuracy over extracted answers.",
            citation="",
            inputs_description="`responses` and `references` are free-form strings.",
            # This defines the format of each prediction and reference.
            features=features,
        )

    def _compute(self, responses, references, verbose=False):
        # Delegate answer extraction to the task-specific parser,
        # e.g. Metrics.gsm8k for the "gsm8k" config.
        extract_responses, extract_references = getattr(Metrics, self.config_name)(
            responses, references
        )
        df = pd.DataFrame(
            {
                "responses": responses,
                "references": references,
            }
        )
        df["extract_responses"] = extract_responses
        df["extract_references"] = extract_references
        results = {
            "Accuracy": (df["extract_references"] == df["extract_responses"]).mean(),
        }
        logging.info(results)
        if verbose:
            logging.info(df)
            results["df"] = df
        return results
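
    # Usage sketch (hypothetical strings; assumes `Metrics.gsm8k` extracts the
    # final number from each answer):
    #
    #   metric = ReasoningMetric(config_name="gsm8k")
    #   metric.compute(responses=["... so the answer is 42."],
    #                  references=["... #### 42"])
    #   # -> {"Accuracy": 1.0}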


class Suite(EvaluationSuite):
    task_class = Task

    def run(
        self,
        model_or_pipeline: Any,
    ) -> dict[str, dict]:
        self.assert_suite_nonempty()

        def run_tasks(tasks):
            # Evaluate each task once; results are cached across repeated runs.
            for task in (bar := tqdm(tasks, leave=False)):
                bar.desc = f"Running {task.name}."
                if task.name not in self.cached_result:
                    self.cached_result[task.name] = task.run(model_or_pipeline)
            results = [self.cached_result[task.name] for task in tasks]
            # Average every metric column across the tasks in this group.
            return pd.DataFrame(results).mean().to_dict()

        if isinstance(self.suite, dict):
            for category, tasks in (bar := tqdm(self.suite.items())):
                bar.desc = f"Running {category}."
                logging.warning(f"Combined results {category}: {run_tasks(tasks)}")
        else:
            logging.warning(f"Combined results: {run_tasks(self.suite)}")

        return self.cached_result
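
    # Shape sketch (hedged): after `suite.load("gsm8k")`, `suite.run(pipe)`
    # returns the raw per-task cache, e.g. {"gsm8k": {"Accuracy": 0.41}},
    # while per-category averages are emitted via `logging.warning` above.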

    def add(self, name):
        # Thin alias for `load`.
        self.load(name)

    def load(self, name):
        # Any name containing "chat" selects the chat-style prompt variants.
        chat = "chat" in name
        match name:
            case _ if name.startswith("mmlu"):
                suite = MMLU.suite(chat=chat)
            case _ if name.startswith("cmmlu"):
                suite = CMMLU.suite(chat=chat)
            case "gsm8k":
                suite = Task(
                    dataset_name=("gsm8k", "main"),
                    metric_name=("sustech/tlem", "gsm8k"),
                    input_column="question",
                    label_column="answer",
                )
            case "bbh":
                suite = BBH.suite()
            case "arc":
                suite = ARC.suite()
            case "hellaswag":
                suite = HellaSwag.suite()
            case "drop":
                suite = DROP.suite()
            case "winogrande":
                suite = Winogrande.suite()
            case _ if name.startswith("ceval"):
                suite = CEVAL.suite(chat=chat)
            case "mt_bench":
                suite = Task(
                    dataset_name="SUSTech/mt_bench_judge",
                    split="train",
                    prompt=mt_bench_prompt,
                )
            case _:
                raise ValueError(f"Unknown suite name: {name}")
        # A name containing "test" narrows the suite to its held-out split.
        if "test" in name:
            suite = suite["Test"]

        # Wrap a bare Task (e.g. gsm8k) in a list so `run` can iterate it.
        self.suite = [suite] if isinstance(suite, Task) else suite

    def __init__(self, name="tlem"):
        super().__init__(name)
        self.cached_result = {}
        self.suite = []
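
# %%
# End-to-end sketch (hedged): `echo_pipeline` is a hypothetical stand-in for
# whatever callable `Task.run` accepts (e.g. a transformers pipeline or an
# API-backed function); guarded so importing this module has no side effects.
if __name__ == "__main__":
    suite = Suite()
    suite.load("gsm8k")

    def echo_pipeline(prompt):
        # A real pipeline would return the model's generated answer here.
        return prompt

    logging.info(suite.run(echo_pipeline))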