from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from functools import cached_property
from tqdm.auto import tqdm
from typing import Any, Optional, Protocol, Iterable, Callable

from .utils import (
    NUMERIC_IN_ZH,
    extract_choice_ans,
    extract_numeric,
    get_answer,
    is_equiv,
)

from evaluate import load

# Signature shared by every generation backend used in this module:
# a callable that receives an iterable of prompt strings and returns a list
# of generated strings.
TextGenerationPipeline = Callable[[Iterable[str]], list[str]]


def fake_pipeline(prompts: Iterable[str]) -> list[str]:
    """Stand-in pipeline that echoes every prompt back unchanged.

    The prompts are consumed through ``tqdm`` so a progress bar is shown
    while they are collected into the returned list.
    """
    return list(tqdm(prompts))


@dataclass
class Task:
    """One evaluation task: a dataset split, a prompt and a metric.

    The dataset, its input samples and the metric are loaded lazily on first
    access and cached on the instance via ``cached_property``.
    """

    # Positional argument(s) for ``load_dataset``: either a single name or a
    # tuple of positional arguments (two strings per the annotation).
    dataset_name: str | tuple[str, str] = ("gsm8k", "main")
    # Passed to ``load_dataset`` as ``split=``.
    split: str = "test"
    # Positional argument(s) for ``evaluate.load``, same convention as
    # ``dataset_name``.
    metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k")
    # Column fed to the pipeline.
    input_column: str = "question"
    # Column passed to the metric as ``references``.
    label_column: str = "answer"
    # Optional prompt: a ``str.format`` template that is given the raw input
    # as the keyword ``input_column``, or a callable applied to each example
    # whose return value is handed back to ``Dataset.map``.
    prompt: Optional[Callable | str] = None

    @cached_property
    def name(self):
        """Return ``"<dataset>-<split>"``, using the first element of a tuple name."""
        return (
            self.dataset_name
            if isinstance(self.dataset_name, str)
            else self.dataset_name[0]
        ) + f"-{self.split}"

    @cached_property
    def samples(self):
        """Return the ``input_column`` of the (prompt-formatted) dataset."""
        return self.dataset[self.input_column]

    @cached_property
    def dataset(self):
        """Load the configured split and apply ``prompt`` to it, if one is set."""
        # Normalise to a tuple *before* unpacking.  ``f(*a if c else b)``
        # parses as ``f(*(a if c else b))``, so unpacking the conditional
        # directly would splat a plain string name into single characters.
        dataset_args = (
            self.dataset_name
            if isinstance(self.dataset_name, tuple)
            else (self.dataset_name,)
        )
        ds = load_dataset(*dataset_args, split=self.split)
        if self.prompt is not None:
            ds = ds.map(
                lambda example: {
                    self.input_column: self.prompt.format(
                        input_column=example[self.input_column]
                    )
                }
                if isinstance(self.prompt, str)
                else self.prompt(example),
            )

        return ds

    @cached_property
    def metric(self):
        """Load the metric with ``evaluate.load`` from ``metric_name``."""
        metric = (
            load(self.metric_name)
            if isinstance(self.metric_name, str)
            else load(*self.metric_name)
        )
        return metric

    def run(self, pipeline: TextGenerationPipeline = fake_pipeline):
        """Generate an output for every sample and score it with the metric.

        Args:
            pipeline: Callable mapping the input samples to generated strings.
                Defaults to ``fake_pipeline``.

        Returns:
            Whatever ``self.metric.compute`` returns when called with the
            pipeline outputs as ``responses`` and the ``label_column`` as
            ``references``.
        """
        outputs = pipeline(self.samples)
        return self.metric.compute(
            responses=outputs, references=self.dataset[self.label_column]
        )


class Metrics:
    def gsm8k(responses: list[str], answers: list[str | int]):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response)
            gold = extract_numeric(answer) if isinstance(answer, str) else str(answer)
            scores.append(1.0 * (pred == gold))
        return scores

    def MATH(responses: list[str], answers: list[str]):
        scores = []

        for response, answer in zip(responses, answers):
            indices = [pos for pos, char in enumerate(response) if char == "$"]
            if len(indices) <= 2:
                scores.append(0)
                continue
            else:
                result = response[indices[-2] + 1 : indices[-1]]
                gold = get_answer(answer)
                scores.append(1.0 * is_equiv(result, gold))

        return scores

    def math23k(responses: list[str], answers: list[str]):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer, pattern=NUMERIC_IN_ZH)
            scores.append(1.0 * (pred == gold))
        return scores

    def gsm8k_zh(responses: list[str], answers: list[str]):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = extract_numeric(answer)
            scores.append(1.0 * (pred == gold))
        return scores

    def svamp(responses: list[float], answers: list[str]):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
            gold = answer
            scores.append(1.0 * (float(pred) == gold))
        return scores

    def mmlu(responses, answers):
        scores = []
        for response, answer in zip(responses, answers):
            pred = extract_choice_ans(response)
            gold = answer.lower()
            scores.append(1.0 * (pred == gold))
        return scores