import matplotlib.pyplot as plt |
import matplotlib.patches as patches |
from matplotlib.patches import Patch |
import io |
import cv2 |
from PIL import Image, ImageDraw, ImageFont |
import numpy as np |
import csv |
import pandas as pd |
from ultralytics import YOLO |
import torch |
from paddleocr import PaddleOCR |
import postprocess |
import gradio as gr |
# Run the YOLO models on the GPU when torch reports one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model used by table_detection() to locate tables in the input image.
detection_model = YOLO('yolov8/runs/detect/yolov8s-custom-detection/weights/best.pt').to(device)
# Model used by table_structure() to find rows, columns and headers in a crop.
structure_model = YOLO('yolov8/runs/detect/yolov8s-custom-structure-all/weights/best.pt').to(device)

# TODO: use a large det_limit_side_len to get better OCR results.
# NOTE(review): the language code was a broken literal (`lang=uk"`); restored
# as "uk" -- confirm this is the intended PaddleOCR language.
ocr_model = PaddleOCR(use_angle_cls=True, lang="uk", det_limit_side_len=1920)
# Names for the detection model's class ids (index == class id).
detection_class_names = ['table', 'table rotated']

# Names for the structure model's class ids (index == class id).
structure_class_names = [
    'table',
    'table column',
    'table row',
    'table column header',
    'table projected row header',
    'table spanning cell',
    'no object',
]

# Reverse lookup: class name -> class id.
structure_class_map = dict(zip(structure_class_names, range(len(structure_class_names))))

# Per-class score thresholds handed to postprocess.objects_to_cells().
# NOTE(review): the value 10 for "no object" presumably can never be reached by
# a confidence score, so that class is always discarded -- confirm in postprocess.
structure_class_thresholds = {
    "table": 0.5,
    "table column": 0.5,
    "table row": 0.5,
    "table column header": 0.5,
    "table projected row header": 0.5,
    "table spanning cell": 0.5,
    "no object": 10,
}
def table_detection(image):
    """Run the table detection model on ``image``.

    Returns a list with one entry per predicted box of the first result, each
    laid out as ``[x, y, w, h, conf, cls]`` where the first four values are
    the box's ``xywhn`` row.
    """
    prediction = detection_model.predict(image, imgsz=800)
    boxes = prediction[0].boxes.cpu().numpy()
    return [
        list(xywhn) + [conf, cls]
        for xywhn, conf, cls in zip(boxes.xywhn, boxes.conf, boxes.cls)
    ]
def table_structure(image):
    """Run the table structure model on ``image``.

    Returns a list with one entry per predicted box of the first result, each
    laid out as ``[x, y, w, h, conf, cls]`` where the first four values are
    the box's ``xywhn`` row.
    """
    prediction = structure_model.predict(image, imgsz=1024)
    boxes = prediction[0].boxes.cpu().numpy()
    return [
        list(xywhn) + [conf, cls]
        for xywhn, conf, cls in zip(boxes.xywhn, boxes.conf, boxes.cls)
    ]
def crop_image(image, detection_result):
    """Crop the first detected table out of ``image``.

    Args:
        image: Array indexed ``[row, column, ...]``; ``image.shape[:2]`` is
            read as ``(height, width)``. Both colour (3-D) and grayscale
            (2-D) arrays are accepted.
        detection_result: Sequence of detections whose first four values are
            the box centre ``x``, centre ``y``, width and height, expressed
            as fractions of the image size.

    Returns:
        The slice of ``image`` covering the first detection, grown by 10
        pixels on every side and clamped to the image bounds. When there are
        no detections, ``image`` itself is returned.
    """
    padding = 10  # extra pixels kept around the detected box
    height, width = image.shape[:2]

    # TODO: only the first detected table is returned.
    if len(detection_result) == 0:
        return image

    center_x, center_y, box_w, box_h = detection_result[0][:4]
    x1 = max(0, int((center_x - box_w / 2) * width) - padding)
    y1 = max(0, int((center_y - box_h / 2) * height) - padding)
    x2 = min(width, int((center_x + box_w / 2) * width) + padding)
    y2 = min(height, int((center_y + box_h / 2) * height) + padding)

    # Slice only the two spatial axes so 2-D images work as well as 3-D ones.
    return image[y1:y2, x1:x2]
