import evaluate
from rapidfuzz.distance.Levenshtein import distance, normalized_similarity

import config

# Metric module behind bleu_fn, loaded once at import time from the
# "saridormi/b_norm" evaluate module into the shared cache directory.
# bleu_fn reads the "b_norm" field of its compute() result.
BLEU = evaluate.load(
    "saridormi/b_norm",
    cache_dir=config.CACHE_DIR,
)


def bleu_fn(pred, ref, **kwargs):
    """Score ``pred`` with the B-Norm module (``BLEU``) and return its ``b_norm`` value.

    Args:
        pred: Predicted string.
        ref: Reference string; used only when no ``refs`` keyword is given.
        **kwargs: Optional ``refs`` -- a sized collection of references. When
            present, ``pred`` is paired with each reference (one
            prediction/reference pair per entry) and all pairs are scored in
            a single ``compute`` call.
    """
    references = kwargs["refs"] if "refs" in kwargs else [ref]
    # One copy of the prediction per reference so the two lists line up.
    predictions = [pred] * len(references)
    return BLEU.compute(predictions=predictions, references=references)["b_norm"]


# Metric module behind meteor_fn, loaded once at import time into the shared
# cache directory; meteor_fn reads the "meteor" field of its result.
METEOR = evaluate.load(
    "meteor",
    cache_dir=config.CACHE_DIR,
)


def meteor_fn(pred, ref, **kwargs):
    """Return the ``meteor`` value computed by ``METEOR`` for ``pred``.

    Args:
        pred: Predicted string.
        ref: Reference string; used only when no ``refs`` keyword is given.
        **kwargs: Optional ``refs`` -- a sized collection of references. When
            present, ``pred`` is repeated once per reference and every
            prediction/reference pair is scored in one ``compute`` call.
    """
    targets = kwargs.get("refs", [ref])
    scored = METEOR.compute(predictions=[pred] * len(targets), references=targets)
    return scored["meteor"]


# Metric module shared by rouge1_fn, rouge2_fn and rougeL_fn; each of them
# picks a different field ("rouge1", "rouge2", "rougeL") from the result.
ROUGE = evaluate.load(
    "rouge",
    cache_dir=config.CACHE_DIR,
)


def rouge1_fn(pred, ref, **kwargs):
    """Return the ``rouge1`` field of the ``ROUGE`` result for ``pred``.

    Args:
        pred: Predicted string.
        ref: Reference string; used only when no ``refs`` keyword is given.
        **kwargs: Optional ``refs`` -- a sized collection of references. When
            present, ``pred`` is paired with each reference and the module
            scores all pairs in a single call.
    """
    if "refs" not in kwargs:
        return ROUGE.compute(predictions=[pred], references=[ref])["rouge1"]
    gold = kwargs["refs"]
    return ROUGE.compute(predictions=[pred] * len(gold), references=gold)["rouge1"]


def rouge2_fn(pred, ref, **kwargs):
    """Return the ``rouge2`` field of the ``ROUGE`` result for ``pred``.

    Args:
        pred: Predicted string.
        ref: Reference string; used only when no ``refs`` keyword is given.
        **kwargs: Optional ``refs`` -- a sized collection of references. When
            present, ``pred`` is paired with each reference and the module
            scores all pairs in a single call.
    """
    gold = kwargs.get("refs", [ref])
    hyps = [pred] * len(gold)
    return ROUGE.compute(predictions=hyps, references=gold)["rouge2"]


def rougeL_fn(pred, ref, **kwargs):
    """Return the ``rougeL`` field of the ``ROUGE`` result for ``pred``.

    Args:
        pred: Predicted string.
        ref: Reference string; used only when no ``refs`` keyword is given.
        **kwargs: Optional ``refs`` -- a sized collection of references. When
            present, ``pred`` is paired with each reference and the module
            scores all pairs in a single call.
    """
    gold = kwargs["refs"] if "refs" in kwargs else [ref]
    outcome = ROUGE.compute(predictions=[pred] * len(gold), references=gold)
    return outcome["rougeL"]


# Metric module behind bertscore_fn, loaded once at import time into the
# shared cache directory; bertscore_fn reads the "f1" field of its result.
BERTSCORE = evaluate.load(
    "bertscore",
    cache_dir=config.CACHE_DIR,
)


