# code adapted from https://github.com/ibm-aur-nlp/PubTabNet/blob/master/src/metric.py
# tree edit distance video explanation: https://www.youtube.com/watch?v=6Ur8B35xCj8
import apted
import distance
from collections import deque
from lxml import etree, html
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Tuple


class TableTree(apted.helpers.Tree):
    """Tree node describing one HTML element of a table.

    Attributes:
        tag: Element tag name of the node.
        colspan: Column span of the node (``None`` unless supplied).
        rowspan: Row span of the node (``None`` unless supplied).
        content: Payload attached to the node (``None`` unless supplied).
        children: List of child nodes, in insertion order.
    """

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag, self.colspan, self.rowspan = tag, colspan, rowspan
        self.content = content
        # Stored as a list so callers can append children after construction.
        self.children = [*children]

    def bracket(self):
        """Return the subtree rooted at this node in bracket notation."""
        if self.tag != "td":
            label = '"tag": %s' % self.tag
        else:
            # Only "td" nodes render their span attributes and content.
            label = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag,
                self.colspan,
                self.rowspan,
                self.content,
            )
        nested = "".join(child.bracket() for child in self.children)
        return "{" + label + nested + "}"


class CustomConfig(apted.Config):
    """APTED configuration defining the rename cost between two table nodes."""

    @staticmethod
    def maximum(*sequences):
        """Return the length of the longest of the given sequences."""
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Return the Levenshtein distance divided by the longest length.

        When every sequence is empty the sequences are identical, so 0.0 is
        returned instead of dividing by zero.
        """
        longest = self.maximum(*sequences)
        if longest == 0:
            return 0.0
        return float(distance.levenshtein(*sequences)) / longest

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        The cost is 1.0 when tag, colspan or rowspan differ. For matching
        "td" nodes where at least one has content, it is the normalized
        Levenshtein distance between their contents. Otherwise it is 0.0.
        """
        if (
            (node1.tag != node2.tag)
            or (node1.colspan != node2.colspan)
            or (node1.rowspan != node2.rowspan)
        ):
            return 1.0
        if node1.tag == "td" and (node1.content or node2.content):
            return self.normalized_distance(node1.content, node2.content)
        return 0.0


