File size: 6,075 Bytes
dd5b268
a0b4833
fe5a8c6
a0b4833
dd5b268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9729f3
dd5b268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120daf4
dd5b268
120daf4
dd5b268
120daf4
dd5b268
120daf4
dd5b268
 
 
120daf4
 
 
 
 
 
dd5b268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import os

# Install detectron2 at start-up: layoutparser's Detectron2LayoutModel (used in
# `detect`) needs it. The repository is pinned to a release tag.
# NOTE(review): the `detectron2.git@v0.5` tag was reconstructed from a mangled
# copy of this URL - confirm v0.5 is the intended pin.
os.system('pip install "detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.5#egg=detectron2"')

import functools
import io
import pandas as pd
import numpy as np
import gradio as gr

## for plotting
import matplotlib.pyplot as plt

## for ocr
import pdf2image
import cv2
import layoutparser as lp

from docx import Document
from docx.shared import Inches


def parse_doc(dic):
    """Pretty-print a dict of parsed page elements to the console.

    Keys containing "Title" are printed in bold red (ANSI escapes), keys
    containing "Figure" are displayed with matplotlib, and every other value
    is printed unchanged. Each entry is followed by a line holding one space.
    """
    red_bold, reset = '\x1b[1;31m', '\x1b[0m'
    for key in dic:
        value = dic[key]
        if "Title" in key:
            print(red_bold + value + reset)
        elif "Figure" in key:
            # figure values are image data; render them in their own window
            plt.figure(figsize=(10,5))
            plt.imshow(value)
            plt.show()
        else:
            print(value)
        print(" ")


def to_image(filename, dpi=350, last_page=1, folder="doc"):
  """Rasterise the start of a PDF and save every rendered page as a JPEG.

  Args:
    filename: path of the PDF file to convert.
    dpi: rendering resolution passed to ``pdf2image`` (default 350).
    last_page: last page to render, 1-based (default 1: only the first page).
    folder: directory that receives ``page_<n>.jpg``; created if missing.

  Returns:
    The page images exactly as returned by ``pdf2image.convert_from_path``.
  """
  doc = pdf2image.convert_from_path(filename, dpi=dpi, last_page=last_page)

  # exist_ok replaces the old "not in os.listdir()" check: no check-then-create
  # race, and it also works when `folder` is a nested path.
  os.makedirs(folder, exist_ok=True)

  for page_no, page in enumerate(doc, start=1):
    page.save(os.path.join(folder, f"page_{page_no}.jpg"), "JPEG")

  return doc



@functools.lru_cache(maxsize=1)
def _layout_model():
  """Build the PubLayNet Mask R-CNN layout detector once and reuse it.

  The model does not depend on the input, so it is constructed on first use
  and shared by later calls instead of being rebuilt for every document.
  """
  return lp.Detectron2LayoutModel(
      "lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config",
      extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.8],
      label_map={0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"},
  )


def detect(doc):
  """Run layout detection on the first page of a rasterised document.

  Args:
    doc: sequence of page images (as returned by ``to_image``); only
      ``doc[0]`` is used.

  Returns:
    ``(img, detected)``: the first page as a NumPy array and the layout
    returned by the detector. Only detections scoring >= 0.8 are kept, and
    each carries one of the labels Text/Title/List/Table/Figure.
  """
  ## turn img into array
  img = np.asarray(doc[0])

  ## predict
  detected = _layout_model().detect(img)

  return img, detected


# sort detected
def split_page(img, n, axis, detected=None):
    """Order detected blocks band by band along one page axis.

    The page is cut into ``n`` equal bands along ``axis`` ("x": vertical
    column bands read left to right; "y": horizontal row bands read top to
    bottom). Each block is assigned by its centre via ``filter_by``.

    Args:
      img: page image; assumed NumPy-style ``(height, width[, channels])``,
        as produced by ``np.asarray`` on a page image - TODO confirm for
        other callers.
      n: number of bands. ``n <= 0`` yields an empty layout.
      axis: "x" to split into columns, anything else splits into rows.
      detected: the ``lp.Layout`` to reorder. This used to be read from a
        global that is never defined (NameError), so it must be passed.

    Returns:
      A new ``lp.Layout`` with blocks re-numbered 0..k-1 in band order.

    Raises:
      TypeError: if ``detected`` is not given and ``n > 0``.
    """
    if n <= 0:
        return lp.Layout([])
    if detected is None:
        raise TypeError("split_page() requires the `detected` layout")

    # width = number of columns, height = number of rows
    extent = len(img[0]) if axis == "x" else len(img)
    # extent * i / n is exact at i == n, so the last band ends on the page edge
    bounds = [extent * i / n for i in range(n + 1)]

    ordered = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        section = lp.Interval(start=start, end=end, axis=axis).put_on_canvas(img)
        ordered += detected.filter_by(section, center=True)._blocks

    return lp.Layout([block.set(id=idx) for idx, block in enumerate(ordered)])



