|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
import re |
|
from collections import Counter |
|
|
|
import numpy as np |
|
from huggingface_hub import snapshot_download |
|
|
|
from api.utils.file_utils import get_project_base_directory |
|
from rag.nlp import rag_tokenizer |
|
from .recognizer import Recognizer |
|
|
|
|
|
class TableStructureRecognizer(Recognizer):
    """Table-structure recognizer built on the shared ``Recognizer`` base.

    The instance call runs the base recognizer on table images and tidies the
    detected row/column regions; ``construct_table`` arranges the text boxes
    of a table into a grid and renders it as HTML or as descriptive text.
    """

    # Class names handed to the base ``Recognizer`` constructor by __init__.
    # NOTE(review): the order presumably matches the detection model's class
    # indices -- confirm against the model before reordering.
    labels = [
        "table",
        "table column",
        "table row",
        "table column header",
        "table projected row header",
        "table spanning cell",
    ]
|
|
|
def __init__(self):
    """Initialise the base recognizer for the "tsr" model.

    The resource directory ``rag/res/deepdoc`` under the project base
    directory is tried first.  If the base-class initialisation raises any
    ``Exception``, the ``InfiniFlow/deepdoc`` repository snapshot is
    downloaded into that same directory and initialisation is retried with
    the path returned by ``snapshot_download``.
    """

    def resource_dir():
        # Evaluated lazily at each use, exactly where the path is needed.
        return os.path.join(get_project_base_directory(), "rag/res/deepdoc")

    try:
        super().__init__(self.labels, "tsr", resource_dir())
    except Exception:
        # Best-effort fallback: fetch the model files, then retry once.
        downloaded_dir = snapshot_download(
            repo_id="InfiniFlow/deepdoc",
            local_dir=resource_dir(),
            local_dir_use_symlinks=False,
        )
        super().__init__(self.labels, "tsr", downloaded_dir)
|
|
|
def __call__(self, images, thr=0.2):
    """Detect table structure in ``images`` and align the detected regions.

    Args:
        images: forwarded unchanged to the base recognizer's ``__call__``.
        thr: threshold forwarded unchanged to the base recognizer.

    Returns:
        A list with one entry per kept table.  Each entry is a list of dicts
        with the keys ``label``, ``score``, ``x0``, ``x1``, ``top`` and
        ``bottom``.  Row/header regions are widened to a common left/right
        edge and ``"table column"`` regions are stretched to a common
        top/bottom edge.

    NOTE(review): a table with no detections, or with no row/header
    detections, is skipped entirely, so the result can be shorter than
    ``images`` -- callers that index results by position should confirm this.
    """

    def is_row_like(label):
        # ``> 0`` (not ``>= 0``): the keyword must occur after position 0.
        return label.find("row") > 0 or label.find("header") > 0

    results = []
    for detections in super().__call__(images, thr):
        layouts = [
            {
                "label": det["type"],
                "score": det["score"],
                "x0": det["bbox"][0],
                "x1": det["bbox"][2],
                "top": det["bbox"][1],
                "bottom": det["bbox"][-1],
            }
            for det in detections
        ]
        if not layouts:
            continue

        # Horizontal alignment of rows and headers.
        row_like = [lt for lt in layouts if is_row_like(lt["label"])]
        if not row_like:
            continue
        lefts = [lt["x0"] for lt in row_like]
        rights = [lt["x1"] for lt in row_like]
        if len(row_like) > 4:
            left_edge, right_edge = np.mean(lefts), np.mean(rights)
        else:
            left_edge, right_edge = np.min(lefts), np.max(rights)
        for lt in row_like:
            if lt["x0"] > left_edge:
                lt["x0"] = left_edge
            if lt["x1"] < right_edge:
                lt["x1"] = right_edge

        # Vertical alignment of columns (only when any column was detected).
        columns = [lt for lt in layouts if lt["label"] == "table column"]
        if columns:
            tops = [lt["top"] for lt in columns]
            bottoms = [lt["bottom"] for lt in columns]
            if len(columns) > 4:
                top_edge, bottom_edge = np.median(tops), np.median(bottoms)
            else:
                top_edge, bottom_edge = np.min(tops), np.max(bottoms)
            for lt in columns:
                if lt["top"] > top_edge:
                    lt["top"] = top_edge
                if lt["bottom"] < bottom_edge:
                    lt["bottom"] = bottom_edge

        results.append(layouts)
    return results
