# Provenance: uploaded by yumikimi381 via huggingface_hub (commit daf0288, verified).
# code adapted from https://github.com/ibm-aur-nlp/PubTabNet/blob/master/src/metric.py
# tree edit distance video explanation: https://www.youtube.com/watch?v=6Ur8B35xCj8
import apted
import distance
from collections import deque
from lxml import etree, html
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Tuple
class TableTree(apted.helpers.Tree):
    """Tree node usable by apted, carrying table-cell attributes."""

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.children = list(children)

    def bracket(self):
        """Show tree using brackets notation."""
        if self.tag == "td":
            own = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag,
                self.colspan,
                self.rowspan,
                self.content,
            )
        else:
            own = '"tag": %s' % self.tag
        nested = "".join(child.bracket() for child in self.children)
        return "{{{}}}".format(own + nested)
class CustomConfig(apted.Config):
    """APTED cost configuration comparing TableTree nodes."""

    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value."""
        return max(len(seq) for seq in sequences)

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1."""
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Compares attributes of trees"""
        attrs_differ = (
            node1.tag != node2.tag
            or node1.colspan != node2.colspan
            or node1.rowspan != node2.rowspan
        )
        if attrs_differ:
            return 1.0
        if node1.tag == "td" and (node1.content or node2.content):
            return self.normalized_distance(node1.content, node2.content)
        return 0.0
class TEDS(object):
    """Tree Edit Distance based Similarity.

    Compares a predicted HTML table against a ground-truth HTML table and
    yields a score in [0, 1], where 1.0 means an exact match of structure
    (and, unless structure_only, of cell text as well).
    """

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        """
        Args:
            structure_only: if True, ignore cell text and compare structure
                (tags plus row/col spans) only.
            n_jobs: number of parallel workers used by batch_evaluate (>= 1).
            ignore_nodes: optional iterable of tag names stripped from both
                trees before scoring (e.g. ["b", "i"]).
        """
        # The check allows n_jobs == 1; the old message claimed
        # "greather than 1", which was both a typo and wrong.
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1
        ), "n_jobs must be an integer greater than or equal to 1"
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []  # scratch buffer filled by tokenize()

    def tokenize(self, node):
        """Tokenizes table cells into self.__tokens__ (tag tokens plus single chars)."""
        self.__tokens__.append("<%s>" % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        for n in node.getchildren():
            self.tokenize(n)
        if node.tag != "unk":
            self.__tokens__.append("</%s>" % node.tag)
        # Tail text belongs to the enclosing cell, so it is skipped for <td>.
        if node.tag != "td" and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        """Converts an lxml HTML tree to the TableTree format required by apted."""
        # (removed dead "global __tokens__": token state lives on self.__tokens__)
        if node.tag == "td":
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the enclosing <td>/</td> tokens; keep only the content.
                cell = self.__tokens__[1:-1].copy()
            new_node = TableTree(
                node.tag,
                int(node.attrib.get("colspan", "1")),
                int(node.attrib.get("rowspan", "1")),
                cell,
            )
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        # <td> is a leaf in the comparison tree: its HTML children are
        # already folded into the cell token list above.
        if node.tag != "td":
            for n in node.getchildren():
                self.load_html_tree(n, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        """Computes TEDS score between the prediction and the ground truth of a
        given sample.

        Args:
            pred: predicted table as an HTML string.
            true: ground-truth table as an HTML string.
        Returns:
            float in [0, 1]; 0.0 when either input is empty or when a
            <table> element cannot be found in both documents.
        """
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        if pred.xpath("body/table") and true.xpath("body/table"):
            pred = pred.xpath("body/table")[0]
            true = true.xpath("body/table")[0]
            if self.ignore_nodes:
                etree.strip_tags(pred, *self.ignore_nodes)
                etree.strip_tags(true, *self.ignore_nodes)
            n_nodes_pred = len(pred.xpath(".//*"))
            n_nodes_true = len(true.xpath(".//*"))
            n_nodes = max(n_nodes_pred, n_nodes_true)
            if n_nodes == 0:
                # Both tables are bare <table/> elements: structurally
                # identical, and dividing by zero below must be avoided.
                return 1.0
            tree_pred = self.load_html_tree(pred)
            tree_true = self.load_html_tree(true)
            # Named edit_distance (not "distance") to avoid shadowing the
            # imported distance module used by CustomConfig.
            edit_distance = apted.APTED(
                tree_pred, tree_true, CustomConfig()
            ).compute_edit_distance()
            return 1.0 - (float(edit_distance) / n_nodes)
        else:
            return 0.0

    def batch_evaluate(self, results_json):
        """Computes TEDS score between the prediction and the ground truth of
        a batch of samples.

        Args:
            results_json: {'FILENAME': {'pred': HTML, 'gt': HTML}, ...}
        Returns:
            {'FILENAME': {'scores': float, 'type': 'simple'|'complex'}, ...}
            where a sample is 'complex' if its ground-truth HTML contains a
            row/col span attribute.
        """
        samples = results_json.keys()
        print(f"Total samples: {len(samples)}")
        if self.n_jobs == 1:
            scores = [
                self.evaluate(
                    results_json[filename]["pred"],
                    results_json[filename]["gt"],
                )
                for filename in tqdm(samples)
            ]
        else:
            inputs = [
                {
                    "pred": results_json[filename]["pred"],
                    "true": results_json[filename]["gt"],
                }
                for filename in samples
            ]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        output = dict()
        # dict.keys() iterates in insertion order, so scores align with samples.
        for filename, score in zip(samples, scores):
            table_type = "complex" if "span" in results_json[filename]["gt"] else "simple"
            output[filename] = dict(scores=score, type=table_type)
        return output
def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
    """
    A parallel version of the map function with a progress bar.

    Args:
        array (array-like): An array to iterate over.
        function (function): A python function to apply to the elements of array
        n_jobs (int, default=16): The number of cores to use
        use_kwargs (boolean, default=False): Whether to consider the elements of
            array as dictionaries of keyword arguments to function
        front_num (int, default=0): The number of iterations to run serially
            before kicking off the parallel job. Useful for catching bugs.
            (The docstring previously claimed default=3; the signature says 0.)
    Returns:
        [function(array[0]), function(array[1]), ...] in input order. If a
        parallel call raises, the exception OBJECT is stored at that position
        instead of propagating, so one bad item does not abort the batch.
    """
    # Run the first few iterations serially to catch bugs before forking.
    if front_num > 0:
        front = [
            function(**a) if use_kwargs else function(a) for a in array[:front_num]
        ]
    else:
        front = []
    # If n_jobs is 1, just run a list comprehension — useful for
    # benchmarking and debugging.
    if n_jobs == 1:
        return front + [
            function(**a) if use_kwargs else function(a)
            for a in tqdm(array[front_num:])
        ]
    # Assemble the workers.
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        # Pass the elements of array into function.
        if use_kwargs:
            futures = [pool.submit(function, **a) for a in array[front_num:]]
        else:
            futures = [pool.submit(function, a) for a in array[front_num:]]
        kwargs = {
            "total": len(futures),
            "unit": "it",
            "unit_scale": True,
            "leave": True,
        }
        # Drive the progress bar as tasks complete (completion order).
        for _ in tqdm(as_completed(futures), **kwargs):
            pass
    out = []
    # Collect results in SUBMISSION order so output aligns with input.
    for future in tqdm(futures):
        try:
            out.append(future.result())
        except Exception as e:
            # Best-effort: record the exception in place of a result.
            out.append(e)
    return front + out
if __name__ == "__main__":
import json
import pprint
import numpy as np
import argparse
parser = argparse.ArgumentParser(description="TEDS Computation")
parser.add_argument("-f", "--file", help="path to html table results in json file")
parser.add_argument("-t", "--type", help="html, html+cell")
parser.add_argument("-n", "--njob", default=200, help="number of jobs in parallel")
args = parser.parse_args()
results_file = args.file
with open(results_file, "r") as f:
results_json = json.load(f)
if args.type == "html":
s_only = True
else:
s_only = False
teds = TEDS(structure_only=s_only, n_jobs=args.njob)
scores = teds.batch_evaluate(results_json)
pp = pprint.PrettyPrinter()
pp.pprint(scores)
# compute teds for simple and complex tables
total, simple, complex = list(), list(), list()
for _, obj in scores.items():
if obj["type"] == "simple":
simple.append(obj["scores"])
elif obj["type"] == "complex":
complex.append(obj["scores"])
total.append(obj["scores"])
total, simple, complex = np.array(total), np.array(simple), np.array(complex)
print(
f"Simple: {np.mean(simple)} \nComplex: {np.mean(complex)} \nTotal: {np.mean(total)}"
)