def get_detected(img, detected):
  """Return ``detected`` with its blocks put into reading order.

  The page is currently always treated as one column and one row
  (``n_cols = n_rows = 1``). In that case blocks are sorted by
  ``coordinates[1]`` (the top y coordinate of a layoutparser box) and
  re-numbered 0..k-1. The multi-column / multi-row branches delegate to
  ``split_page``. A page with both several columns and several rows is left
  in detector order.

  Args:
    img: the page image the blocks were detected on.
    detected: ``lp.Layout`` returned by ``detect``.

  Returns:
    The reordered ``lp.Layout``.
  """
  n_cols, n_rows = 1, 1

  if n_cols == 1 and n_rows == 1:
    ## single page: sort top-to-bottom on y
    ordered = detected.sort(key=lambda block: block.coordinates[1])
    detected = lp.Layout([block.set(id=idx) for idx, block in enumerate(ordered)])

  elif n_cols > 1 and n_rows == 1:
    ## multiple columns: read column bands left to right
    detected = split_page(img, n_cols, axis="x")

  elif n_cols == 1 and n_rows > 1:
    ## multiple rows: read row bands top to bottom
    detected = split_page(img, n_rows, axis="y")

  else:
    ## multiple columns AND rows: not handled, keep detector order
    pass

  return detected


def _table_from_ocr(text):
  """Parse OCR'd table text as CSV; fall back to one row per non-empty line."""
  # Local import: the module-level name `io` is rebound to the Gradio
  # interface at the bottom of this file, so the global is not the stdlib module.
  import io

  try:
    return pd.read_csv(io.StringIO(text))
  except (pd.errors.EmptyDataError, pd.errors.ParserError):
    # Tesseract output is plain text, not guaranteed CSV (it may be empty or
    # have ragged fields); keep the content rather than failing the request.
    return pd.DataFrame({"text": [line for line in text.splitlines() if line.strip()]})


def predict_elements(img, detected)->dict:
  """Extract the content of every supported layout block, in ``detected`` order.

  Args:
    img: the full page image the blocks were detected on.
    detected: iterable of layoutparser blocks (with ``id`` and ``type``),
      normally already in reading order from ``get_detected``.

  Returns:
    Dict keyed ``"<id>-<type>"``, insertion-ordered like ``detected``:
      * Title / Text / List -> Tesseract text, newlines collapsed to spaces;
      * Figure -> the cropped image region;
      * Table  -> Tesseract text parsed into a ``pandas.DataFrame``.
    Blocks of any other type are skipped.
  """
  model = lp.TesseractAgent(languages='eng')
  dic_predicted = {}

  # One pass over `detected` (instead of one loop per type) so the dict keeps
  # the reading order computed upstream; gen_doc writes entries in this order.
  for block in detected:
    if block.type not in ("Title", "Text", "List", "Figure", "Table"):
      continue
    ## segmentation: pad the box before cropping it out of the page
    segmented = block.pad(left=15, right=15, top=5, bottom=5).crop_image(img)
    key = str(block.id) + "-" + block.type

    if block.type == "Figure":
      dic_predicted[key] = segmented
    elif block.type == "Table":
      dic_predicted[key] = _table_from_ocr(model.detect(segmented))
    else:
      dic_predicted[key] = model.detect(segmented).replace('\n', ' ').strip()

  return dic_predicted