|
|
|
@staticmethod
def is_caption(bx):
    """Tell whether a text box is a table/figure caption.

    Args:
        bx: dict with at least ``"text"``; ``"layout_type"`` is read only
            when the text does not match a caption pattern.

    Returns:
        ``True`` when the stripped text starts with a caption marker followed
        by at least two characters from the digit/space/colon class, or when
        ``layout_type`` contains ``"caption"``; ``False`` otherwise.
    """
    caption_patterns = (
        r"[图表]+[ 0-9::]{2,}",
    )
    stripped = bx["text"].strip()
    for pattern in caption_patterns:
        if re.match(pattern, stripped):
            return True
    return bx["layout_type"].find("caption") >= 0
|
|
|
@staticmethod
def blockType(b):
    """Classify the text of a box into a short content-type code.

    The stripped text is tested against the patterns below in order and the
    code of the first match is returned.  When nothing matches, the text is
    tokenized with ``rag_tokenizer`` and classified by the number of tokens
    longer than one character.

    Args:
        b: dict with a ``"text"`` entry.

    Returns:
        One of ``"Dt"``, ``"Nu"``, ``"Ca"``, ``"En"``, ``"NE"``, ``"Sg"``
        (pattern matches), ``"Tx"`` (4-11 tokens), ``"Lx"`` (12 or more
        tokens), ``"Nr"`` (a single token tagged ``"nr"``) or ``"Ot"``.
    """
    typed_patterns = (
        ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
        (r"^(20|19)[0-9]{2}年$", "Dt"),
        (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
        ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
        (r"^第*[一二三四1-4]季度$", "Dt"),
        (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
        (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
        ("^[0-9.,+%/ -]+$", "Nu"),
        (r"^[0-9A-Z/\._~-]+$", "Ca"),
        (r"^[A-Z]*[a-z' -]+$", "En"),
        (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
        (r"^.{1}$", "Sg"),
    )
    stripped = b["text"].strip()
    for pattern, type_code in typed_patterns:
        if re.search(pattern, stripped):
            return type_code

    # No pattern matched: fall back to token statistics on the raw text.
    tokens = [
        tok for tok in rag_tokenizer.tokenize(b["text"]).split() if len(tok) > 1
    ]
    token_count = len(tokens)
    if token_count > 3:
        return "Tx" if token_count < 12 else "Lx"
    if token_count == 1 and rag_tokenizer.tag(tokens[0]) == "nr":
        return "Nr"
    return "Ot"
|
|
|
@staticmethod
def construct_table(boxes, is_english=False, html=False):
    """Arrange the text boxes of one table into a grid and render it.

    Caption boxes are removed from ``boxes`` (in place) and their text is
    collected as the caption.  The remaining boxes are grouped into rows and
    columns, columns and rows that hold a single box are merged into a
    neighbour, header rows are detected, spanning cells are resolved and the
    grid is rendered.

    Args:
        boxes: list of dicts, one per text box.  The code reads ``"text"``,
            ``"layout_type"``, ``"top"``, ``"bottom"``, ``"x0"``, ``"x1"``
            and ``"page_number"``, plus the optional annotations ``"R"``,
            ``"R_top"``, ``"R_bott"``, ``"C"``, ``"C_left"``, ``"C_right"``
            and ``"H"``.  Every kept box gains ``"btype"``, ``"rn"`` (row
            index) and ``"cn"`` (column index).
        is_english: when true, caption fragments are joined with a space and
            the descriptive-text rendering uses English connectives.
        html: when true an HTML string is returned, otherwise a list of
            descriptive text lines.

    Returns:
        ``[]`` when no boxes remain after caption removal; otherwise the
        result of the HTML or descriptive-text renderer.

    Raises:
        AssertionError: when a lone column/row has no neighbour to merge
            into, or when the column bookkeeping gets out of sync.
    """
    # 1. Pull caption boxes out of ``boxes`` and concatenate their text.
    cap = ""
    i = 0
    while i < len(boxes):
        if TableStructureRecognizer.is_caption(boxes[i]):
            # Fix: this used to be the no-op expression ``cap + " "``, so
            # English caption fragments were glued together.  The separator
            # is only inserted between fragments, never in front.
            if is_english and cap:
                cap += " "
            cap += boxes[i]["text"]
            boxes.pop(i)
            i -= 1
        i += 1

    if not boxes:
        return []

    # 2. Classify every box and find the most frequent content type.
    for b in boxes:
        b["btype"] = TableStructureRecognizer.blockType(b)
    max_type = Counter([b["btype"] for b in boxes]).items()
    max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
    logging.debug("MAXTYPE: " + max_type)

    # 3. Group boxes into rows, using half the smallest row height as the
    #    sorting tolerance.
    rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
    rowh = np.min(rowh) if rowh else 0
    boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)

    boxes[0]["rn"] = 0
    rows = [[boxes[0]]]
    btm = boxes[0]["bottom"]
    for b in boxes[1:]:
        b["rn"] = len(rows) - 1
        lst_r = rows[-1]
        # A new row starts when the row annotation changes, or when the box
        # lies below the running bottom and neither box carries an "R" tag
        # (the differing defaults make un-annotated boxes compare unequal).
        if lst_r[-1].get("R", "") != b.get("R", "") \
                or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")):
            btm = b["bottom"]
            b["rn"] += 1
            rows.append([b])
            continue
        btm = (btm + b["bottom"]) / 2.
        rows[-1].append(b)

    # 4. Group boxes into columns; tables spanning several pages are sorted
    #    by x position instead of by column annotation.
    colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
    colwm = np.min(colwm) if colwm else 0
    crosspage = len(set([b["page_number"] for b in boxes])) > 1
    if crosspage:
        boxes = Recognizer.sort_X_firstly(boxes, colwm / 2, False)
    else:
        boxes = Recognizer.sort_C_firstly(boxes, colwm / 2)
    boxes[0]["cn"] = 0
    cols = [[boxes[0]]]
    right = boxes[0]["x1"]
    for b in boxes[1:]:
        b["cn"] = len(cols) - 1
        lst_c = cols[-1]
        if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1
                and b["page_number"] == lst_c[-1]["page_number"]) \
                or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")):
            right = b["x1"]
            b["cn"] += 1
            cols.append([b])
            continue
        right = (right + b["x1"]) / 2.
        cols[-1].append(b)

    # 5. Build the grid: tbl[row][col] is the list of boxes in that cell.
    tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
    for b in boxes:
        tbl[b["rn"]][b["cn"]].append(b)

    # 6. Merge columns that contain a single box into the nearer neighbour.
    if len(rows) >= 4:
        j = 0
        while j < len(tbl[0]):
            # Count occupied cells in column j (stop once a second is found).
            e, ii = 0, 0
            for i in range(len(tbl)):
                if tbl[i][j]:
                    e += 1
                    ii = i
                if e > 1:
                    break
            if e > 1:
                j += 1
                continue
            # f / ff: the left / right neighbour cell in the same row already
            # holds text (or there is no such neighbour).
            f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
                 [j - 1][0].get("text")) or j == 0
            ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
                  [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
            if f and ff:
                j += 1
                continue
            bx = tbl[ii][j][0]
            logging.debug("Relocate column single: " + bx["text"])

            # Horizontal gap to the nearest box of each candidate neighbour
            # column; 100000 acts as "no candidate".
            left, right = 100000, 100000
            if j > 0 and not f:
                for i in range(len(tbl)):
                    if tbl[i][j - 1]:
                        left = min(left, np.min(
                            [bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
            if j + 1 < len(tbl[0]) and not ff:
                for i in range(len(tbl)):
                    if tbl[i][j + 1]:
                        right = min(right, np.min(
                            [a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
            assert left < 100000 or right < 100000
            if left < right:
                # Merge into the left neighbour; columns >= j shift left.
                for jj in range(j, len(tbl[0])):
                    for i in range(len(tbl)):
                        for a in tbl[i][jj]:
                            a["cn"] -= 1
                if tbl[ii][j - 1]:
                    tbl[ii][j - 1].extend(tbl[ii][j])
                else:
                    tbl[ii][j - 1] = tbl[ii][j]
                for i in range(len(tbl)):
                    tbl[i].pop(j)
            else:
                # Merge into the right neighbour; columns > j shift left.
                for jj in range(j + 1, len(tbl[0])):
                    for i in range(len(tbl)):
                        for a in tbl[i][jj]:
                            a["cn"] -= 1
                if tbl[ii][j + 1]:
                    tbl[ii][j + 1].extend(tbl[ii][j])
                else:
                    tbl[ii][j + 1] = tbl[ii][j]
                for i in range(len(tbl)):
                    tbl[i].pop(j)
            # j is not advanced: the next column has moved into position j.
            cols.pop(j)
    assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (
        len(cols), len(tbl[0]))

    # 7. Merge rows that contain a single box into the nearer neighbour.
    if len(cols) >= 4:
        i = 0
        while i < len(tbl):
            e, jj = 0, 0
            for j in range(len(tbl[i])):
                if tbl[i][j]:
                    e += 1
                    jj = j
                if e > 1:
                    break
            if e > 1:
                i += 1
                continue
            # f / ff: the cell above / below already holds text (or there is
            # no such neighbour).
            f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
                 [jj][0].get("text")) or i == 0
            ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
                  [jj][0].get("text")) or i + 1 >= len(tbl)
            if f and ff:
                i += 1
                continue

            bx = tbl[i][jj][0]
            logging.debug("Relocate row single: " + bx["text"])

            # Vertical gap to the nearest box of each candidate neighbour row.
            up, down = 100000, 100000
            if i > 0 and not f:
                for j in range(len(tbl[i - 1])):
                    if tbl[i - 1][j]:
                        up = min(up, np.min(
                            [bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
            if i + 1 < len(tbl) and not ff:
                for j in range(len(tbl[i + 1])):
                    if tbl[i + 1][j]:
                        down = min(down, np.min(
                            [a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
            assert up < 100000 or down < 100000
            if up < down:
                # Merge into the row above; rows >= i shift up.
                for ii in range(i, len(tbl)):
                    for j in range(len(tbl[ii])):
                        for a in tbl[ii][j]:
                            a["rn"] -= 1
                if tbl[i - 1][jj]:
                    tbl[i - 1][jj].extend(tbl[i][jj])
                else:
                    tbl[i - 1][jj] = tbl[i][jj]
                tbl.pop(i)
            else:
                # Merge into the row below; rows > i shift up.
                for ii in range(i + 1, len(tbl)):
                    for j in range(len(tbl[ii])):
                        for a in tbl[ii][j]:
                            a["rn"] -= 1
                if tbl[i + 1][jj]:
                    tbl[i + 1][jj].extend(tbl[i][jj])
                else:
                    tbl[i + 1][jj] = tbl[i][jj]
                tbl.pop(i)
            # i is not advanced: the next row has moved into position i.
            rows.pop(i)

    # 8. A row is a header when more than half of its occupied cells are
    #    flagged "H" or, in a mostly numeric table, are non-numeric.
    hdset = set([])
    for i in range(len(tbl)):
        cnt, h = 0, 0
        for j, arr in enumerate(tbl[i]):
            if not arr:
                continue
            cnt += 1
            if max_type == "Nu" and arr[0]["btype"] == "Nu":
                continue
            if any([a.get("H") for a in arr]) \
                    or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
                h += 1
        if h / cnt > 0.5:
            hdset.add(i)

    # 9. Resolve spanning cells and render.
    if html:
        return TableStructureRecognizer.__html_table(
            cap, hdset,
            TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, True))

    return TableStructureRecognizer.__desc_table(
        cap, hdset,
        TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, False),
        is_english)
|
|
|
@staticmethod
def __html_table(cap, hdset, tbl):
    """Render the cell grid as an HTML ``<table>`` string.

    Args:
        cap: caption text; a ``<caption>`` element is emitted when non-empty.
        hdset: set of header row indices.  It is mutated: the cell texts of
            every emitted header row are added to it, and a header row whose
            texts are all already present is skipped as a repeat.
        tbl: grid ``tbl[row][col]``; a cell is a list of box dicts, an empty
            list (rendered as an empty cell) or ``None`` (omitted because a
            spanning cell covers it).

    Returns:
        The HTML string, one line per grid row.  Cell text and caption are
        inserted as-is, without HTML escaping.
    """
    pieces = ["<table>"]
    if cap:
        pieces.append(f"<caption>{cap}</caption>")

    for row_no, cells in enumerate(tbl):
        is_header = row_no in hdset
        tag = "th" if is_header else "td"
        rendered = []
        texts = []
        for cell in cells:
            if cell is None:
                continue
            if not cell:
                rendered.append(f"<{tag}></{tag}>")
                continue
            # Order the boxes of a cell top-to-bottom with a tolerance of
            # half the smallest box height, capped at 10.
            tolerance = min(
                np.min([c["bottom"] - c["top"] for c in cell]) / 2, 10)
            text = " ".join(
                [c["text"] for c in Recognizer.sort_Y_firstly(cell, tolerance)])
            texts.append(text)
            span = ""
            if cell[0].get("colspan"):
                span = "colspan={}".format(cell[0]["colspan"])
            if cell[0].get("rowspan"):
                span += " rowspan={}".format(cell[0]["rowspan"])
            rendered.append(f"<{tag} {span} >" + text + f"</{tag}>")

        if is_header:
            if all(t in hdset for t in texts):
                # Repeated header row: nothing is emitted for it.
                continue
            hdset.update(texts)

        row_html = ("<tr>" + "".join(rendered) + "</tr>") if rendered else ""
        pieces.append("\n" + row_html)

    pieces.append("\n</table>")
    return "".join(pieces)
|
|
|
@staticmethod
def __desc_table(cap, hdr_rowno, tbl, is_english):
    """Render the cell grid as descriptive text lines.

    Args:
        cap: caption text; when non-empty a source suffix quoting it is
            appended to every returned line.
        hdr_rowno: set of header row indices.  It is mutated: header rows
            whose cells are all empty are removed from it.
        tbl: grid ``tbl[row][col]`` whose cells are lists of box dicts with a
            ``"text"`` entry (or empty/falsy for an empty cell).
        is_english: selects the English connectives.

    Returns:
        A list of strings.  Data rows governed by a header row become
        ``"header:value; header:value"``; rows of a table with at most two
        columns and no usable header become ``"a:b"`` and are packed into
        the previous entry (newline separated) while the combined length
        stays below 64.
    """
    clmno = len(tbl[0])
    rowno = len(tbl)
    de = " for " if is_english else "的"

    # Header text per header row; blanks inherit from the header row above.
    headers = {}
    inherited = []
    for r in sorted(hdr_rowno):
        texts = []
        for c in range(clmno):
            cell = tbl[r][c]
            texts.append(
                " ".join([a["text"].strip() for a in cell]) if cell else "")
        if not any(texts):
            hdr_rowno.remove(r)
            continue
        for c in range(clmno):
            if texts[c]:
                continue
            if c >= len(inherited):
                break
            texts[c] = inherited[c]
        headers[r] = texts
        inherited = texts

    # Fold each run of consecutive header rows downwards so the last row of
    # the run carries the combined header text.
    for i in range(rowno):
        if i not in hdr_rowno:
            continue
        for j in range(i + 1, rowno):
            if j not in hdr_rowno:
                break
            for k in range(clmno):
                upper = headers[j - 1][k]
                lower = headers[j][k]
                if not upper:
                    continue
                if lower.find(upper) >= 0:
                    continue
                if len(lower) > len(upper):
                    headers[j][k] = lower + (de if lower else "") + upper
                else:
                    headers[j][k] = upper + (de if upper else "") + lower

    logging.debug(
        f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")

    row_txt = []
    for i in range(rowno):
        if i in hdr_rowno:
            continue

        # Nearest header row above this one (0 when there is none).
        above = [h for h in headers if h < i]
        r = max(above) if above else 0

        # (column, text) for every cell of this row that has text.
        filled = []
        for j in range(clmno):
            cell = tbl[i][j]
            if not cell:
                continue
            txt = "".join([a["text"].strip() for a in cell])
            if txt:
                filled.append((j, txt))

        if r not in headers and clmno <= 2:
            # Key/value style table without a usable header.
            if filled:
                line = ":".join([txt for _, txt in filled])
                if row_txt and len(row_txt[-1]) + len(line) < 64:
                    row_txt[-1] += "\n" + line
                else:
                    row_txt.append(line)
            continue

        rtxt = []
        for j, txt in filled:
            label = headers[r][j] if r in headers else ""
            rtxt.append((label + ":" if label else "") + txt)
        if rtxt:
            row_txt.append("; ".join(rtxt))

    if cap:
        from_ = " in " if is_english else "来自"
        row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
    return row_txt
|
|
|
@staticmethod
def __cal_spans(boxes, rows, cols, tbl, html=True):
    """Work out row/column spans and merge the grid cells they cover.

    Boxes carrying an ``"SP"`` key are treated as spanning cells: every
    column (row) whose horizontal (vertical) centre lies inside the box's
    ``H_left``..``H_right`` (``H_top``..``H_bott``) extent is added to its
    span.  For each cell that really spans more than one row or column, the
    boxes of all covered cells are gathered into the top-left cell and
    tagged with integer ``"rowspan"`` / ``"colspan"`` counts.

    Args:
        boxes: all box dicts of the table, with ``"rn"`` / ``"cn"`` set.
        rows: list of rows, each a list of the boxes in that row.
        cols: list of columns, each a list of the boxes in that column.
        tbl: grid ``tbl[row][col]`` of box lists; modified in place.
        html: when true, covered cells other than the top-left one are set
            to ``None``; otherwise they all reference the merged box list.

    Returns:
        ``tbl`` (the same object that was passed in).

    Raises:
        AssertionError: when a cell's computed span does not include the
            cell's own row or column.
    """
    # Mean edges of each column/row.  The region annotation is preferred and
    # the box's own edge is the fallback.
    clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
            for cln in cols]
    crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
            for cln in cols]
    rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
            for row in rows]
    # Fix: the row-bottom annotation is stored under "R_bott" (the key
    # construct_table reads); the former "R_btm" never matched, so the box
    # bottom was always used and row centres were inconsistent with R_top.
    rbtm = [np.mean([c.get("R_bott", c["bottom"]) for c in row])
            for row in rows]

    for b in boxes:
        if "SP" not in b:
            continue
        # Spans start as index lists; they become counts further below.
        b["colspan"] = [b["cn"]]
        b["rowspan"] = [b["rn"]]

        for j in range(0, len(clft)):
            if j == b["cn"]:
                continue
            # Skip columns whose centre is outside the spanning extent.
            if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
                continue
            if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
                continue
            b["colspan"].append(j)

        for j in range(0, len(rtop)):
            if j == b["rn"]:
                continue
            # Skip rows whose centre is outside the spanning extent.
            if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
                continue
            if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
                continue
            b["rowspan"].append(j)

    def join(arr):
        # Concatenated text of a cell; "" for an empty or None cell.
        if not arr:
            return ""
        return "".join([t["text"] for t in arr])

    for i in range(len(tbl)):
        for j, arr in enumerate(tbl[i]):
            if not arr:
                continue
            if all(["rowspan" not in a and "colspan" not in a for a in arr]):
                continue
            # Union of the index lists of all boxes in this cell (integer
            # counts left by an earlier merge are ignored).
            rowspan, colspan = [], []
            for a in arr:
                if isinstance(a.get("rowspan", 0), list):
                    rowspan.extend(a["rowspan"])
                if isinstance(a.get("colspan", 0), list):
                    colspan.extend(a["colspan"])
            rowspan, colspan = set(rowspan), set(colspan)
            if len(rowspan) < 2 and len(colspan) < 2:
                # Not actually spanning: drop the markers.
                for a in arr:
                    if "rowspan" in a:
                        del a["rowspan"]
                    if "colspan" in a:
                        del a["colspan"]
                continue
            # Make both spans contiguous index ranges.
            rowspan, colspan = sorted(rowspan), sorted(colspan)
            rowspan = list(range(rowspan[0], rowspan[-1] + 1))
            colspan = list(range(colspan[0], colspan[-1] + 1))
            assert i in rowspan, rowspan
            assert j in colspan, colspan
            # Gather the boxes of every covered cell, skipping a cell whose
            # text equals what has been gathered so far.
            arr = []
            for r in rowspan:
                for c in colspan:
                    arr_txt = join(arr)
                    if tbl[r][c] and join(tbl[r][c]) != arr_txt:
                        arr.extend(tbl[r][c])
                    tbl[r][c] = None if html else arr
            for a in arr:
                if len(rowspan) > 1:
                    a["rowspan"] = len(rowspan)
                elif "rowspan" in a:
                    del a["rowspan"]
                if len(colspan) > 1:
                    a["colspan"] = len(colspan)
                elif "colspan" in a:
                    del a["colspan"]
            tbl[rowspan[0]][colspan[0]] = arr

    return tbl
|
|