"""Gradio demo: detect the layout of a scanned PDF page and export it to DOCX.

Pipeline: PDF -> page image -> layout detection (Detectron2 / PubLayNet)
-> reading-order sort -> OCR per block (Tesseract) -> .docx + JSON preview.
"""
import os

# NOTE: detectron2 is installed at start-up, before layoutparser is imported,
# because it is pulled from a git tag rather than a regular requirement.
os.system('pip install "detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.5#egg=detectron2"')

import functools
import io
import json

import pandas as pd
import numpy as np
import gradio as gr

## for plotting
import matplotlib.pyplot as plt

## for ocr
import pdf2image
import cv2
import layoutparser as lp

from docx import Document
from docx.shared import Inches


# Block types whose content is extracted as plain OCR text.
_TEXT_TYPES = ("Title", "Text", "List")


def parse_doc(dic):
    """Pretty-print a predicted-elements dict to stdout / matplotlib.

    Titles are printed in bold red (ANSI escape codes), figures are shown
    with matplotlib, everything else is printed as-is. A blank separator
    line follows every element.
    """
    for k, v in dic.items():
        if "Title" in k:
            print('\x1b[1;31m' + v + '\x1b[0m')
        elif "Figure" in k:
            plt.figure(figsize=(10, 5))
            plt.imshow(v)
            plt.show()
        else:
            print(v)
        print(" ")


def to_image(filename):
    """Render the first page of the PDF at *filename* and save it under ``doc/``.

    Returns the list of page images produced by ``pdf2image`` (only the
    first page is requested via ``last_page=1``).
    """
    doc = pdf2image.convert_from_path(filename, dpi=350, last_page=1)
    # Save imgs
    folder = "doc"
    os.makedirs(folder, exist_ok=True)
    for p, page in enumerate(doc, start=1):
        image_name = "page_" + str(p) + ".jpg"
        page.save(os.path.join(folder, image_name), "JPEG")
    return doc


@functools.lru_cache(maxsize=1)
def _layout_model():
    """Build the PubLayNet layout detector once and reuse it across requests."""
    return lp.Detectron2LayoutModel(
        "lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config",
        extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.8],
        label_map={0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"},
    )


@functools.lru_cache(maxsize=1)
def _ocr_agent():
    """Build the English Tesseract OCR agent once and reuse it across requests."""
    return lp.TesseractAgent(languages='eng')


def detect(doc):
    """Run layout detection on the first page of *doc*.

    Args:
        doc: sequence of page images as returned by :func:`to_image`.

    Returns:
        ``(img, detected)`` where ``img`` is the first page as a numpy array
        and ``detected`` is the layout returned by the model.

    Raises:
        ValueError: if *doc* contains no pages.
    """
    if not doc:
        raise ValueError("no pages to analyse: the PDF produced no images")
    ## turn img into array
    img = np.asarray(doc[0])
    ## predict
    detected = _layout_model().detect(img)
    return img, detected


# sort detected
def split_page(img, n, axis, detected=None):
    """Order blocks section by section after cutting the page into *n* strips.

    Args:
        img: page image (array-like, height x width [x channels]).
        n: number of equally sized sections to cut the page into.
        axis: ``"x"`` to cut into columns, anything else (``"y"``) for rows.
        detected: layout whose blocks are assigned to sections by their
            centre point. Required; defaults to ``None`` only so that the
            original three-argument call signature still parses.

    Returns:
        A new ``lp.Layout`` with blocks ordered section by section (and
        top-to-bottom inside a column / left-to-right inside a row), with
        ``id`` renumbered from 0.

    Raises:
        ValueError: if *detected* is missing or *n* is smaller than 1.
    """
    if detected is None:
        raise ValueError("split_page needs the detected layout to split")
    if n < 1:
        raise ValueError("n must be at least 1")

    height, width = np.asarray(img).shape[:2]
    if axis == "x":
        size, sort_idx = width, 1   # columns: order each column by top edge
    else:
        size, sort_idx = height, 0  # rows: order each row by left edge

    new_detected = []
    for s in range(n):
        start = size * s / n
        end = size * (s + 1) / n
        section = lp.Interval(start=start, end=end, axis=axis).put_on_canvas(img)
        in_section = list(detected.filter_by(section, center=True))
        in_section.sort(key=lambda block: block.coordinates[sort_idx])
        new_detected.extend(in_section)
    return lp.Layout([block.set(id=idx) for idx, block in enumerate(new_detected)])


def get_detected(img, detected, n_cols=1, n_rows=1):
    """Return *detected* sorted into reading order with sequential ids.

    Args:
        img: page image the layout was detected on.
        detected: layout returned by :func:`detect`.
        n_cols: number of text columns on the page (default 1).
        n_rows: number of horizontal bands on the page (default 1).

    A combined multi-column *and* multi-row layout is not supported yet and
    is returned unchanged.
    """
    ## if single page just sort based on y
    if n_cols == 1 and n_rows == 1:
        ordered = detected.sort(key=lambda x: x.coordinates[1])
        return lp.Layout([block.set(id=idx) for idx, block in enumerate(ordered)])
    ## if multi columns sort by x,y
    if n_cols > 1 and n_rows == 1:
        return split_page(img, n_cols, axis="x", detected=detected)
    ## if multi rows sort by y,x
    if n_cols == 1 and n_rows > 1:
        return split_page(img, n_rows, axis="y", detected=detected)
    ## if multi columns-rows
    return detected