def _add_figure(document, image):
  """Append an image array to *document* as a 3-inch-wide JPEG picture."""
  # Local import: the module-level name `io` is rebound to the Gradio
  # interface at the bottom of this file.
  import io

  image = np.asarray(image)
  # Crops come from np.asarray(<PIL page image>) and are presumably RGB, while
  # OpenCV encodes colour images as BGR: swap channels to keep the colours.
  if image.ndim == 3 and image.shape[2] == 3:
    image = image[:, :, ::-1]
  ok, encoded = cv2.imencode('.jpg', np.ascontiguousarray(image))
  if not ok:
    raise ValueError("cv2.imencode could not encode the figure as JPEG")
  # add_picture accepts a stream, so no temporary JPEG is left in the cwd.
  document.add_picture(io.BytesIO(encoded.tobytes()), width=Inches(3))


def _add_table(document, frame):
  """Append DataFrame *frame* as a Word table: header row, then one row per record."""
  n_cols = frame.shape[1]
  if n_cols == 0:
    return  # no columns, nothing to lay out
  table = document.add_table(rows=1, cols=n_cols)
  for cell, column in zip(table.rows[0].cells, frame.columns):
    cell.text = str(column)
  for record in frame.itertuples(index=False, name=None):
    texts = ["" if pd.isna(value) else str(value) for value in record]
    if not any(text.strip() for text in texts):
      continue  # skip records with no visible content
    for cell, text in zip(table.add_row().cells, texts):
      cell.text = text


def gen_doc(dic_predicted:dict, path='demo.docx'):
  """Write predicted page elements to a Word document, in dict order.

  Args:
    dic_predicted: mapping as produced by ``predict_elements``. Keys
      containing "Figure" hold image arrays, keys containing "Table" hold
      ``pandas.DataFrame`` objects, and every other value is written as a
      paragraph of ``str(value)``.
    path: output file (default ``'demo.docx'``, the file ``main_convert`` serves).

  Returns:
    ``path``.
  """
  document = Document()

  for k, v in dic_predicted.items():
    if "Figure" in k:
      _add_figure(document, v)
    elif "Table" in k:
      _add_table(document, v)
    else:
      document.add_paragraph(str(v))

  document.save(path)
  return path


def main_convert(filename):
  """Gradio handler: turn the first page of an uploaded PDF into a .docx file.

  Steps: rasterise the page, detect layout blocks, sort them into reading
  order, OCR/crop each block, and write ``demo.docx``.

  Args:
    filename: the uploaded-file object from ``gr.File``; only its ``.name``
      (path on disk) is used.

  Returns:
    ``('demo.docx', annotated_page, predictions)``. The annotated page has the
    detected boxes drawn on it, and ``predictions`` holds every non-figure entry.
  """
  pdf_path = filename.name
  print(pdf_path)

  pages = to_image(pdf_path)
  page_img, layout = detect(pages)
  ordered_layout = get_detected(page_img, layout)
  predictions = predict_elements(page_img, ordered_layout)
  gen_doc(predictions)

  annotated = lp.draw_box(page_img, layout, box_width=5, box_alpha=0.2, show_element_type=True)
  # figures are raw image crops; keep them out of the JSON panel
  json_view = {key: value for key, value in predictions.items() if "figure" not in key.lower()}

  return 'demo.docx', annotated, json_view
  
  
inputs = [gr.File(type='file', label="Original PDF File")]
outputs = [gr.File(label="Converted DOC File"),gr.Image(type="PIL.Image", label="Detected Image"),  gr.JSON()]

title = "A Document AI parser"
description = "This demo uses AI Models to detect text, titles, tables, figures and lists as well as table cells from an Scanned document.\nBased on the layout it determines reading order and generates an MS-DOC file to Download."


# Bound to `demo`, not `io`: a module-level `io` shadows the standard-library
# `io` module that the processing functions rely on (io.StringIO when parsing
# tables), so they would break on every request once the app is running.
demo = gr.Interface(fn=main_convert, inputs=inputs, outputs=outputs, title=title, description=description, 
                  css= """.gr-button-primary { background: -webkit-linear-gradient( 
                    90deg, #355764 0%, #55a8a1 100% ) !important;     background: #355764;
                        background: linear-gradient( 
                    90deg, #355764 0%, #55a8a1 100% ) !important;
                        background: -moz-linear-gradient( 90deg, #355764 0%, #55a8a1 100% ) !important;
                        background: -webkit-linear-gradient( 
                    90deg, #355764 0%, #55a8a1 100% ) !important;
                    color:white !important}"""
                  )

demo.launch()