import logging
import random
from collections import Counter

from rag.utils import num_tokens_from_string
from . import rag_tokenizer
import re
import copy
import roman_numbers as r
from word2number import w2n
from cn2an import cn2an
from PIL import Image

import chardet

all_codecs = [
    'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
    'cp037', 'cp273', 'cp424', 'cp437',
    'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857',
    'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869',
    'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125',
    'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256',
    'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr',
    'gb2312', 'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2',
    'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1',
    'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7',
    'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13',
    'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u',
    'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman',
    'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213',
    'utf_32', 'utf_32_be', 'utf_32_le', 'utf_16_be', 'utf_16_le', 'utf_7', 'windows-1250', 'windows-1251',
    'windows-1252', 'windows-1253', 'windows-1254', 'windows-1255', 'windows-1256',
    'windows-1257', 'windows-1258', 'latin-2'
]

def find_codec(blob):
    detected = chardet.detect(blob[:1024])
    if detected['confidence'] > 0.5:
        return detected['encoding']

    for c in all_codecs:
        try:
            blob[:1024].decode(c)
            return c
        except Exception:
            pass
        try:
            blob.decode(c)
            return c
        except Exception:
            pass

    return "utf-8"

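# Illustrative usage (a sketch, not part of the original module): find_codec is meant
# for raw bytes read from disk, before any decoding; `path` below is a hypothetical
# variable.
#
#     with open(path, "rb") as f:
#         blob = f.read()
#     text = blob.decode(find_codec(blob), errors="ignore")
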
QUESTION_PATTERN = [
    r"第([零一二三四五六七八九十百0-9]+)问",
    r"第([零一二三四五六七八九十百0-9]+)条",
    r"[\(（]([零一二三四五六七八九十百]+)[\)）]",
    r"第([0-9]+)问",
    r"第([0-9]+)条",
    r"([0-9]{1,2})[\. 、]",
    r"([零一二三四五六七八九十百]+)[ 、]",
    r"[\(（]([0-9]{1,2})[\)）]",
    r"QUESTION (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
    r"QUESTION (I+V?|VI*|XI|IX|X)",
    r"QUESTION ([0-9]+)",
]

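# Each pattern captures the question/clause index in group(1), e.g.
# re.match(QUESTION_PATTERN[0], "第三问 ……").group(1) == "三" and
# re.match(QUESTION_PATTERN[-1], "QUESTION 12").group(1) == "12".
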
def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
    section, last_section = box['text'], last_box['text']
    q_reg = r'(\w|\W)*?(?:？|\?|\n|$)+'
    full_reg = reg + q_reg
    has_bull = re.match(full_reg, section)
    index_str = None
    if has_bull:
        if 'x0' not in last_box:
            last_box['x0'] = box['x0']
        if 'top' not in last_box:
            last_box['top'] = box['top']
        if last_bull and box['x0'] - last_box['x0'] > 10:
            return None, last_index
        if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20:
            return None, last_index
        avg_bull_x0 = 0
        if bull_x0_list:
            avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list)
        else:
            avg_bull_x0 = box['x0']
        if box['x0'] - avg_bull_x0 > 10:
            return None, last_index
        index_str = has_bull.group(1)
        index = index_int(index_str)
        if last_section[-1] == ':' or last_section[-1] == '：':
            return None, last_index
        if not last_index or index >= last_index:
            bull_x0_list.append(box['x0'])
            return has_bull, index
        if section[-1] == '?' or section[-1] == '？':
            bull_x0_list.append(box['x0'])
            return has_bull, index
        if box['layout_type'] == 'title':
            bull_x0_list.append(box['x0'])
            return has_bull, index
        # Drop the matched bullet prefix by slicing; str.lstrip treats its argument
        # as a character set rather than a prefix and would strip too much.
        pure_section = section[len(re.match(reg, section).group()):].lower()
        ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)'
        if re.match(ask_reg, pure_section):
            bull_x0_list.append(box['x0'])
            return has_bull, index
    return None, last_index

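# A minimal, synthetic illustration (not from the original code): has_qbullet works on
# PDF-layout text boxes carrying 'text', 'x0', 'top' and 'layout_type'.
#
#     box = {'text': '第一问 什么是检索增强生成？', 'x0': 36.0, 'top': 140.0, 'layout_type': 'text'}
#     last_box = {'text': '以下为问题列表。', 'x0': 36.0, 'top': 100.0, 'layout_type': 'text'}
#     match, idx = has_qbullet(QUESTION_PATTERN[0], box, last_box, 0, None, [])
#     # `match` should be the regex match for the bullet and `idx` the parsed index (1).
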
def index_int(index_str):
    res = -1
    try:
        res = int(index_str)
    except ValueError:
        try:
            res = w2n.word_to_num(index_str)
        except ValueError:
            try:
                res = cn2an(index_str)
            except ValueError:
                try:
                    res = r.number(index_str)
                except ValueError:
                    return -1
    return res

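# Examples of the fallback chain (illustrative): index_int("3") == 3 (plain int),
# index_int("three") == 3 (word2number), index_int("三") == 3 (cn2an); Roman numerals
# such as "III" are left to the final roman_numbers fallback.
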
def qbullets_category(sections):
    global QUESTION_PATTERN
    hits = [0] * len(QUESTION_PATTERN)
    for i, pro in enumerate(QUESTION_PATTERN):
        for sec in sections:
            if re.match(pro, sec) and not not_bullet(sec):
                hits[i] += 1
                break
    maximum = 0
    res = -1
    for i, h in enumerate(hits):
        if h <= maximum:
            continue
        res = i
        maximum = h
    # res stays -1 when nothing matched, in which case QUESTION_PATTERN[-1]
    # (the plain "QUESTION n" pattern) is returned as the fallback regex.
    return res, QUESTION_PATTERN[res]

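# Illustrative call: for sections like ["第一问 概述", "第二问 方法"] only the first
# pattern hits, so qbullets_category(...) returns (0, QUESTION_PATTERN[0]).
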
BULLET_PATTERN = [[
    r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
    r"第[零一二三四五六七八九十百0-9]+章",
    r"第[零一二三四五六七八九十百0-9]+节",
    r"第[零一二三四五六七八九十百0-9]+条",
    r"[\(（][零一二三四五六七八九十百]+[\)）]",
], [
    r"第[0-9]+章",
    r"第[0-9]+节",
    r"[0-9]{,2}[\. 、]",
    r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]",
    r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
    r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
], [
    r"第[零一二三四五六七八九十百0-9]+章",
    r"第[零一二三四五六七八九十百0-9]+节",
    r"[零一二三四五六七八九十百]+[ 、]",
    r"[\(（][零一二三四五六七八九十百]+[\)）]",
    r"[\(（][0-9]{,2}[\)）]",
], [
    r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
    r"Chapter (I+V?|VI*|XI|IX|X)",
    r"Section [0-9]+",
    r"Article [0-9]+"
]
]

def random_choices(arr, k):
    k = min(len(arr), k)
    return random.choices(arr, k=k)

def not_bullet(line):
    patt = [
        r"0", r"[0-9]+ +[0-9~个只-]", r"[0-9]+\.{2,}"
    ]
    return any([re.match(p, line) for p in patt])

def bullets_category(sections):
    global BULLET_PATTERN
    hits = [0] * len(BULLET_PATTERN)
    for i, pro in enumerate(BULLET_PATTERN):
        for sec in sections:
            for p in pro:
                if re.match(p, sec) and not not_bullet(sec):
                    hits[i] += 1
                    break
    maximum = 0
    res = -1
    for i, h in enumerate(hits):
        if h <= maximum:
            continue
        res = i
        maximum = h
    return res

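# Illustrative call: bullets_category(["第一章 总则", "第一节 适用范围", "第一条 本法所称……"])
# should return 0, i.e. the 章/节/条 hierarchy in BULLET_PATTERN[0].
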
def is_english(texts):
    eng = 0
    if not texts:
        return False
    for t in texts:
        if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()):
            eng += 1
    if eng / len(texts) > 0.8:
        return True
    return False

def is_chinese(text):
    if not text:
        # guard against empty input to avoid dividing by zero below
        return False
    chinese = 0
    for ch in text:
        if '\u4e00' <= ch <= '\u9fff':
            chinese += 1
    if chinese / len(text) > 0.2:
        return True
    return False

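# Quick checks (illustrative): is_english(["Hello", "world."]) is True, while
# is_chinese("这是一段中文文本") is True; both are crude ratio heuristics, not full
# language detection.
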
def tokenize(d, t, eng):
    d["content_with_weight"] = t
    t = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", t)
    d["content_ltks"] = rag_tokenizer.tokenize(t)
    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])

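# Note (illustrative): tokenize keeps the raw text, table markup included, in
# d["content_with_weight"], but strips <table>/<tr>/<td>/<th>/<caption> tags before
# building the coarse ("content_ltks") and fine-grained ("content_sm_ltks") token fields.
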
def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
    res = []

    for ck in chunks:
        if len(ck.strip()) == 0:
            continue
        logging.debug("-- {}".format(ck))
        d = copy.deepcopy(doc)
        if pdf_parser:
            try:
                d["image"], poss = pdf_parser.crop(ck, need_position=True)
                add_positions(d, poss)
                ck = pdf_parser.remove_tag(ck)
            except NotImplementedError:
                pass
        tokenize(d, ck, eng)
        res.append(d)
    return res

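# Usage sketch (assumptions: `doc` carries shared metadata, `pdf_parser` is a parser
# exposing crop()/remove_tag()): each non-empty chunk becomes a deep copy of `doc` with
# its own tokens and, when a pdf_parser is supplied, a cropped image plus positions.
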
def tokenize_chunks_docx(chunks, doc, eng, images):
    res = []

    for ck, image in zip(chunks, images):
        if len(ck.strip()) == 0:
            continue
        logging.debug("-- {}".format(ck))
        d = copy.deepcopy(doc)
        d["image"] = image
        tokenize(d, ck, eng)
        res.append(d)
    return res

def tokenize_table(tbls, doc, eng, batch_size=10):
    res = []

    for (img, rows), poss in tbls:
        if not rows:
            continue
        if isinstance(rows, str):
            d = copy.deepcopy(doc)
            tokenize(d, rows, eng)
            d["content_with_weight"] = rows
            if img:
                d["image"] = img
            if poss:
                add_positions(d, poss)
            res.append(d)
            continue
        de = "; " if eng else "； "
        for i in range(0, len(rows), batch_size):
            d = copy.deepcopy(doc)
            row_txt = de.join(rows[i:i + batch_size])
            tokenize(d, row_txt, eng)
            d["image"] = img
            add_positions(d, poss)
            res.append(d)
    return res

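# Note (illustrative): `rows` may be a single HTML string (kept as one chunk) or a list
# of row strings, which are merged `batch_size` rows at a time, joined with "; " for
# English content and "； " otherwise.
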
def add_positions(d, poss):
    if not poss:
        return
    page_num_int = []
    position_int = []
    top_int = []
    for pn, left, right, top, bottom in poss:
        page_num_int.append(int(pn + 1))
        top_int.append(int(top))
        position_int.append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
    d["page_num_int"] = page_num_int
    d["position_int"] = position_int
    d["top_int"] = top_int

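# Example (illustrative): poss = [(0, 56, 300, 120, 180)] (0-based page, left, right,
# top, bottom) yields d["page_num_int"] == [1], d["top_int"] == [120] and
# d["position_int"] == [(1, 56, 300, 120, 180)].
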
def remove_contents_table(sections, eng=False):
    i = 0
    while i < len(sections):
        def get(i):
            nonlocal sections
            return (sections[i] if isinstance(sections[i],
                    type("")) else sections[i][0]).strip()

        # case-insensitive match against common table-of-contents headings
        if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$",
                        re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0]),
                        re.IGNORECASE):
            i += 1
            continue
        sections.pop(i)
        if i >= len(sections):
            break
        prefix = get(i)[:3] if not eng else " ".join(get(i).split()[:2])
        while not prefix:
            sections.pop(i)
            if i >= len(sections):
                break
            prefix = get(i)[:3] if not eng else " ".join(get(i).split()[:2])
        sections.pop(i)
        if i >= len(sections) or not prefix:
            break
        for j in range(i, min(i + 128, len(sections))):
            if not re.match(prefix, get(j)):
                continue
            for _ in range(i, j):
                sections.pop(i)
            break

def make_colon_as_title(sections):
    if not sections:
        return []
    if isinstance(sections[0], type("")):
        return sections
    i = 0
    while i < len(sections):
        txt, layout = sections[i]
        i += 1
        txt = txt.split("@")[0].strip()
        if not txt:
            continue
        if txt[-1] not in ":：":
            continue
        txt = txt[::-1]
        arr = re.split(r"([。？！!?;；]| \.)", txt)
        if len(arr) < 2 or len(arr[1]) < 32:
            continue
        sections.insert(i - 1, (arr[0][::-1], "title"))
        i += 1

def title_frequency(bull, sections):
    bullets_size = len(BULLET_PATTERN[bull])
    levels = [bullets_size + 1 for _ in range(len(sections))]
    if not sections or bull < 0:
        return bullets_size + 1, levels

    for i, (txt, layout) in enumerate(sections):
        for j, p in enumerate(BULLET_PATTERN[bull]):
            if re.match(p, txt.strip()) and not not_bullet(txt):
                levels[i] = j
                break
        else:
            if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
                levels[i] = bullets_size
    most_level = bullets_size + 1
    for level, c in sorted(Counter(levels).items(), key=lambda x: x[1] * -1):
        if level <= bullets_size:
            most_level = level
            break
    return most_level, levels

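# Interpretation (illustrative): title_frequency returns (most_level, levels), where
# levels[i] is the bullet-pattern index matched by section i (bullets_size for layout
# titles, bullets_size + 1 for body text) and most_level is the most frequent real
# heading level, intended for callers deciding where to cut chunks.
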
def not_title(txt):
    if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt):
        return False
    if len(txt.split()) > 12 or (txt.find(" ") < 0 and len(txt) >= 32):
        return True
    return re.search(r"[,;，。；！!]", txt)

def hierarchical_merge(bull, sections, depth):
    if not sections or bull < 0:
        return []
    if isinstance(sections[0], type("")):
        sections = [(s, "") for s in sections]
    sections = [(t, o) for t, o in sections if
                t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
    bullets_size = len(BULLET_PATTERN[bull])
    levels = [[] for _ in range(bullets_size + 2)]

    for i, (txt, layout) in enumerate(sections):
        for j, p in enumerate(BULLET_PATTERN[bull]):
            if re.match(p, txt.strip()):
                levels[j].append(i)
                break
        else:
            if re.search(r"(title|head)", layout) and not not_title(txt):
                levels[bullets_size].append(i)
            else:
                levels[bullets_size + 1].append(i)
    sections = [t for t, _ in sections]

    def binary_search(arr, target):
        if not arr:
            return -1
        if target > arr[-1]:
            return len(arr) - 1
        if target < arr[0]:
            return -1
        s, e = 0, len(arr)
        while e - s > 1:
            i = (e + s) // 2
            if target > arr[i]:
                s = i
                continue
            elif target < arr[i]:
                e = i
                continue
            else:
                assert False
        return s

    cks = []
    readed = [False] * len(sections)
    levels = levels[::-1]
    for i, arr in enumerate(levels[:depth]):
        for j in arr:
            if readed[j]:
                continue
            readed[j] = True
            cks.append([j])
            if i + 1 == len(levels) - 1:
                continue
            for ii in range(i + 1, len(levels)):
                jj = binary_search(levels[ii], j)
                if jj < 0:
                    continue
                if levels[ii][jj] > cks[-1][-1]:
                    cks[-1].pop(-1)
                cks[-1].append(levels[ii][jj])
                for idx in cks[-1]:
                    readed[idx] = True

    if not cks:
        return cks

    for i in range(len(cks)):
        cks[i] = [sections[j] for j in cks[i][::-1]]
        logging.debug("\n* ".join(cks[i]))

    res = [[]]
    num = [0]
    for ck in cks:
        if len(ck) == 1:
            n = num_tokens_from_string(re.sub(r"@@[0-9]+.*", "", ck[0]))
            if n + num[-1] < 218:
                res[-1].append(ck[0])
                num[-1] += n
                continue
            res.append(ck)
            num.append(n)
            continue
        res.append(ck)
        num.append(218)

    return res

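# How it fits together (a descriptive note, not from the original source): sections are
# bucketed by heading depth with BULLET_PATTERN[bull]; `levels` is reversed so body text
# and the deepest headings are visited first, and binary_search links each one back to
# the nearest preceding heading at every shallower level, so each returned chunk is a
# passage prefixed by its chain of headings. A caller typically passes the result of
# bullets_category(texts) as `bull` together with a small `depth`.
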
def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"):
    if not sections:
        return []
    if isinstance(sections[0], type("")):
        sections = [(s, "") for s in sections]
    cks = [""]
    tk_nums = [0]

    def add_chunk(t, pos):
        nonlocal cks, tk_nums, delimiter
        tnum = num_tokens_from_string(t)
        if not pos:
            pos = ""
        if tnum < 8:
            pos = ""

        # start a new chunk once the current one already exceeds chunk_token_num tokens
        if tk_nums[-1] > chunk_token_num:
            if t.find(pos) < 0:
                t += pos
            cks.append(t)
            tk_nums.append(tnum)
        else:
            if cks[-1].find(pos) < 0:
                t += pos
            cks[-1] += t
            tk_nums[-1] += tnum

    for sec, pos in sections:
        add_chunk(sec, pos)

    return cks

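# Usage sketch (illustrative): naive_merge(["short para 1。", "short para 2。"], 128)
# greedily appends sections to the current chunk until it already holds more than
# chunk_token_num tokens, then starts a new one; the `pos` tag of a section (position
# markup) is appended only once per chunk and skipped for very short sections. The
# delimiter parameter is accepted but not used in this function.
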
def docx_question_level(p, bull=-1):
    txt = re.sub(r"\u3000", " ", p.text).strip()
    if p.style.name.startswith('Heading'):
        return int(p.style.name.split(' ')[-1]), txt
    else:
        if bull < 0:
            return 0, txt
        for j, title in enumerate(BULLET_PATTERN[bull]):
            if re.match(title, txt):
                return j + 1, txt
        return len(BULLET_PATTERN[bull]), txt

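# Example (illustrative): for a python-docx paragraph whose style is "Heading 2",
# docx_question_level(p) returns (2, text); non-heading paragraphs fall back to
# BULLET_PATTERN matching when a bullet group is given, else level 0.
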
def concat_img(img1, img2):
    if img1 and not img2:
        return img1
    if not img1 and img2:
        return img2
    if not img1 and not img2:
        return None
    width1, height1 = img1.size
    width2, height2 = img2.size

    new_width = max(width1, width2)
    new_height = height1 + height2
    new_image = Image.new('RGB', (new_width, new_height))

    new_image.paste(img1, (0, 0))
    new_image.paste(img2, (0, height1))

    return new_image

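# Example (illustrative): stacking a 100x40 image on top of an 80x60 image yields a
# 100x100 canvas, i.e.
# concat_img(Image.new('RGB', (100, 40)), Image.new('RGB', (80, 60))).size == (100, 100).
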
def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
    if not sections:
        return [], []

    cks = [""]
    images = [None]
    tk_nums = [0]

    def add_chunk(t, image, pos=""):
        nonlocal cks, tk_nums, delimiter
        tnum = num_tokens_from_string(t)
        if tnum < 8:
            pos = ""
        if tk_nums[-1] > chunk_token_num:
            if t.find(pos) < 0:
                t += pos
            cks.append(t)
            images.append(image)
            tk_nums.append(tnum)
        else:
            if cks[-1].find(pos) < 0:
                t += pos
            cks[-1] += t
            images[-1] = concat_img(images[-1], image)
            tk_nums[-1] += tnum

    for sec, image in sections:
        add_chunk(sec, image, '')

    return cks, images