"""Demo gradio app for some text/query augmentation."""
from __future__ import annotations

import functools
from typing import Any
from typing import Callable
from typing import Mapping
from typing import Sequence

import attr
import environ
import fasttext  # not working with python3.9
import gradio as gr
from transformers.pipelines import pipeline
from transformers.pipelines.base import Pipeline
from transformers.pipelines.token_classification import AggregationStrategy


def compose(*functions) -> Callable:
    """
    Chain functions into a left-to-right pipeline.

    Args:
        functions: functions to chain; the first one receives the input and
            each following one receives the previous result.
    Returns:
        A one-argument callable running every function in order. Called
        with no functions, it hands its argument back unchanged.
    """

    def composed(x):
        result = x
        for function in functions:
            result = function(result)
        return result

    return composed


def mapped(fn) -> Callable:
    """
    Build a decorator that routes the decorated function through ``fn``.

    Calling the decorated function with ``*args, **kwargs`` evaluates
    ``fn(func, *args, **kwargs)``, so ``@mapped(map)`` turns a per-item
    function into one applied over an iterable (likewise with ``filter``).
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The decorated function always travels as fn's first argument.
            return fn(func, *args, **kwargs)

        return wrapper

    return decorator


@attr.frozen
class Prediction:
    """Immutable (label, score) pair describing a single model prediction."""

    # Name of the predicted class.
    label: str
    # Score the model attached to ``label``.
    score: float


@attr.frozen
class Models:
    """Immutable bundle of the predictors needed to interpret a query."""

    # Language identification of the query.
    identification: Predictor
    # Translation of the query from a given source language.
    translation: Predictor
    # Scoring of the query against candidate categories.
    classification: Predictor
    # General-purpose named-entity recognition.
    ner: Predictor
    # Recipe-specific entity tagging.
    recipe: Predictor


@attr.frozen
class Predictor:
    """Pair a model loader with the function used to query the loaded model.

    ``load_fn`` is called once, right after initialisation, and its result is
    kept in ``model``. Calling the instance forwards every argument to
    ``predict_fn`` with the loaded model inserted as the first one.
    """

    # Zero-argument callable returning the model.
    load_fn: Callable
    # Called as ``predict_fn(model, *args, **kwds)``; by default the model
    # itself is invoked on the query.
    predict_fn: Callable = attr.field(default=lambda model, query: model(query))
    # Result of ``load_fn()``; not an ``__init__`` argument.
    model: Any = attr.field(init=False)

    def __attrs_post_init__(self):
        loaded_model = self.load_fn()
        # The class is frozen, so the attribute is set through ``object``.
        object.__setattr__(self, "model", loaded_model)

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        run_prediction = self.predict_fn
        return run_prediction(self.model, *args, **kwds)


@environ.config(prefix="QUERY_INTERPRETATION")
class AppConfig:
    """Application settings declared with ``environ.config``.

    Uses the ``QUERY_INTERPRETATION`` prefix; every variable below has a
    default, and the nested groups are exposed as attributes of this class.
    """

    @environ.config
    class Identification:
        """Identification model configuration."""

        # Path of the language identification model file.
        model = environ.var(default="./models/lid.176.ftz")
        # Number of top classes shown for language identification (int).
        max_results = environ.var(default=3, converter=int)

    @environ.config
    class Translation:
        """Translation models configuration."""

        # NOTE(review): not referenced elsewhere in the visible code --
        # confirm whether this setting is still needed.
        model = environ.var(default="t5-small")
        # Comma-separated source language codes.
        sources = environ.var(default="de,fr")
        # Target language code.
        target = environ.var(default="en")

    @environ.config
    class Classification:
        """Classification model configuration."""

        # Model used for the zero-shot classification pipeline.
        model = environ.var(default="typeform/distilbert-base-uncased-mnli")
        # Number of top classes shown for the predicted categories (int).
        max_results = environ.var(default=5, converter=int)

    @environ.config
    class NER:
        """Named-entity recognition models configuration."""

        # Model used for general-purpose entity extraction.
        general = environ.var(
            default="asahi417/tner-xlm-roberta-large-uncased-wnut2017",
        )

        # Model used for recipe-specific entity extraction.
        recipe = environ.var(default="adamlin/recipe-tag-model")

    identification: Identification = environ.group(Identification)
    translation: Translation = environ.group(Translation)
    classification: Classification = environ.group(Classification)
    ner: NER = environ.group(NER)


def predict(
    models: Models,
    query: str,
    categories: Sequence[str],
    supported_languages: tuple[str, ...] = ("fr", "de"),
) -> tuple[
    Mapping[str, float],
    Mapping[str, float],
    str,
    Sequence[tuple[str, str | None]],
    Sequence[tuple[str, str | None]],
]:
    """Run the whole interpretation pipeline on a textual query.

    Steps:
    - identify the language of the query,
    - translate it when the most likely language is one of
      ``supported_languages`` (otherwise the query is kept as is),
    - score the resulting text against ``categories``,
    - extract general entities from the original query and recipe entities
      from the resulting text.

    Args:
        models: predictors used for each step.
        query: text to interpret.
        categories: candidate labels for the classification step.
        supported_languages: language codes that can be translated.

    Returns:
        A tuple ``(languages, classifications, translation,
        general_entities, recipe_entities)`` where the first two items map
        labels to scores, ``translation`` is the text fed to the
        classification and recipe steps, and both entity items are lists of
        ``(text, entity label or None)`` pairs.
    """
    # Languages kept in the identification output: the translatable ones plus
    # English, so an English query is recognised and left untranslated.
    # Built once rather than for every filtered prediction.
    reported_languages = frozenset(supported_languages) | {"en"}

    def predict_lang(query) -> Mapping[str, float]:
        def predict_fn(query) -> Sequence[Prediction]:
            # k=176 presumably asks for every language of the lid.176 model
            # -- TODO confirm against the configured identification model.
            return tuple(
                Prediction(label=label, score=score)
                for label, score in zip(*models.identification(query, k=176))
            )

        @mapped(map)
        def format_label(prediction: Prediction) -> Prediction:
            return attr.evolve(
                prediction,
                label=prediction.label.replace("__label__", ""),
            )

        def filter_labels(prediction: Prediction) -> bool:
            return prediction.label in reported_languages

        def format_output(predictions: Sequence[Prediction]) -> dict:
            return {pred.label: pred.score for pred in predictions}

        apply_fn = compose(
            predict_fn,
            format_label,
            functools.partial(filter, filter_labels),
            format_output,
        )
        return apply_fn(query)

    def translate_query(query: str, languages: Mapping[str, float]) -> str:
        # ``languages`` is empty when none of the reported languages was
        # detected; fall back to the untranslated query instead of letting
        # max() raise on an empty mapping.
        lang = max(languages, key=languages.get, default=None)
        if lang not in supported_languages:
            return query
        return models.translation(query, lang)[0]["translation_text"]

    def classify_query(query, categories) -> Mapping[str, float]:
        predictions = models.classification(query, categories)
        return dict(zip(predictions["labels"], predictions["scores"]))

    def extract_entities(
        predict_fn: Callable,
        query: str,
    ) -> Sequence[tuple[str, str | None]]:
        predictions = predict_fn(query)
        if not predictions:
            # Nothing recognised: return the whole query without a label.
            return [(query, None)]
        # Use the "entity_group" key when present and fall back to "entity".
        return [
            (pred["word"], pred.get("entity_group", pred.get("entity")))
            for pred in predictions
        ]

    languages = predict_lang(query)
    translation = translate_query(query, languages)
    classifications = classify_query(translation, categories)
    general_entities = extract_entities(models.ner, query)
    recipe_entities = extract_entities(models.recipe, translation)
    return languages, classifications, translation, general_entities, recipe_entities


def main():
    """Build the models from the environment configuration and serve the demo.

    Reads ``AppConfig`` from the environment, creates one ``Predictor`` per
    task and launches a gradio interface that calls ``predict`` with them.
    """
    cfg: AppConfig = AppConfig.from_environ()

    def extract_commas_separated_values(value: str) -> Sequence[str]:
        """Split on commas, trimming whitespace and dropping empty items."""
        stripped = (item.strip() for item in value.split(","))
        return tuple(item for item in stripped if item)

    def load_translation_models(
        sources: Sequence[str],
        target: str,
    ) -> Mapping[str, Pipeline]:
        """Load one translation pipeline per source language, keyed by code."""
        # The model name is derived from the language pair instead of being
        # zipped with a fixed list, so a source can never be paired with
        # another language's model whatever the order or number of sources.
        return {
            src: pipeline(
                f"translation_{src}_to_{target}",
                model=f"Helsinki-NLP/opus-mt-{src}-{target}",
            )
            for src in sources
        }

    # Shared by the model loader and by predict() so that a language is only
    # ever selected for translation when a pipeline was loaded for it.
    translation_sources = extract_commas_separated_values(cfg.translation.sources)

    models = Models(
        identification=Predictor(
            load_fn=lambda: fasttext.load_model(cfg.identification.model),
            # fasttext predicts one line at a time and rejects text that
            # contains a newline, hence the replacement.
            predict_fn=lambda model, query, k: model.predict(
                query.replace("\n", " "),
                k=k,
            ),
        ),
        translation=Predictor(
            load_fn=functools.partial(
                load_translation_models,
                sources=translation_sources,
                target=cfg.translation.target,
            ),
            predict_fn=lambda pipelines, query, src: pipelines[src](query),
        ),
        classification=Predictor(
            load_fn=lambda: pipeline(
                "zero-shot-classification",
                model=cfg.classification.model,
            ),
            predict_fn=lambda model, query, categories: model(query, categories),
        ),
        ner=Predictor(
            load_fn=lambda: pipeline(
                "ner",
                model=cfg.ner.general,
                aggregation_strategy=AggregationStrategy.SIMPLE,
            ),
        ),
        recipe=Predictor(
            load_fn=lambda: pipeline("ner", model=cfg.ner.recipe),
        ),
    )

    iface = gr.Interface(
        fn=lambda query, categories: predict(
            models,
            query.strip(),
            extract_commas_separated_values(categories),
            supported_languages=translation_sources,
        ),
        examples=[["gateau au chocolat paris"], ["Newyork LA flight"]],
        inputs=[
            gr.inputs.Textbox(label="Query"),
            gr.inputs.Textbox(
                label="categories (commas separated and in english)",
                default="cooking and recipe,traveling,location,information,buy or sell",
            ),
        ],
        outputs=[
            gr.outputs.Label(
                num_top_classes=cfg.identification.max_results,
                type="auto",
                label="Language identification",
            ),
            gr.outputs.Label(
                num_top_classes=cfg.classification.max_results,
                type="auto",
                label="Predicted categories",
            ),
            gr.outputs.Textbox(
                label="English query",
                type="auto",
            ),
            gr.outputs.HighlightedText(label="NER generic"),
            gr.outputs.HighlightedText(label="NER Recipes"),
        ],
        interpretation="default",
    )

    iface.launch(debug=True)


# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()