def bertscore_fn(pred, ref, **kwargs):
    """Return the BERTScore F1 of ``pred`` as computed by ``BERTSCORE``.

    Args:
        pred: Predicted string.
        ref: Reference string; used only when no ``refs`` keyword is given.
        **kwargs:
            refs: Optional collection of references. The whole collection is
                passed as the reference list of the single prediction
                (multi-reference input), not as separate pairs.
            model_type: Optional backbone name forwarded to BERTScore.
                Defaults to ``"distilbert-base-uncased"``, the model this
                function has always used.

    Returns:
        The first (and only) entry of the ``f1`` list in the result.
    """
    # Previously hard-coded (and duplicated) in both branches; now overridable.
    model_type = kwargs.get("model_type", "distilbert-base-uncased")
    references = [kwargs["refs"]] if "refs" in kwargs else [ref]
    result = BERTSCORE.compute(predictions=[pred], references=references, model_type=model_type)
    return result["f1"][0]


# Metric module behind chrf_fn. Loaded into config.CACHE_DIR like the other
# evaluate modules in this file; previously it was the only one falling back
# to the library's default cache location.
CHRF = evaluate.load("chrf", cache_dir=config.CACHE_DIR)


def chrf_fn(pred, ref, **kwargs):
    """Return the ``score`` field of the ``CHRF`` result for ``pred``.

    Args:
        pred: Predicted string.
        ref: Reference string; used only when no ``refs`` keyword is given.
        **kwargs: Optional ``refs`` -- a collection of references passed as
            the reference list of the single prediction (multi-reference
            input), rather than as separate prediction/reference pairs.
    """
    candidates = kwargs["refs"] if "refs" in kwargs else [ref]
    # compute() expects one list of references per prediction.
    return CHRF.compute(predictions=[pred], references=[candidates])["score"]


def edit_distance_fn(pred, ref, **kwargs):
    """Return the Levenshtein edit distance between ``pred`` and the reference(s).

    Args:
        pred: Predicted string.
        ref: Reference string; used only when no ``refs`` keyword is given.
        **kwargs: Optional ``refs`` -- an iterable of references. When
            present, the distance to each reference is computed and their
            arithmetic mean is returned.

    Returns:
        The raw (unnormalized) distance from rapidfuzz ``distance``, or the
        mean of those distances over ``refs``.

    Raises:
        ValueError: If ``refs`` is given but yields no references.
    """
    if "refs" in kwargs:
        # Distinct loop name so the ``ref`` parameter is not shadowed.
        scores = [distance(pred, candidate) for candidate in kwargs["refs"]]
        if not scores:
            # Fail clearly instead of with a ZeroDivisionError from the mean below.
            raise ValueError("edit_distance_fn: 'refs' must contain at least one reference")
        return sum(scores) / len(scores)
    return distance(pred, ref)


def edit_distance_norm_fn(pred, ref, **kwargs):
    """Return the normalized Levenshtein similarity of ``pred`` to the reference(s).

    This uses rapidfuzz ``normalized_similarity``, so unlike
    ``edit_distance_fn`` the result is a similarity (larger means closer),
    not a distance.

    Args:
        pred: Predicted string.
        ref: Reference string; used only when no ``refs`` keyword is given.
        **kwargs: Optional ``refs`` -- an iterable of references. When
            present, the similarity to each reference is computed and their
            arithmetic mean is returned.

    Raises:
        ValueError: If ``refs`` is given but yields no references.
    """
    if "refs" in kwargs:
        # Distinct loop name so the ``ref`` parameter is not shadowed.
        scores = [normalized_similarity(pred, candidate) for candidate in kwargs["refs"]]
        if not scores:
            # Fail clearly instead of with a ZeroDivisionError from the mean below.
            raise ValueError("edit_distance_norm_fn: 'refs' must contain at least one reference")
        return sum(scores) / len(scores)
    return normalized_similarity(pred, ref)


# Registry of every scoring function in this module, keyed by metric name.
# Each entry has the signature (pred, ref, **kwargs) and accepts an optional
# ``refs`` keyword. Insertion order is kept, since callers may iterate it.
AGGR_METRICS = dict(
    editdist=edit_distance_fn,
    editsim=edit_distance_norm_fn,
    bleu=bleu_fn,
    meteor=meteor_fn,
    rouge1=rouge1_fn,
    rouge2=rouge2_fn,
    rougeL=rougeL_fn,
    bertscore=bertscore_fn,
    chrF=chrf_fn,
)


# Smaller registry containing only the raw edit distance. Presumably used for
# "relative" comparisons, going by the name -- NOTE(review): confirm with callers.
REL_METRICS = dict(editdist=edit_distance_fn)