import enum
import functools

import pandas as pd

from tasks import ner, nli, qa, summarization


class LanguageType(enum.Enum):
    """Resource level of a language, as returned by ``_get_language_type``."""

    # Member names double as their string values.
    Low = "Low"
    High = "High"


class ModelType(enum.Enum):
    """Kind of model a prompt targets; ``recommend_config`` branches on ``English``."""

    # Member names double as their string values.
    English = "English"
    Multilingual = "Multilingual"



# Task identifiers compared against by construct_generic_prompt and recommend_config.
QA, SUMMARIZATION, NLI, NER = "QA", "Summarization", "NLI", "NER"


def construct_generic_prompt(
    task,
    instruction,
    test_example,
    zero_shot,
    num_examples,
    selected_language,
    dataset,
    config,
):
    """Build a prompt for ``task`` by dispatching to the matching task module.

    Args:
        task: One of the module-level task constants (``SUMMARIZATION``,
            ``NER``, ``QA``, ``NLI``). Any value that is not SUMMARIZATION,
            NER or QA is routed to the NLI builder.
        instruction: Forwarded as ``instruction`` to the task's builder.
        test_example: Forwarded as ``test_example``.
        zero_shot: Forwarded as ``zero_shot``.
        num_examples: Forwarded as ``num_examples``.
        selected_language: Converted with ``str(...).lower()`` and forwarded
            as ``lang``.
        dataset: Forwarded only to the summarization and NER builders; the
            QA and NLI builders are not given it.
        config: Forwarded as ``config``.

    Returns:
        Whatever the selected module's ``construct_prompt`` returns.
    """
    # Normalise the language once instead of in every branch.
    common = dict(
        instruction=instruction,
        test_example=test_example,
        zero_shot=zero_shot,
        num_examples=num_examples,
        lang=str(selected_language).lower(),
        config=config,
    )
    if task == SUMMARIZATION:
        return summarization.construct_prompt(dataset=dataset, **common)
    if task == NER:
        return ner.construct_prompt(dataset=dataset, **common)
    if task == QA:
        # NOTE(review): the QA builder is called without the dataset (a
        # `dataset_name=` keyword was previously commented out) — confirm intent.
        return qa.construct_prompt(**common)
    # Fallback kept from the original: unrecognised tasks are treated as NLI.
    return nli.construct_prompt(**common)


# Languages with fewer words than this in the CSV are classified as Low.
_LOW_RESOURCE_WORD_THRESHOLD = 150276400


@functools.lru_cache(maxsize=8)
def _load_word_counts(path):
    """Return a ``{language: number of words}`` mapping read from the CSV at ``path``.

    Cached so the file is parsed once per path instead of on every lookup;
    changes to the file during the process are therefore not picked up.
    When a language appears in several rows the first row wins, matching
    the ``.iloc[0]`` selection previously used. Callers must not mutate the
    returned dict, since it is shared through the cache.
    """
    df = pd.read_csv(path)
    counts = {}
    for language, words in zip(df["Language"], df["number of words"]):
        counts.setdefault(language, words)
    return counts


def _get_language_type(
    language: str,
    path="utils/languages_by_word_count.csv",
    threshold=_LOW_RESOURCE_WORD_THRESHOLD,
) -> "LanguageType":
    """Classify ``language`` as low- or high-resource by its word count.

    Args:
        language: Value looked up in the CSV's ``"Language"`` column.
        path: CSV with ``"Language"`` and ``"number of words"`` columns. The
            default is relative, so it resolves against the current working
            directory.
        threshold: Word counts strictly below this are ``LanguageType.Low``;
            all others are ``LanguageType.High``.

    Returns:
        ``LanguageType.Low`` or ``LanguageType.High``.

    Raises:
        IndexError: ``language`` is not present in the CSV.
    """
    counts = _load_word_counts(path)
    try:
        number_of_words = counts[language]
    except KeyError:
        # IndexError keeps the exception type the original `.iloc[0]` lookup
        # raised for a missing language, so existing handlers keep working.
        raise IndexError(f"language {language!r} not found in {path}") from None
    return LanguageType.Low if number_of_words < threshold else LanguageType.High


class Config:
    """Which language each prompt component should use.

    Every component defaults to the placeholder string ``"source"``.
    """

    def __init__(
        self, prefix="source", context="source", examples="source", output="source"
    ):
        self.prefix = prefix  # exported under the "instruction" key by to_dict()
        self.context = context
        self.examples = examples
        self.output = output

    def set(self, prefix=None, context=None, examples=None, output=None):
        """Overwrite the given components in place.

        An argument left as ``None`` keeps the current value. Any other value,
        including an empty string, is assigned.
        """
        updates = {
            "prefix": prefix,
            "context": context,
            "examples": examples,
            "output": output,
        }
        for name, value in updates.items():
            # `is not None` instead of truthiness so falsy values such as ""
            # can be assigned deliberately.
            if value is not None:
                setattr(self, name, value)

    def to_dict(self):
        """Return the components as a dict; ``prefix`` is exported as ``"instruction"``."""
        return {
            "instruction": self.prefix,
            "context": self.context,
            "examples": self.examples,
            "output": self.output,
        }

    def __repr__(self):
        return (
            f"{type(self).__name__}(prefix={self.prefix!r}, context={self.context!r}, "
            f"examples={self.examples!r}, output={self.output!r})"
        )


def recommend_config(task, lang, model_type):
    """Recommend the language of each prompt component.

    Every component starts in ``lang``. Some are then switched to
    ``"English"``, depending on the task, the model type and, for NER/NLI
    with a non-English model, the resource level of ``lang``.

    Args:
        task: One of ``QA``, ``NER``, ``NLI``, ``SUMMARIZATION``. Any other
            value leaves every component in ``lang``.
        lang: Language name. For NER/NLI with a non-English model it must be
            present in the word-count CSV read by ``_get_language_type``.
        model_type: ``ModelType.English`` or its value ``"English"`` selects
            the English-model branch. Anything else is treated as multilingual.

    Returns:
        dict with keys ``"instruction"``, ``"context"``, ``"examples"`` and
        ``"output"``, as produced by ``Config.to_dict``.
    """
    config = Config(lang, lang, lang, lang)
    # Accept the enum member as well as its string value; previously only the
    # string matched, so ModelType.English was silently treated as multilingual.
    is_english_model = model_type in (ModelType.English, ModelType.English.value)

    if task == QA:
        if not is_english_model:
            config.set(prefix="English")
    elif task in (NER, NLI):
        if not is_english_model:
            # Only these branches need the resource level. Looking it up lazily
            # avoids reading the CSV (and failing for languages missing from
            # it) when the result would be ignored.
            high_resource = _get_language_type(lang) == LanguageType.High
            if task == NER:
                if high_resource:
                    config.set(prefix="English")
                else:
                    config.set(prefix="English", output="English")
            elif high_resource:
                config.set(prefix="English", examples="English")
            else:
                config.set(prefix="English", context="English", examples="English")
    elif task == SUMMARIZATION:
        config.set(context="English")
    return config.to_dict()