Spaces:
Runtime error
Runtime error
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| import logging | |
| import os | |
| import re | |
| from collections import Counter | |
| import numpy as np | |
| from huggingface_hub import snapshot_download | |
| from deepdoc.utils.file_utils import get_project_base_directory | |
| from deepdoc.utils import rag_tokenizer | |
| from .recognizer import Recognizer | |
| class TableStructureRecognizer(Recognizer): | |
| labels = [ | |
| "table", | |
| "table column", | |
| "table row", | |
| "table column header", | |
| "table projected row header", | |
| "table spanning cell", | |
| ] | |
| def __init__(self): | |
| try: | |
| super().__init__(self.labels, "tsr", os.path.join( | |
| get_project_base_directory(), | |
| "rag/res/deepdoc")) | |
| except Exception as e: | |
| super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc", | |
| local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), | |
| local_dir_use_symlinks=False)) | |
| def __call__(self, images, thr=0.2): | |
| tbls = super().__call__(images, thr) | |
| res = [] | |
| # align left&right for rows, align top&bottom for columns | |
| for tbl in tbls: | |
| lts = [{"label": b["type"], | |
| "score": b["score"], | |
| "x0": b["bbox"][0], "x1": b["bbox"][2], | |
| "top": b["bbox"][1], "bottom": b["bbox"][-1] | |
| } for b in tbl] | |
| if not lts: | |
| continue | |
| left = [b["x0"] for b in lts if b["label"].find( | |
| "row") > 0 or b["label"].find("header") > 0] | |
| right = [b["x1"] for b in lts if b["label"].find( | |
| "row") > 0 or b["label"].find("header") > 0] | |
| if not left: | |
| continue | |
| left = np.mean(left) if len(left) > 4 else np.min(left) | |
| right = np.mean(right) if len(right) > 4 else np.max(right) | |
| for b in lts: | |
| if b["label"].find("row") > 0 or b["label"].find("header") > 0: | |
| if b["x0"] > left: | |
| b["x0"] = left | |
| if b["x1"] < right: | |
| b["x1"] = right | |
| top = [b["top"] for b in lts if b["label"] == "table column"] | |
| bottom = [b["bottom"] for b in lts if b["label"] == "table column"] | |
| if not top: | |
| res.append(lts) | |
| continue | |
| top = np.median(top) if len(top) > 4 else np.min(top) | |
| bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom) | |
| for b in lts: | |
| if b["label"] == "table column": | |
| if b["top"] > top: | |
| b["top"] = top | |
| if b["bottom"] < bottom: | |
| b["bottom"] = bottom | |
| res.append(lts) | |
| return res | |
| def is_caption(bx): | |
| patt = [ | |
| r"[图表]+[ 0-9::]{2,}" | |
| ] | |
| if any([re.match(p, bx["text"].strip()) for p in patt]) \ | |
| or bx["layout_type"].find("caption") >= 0: | |
| return True | |
| return False | |
| def blockType(b): | |
| patt = [ | |
| ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"), | |
| (r"^(20|19)[0-9]{2}年$", "Dt"), | |
| (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"), | |
| ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"), | |
| (r"^第*[一二三四1-4]季度$", "Dt"), | |
| (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"), | |
| (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"), | |
| ("^[0-9.,+%/ -]+$", "Nu"), | |
| (r"^[0-9A-Z/\._~-]+$", "Ca"), | |
| (r"^[A-Z]*[a-z' -]+$", "En"), | |
| (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"), | |
| (r"^.{1}$", "Sg") | |
| ] | |
| for p, n in patt: | |
| if re.search(p, b["text"].strip()): | |
| return n | |
| tks = [t for t in rag_tokenizer.tokenize(b["text"]).split(" ") if len(t) > 1] | |
| if len(tks) > 3: | |
| if len(tks) < 12: | |
| return "Tx" | |
| else: | |
| return "Lx" | |
| if len(tks) == 1 and rag_tokenizer.tag(tks[0]) == "nr": | |
| return "Nr" | |
| return "Ot" | |
| def construct_table(boxes, is_english=False, html=False): | |
| cap = "" | |
| i = 0 | |
| while i < len(boxes): | |
| if TableStructureRecognizer.is_caption(boxes[i]): | |
| if is_english: | |
| cap + " " | |
| cap += boxes[i]["text"] | |
| boxes.pop(i) | |
| i -= 1 | |
| i += 1 | |
| if not boxes: | |
| return [] | |
| for b in boxes: | |
| b["btype"] = TableStructureRecognizer.blockType(b) | |
| max_type = Counter([b["btype"] for b in boxes]).items() | |
| max_type = max(max_type, key=lambda x: x[1])[0] if max_type else "" | |
| logging.debug("MAXTYPE: " + max_type) | |
| rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b] | |
| rowh = np.min(rowh) if rowh else 0 | |
| boxes = Recognizer.sort_R_firstly(boxes, rowh / 2) | |
| #for b in boxes:print(b) | |
| boxes[0]["rn"] = 0 | |
| rows = [[boxes[0]]] | |
| btm = boxes[0]["bottom"] | |
| for b in boxes[1:]: | |
| b["rn"] = len(rows) - 1 | |
| lst_r = rows[-1] | |
| if lst_r[-1].get("R", "") != b.get("R", "") \ | |
| or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2") | |
| ): # new row | |
| btm = b["bottom"] | |
| b["rn"] += 1 | |
| rows.append([b]) | |
| continue | |
| btm = (btm + b["bottom"]) / 2. | |
| rows[-1].append(b) | |
| colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b] | |
| colwm = np.min(colwm) if colwm else 0 | |
| crosspage = len(set([b["page_number"] for b in boxes])) > 1 | |
| if crosspage: | |
| boxes = Recognizer.sort_X_firstly(boxes, colwm / 2, False) | |
| else: | |
| boxes = Recognizer.sort_C_firstly(boxes, colwm / 2) | |
| boxes[0]["cn"] = 0 | |
| cols = [[boxes[0]]] | |
| right = boxes[0]["x1"] | |
| for b in boxes[1:]: | |
| b["cn"] = len(cols) - 1 | |
| lst_c = cols[-1] | |
| if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][ | |
| "page_number"]) \ | |
| or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")): # new col | |
| right = b["x1"] | |
| b["cn"] += 1 | |
| cols.append([b]) | |
| continue | |
| right = (right + b["x1"]) / 2. | |
| cols[-1].append(b) | |
| tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))] | |
| for b in boxes: | |
| tbl[b["rn"]][b["cn"]].append(b) | |
| if len(rows) >= 4: | |
| # remove single in column | |
| j = 0 | |
| while j < len(tbl[0]): | |
| e, ii = 0, 0 | |
| for i in range(len(tbl)): | |
| if tbl[i][j]: | |
| e += 1 | |
| ii = i | |
| if e > 1: | |
| break | |
| if e > 1: | |
| j += 1 | |
| continue | |
| f = (j > 0 and tbl[ii][j - 1] and tbl[ii] | |
| [j - 1][0].get("text")) or j == 0 | |
| ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii] | |
| [j + 1][0].get("text")) or j + 1 >= len(tbl[ii]) | |
| if f and ff: | |
| j += 1 | |
| continue | |
| bx = tbl[ii][j][0] | |
| logging.debug("Relocate column single: " + bx["text"]) | |
| # j column only has one value | |
| left, right = 100000, 100000 | |
| if j > 0 and not f: | |
| for i in range(len(tbl)): | |
| if tbl[i][j - 1]: | |
| left = min(left, np.min( | |
| [bx["x0"] - a["x1"] for a in tbl[i][j - 1]])) | |
| if j + 1 < len(tbl[0]) and not ff: | |
| for i in range(len(tbl)): | |
| if tbl[i][j + 1]: | |
| right = min(right, np.min( | |
| [a["x0"] - bx["x1"] for a in tbl[i][j + 1]])) | |
| assert left < 100000 or right < 100000 | |
| if left < right: | |
| for jj in range(j, len(tbl[0])): | |
| for i in range(len(tbl)): | |
| for a in tbl[i][jj]: | |
| a["cn"] -= 1 | |
| if tbl[ii][j - 1]: | |
| tbl[ii][j - 1].extend(tbl[ii][j]) | |
| else: | |
| tbl[ii][j - 1] = tbl[ii][j] | |
| for i in range(len(tbl)): | |
| tbl[i].pop(j) | |
| else: | |
| for jj in range(j + 1, len(tbl[0])): | |
| for i in range(len(tbl)): | |
| for a in tbl[i][jj]: | |
| a["cn"] -= 1 | |
| if tbl[ii][j + 1]: | |
| tbl[ii][j + 1].extend(tbl[ii][j]) | |
| else: | |
| tbl[ii][j + 1] = tbl[ii][j] | |
| for i in range(len(tbl)): | |
| tbl[i].pop(j) | |
| cols.pop(j) | |
| assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % ( | |
| len(cols), len(tbl[0])) | |
| if len(cols) >= 4: | |
| # remove single in row | |
| i = 0 | |
| while i < len(tbl): | |
| e, jj = 0, 0 | |
| for j in range(len(tbl[i])): | |
| if tbl[i][j]: | |
| e += 1 | |
| jj = j | |
| if e > 1: | |
| break | |
| if e > 1: | |
| i += 1 | |
| continue | |
| f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1] | |
| [jj][0].get("text")) or i == 0 | |
| ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1] | |
| [jj][0].get("text")) or i + 1 >= len(tbl) | |
| if f and ff: | |
| i += 1 | |
| continue | |
| bx = tbl[i][jj][0] | |
| logging.debug("Relocate row single: " + bx["text"]) | |
| # i row only has one value | |
| up, down = 100000, 100000 | |
| if i > 0 and not f: | |
| for j in range(len(tbl[i - 1])): | |
| if tbl[i - 1][j]: | |
| up = min(up, np.min( | |
| [bx["top"] - a["bottom"] for a in tbl[i - 1][j]])) | |
| if i + 1 < len(tbl) and not ff: | |
| for j in range(len(tbl[i + 1])): | |
| if tbl[i + 1][j]: | |
| down = min(down, np.min( | |
| [a["top"] - bx["bottom"] for a in tbl[i + 1][j]])) | |
| assert up < 100000 or down < 100000 | |
| if up < down: | |
| for ii in range(i, len(tbl)): | |
| for j in range(len(tbl[ii])): | |
| for a in tbl[ii][j]: | |
| a["rn"] -= 1 | |
| if tbl[i - 1][jj]: | |
| tbl[i - 1][jj].extend(tbl[i][jj]) | |
| else: | |
| tbl[i - 1][jj] = tbl[i][jj] | |
| tbl.pop(i) | |
| else: | |
| for ii in range(i + 1, len(tbl)): | |
| for j in range(len(tbl[ii])): | |
| for a in tbl[ii][j]: | |
| a["rn"] -= 1 | |
| if tbl[i + 1][jj]: | |
| tbl[i + 1][jj].extend(tbl[i][jj]) | |
| else: | |
| tbl[i + 1][jj] = tbl[i][jj] | |
| tbl.pop(i) | |
| rows.pop(i) | |
| # which rows are headers | |
| hdset = set([]) | |
| for i in range(len(tbl)): | |
| cnt, h = 0, 0 | |
| for j, arr in enumerate(tbl[i]): | |
| if not arr: | |
| continue | |
| cnt += 1 | |
| if max_type == "Nu" and arr[0]["btype"] == "Nu": | |
| continue | |
| if any([a.get("H") for a in arr]) \ | |
| or (max_type == "Nu" and arr[0]["btype"] != "Nu"): | |
| h += 1 | |
| if h / cnt > 0.5: | |
| hdset.add(i) | |
| if html: | |
| return TableStructureRecognizer.__html_table(cap, hdset, | |
| TableStructureRecognizer.__cal_spans(boxes, rows, | |
| cols, tbl, True) | |
| ) | |
| return TableStructureRecognizer.__desc_table(cap, hdset, | |
| TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, | |
| False), | |
| is_english) | |
| def __html_table(cap, hdset, tbl): | |
| # constrcut HTML | |
| html = "<table>" | |
| if cap: | |
| html += f"<caption>{cap}</caption>" | |
| for i in range(len(tbl)): | |
| row = "<tr>" | |
| txts = [] | |
| for j, arr in enumerate(tbl[i]): | |
| if arr is None: | |
| continue | |
| if not arr: | |
| row += "<td></td>" if i not in hdset else "<th></th>" | |
| continue | |
| txt = "" | |
| if arr: | |
| h = min(np.min([c["bottom"] - c["top"] | |
| for c in arr]) / 2, 10) | |
| txt = " ".join([c["text"] | |
| for c in Recognizer.sort_Y_firstly(arr, h)]) | |
| txts.append(txt) | |
| sp = "" | |
| if arr[0].get("colspan"): | |
| sp = "colspan={}".format(arr[0]["colspan"]) | |
| if arr[0].get("rowspan"): | |
| sp += " rowspan={}".format(arr[0]["rowspan"]) | |
| if i in hdset: | |
| row += f"<th {sp} >" + txt + "</th>" | |
| else: | |
| row += f"<td {sp} >" + txt + "</td>" | |
| if i in hdset: | |
| if all([t in hdset for t in txts]): | |
| continue | |
| for t in txts: | |
| hdset.add(t) | |
| if row != "<tr>": | |
| row += "</tr>" | |
| else: | |
| row = "" | |
| html += "\n" + row | |
| html += "\n</table>" | |
| return html | |
| def __desc_table(cap, hdr_rowno, tbl, is_english): | |
| # get text of every colomn in header row to become header text | |
| clmno = len(tbl[0]) | |
| rowno = len(tbl) | |
| headers = {} | |
| hdrset = set() | |
| lst_hdr = [] | |
| de = "的" if not is_english else " for " | |
| for r in sorted(list(hdr_rowno)): | |
| headers[r] = ["" for _ in range(clmno)] | |
| for i in range(clmno): | |
| if not tbl[r][i]: | |
| continue | |
| txt = " ".join([a["text"].strip() for a in tbl[r][i]]) | |
| headers[r][i] = txt | |
| hdrset.add(txt) | |
| if all([not t for t in headers[r]]): | |
| del headers[r] | |
| hdr_rowno.remove(r) | |
| continue | |
| for j in range(clmno): | |
| if headers[r][j]: | |
| continue | |
| if j >= len(lst_hdr): | |
| break | |
| headers[r][j] = lst_hdr[j] | |
| lst_hdr = headers[r] | |
| for i in range(rowno): | |
| if i not in hdr_rowno: | |
| continue | |
| for j in range(i + 1, rowno): | |
| if j not in hdr_rowno: | |
| break | |
| for k in range(clmno): | |
| if not headers[j - 1][k]: | |
| continue | |
| if headers[j][k].find(headers[j - 1][k]) >= 0: | |
| continue | |
| if len(headers[j][k]) > len(headers[j - 1][k]): | |
| headers[j][k] += (de if headers[j][k] | |
| else "") + headers[j - 1][k] | |
| else: | |
| headers[j][k] = headers[j - 1][k] \ | |
| + (de if headers[j - 1][k] else "") \ | |
| + headers[j][k] | |
| logging.debug( | |
| f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}") | |
| row_txt = [] | |
| for i in range(rowno): | |
| if i in hdr_rowno: | |
| continue | |
| rtxt = [] | |
| def append(delimer): | |
| nonlocal rtxt, row_txt | |
| rtxt = delimer.join(rtxt) | |
| if row_txt and len(row_txt[-1]) + len(rtxt) < 64: | |
| row_txt[-1] += "\n" + rtxt | |
| else: | |
| row_txt.append(rtxt) | |
| r = 0 | |
| if len(headers.items()): | |
| _arr = [(i - r, r) for r, _ in headers.items() if r < i] | |
| if _arr: | |
| _, r = min(_arr, key=lambda x: x[0]) | |
| if r not in headers and clmno <= 2: | |
| for j in range(clmno): | |
| if not tbl[i][j]: | |
| continue | |
| txt = "".join([a["text"].strip() for a in tbl[i][j]]) | |
| if txt: | |
| rtxt.append(txt) | |
| if rtxt: | |
| append(":") | |
| continue | |
| for j in range(clmno): | |
| if not tbl[i][j]: | |
| continue | |
| txt = "".join([a["text"].strip() for a in tbl[i][j]]) | |
| if not txt: | |
| continue | |
| ctt = headers[r][j] if r in headers else "" | |
| if ctt: | |
| ctt += ":" | |
| ctt += txt | |
| if ctt: | |
| rtxt.append(ctt) | |
| if rtxt: | |
| row_txt.append("; ".join(rtxt)) | |
| if cap: | |
| if is_english: | |
| from_ = " in " | |
| else: | |
| from_ = "来自" | |
| row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt] | |
| return row_txt | |
| def __cal_spans(boxes, rows, cols, tbl, html=True): | |
| # caculate span | |
| clft = [np.mean([c.get("C_left", c["x0"]) for c in cln]) | |
| for cln in cols] | |
| crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln]) | |
| for cln in cols] | |
| rtop = [np.mean([c.get("R_top", c["top"]) for c in row]) | |
| for row in rows] | |
| rbtm = [np.mean([c.get("R_btm", c["bottom"]) | |
| for c in row]) for row in rows] | |
| for b in boxes: | |
| if "SP" not in b: | |
| continue | |
| b["colspan"] = [b["cn"]] | |
| b["rowspan"] = [b["rn"]] | |
| # col span | |
| for j in range(0, len(clft)): | |
| if j == b["cn"]: | |
| continue | |
| if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]: | |
| continue | |
| if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]: | |
| continue | |
| b["colspan"].append(j) | |
| # row span | |
| for j in range(0, len(rtop)): | |
| if j == b["rn"]: | |
| continue | |
| if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]: | |
| continue | |
| if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]: | |
| continue | |
| b["rowspan"].append(j) | |
| def join(arr): | |
| if not arr: | |
| return "" | |
| return "".join([t["text"] for t in arr]) | |
| # rm the spaning cells | |
| for i in range(len(tbl)): | |
| for j, arr in enumerate(tbl[i]): | |
| if not arr: | |
| continue | |
| if all(["rowspan" not in a and "colspan" not in a for a in arr]): | |
| continue | |
| rowspan, colspan = [], [] | |
| for a in arr: | |
| if isinstance(a.get("rowspan", 0), list): | |
| rowspan.extend(a["rowspan"]) | |
| if isinstance(a.get("colspan", 0), list): | |
| colspan.extend(a["colspan"]) | |
| rowspan, colspan = set(rowspan), set(colspan) | |
| if len(rowspan) < 2 and len(colspan) < 2: | |
| for a in arr: | |
| if "rowspan" in a: | |
| del a["rowspan"] | |
| if "colspan" in a: | |
| del a["colspan"] | |
| continue | |
| rowspan, colspan = sorted(rowspan), sorted(colspan) | |
| rowspan = list(range(rowspan[0], rowspan[-1] + 1)) | |
| colspan = list(range(colspan[0], colspan[-1] + 1)) | |
| assert i in rowspan, rowspan | |
| assert j in colspan, colspan | |
| arr = [] | |
| for r in rowspan: | |
| for c in colspan: | |
| arr_txt = join(arr) | |
| if tbl[r][c] and join(tbl[r][c]) != arr_txt: | |
| arr.extend(tbl[r][c]) | |
| tbl[r][c] = None if html else arr | |
| for a in arr: | |
| if len(rowspan) > 1: | |
| a["rowspan"] = len(rowspan) | |
| elif "rowspan" in a: | |
| del a["rowspan"] | |
| if len(colspan) > 1: | |
| a["colspan"] = len(colspan) | |
| elif "colspan" in a: | |
| del a["colspan"] | |
| tbl[rowspan[0]][colspan[0]] = arr | |
| return tbl | |