Armandoliv committed on
Commit
dd5b268
·
1 Parent(s): 3c96ee0

create app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -0
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import pandas as pd
4
+ import numpy as np
5
+ import gradio as gr
6
+
7
+ ## for plotting
8
+ import matplotlib.pyplot as plt
9
+
10
+ ## for ocr
11
+ import pdf2image
12
+ import cv2
13
+ import layoutparser as lp
14
+
15
+ from docx import Document
16
+ from docx.shared import Inches
17
+
18
+
19
def parse_doc(dic):
    """Dump a parsed-document dictionary to the console.

    Entries whose key contains "Title" are printed in bold red (ANSI escape
    codes), entries whose key contains "Figure" are displayed with matplotlib,
    and every other entry is printed as-is. A single-space line is printed
    after each entry as a separator.

    Args:
        dic: Mapping of element key to its extracted content.
    """
    red, reset = '\x1b[1;31m', '\x1b[0m'
    for key in dic:
        value = dic[key]
        if "Title" in key:
            print(red + value + reset)
        elif "Figure" not in key:
            print(value)
        else:
            plt.figure(figsize=(10,5))
            plt.imshow(value)
            plt.show()
        print(" ")
30
+
31
+
32
def to_image(filename):
    """Render the first page of a PDF and save a JPEG copy under ``doc/``.

    Args:
        filename: Path of the PDF file to convert.

    Returns:
        The sequence of page images returned by
        ``pdf2image.convert_from_path`` (limited to the first page because
        ``last_page=1``).
    """
    doc = pdf2image.convert_from_path(filename, dpi=350, last_page=1)

    # Save imgs
    folder = "doc"
    # exist_ok replaces the list-then-create check, which was racy and
    # scanned the whole working directory just to test for one name.
    os.makedirs(folder, exist_ok=True)

    for p, page in enumerate(doc, start=1):
        image_name = f"page_{p}.jpg"
        page.save(os.path.join(folder, image_name), "JPEG")

    return doc
46
+
47
+
48
+
49
def detect(doc):
    """Run layout detection on the first page of ``doc``.

    Args:
        doc: Indexable collection of page images; only ``doc[0]`` is used.

    Returns:
        A ``(img, detected)`` tuple: the page converted with ``np.asarray``
        and whatever the layout model's ``detect`` returns for it.
    """
    # General-purpose PubLayNet model; detections scoring below 0.8 are dropped.
    class_names = {0:"Text", 1:"Title", 2:"List", 3:"Table", 4:"Figure"}
    layout_model = lp.Detectron2LayoutModel(
        "lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config",
        extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.8],
        label_map=class_names,
    )

    # The model works on arrays, not on page image objects.
    page_array = np.asarray(doc[0])
    return page_array, layout_model.detect(page_array)
62
+
63
+
64
+ # sort detected
65
def split_page(img, n, axis, detected=None):
    """Order detected blocks by splitting the page into ``n`` equal sections.

    The page is cut into ``n`` consecutive intervals along ``axis``. Blocks
    are assigned to the interval containing their center, sections are
    visited in order, and blocks inside one section are ordered along the
    other axis. The result is re-numbered so ``block.id`` is the reading
    position.

    Args:
        img: Page image array (rows first, then columns).
        n: Number of sections to split the page into.
        axis: ``"x"`` to split into columns, ``"y"`` to split into rows.
        detected: Layout of detected blocks to reorder. Required; it is a
            keyword with a default only to keep the original positional
            signature intact.

    Returns:
        A new ``lp.Layout`` with the blocks in reading order.

    Raises:
        ValueError: If ``detected`` is not supplied.
    """
    if detected is None:
        raise ValueError("split_page() requires the detected layout")

    # Width for a column split, height for a row split.
    extent = len(img[0]) if axis == "x" else len(img)
    # Inside a column read top-to-bottom (y1); inside a row left-to-right (x1).
    cross = 1 if axis == "x" else 0

    ordered, start = [], 0
    for s in range(1, n + 1):
        end = extent * s / n
        section = lp.Interval(start=start, end=end, axis=axis).put_on_canvas(img)
        in_section = detected.filter_by(section, center=True)
        ordered.extend(sorted(in_section, key=lambda b: b.coordinates[cross]))
        start = end
    return lp.Layout([block.set(id=idx) for idx, block in enumerate(ordered)])


def get_detected(img, detected, n_cols=1, n_rows=1):
    """Return the detected blocks sorted into reading order with fresh ids.

    Args:
        img: Page image array the blocks were detected on.
        detected: Layout of detected blocks.
        n_cols: Number of text columns on the page (default 1).
        n_rows: Number of horizontal bands on the page (default 1).

    Returns:
        A layout whose block ids follow the reading order. For a grid page
        (more than one column AND more than one row) no ordering is
        implemented and ``detected`` is returned unchanged.
    """
    ## if single page just sort based on y
    if (n_cols == 1) and (n_rows == 1):
        new_detected = detected.sort(key=lambda x: x.coordinates[1])
        detected = lp.Layout([block.set(id=idx) for idx, block in enumerate(new_detected)])

    ## if multi columns sort by x,y
    elif (n_cols > 1) and (n_rows == 1):
        detected = split_page(img, n_cols, axis="x", detected=detected)

    ## if multi rows sort by y,x
    elif (n_cols == 1) and (n_rows > 1):
        detected = split_page(img, n_rows, axis="y", detected=detected)

    ## multi columns-rows: not supported, keep the detection order

    return detected
