""" Official evaluation script for CUAD dataset. """

import argparse
import json
import re
import string
import sys

import numpy as np


IOU_THRESH = 0.5


def get_jaccard(prediction, ground_truth):
    remove_tokens = [".", ",", ";", ":"]

    for token in remove_tokens:
        ground_truth = ground_truth.replace(token, "")
        prediction = prediction.replace(token, "")

    ground_truth, prediction = ground_truth.lower(), prediction.lower()
    ground_truth, prediction = ground_truth.replace("/", " "), prediction.replace("/", " ")
    ground_truth, prediction = set(ground_truth.split(" ")), set(prediction.split(" "))

    intersection = ground_truth.intersection(prediction)
    union = ground_truth.union(prediction)
    jaccard = len(intersection) / len(union)
    return jaccard


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def compute_precision_recall(predictions, ground_truths, qa_id):
    tp, fp, fn = 0, 0, 0

    substr_ok = "Parties" in qa_id

    # first check if ground truth is empty
    if len(ground_truths) == 0:
        if len(predictions) > 0:
            fp += len(predictions)  # false positive for each one
    else:
        for ground_truth in ground_truths:
            assert len(ground_truth) > 0
            # check if there is a match
            match_found = False
            for pred in predictions:
                if substr_ok:
                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred
                else:
                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH
                if is_match:
                    match_found = True

            if match_found:
                tp += 1
            else:
                fn += 1

        # now also get any fps by looping through preds
        for pred in predictions:
            # Check if there's a match. if so, don't count (don't want to double count based on the above)
            # but if there's no match, then this is a false positive.
            # (Note: we get the true positives in the above loop instead of this loop so that we don't double count
            # multiple predictions that are matched with the same answer.)
            match_found = False
            for ground_truth in ground_truths:
                assert len(ground_truth) > 0
                if substr_ok:
                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred
                else:
                    is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH
                if is_match:
                    match_found = True

            if not match_found:
                fp += 1

    precision = tp / (tp + fp) if tp + fp > 0 else np.nan
    recall = tp / (tp + fn) if tp + fn > 0 else np.nan

    return precision, recall


def process_precisions(precisions):
    """
    Processes precisions to ensure that precision and recall don't both get worse.
    Assumes the list precision is sorted in order of recalls
    """
    precision_best = precisions[::-1]
    for i in range(1, len(precision_best)):
        precision_best[i] = max(precision_best[i - 1], precision_best[i])
    precisions = precision_best[::-1]
    return precisions


def get_aupr(precisions, recalls):
    processed_precisions = process_precisions(precisions)
    aupr = np.trapz(processed_precisions, recalls)
    if np.isnan(aupr):
        return 0
    return aupr


def get_prec_at_recall(precisions, recalls, recall_thresh):
    """Assumes recalls are sorted in increasing order"""
    processed_precisions = process_precisions(precisions)
    prec_at_recall = 0
    for prec, recall in zip(processed_precisions, recalls):
        if recall >= recall_thresh:
            prec_at_recall = prec
            break
    return prec_at_recall


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, predictions, ground_truths):
    score = 0
    for pred in predictions:
        for ground_truth in ground_truths:
            score = metric_fn(pred, ground_truth)
            if score == 1:  # break the loop when one prediction matches the ground truth
                break
        if score == 1:
            break
    return score


def compute_score(dataset, predictions):
    f1 = exact_match = total = 0
    precisions = []
    recalls = []
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                total += 1
                if qa["id"] not in predictions:
                    message = "Unanswered question " + qa["id"] + " will receive score 0."
                    print(message, file=sys.stderr)
                    continue
                ground_truths = list(map(lambda x: x["text"], qa["answers"]))
                prediction = predictions[qa["id"]]
                precision, recall = compute_precision_recall(prediction, ground_truths, qa["id"])

                precisions.append(precision)
                recalls.append(recall)

                if precision == 0 and recall == 0:
                    f1 += 0
                else:
                    f1 += 2 * (precision * recall) / (precision + recall)

                exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)

    precisions = [x for _, x in sorted(zip(recalls, precisions))]
    recalls.sort()

    f1 = 100.0 * f1 / total
    exact_match = 100.0 * exact_match / total
    aupr = get_aupr(precisions, recalls)

    prec_at_90_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.9)
    prec_at_80_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.8)

    return {
        "exact_match": exact_match,
        "f1": f1,
        "aupr": aupr,
        "prec_at_80_recall": prec_at_80_recall,
        "prec_at_90_recall": prec_at_90_recall,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluation for CUAD")
    parser.add_argument("dataset_file", help="Dataset file")
    parser.add_argument("prediction_file", help="Prediction File")
    args = parser.parse_args()
    with open(args.dataset_file) as dataset_file:
        dataset_json = json.load(dataset_file)
        dataset = dataset_json["data"]
    with open(args.prediction_file) as prediction_file:
        predictions = json.load(prediction_file)
    print(json.dumps(compute_score(dataset, predictions)))