try:
    from ipytorch import logging
except Exception as e:
    import logging

from typing import Any, Optional, Protocol, Iterable, Callable
from tqdm.auto import tqdm
from evaluate.evaluation_suite import EvaluationSuite
import evaluate
import numpy as np
import datasets
import pandas as pd
from .tasks import *
from .utils import *
from itertools import chain
from copy import deepcopy
from . import utils


class ReasoningMetric(evaluate.Metric):
    """TODO: Short description of my evaluation module."""

    def _info(self):
        # if self.config_name in ["cmmlu"]:
        features = datasets.Features(
            {
                "responses": datasets.Value("string"),
                # "responses": datasets.Sequence(datasets.Value("float")),
                "references": datasets.Value("string"),
            }
        )

        # TODO: Specifies the evaluate.EvaluationModuleInfo object
        return evaluate.EvaluationModuleInfo(
            # This is the description that will appear on the modules page.
            # module_type="measurement",
            description="",
            citation="",
            inputs_description="",
            # This defines the format of each prediction and reference
            features=features,
            # Homepage of the module for documentation
            homepage="http://module.homepage",
            # Additional links to the codebase or references
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"],
        )

    def _compute(self, responses, references):
        return_value = getattr(Metrics, self.config_name)(responses, references)
        match return_value:
            case extract_responses, extract_references:
                results = {
                    self.config_name: np.mean(
                        sync_pipe(lambda x, y: x == y)(
                            zip(extract_responses, extract_references)
                        )
                    )
                }
            case dict():
                results = return_value

            case list():
                results = {self.config_name: np.mean(return_value)}

            case _:
                raise NotImplementedError

        return results


class Suite(EvaluationSuite):
    """Named collection of evaluation ``Task`` objects.

    ``self.suite`` maps a category name to either a list of tasks or a nested
    dict of the same shape. ``self.tasks`` is a flat list holding each distinct
    task once (maintained by ``singleton``); ``run`` and ``arun`` iterate it.
    """

    task_class = Task
    utils = utils
    # Individual datasets that get_suite("all") expands to, in this order; also
    # echoed in get_suite's "not supported" error. get_suite accepts further
    # names not listed here (e.g. "tlem", "open-leaderboard", "mt_bench").
    supported_datasets = [
        "arc",
        "hellaswag",
        "mmlu-chat",
        "winogrande",
        "gsm8k",
        "cmmlu-chat",
        "ceval-chat",
        "bbh",
        "drop",
        "MATH",
    ]

    def __getitem__(self, key) -> Task:
        match key:
            case str():
                return self.suite[key]
            case slice() | int():
                return self.tasks[key]

    def agg(self, suite):
        for cate, tasks in suite.items():
            if isinstance(tasks, dict):
                suite[cate] = self.agg(tasks)
            else:
                suite[cate] = np.mean([pd.Series(task.result).mean() for task in tasks])

        return suite

    def run(
        self,
        model_or_pipeline: Any,
    ) -> dict[str, float]:
        self.assert_suite_nonempty()

        self.suite: dict[str, list[Task]]
        for task in (bar := tqdm(self.tasks)):
            bar.desc = f"complete {task.name}."
            _ = task.run(model_or_pipeline)
            logging.info(f"{task.name} {task.result=}")
        return self.agg(deepcopy(self.suite))

    def arun(self, model_or_pipeline):
        async def sync_function():
            return await tqdm.gather(
                *[task.arun(model_or_pipeline) for task in self.tasks], leave=False
            )

        asyncio.run(sync_function())

        return self.agg(deepcopy(self.suite))

    def get_suite(self, name) -> dict[str, Task]:
        chat = False
        suite={}
        match name:
            case _ if "chat" in name:
                chat = True
        match name:
            case _ if name.startswith("mmlu"):
                suite = MMLU.suite(chat=chat)
            case _ if name.startswith("cmmlu"):
                suite = CMMLU.suite(chat=chat)
            case _ if name.startswith("ceval"):
                suite = CEVAL.suite(chat=chat)
            case "gsm8k":
                suite = Task(
                    dataset_name=("gsm8k", "main"),
                    metric_name=("sustech/tlem", "gsm8k"),
                    input_column="question",
                    label_column="answer",
                )
            case "bbh":
                suite = BBH.suite()
            case "arc":
                suite = ARC.suite()
            case "hellaswag":
                suite = HellaSwag.suite()
            case "drop":
                suite = DROP.suite()
            case "winogrande":
                suite = Winogrande.suite()

            case "mt_bench":
                suite = Task(
                    dataset_name="SUSTech/mt_bench_judge",
                    split="train",
                    prompt=mt_bench_prompt
                    # metric_name=("sustech/tlem", "gsm8k"),
                )
            case "MATH" | "competition_math":
                suite = Task(
                    dataset_name="hendrycks/competition_math",
                    prompt="This is a math problem, please think step by step and slove it: {input_column}. Simplify your final answer as much as possible and surround them with '$' in TeX form.",
                    metric_name=("sustech/tlem", "MATH"),
                    input_column="problem",
                    label_column="solution",
                )

            case "open-leaderboard":
                for name in [
                    "arc",
                    "hellaswag",
                    "mmlu-chat",
                    "winogrande",
                    "gsm8k",
                    # "truthful_qa",
                    "drop",
                ]:
                    suite.update(self.get_suite(name))
            case "tlem":
                for name in [
                    "arc",
                    "hellaswag",
                    "mmlu-chat",
                    "winogrande",
                    "gsm8k",
                    # "truthful_qa",
                    "cmmlu-chat",
                    "ceval-chat",
                    "bbh",
                ]:
                    suite.update(self.get_suite(name))

            case "all":
                for name in self.supported_datasets:
                    suite.update(self.get_suite(name))
            case _:
                raise NotImplementedError(
                    f"{name} is not supported in {self.supported_datasets}"
                )

        if isinstance(suite, Task):
            suite = [suite]
        suite = {name: suite}

        return suite

    def singleton(self, task):
        try:
            return self.tasks[self.tasks.index(task)]
        except ValueError:
            logging.debug(f"add {task.name} to suite.")
            self.tasks.append(task)
            logging.debug(self.tasks)
            return self.tasks[-1]

    def drop_duplicates(self, suite):
        for category, tasks in suite.items():
            match tasks:
                case list():
                    suite[category] = [self.singleton(task) for task in tasks]
                case dict():
                    suite[category] = self.drop_duplicates(tasks)
                case _:
                    raise NotImplementedError
        return suite

    def load(self, name):
        """Merge the sub-suite ``name`` into ``self.suite`` and de-duplicate.

        New tasks are registered in ``self.tasks``; tasks equal to ones already
        loaded are replaced by the existing objects. Returns ``None``.
        """
        additions = self.get_suite(name)
        self.suite.update(additions)
        self.suite = self.drop_duplicates(self.suite)

    def __init__(self, name="tlem"):
        super().__init__(name)
        self.tasks = []
        self.suite = {}