98
+
99
+
100
def predict_elements(img, detected)->dict:
    """Extract the content of every detected block, in layout order.

    Title/Text/List blocks are OCR'd to a single-line string, Figure blocks
    are kept as cropped image arrays, and Table blocks are OCR'd and parsed
    into a DataFrame. Entries are inserted in the order the blocks appear in
    ``detected``, so iterating the result follows the reading order.

    Args:
        img: Page image array the blocks were detected on.
        detected: Iterable of layout blocks exposing ``id``, ``type``,
            ``pad`` and ``crop_image``.

    Returns:
        Dict keyed by ``"<id>-<type>"`` mapping to a string, an image crop
        or a ``pandas.DataFrame`` depending on the block type.
    """
    # Imported locally so table parsing does not depend on the module-level
    # name ``io``, which is easy to shadow with an application object.
    from io import StringIO

    ocr_agent = lp.TesseractAgent(languages='eng')
    dic_predicted = {}

    for block in detected:
        if block.type not in ("Title", "Text", "List", "Figure", "Table"):
            continue

        key = str(block.id)+"-"+block.type
        ## segmentation (small padding so glyphs on the border are not cut)
        segmented = block.pad(left=15, right=15, top=5, bottom=5).crop_image(img)

        if block.type == "Figure":
            dic_predicted[key] = segmented
            continue

        ## extraction
        extracted = ocr_agent.detect(segmented)

        if block.type != "Table":
            dic_predicted[key] = extracted.replace('\n',' ').strip()
            continue

        try:
            dic_predicted[key] = pd.read_csv(StringIO(extracted))
        except (pd.errors.EmptyDataError, pd.errors.ParserError):
            # OCR output is often not well-formed CSV. Keep the raw lines as
            # a one-column table (first line as header, like read_csv)
            # instead of failing the whole conversion.
            lines = [ln for ln in extracted.splitlines() if ln.strip()]
            if lines:
                dic_predicted[key] = pd.DataFrame({lines[0]: lines[1:]})

    return dic_predicted
129
+
130
def _add_table(document, frame):
    """Append ``frame`` to ``document`` as a table: one header row plus one row per record."""
    n_cols = frame.shape[1]
    if n_cols == 0:
        return

    table = document.add_table(rows=1, cols=n_cols)
    for cell, col in zip(table.rows[0].cells, frame.columns):
        cell.text = str(col)

    for _, record in frame.iterrows():
        row_cells = table.add_row().cells
        for cell, value in zip(row_cells, record):
            # Missing values stay as empty cells; everything else (including
            # numbers) is written as text.
            if pd.isna(value):
                continue
            cell.text = str(value).strip()


def gen_doc(dic_predicted:dict):
    """Write the extracted elements to ``demo.docx`` in dictionary order.

    Keys containing "Figure" are saved as ``<key>.jpg`` and embedded as a
    3-inch-wide picture, keys containing "Table" are rendered as a table
    from their DataFrame, and everything else becomes a paragraph.

    Args:
        dic_predicted: Mapping of ``"<id>-<type>"`` keys to extracted content.
    """
    document = Document()

    for k, v in dic_predicted.items():

        if "Figure" in k:
            picture_path = f'{k}.jpg'
            figure = v
            # cv2.imwrite expects BGR channel order. Assumes 3-channel crops
            # are RGB, as they are when the page array comes from a PIL image.
            if getattr(figure, "ndim", 0) == 3 and figure.shape[2] == 3:
                figure = cv2.cvtColor(figure, cv2.COLOR_RGB2BGR)
            cv2.imwrite(picture_path, figure)
            document.add_picture(picture_path, width=Inches(3))

        elif "Table" in k:
            _add_table(document, v)

        else:
            document.add_paragraph(str(v))

    document.save('demo.docx')
158
+
159
+
160
def main_convert(filename):
    """Convert an uploaded PDF to ``demo.docx`` and report what was detected.

    Args:
        filename: Uploaded file object; its ``name`` attribute is used as
            the path of the PDF.

    Returns:
        A 3-tuple ``(annotated image, 'demo.docx', extracted elements dict)``.
    """
    print(filename.name)
    pages = to_image(filename.name)
    img, detected = detect(pages)

    # Reading order first, then content extraction, then the Word document.
    dic_predicted = predict_elements(img, get_detected(img, detected))
    gen_doc(dic_predicted)

    # Overlay the raw detections on the page for visual feedback.
    im_out = lp.draw_box(
        img,
        detected,
        box_width=5,
        box_alpha=0.2,
        show_element_type=True,
    )
    return im_out, 'demo.docx', dic_predicted
171
+
172
+
173
# ---- Gradio UI wiring ----
inputs = [gr.File(type='file', label="Original PDF File")]
# Components are listed in the order main_convert returns its values:
# (annotated image, path of the generated document, extracted elements).
outputs = [gr.Image(type="pil", label="Detected Image"), gr.File(label="Converted DOC File"), gr.JSON()]

title = "A Document AI parser"
description = "This demo uses AI Models to detect text, titles, tables, figures and lists as well as table cells from an Scanned document.\nBased on the layout it determines reading order and generates an MS-DOC file to Download."

button_css = """.gr-button-primary { background: -webkit-linear-gradient(
90deg, #355764 0%, #55a8a1 100% ) !important; background: #355764;
background: linear-gradient(
90deg, #355764 0%, #55a8a1 100% ) !important;
background: -moz-linear-gradient( 90deg, #355764 0%, #55a8a1 100% ) !important;
background: -webkit-linear-gradient(
90deg, #355764 0%, #55a8a1 100% ) !important;
color:white !important}"""

# Named ``demo`` rather than ``io`` so the stdlib ``io`` module imported at
# the top of the file is not shadowed at module level.
demo = gr.Interface(
    fn=main_convert,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    css=button_css,
)

demo.launch()