# code adapted from https://github.com/ibm-aur-nlp/PubTabNet/blob/master/src/metric.py
# tree edit distance video explanation: https://www.youtube.com/watch?v=6Ur8B35xCj8
import apted
import distance
from lxml import etree, html
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed


class TableTree(apted.helpers.Tree):
    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.children = list(children)

    def bracket(self):
        """Show tree using brackets notation."""
        if self.tag == "td":
            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag,
                self.colspan,
                self.rowspan,
                self.content,
            )
        else:
            result = '"tag": %s' % self.tag
        for child in self.children:
            result += child.bracket()
        return "{{{}}}".format(result)
class CustomConfig(apted.Config):
    @staticmethod  # must be static: callers pass only the sequences, no instance
    def maximum(*sequences):
        """Get maximum possible value."""
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1."""
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Compares attributes of trees"""
        if (
            (node1.tag != node2.tag)
            or (node1.colspan != node2.colspan)
            or (node1.rowspan != node2.rowspan)
        ):
            return 1.0
        if node1.tag == "td":
            if node1.content or node2.content:
                return self.normalized_distance(node1.content, node2.content)
        return 0.0
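
# Illustrative (hypothetical values): rename() charges the full cost of 1.0
# whenever tag/colspan/rowspan differ; otherwise <td> cells pay a normalized
# Levenshtein cost over their token lists:
#
#   cfg = CustomConfig()
#   cfg.normalized_distance(list("cat"), list("cart"))  # 1 edit / max len 4 -> 0.25
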
class TEDS(object):
    """Tree Edit Distance based Similarity"""

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1
        ), "n_jobs must be an integer greater than or equal to 1"
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []

    def tokenize(self, node):
        """Tokenizes table cells"""
        self.__tokens__.append("<%s>" % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        for n in node.getchildren():
            self.tokenize(n)
        if node.tag != "unk":
            self.__tokens__.append("</%s>" % node.tag)
        if node.tag != "td" and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        """Converts HTML tree to the format required by apted"""
        if node.tag == "td":
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # drop the enclosing <td>/</td> tokens, keep the cell content
                cell = self.__tokens__[1:-1].copy()
            new_node = TableTree(
                node.tag,
                int(node.attrib.get("colspan", "1")),
                int(node.attrib.get("rowspan", "1")),
                cell,
            )
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != "td":
            for n in node.getchildren():
                self.load_html_tree(n, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        """Computes TEDS score between the prediction and the ground truth of a
        given sample
        """
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        if pred.xpath("body/table") and true.xpath("body/table"):
            pred = pred.xpath("body/table")[0]
            true = true.xpath("body/table")[0]
            if self.ignore_nodes:
                etree.strip_tags(pred, *self.ignore_nodes)
                etree.strip_tags(true, *self.ignore_nodes)
            n_nodes_pred = len(pred.xpath(".//*"))
            n_nodes_true = len(true.xpath(".//*"))
            n_nodes = max(n_nodes_pred, n_nodes_true)
            tree_pred = self.load_html_tree(pred)
            tree_true = self.load_html_tree(true)
            # renamed from `distance` to avoid shadowing the imported module
            edit_distance = apted.APTED(
                tree_pred, tree_true, CustomConfig()
            ).compute_edit_distance()
            return 1.0 - (float(edit_distance) / n_nodes)
        else:
            return 0.0
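
    # Minimal usage sketch (hypothetical inputs, not from the original file):
    #   teds = TEDS()
    #   teds.evaluate(
    #       "<html><body><table><tr><td>1</td></tr></table></body></html>",
    #       "<html><body><table><tr><td>1</td></tr></table></body></html>",
    #   )  # identical tables -> 1.0
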
    def batch_evaluate(self, results_json):
        """Computes TEDS score between the prediction and the ground truth of
        a batch of samples
        @params results_json: {'FILENAME': {'pred': 'HTML CODE', 'gt': 'HTML CODE'}, ...}
        @output: {'FILENAME': {'scores': TEDS SCORE, 'type': 'simple'|'complex'}, ...}
        """
        samples = results_json.keys()
        print(f"Total samples: {len(samples)}")
        if self.n_jobs == 1:
            scores = [
                self.evaluate(
                    results_json[filename]["pred"],
                    results_json[filename]["gt"],
                )
                for filename in tqdm(samples)
            ]
        else:
            inputs = [
                {
                    "pred": results_json[filename]["pred"],
                    "true": results_json[filename]["gt"],
                }
                for filename in samples
            ]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        output = dict()
        for i, j in zip(samples, scores):
            # tables whose ground truth contains colspan/rowspan are "complex"
            if "span" in results_json[i]["gt"]:
                output[i] = dict(scores=j, type="complex")
            else:
                output[i] = dict(scores=j, type="simple")
        return output
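
# Example return value of batch_evaluate (hypothetical filenames and scores):
#   {"table_001.png": {"scores": 0.98, "type": "simple"},
#    "table_002.png": {"scores": 0.71, "type": "complex"}}
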
def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
    """
    A parallel version of the map function with a progress bar.
    Args:
        array (array-like): An array to iterate over.
        function (function): A python function to apply to the elements of array
        n_jobs (int, default=16): The number of cores to use
        use_kwargs (boolean, default=False): Whether to consider the elements of array
            as dictionaries of keyword arguments to function
        front_num (int, default=0): The number of iterations to run serially before
            kicking off the parallel job. Useful for catching bugs
    Returns:
        [function(array[0]), function(array[1]), ...]
    """
    # Run the first few iterations serially to catch bugs early
    if front_num > 0:
        front = [
            function(**a) if use_kwargs else function(a) for a in array[:front_num]
        ]
    else:
        front = []
    # If n_jobs is 1, just run a list comprehension. Useful for benchmarking and debugging.
    if n_jobs == 1:
        return front + [
            function(**a) if use_kwargs else function(a)
            for a in tqdm(array[front_num:])
        ]
    # Assemble the workers
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        # Pass the elements of array into function
        if use_kwargs:
            futures = [pool.submit(function, **a) for a in array[front_num:]]
        else:
            futures = [pool.submit(function, a) for a in array[front_num:]]
        kwargs = {
            "total": len(futures),
            "unit": "it",
            "unit_scale": True,
            "leave": True,
        }
        # Print out the progress as tasks complete
        for _ in tqdm(as_completed(futures), **kwargs):
            pass
    out = []
    # Get the results from the futures, preserving input order
    for future in tqdm(futures):
        try:
            out.append(future.result())
        except Exception as e:
            out.append(e)
    return front + out
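
# Illustrative use of parallel_process (hypothetical toy function):
#   def square(x):
#       return x * x
#   parallel_process([1, 2, 3, 4], square, n_jobs=2)  # -> [1, 4, 9, 16]
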
if __name__ == "__main__":
    import json
    import pprint
    import numpy as np
    import argparse

    parser = argparse.ArgumentParser(description="TEDS Computation")
    parser.add_argument("-f", "--file", help="path to html table results in json file")
    parser.add_argument("-t", "--type", help="html, html+cell")
    parser.add_argument(
        # type=int is required: TEDS asserts n_jobs is an int, and argparse
        # would otherwise hand over a string for any value passed on the CLI
        "-n", "--njob", type=int, default=200, help="number of jobs in parallel"
    )
    args = parser.parse_args()

    results_file = args.file
    with open(results_file, "r") as f:
        results_json = json.load(f)
    if args.type == "html":
        s_only = True
    else:
        s_only = False
    teds = TEDS(structure_only=s_only, n_jobs=args.njob)
    scores = teds.batch_evaluate(results_json)
    pp = pprint.PrettyPrinter()
    pp.pprint(scores)

    # compute teds for simple and complex tables
    total, simple, complex = list(), list(), list()
    for _, obj in scores.items():
        if obj["type"] == "simple":
            simple.append(obj["scores"])
        elif obj["type"] == "complex":
            complex.append(obj["scores"])
        total.append(obj["scores"])
    total, simple, complex = np.array(total), np.array(simple), np.array(complex)
    print(
        f"Simple: {np.mean(simple)} \nComplex: {np.mean(complex)} \nTotal: {np.mean(total)}"
    )
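
# Example invocation (hypothetical script and file names):
#   python teds.py -f results.json -t html -n 4
# With -t html only the table structure is scored (structure_only=True);
# any other value also scores cell text content.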