class TEDS(object):
    """Tree Edit Distance based Similarity (TEDS) between two HTML tables.

    Args:
        structure_only: When true, cell contents are ignored and only the
            tag / colspan / rowspan structure is compared.
        n_jobs: Number of jobs used by ``batch_evaluate``; must be an
            integer >= 1.
        ignore_nodes: Optional iterable of tag names that are stripped from
            both tables (via ``etree.strip_tags``) before comparison.
    """

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        # Kept as an assert so the raised exception type stays AssertionError.
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1
        ), "n_jobs must be an integer greater than or equal to 1"
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        # Scratch buffer filled by tokenize(); reset per cell in load_html_tree().
        self.__tokens__ = []

    def tokenize(self, node):
        """Append the tokens of ``node`` and its descendants to ``self.__tokens__``.

        Each element contributes an opening tag token, its text split into
        single characters, the tokens of its children, and a closing tag
        token (omitted for "unk" elements). Tail text is added character by
        character for every element except "td".
        """
        self.__tokens__.append("<%s>" % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        for child in node.getchildren():
            self.tokenize(child)
        if node.tag != "unk":
            self.__tokens__.append("</%s>" % node.tag)
        if node.tag != "td" and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        """Convert an HTML element tree to the ``TableTree`` format used by apted.

        Args:
            node: HTML element to convert.
            parent: ``TableTree`` node the converted node is appended to, or
                ``None`` when ``node`` is the root.

        Returns:
            The converted root ``TableTree`` when ``parent`` is ``None``;
            otherwise ``None`` (the new node is attached to ``parent``).
        """
        if node.tag == "td":
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the enclosing "<td>" / "</td>" tokens; slicing copies.
                cell = self.__tokens__[1:-1]
            new_node = TableTree(
                node.tag,
                int(node.attrib.get("colspan", "1")),
                int(node.attrib.get("rowspan", "1")),
                cell,
            )
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        # "td" nodes are leaves of the tree: their markup lives in `cell`.
        if node.tag != "td":
            for child in node.getchildren():
                self.load_html_tree(child, new_node)
        if parent is None:
            return new_node
        return None

    def evaluate(self, pred, true):
        """Compute the TEDS score between a prediction and its ground truth.

        Args:
            pred: HTML code of the predicted table.
            true: HTML code of the ground-truth table.

        Returns:
            ``1 - edit_distance / max(node counts)``. The result is 0.0 when
            either input is empty or does not contain a ``body/table``
            element, and 1.0 when both tables have no descendant elements.
        """
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
        pred_tables = html.fromstring(pred, parser=parser).xpath("body/table")
        true_tables = html.fromstring(true, parser=parser).xpath("body/table")
        if not (pred_tables and true_tables):
            return 0.0
        pred_table, true_table = pred_tables[0], true_tables[0]
        if self.ignore_nodes:
            etree.strip_tags(pred_table, *self.ignore_nodes)
            etree.strip_tags(true_table, *self.ignore_nodes)
        n_nodes = max(
            len(pred_table.xpath(".//*")),
            len(true_table.xpath(".//*")),
        )
        if n_nodes == 0:
            # Both tables are bare <table> elements: identical trees, and
            # the normalization below would divide by zero.
            return 1.0
        tree_pred = self.load_html_tree(pred_table)
        tree_true = self.load_html_tree(true_table)
        edit_distance = apted.APTED(
            tree_pred, tree_true, CustomConfig()
        ).compute_edit_distance()
        return 1.0 - (float(edit_distance) / n_nodes)

    def batch_evaluate(self, results_json):
        """Compute TEDS scores for a batch of samples.

        Args:
            results_json: Mapping of the form
                ``{'FILENAME': {'pred': 'HTML CODE', 'gt': 'HTML CODE'}, ...}``.

        Returns:
            ``{'FILENAME': {'scores': TEDS_SCORE, 'type': 'simple' | 'complex'}, ...}``
            where a sample is "complex" when its ground truth contains the
            substring "span" and "simple" otherwise.
        """
        samples = list(results_json)
        print(f"Total samples: {len(samples)}")
        if self.n_jobs == 1:
            scores = [
                self.evaluate(
                    results_json[filename]["pred"],
                    results_json[filename]["gt"],
                )
                for filename in tqdm(samples)
            ]
        else:
            inputs = [
                {
                    "pred": results_json[filename]["pred"],
                    "true": results_json[filename]["gt"],
                }
                for filename in samples
            ]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        output = dict()
        for filename, score in zip(samples, scores):
            # NOTE(review): plain substring test; it matches colspan/rowspan
            # but also any other occurrence of "span" in the ground truth.
            if "span" in results_json[filename]["gt"]:
                output[filename] = dict(scores=score, type="complex")
            else:
                output[filename] = dict(scores=score, type="simple")
        return output


def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
    """
    Apply ``function`` to every element of ``array`` with a progress bar.

    Args:
        array (array-like): The elements to process.
        function (function): A python function to apply to the elements of array
        n_jobs (int, default=16): The number of worker processes to use
        use_kwargs (boolean, default=False): Whether each element of array is a
            dictionary of keyword arguments for function
        front_num (int, default=0): The number of leading elements processed
            serially before the parallel job starts. Useful for catching bugs
    Returns:
        [function(array[0]), function(array[1]), ...]; for elements handled
        by the process pool, an exception raised by function is stored in
        the list in place of the result.
    """

    def call(item):
        # Single place deciding between keyword and positional invocation.
        return function(**item) if use_kwargs else function(item)

    # Leading elements run in this process so failures surface immediately.
    head = [call(item) for item in array[:front_num]] if front_num > 0 else []
    remaining = array[front_num:]

    # A single job needs no pool; handy for benchmarking and debugging.
    if n_jobs == 1:
        return head + [call(item) for item in tqdm(remaining)]

    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        futures = [
            pool.submit(function, **item) if use_kwargs else pool.submit(function, item)
            for item in remaining
        ]
        progress_options = {
            "total": len(futures),
            "unit": "it",
            "unit_scale": True,
            "leave": True,
        }
        # Drain the completion iterator only to drive the progress bar.
        for _ in tqdm(as_completed(futures), **progress_options):
            pass

    # Collect results in submission order, keeping exceptions as values.
    tail = []
    for _, future in tqdm(enumerate(futures)):
        try:
            tail.append(future.result())
        except Exception as exc:
            tail.append(exc)
    return head + tail


if __name__ == "__main__":
    # Command-line entry point: score a JSON file of table predictions and
    # print per-sample scores plus the mean TEDS of simple / complex / all.
    import json
    import pprint
    import numpy as np
    import argparse

    parser = argparse.ArgumentParser(description="TEDS Computation")

    # Required: without it the script would fail later on open(None).
    parser.add_argument(
        "-f",
        "--file",
        required=True,
        help="path to html table results in json file",
    )
    parser.add_argument("-t", "--type", help="html, html+cell")
    # type=int: TEDS asserts that n_jobs is an int, and argparse would
    # otherwise pass a command-line value through as a string.
    parser.add_argument(
        "-n", "--njob", type=int, default=200, help="number of jobs in parallel"
    )
    args = parser.parse_args()

    with open(args.file, "r", encoding="utf-8") as f:
        results_json = json.load(f)

    # "html" scores the table structure only; any other value includes cells.
    teds = TEDS(structure_only=(args.type == "html"), n_jobs=args.njob)
    scores = teds.batch_evaluate(results_json)
    pprint.PrettyPrinter().pprint(scores)

    # compute teds for simple and complex tables
    total_scores, simple_scores, complex_scores = [], [], []
    for obj in scores.values():
        if obj["type"] == "simple":
            simple_scores.append(obj["scores"])
        elif obj["type"] == "complex":
            complex_scores.append(obj["scores"])
        total_scores.append(obj["scores"])

    def _mean(values):
        # An empty group reports nan without NumPy's empty-slice warning.
        return np.mean(np.array(values)) if values else float("nan")

    print(
        f"Simple: {_mean(simple_scores)} \nComplex: {_mean(complex_scores)} \nTotal: {_mean(total_scores)}"
    )