def _parse_table(extracted):
    """Turn OCR text of a table into a DataFrame, tolerating malformed input."""
    try:
        return pd.read_csv(io.StringIO(extracted))
    except (pd.errors.EmptyDataError, pd.errors.ParserError):
        # OCR output is frequently not valid CSV; keep the text line by line
        # instead of failing the whole conversion.
        lines = [line.strip() for line in extracted.splitlines() if line.strip()]
        return pd.DataFrame({"Table": lines})


def predict_elements(img, detected) -> dict:
    """Extract the content of every detected block.

    Blocks are visited in the order given by *detected* so the returned dict
    preserves reading order. Keys are ``"<id>-<type>"``; values are the OCR
    text (Title/Text/List), the cropped image array (Figure) or a DataFrame
    (Table).
    """
    model = _ocr_agent()
    dic_predicted = {}
    for block in detected:
        if block.type not in _TEXT_TYPES and block.type not in ("Figure", "Table"):
            continue
        ## segmentation
        segmented = block.pad(left=15, right=15, top=5, bottom=5).crop_image(img)
        key = str(block.id) + "-" + block.type
        if block.type == "Figure":
            dic_predicted[key] = segmented
        elif block.type == "Table":
            dic_predicted[key] = _parse_table(model.detect(segmented))
        else:
            dic_predicted[key] = model.detect(segmented).replace('\n', ' ').strip()
    return dic_predicted


def _add_figure(document, image):
    """Embed an image array in *document* without touching the filesystem."""
    image = np.asarray(image)
    if image.size == 0:
        return
    if image.ndim == 3 and image.shape[2] == 3:
        # Page arrays come from PIL (RGB); OpenCV encoders expect BGR.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    ok, encoded = cv2.imencode(".jpg", image)
    if ok:
        document.add_picture(io.BytesIO(encoded.tobytes()), width=Inches(3))


def _add_table(document, frame):
    """Append *frame* to *document* as a header row plus one row per record."""
    n_records, n_cols = frame.shape
    if n_cols == 0:
        return
    table = document.add_table(rows=n_records + 1, cols=n_cols)
    for cell, col in zip(table.rows[0].cells, frame.columns):
        cell.text = str(col)
    records = frame.itertuples(index=False, name=None)
    for row_idx, record in enumerate(records, start=1):
        for cell, value in zip(table.rows[row_idx].cells, record):
            cell.text = "" if pd.isna(value) else str(value).strip()


def gen_doc(dic_predicted: dict, output_path: str = 'demo.docx'):
    """Write the predicted elements to a .docx file, in dict order.

    Args:
        dic_predicted: mapping produced by :func:`predict_elements`.
        output_path: destination file (default ``'demo.docx'``).

    Returns:
        The path the document was saved to.
    """
    document = Document()
    for k, v in dic_predicted.items():
        if "Figure" in k:
            _add_figure(document, v)
        elif "Table" in k:
            _add_table(document, v)
        else:
            document.add_paragraph(str(v))
    document.save(output_path)
    return output_path


def _to_jsonable(value):
    """Convert table DataFrames to plain records so the JSON output can serialise them."""
    if isinstance(value, pd.DataFrame):
        return json.loads(value.to_json(orient="records"))
    return value


def main_convert(filename):
    """Gradio callback: convert the uploaded PDF and return all three outputs.

    Args:
        filename: uploaded file object; its ``.name`` attribute is the path
            of the PDF on disk.

    Returns:
        ``(docx_path, annotated_image, elements)`` where ``elements`` holds
        every non-figure element keyed by ``"<id>-<type>"``.
    """
    print(filename.name)
    doc = to_image(filename.name)
    img, detected = detect(doc)
    n_detected = get_detected(img, detected)
    dic_predicted = predict_elements(img, n_detected)
    gen_doc(dic_predicted)
    im_out = lp.draw_box(img, detected, box_width=5, box_alpha=0.2, show_element_type=True)
    # Figures are image arrays and are left out of the JSON preview.
    dict_out = {
        k: _to_jsonable(v)
        for k, v in dic_predicted.items()
        if "figure" not in k.lower()
    }
    return 'demo.docx', im_out, dict_out


inputs = [gr.File(type='file', label="Original PDF File")]
outputs = [
    gr.File(label="Converted DOC File"),
    gr.Image(type="PIL.Image", label="Detected Image"),
    gr.JSON(),
]

title = "A Document AI parser"
description = "This demo uses AI Models to detect text, titles, tables, figures and lists as well as table cells from an Scanned document.\nBased on the layout it determines reading order and generates an MS-DOC file to Download."

CUSTOM_CSS = (
    ".gr-button-primary { "
    "background: -webkit-linear-gradient( 90deg, #355764 0%, #55a8a1 100% ) !important; "
    "background: #355764; "
    "background: linear-gradient( 90deg, #355764 0%, #55a8a1 100% ) !important; "
    "background: -moz-linear-gradient( 90deg, #355764 0%, #55a8a1 100% ) !important; "
    "background: -webkit-linear-gradient( 90deg, #355764 0%, #55a8a1 100% ) !important; "
    "color:white !important}"
)

# Named `demo`, not `io`: rebinding `io` here would shadow the stdlib module
# that the table parser relies on at request time.
demo = gr.Interface(
    fn=main_convert,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    css=CUSTOM_CSS,
)

demo.launch()