def convert_stucture(ocr_result, image, structure_result):
    """Turn raw structure predictions plus OCR tokens into table cells.

    Args:
        ocr_result: List of token dicts, each with a ``'bbox'`` entry (as
            produced by ``ocr``).
        image: Array whose ``shape[:2]`` is ``(height, width)``; used to
            scale the normalised boxes to pixels.
        structure_result: Sequence of ``[x, y, w, h, conf, cls]`` entries
            with a normalised centre/size box.

    Returns:
        The ``(table_structures, cells, confidence_score)`` triple returned
        by ``postprocess.objects_to_cells``.
    """
    height, width = image.shape[:2]

    # Convert every normalised centre/size box to pixel corner coordinates.
    table_objects = []
    for result in structure_result:
        center_x, center_y, box_w, box_h = result[:4]
        table_objects.append({
            'bbox': [
                int((center_x - box_w / 2) * width),
                int((center_y - box_h / 2) * height),
                int((center_x + box_w / 2) * width),
                int((center_y + box_h / 2) * height),
            ],
            'score': float(result[4]),
            'label': int(result[5]),
        })
    table = {'objects': table_objects, 'page_num': 0}

    # Use the highest-scoring 'table' object as the table region.
    table_label = structure_class_map['table']
    table_candidates = [obj for obj in table_objects if obj['label'] == table_label]
    if table_candidates:
        best_table = max(table_candidates, key=lambda obj: obj['score'])
        table_bbox = list(best_table['bbox'])
    else:
        # No 'table' object was predicted.
        # NOTE(review): this fixed 1000x1000 fallback ignores the real image
        # size; consider (0, 0, width, height) instead.
        table_bbox = (0, 0, 1000, 1000)

    # Keep only the OCR tokens whose overlap score with the table is >= 0.5.
    tokens_in_table = [
        token for token in ocr_result
        if postprocess.iob(token['bbox'], table_bbox) >= 0.5
    ]

    table_structures, cells, confidence_score = postprocess.objects_to_cells(
        table,
        table_objects,
        tokens_in_table,
        structure_class_names,
        structure_class_thresholds,
    )
    return table_structures, cells, confidence_score
def visualize_cells(image, table_structures, cells):
    """Outline every cell on ``image`` and collect the cell texts in a table.

    Args:
        image: Image array; a green rectangle is drawn onto it in place for
            every cell.
        table_structures: Mapping whose ``'columns'`` and ``'rows'`` entries
            determine the size of the resulting table.
        cells: Iterable of cell dicts with ``'bbox'``, ``'column_nums'``,
            ``'row_nums'`` and ``'spans'`` entries.

    Returns:
        ``(image, df, json_text)``: the annotated image, a DataFrame of cell
        texts with string column labels, and ``df.to_json()``.
    """
    num_cols = len(table_structures['columns'])
    num_rows = len(table_structures['rows'])
    data_rows = [['' for _ in range(num_cols)] for _ in range(num_rows)]

    for cell in cells:
        x1, y1, x2, y2 = (int(coord) for coord in cell['bbox'][:4])

        # Only the first row/column index of the cell is used to place its text.
        row_num = cell['row_nums'][0]
        col_num = cell['column_nums'][0]
        data_rows[row_num][col_num] = ''.join(
            span['text'] for span in cell['spans'] if 'text' in span
        )

        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))

    df = pd.DataFrame(data_rows)
    df.columns = df.columns.astype(str)
    return image, df, df.to_json()
def ocr(image):
    """Run OCR on ``image`` and return the recognised text tokens.

    Returns:
        A list of ``{'bbox': [x1, y1, x2, y2], 'text': str}`` dicts, one per
        recognised line. The list is empty when the OCR call yields no result
        or its first entry is ``None``.
    """
    result = ocr_model.ocr(image, cls=True)
    # Guard the outer container too, so indexing [0] cannot fail on no output.
    if result is None or len(result) == 0 or result[0] is None:
        return []

    tokens = []
    for line in result[0]:
        quad = line[0]
        # The box is built from the first and third points of the quadrilateral.
        tokens.append({
            'bbox': [quad[0][0], quad[0][1], quad[2][0], quad[2][1]],
            'text': line[1][0],
        })
    return tokens
def detect_and_crop_table(image):
    """Detect tables in ``image`` and return the crop made by ``crop_image``."""
    return crop_image(image, table_detection(image))
def recognize_table(image, ocr_result):
    """Recognise the structure of a cropped table and render its cells.

    Prints the intermediate results, then returns whatever
    ``visualize_cells`` produces for the recognised cells: the annotated
    image, the DataFrame and its JSON text.
    """
    detected_objects = table_structure(image)
    print('structure_result:', detected_objects)

    table_structures, cells, confidence_score = convert_stucture(
        ocr_result, image, detected_objects
    )
    debug_values = (
        ('table_structures:', table_structures),
        ('cells:', cells),
        ('confidence_score:', confidence_score),
    )
    for caption, value in debug_values:
        print(caption, value)

    annotated, df, data = visualize_cells(image, table_structures, cells)
    return annotated, df, data
def process_pdf(image):
    """Gradio entry point: extract the first table found in ``image``.

    The input is converted from RGB to BGR before processing, and the
    annotated crop is converted back to RGB before it is returned together
    with the DataFrame and its JSON text.
    """
    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    table_crop = detect_and_crop_table(bgr_image)
    tokens = ocr(table_crop)
    annotated, df, data = recognize_table(table_crop, tokens)
    print('df:', df)
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB), df, data
# Texts and example inputs shown in the Gradio UI.
title = "Demo: table detection & recognition with Table Structure Recognition (Yolov8)."
description = "Demo for table extraction with the Table Structure Recognition (Yolov8)."
examples = [['image.png'], ['mistral_paper.png']]

# One image goes in; the annotated crop, the table and its JSON come out.
app = gr.Interface(
    fn=process_pdf,
    inputs=gr.Image(type="numpy"),
    outputs=[
        gr.Image(type="numpy", label="Detected table"),
        gr.Dataframe(label="Table as CSV"),
        gr.JSON(label="Data as JSON"),
    ],
    title=title,
    description=description,
    examples=examples,
)
app.queue()
# For debugging with a public share link: app.launch(debug=True, share=True)
app.launch()