import csv
import io

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
from PIL import Image, ImageDraw, ImageFont

import torch
from ultralytics import YOLO
from paddleocr import PaddleOCR

import gradio as gr

import postprocess

# Run the YOLOv8 models on GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Table detection model (finds tables on the page) and table structure model
# (finds rows, columns, headers and spanning cells inside a cropped table).
detection_model = YOLO('yolov8/runs/detect/yolov8s-custom-detection/weights/best.pt').to(device)
structure_model = YOLO('yolov8/runs/detect/yolov8s-custom-structure-all/weights/best.pt').to(device)

# PaddleOCR with angle classification for rotated text.
ocr_model = PaddleOCR(use_angle_cls=True, lang="ch", det_limit_side_len=1920)

detection_class_names = ['table', 'table rotated']
structure_class_names = [
    'table', 'table column', 'table row', 'table column header',
    'table projected row header', 'table spanning cell', 'no object'
]
structure_class_map = {k: v for v, k in enumerate(structure_class_names)}

# Per-class confidence thresholds passed to postprocess.objects_to_cells.
# The "no object" threshold is set above any possible confidence, so that
# class is effectively never kept.
structure_class_thresholds = {
    "table": 0.5,
    "table column": 0.5,
    "table row": 0.5,
    "table column header": 0.5,
    "table projected row header": 0.5,
    "table spanning cell": 0.5,
    "no object": 10
}

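# Pipeline overview: detect the table on the page, crop it, run OCR on the
# crop, predict the table structure, merge OCR tokens into cells with
# postprocess.objects_to_cells, and finally return an annotated image plus a
# pandas DataFrame / JSON version of the table.
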
def table_detection(image):
    # Detect tables on the full page; each result is
    # [x_center, y_center, w, h] (normalized) + [confidence, class_id].
    imgsz = 800
    pred = detection_model.predict(image, imgsz=imgsz)
    pred = pred[0].boxes
    result = pred.cpu().numpy()
    result_list = [list(result.xywhn[i]) + [result.conf[i], result.cls[i]] for i in range(result.shape[0])]
    return result_list


def table_structure(image):
    # Detect rows/columns/headers/spanning cells inside a cropped table;
    # the output format matches table_detection.
    imgsz = 1024
    pred = structure_model.predict(image, imgsz=imgsz)
    pred = pred[0].boxes
    result = pred.cpu().numpy()
    result_list = [list(result.xywhn[i]) + [result.conf[i], result.cls[i]] for i in range(result.shape[0])]
    return result_list

def crop_image(image, detection_result):
    # Crop the first detected table out of the page with a 10 px margin,
    # clamped to the image borders. If nothing was detected, the full page
    # is returned unchanged.
    height, width = image.shape[:2]

    cropped = image
    for result in detection_result[:1]:
        min_x, min_y, w, h = result[0], result[1], result[2], result[3]

        # Convert the normalized center/width/height box to pixel corners.
        x1 = max(0, int((min_x - w / 2) * width) - 10)
        y1 = max(0, int((min_y - h / 2) * height) - 10)
        x2 = min(width, int((min_x + w / 2) * width) + 10)
        y2 = min(height, int((min_y + h / 2) * height) + 10)

        cropped = image[y1:y2, x1:x2, :]

    return cropped

def convert_structure(ocr_result, image, structure_result):
    # Turn raw structure detections into the object/cell representation
    # expected by postprocess.objects_to_cells.
    height, width = image.shape[:2]

    bboxes = []
    scores = []
    labels = []
    for result in structure_result:
        class_id = int(result[5])
        score = float(result[4])
        min_x, min_y, w, h = result[0], result[1], result[2], result[3]

        x1 = int((min_x - w / 2) * width)
        y1 = int((min_y - h / 2) * height)
        x2 = int((min_x + w / 2) * width)
        y2 = int((min_y + h / 2) * height)

        bboxes.append([x1, y1, x2, y2])
        scores.append(score)
        labels.append(class_id)

    table_objects = []
    for bbox, score, label in zip(bboxes, scores, labels):
        table_objects.append({'bbox': bbox, 'score': score, 'label': label})

    table = {'objects': table_objects, 'page_num': 0}

    # Use the highest-scoring "table" object as the table region; fall back to
    # a large default box if the structure model did not predict one.
    table_class_objects = [obj for obj in table_objects if obj['label'] == structure_class_map['table']]
    if len(table_class_objects) > 1:
        table_class_objects = sorted(table_class_objects, key=lambda x: x['score'], reverse=True)
    try:
        table_bbox = list(table_class_objects[0]['bbox'])
    except IndexError:
        table_bbox = [0, 0, 1000, 1000]

    # Keep only OCR tokens that lie mostly inside the table region
    # (postprocess.iob scores the overlap of a token box with the table box).
    page_tokens = ocr_result
    tokens_in_table = [token for token in page_tokens if postprocess.iob(token['bbox'], table_bbox) >= 0.5]

    table_structures, cells, confidence_score = postprocess.objects_to_cells(table, table_objects, tokens_in_table, structure_class_names, structure_class_thresholds)

    return table_structures, cells, confidence_score

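# Note: as consumed by visualize_cells below, each entry in `cells` carries
# 'bbox', 'column_nums', 'row_nums' and 'spans' (the OCR tokens assigned to
# the cell), and `table_structures` exposes the detected 'rows' and 'columns'
# (inferred from how those objects are used here, not from the postprocess
# module itself).
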
def visualize_cells(image, table_structures, cells):
    # Draw the recognized cell boxes on the cropped table image and collect
    # the cell texts into a row/column grid for the DataFrame output.
    height, width = image.shape[:2]

    # Blank white canvas for a text-only rendering of the table. It is drawn
    # on below but is not part of the values returned by this function.
    empty_image = np.zeros((height, width, 3), np.uint8)
    empty_image.fill(255)
    empty_image = Image.fromarray(cv2.cvtColor(empty_image, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(empty_image)
    try:
        # SimSong.ttc is assumed to be installed on the host; fall back to the
        # default PIL font so the app still runs without it.
        fontStyle = ImageFont.truetype("SimSong.ttc", 10, encoding="utf-8")
    except OSError:
        fontStyle = ImageFont.load_default()

    num_cols = len(table_structures['columns'])
    num_rows = len(table_structures['rows'])
    data_rows = [['' for _ in range(num_cols)] for _ in range(num_rows)]
    for cell in cells:
        bbox = cell['bbox']
        x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
        col_num = cell['column_nums'][0]
        row_num = cell['row_nums'][0]

        # Concatenate the OCR tokens assigned to this cell.
        text = ''
        for span in cell['spans']:
            if 'text' in span:
                text += span['text']
        data_rows[row_num][col_num] = text

        # Wrap the text to roughly fit the cell width (about 10 px per character).
        text_len = len(text)
        cell_width = x2 - x1
        num_per_line = cell_width // 10
        if num_per_line != 0:
            line_num = text_len // num_per_line
        else:
            line_num = 0
        new_text = text[:num_per_line] + '\n'
        for j in range(line_num):
            new_text += text[(j + 1) * num_per_line:(j + 2) * num_per_line] + '\n'
        text = new_text

        # Cell outline on the returned image; cell box and wrapped text on the canvas.
        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
        draw.rectangle([(x1, y1), (x2, y2)], fill=(255, 255, 255), outline=(0, 255, 0))
        draw.text((x1, y1), text, fill=(0, 0, 0), font=fontStyle)

    df = pd.DataFrame(data_rows)
    df.columns = df.columns.astype(str)
    return image, df, df.to_json()

def ocr(image):
    # Run PaddleOCR on the cropped table and convert each detected line into
    # {'bbox': [x1, y1, x2, y2], 'text': ...}. The quadrilateral returned by
    # PaddleOCR is reduced to an axis-aligned box using its top-left and
    # bottom-right corners.
    result = ocr_model.ocr(image, cls=True)
    result = result[0]
    new_result = []
    if result is not None:
        bounding_boxes = [line[0] for line in result]
        txts = [line[1][0] for line in result]

        for label, bbox in zip(txts, bounding_boxes):
            new_result.append({'bbox': [bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1]], 'text': label})

    return new_result

def detect_and_crop_table(image):
    # Detect the table on the page and return the cropped table region.
    detection_result = table_detection(image)
    cropped_table = crop_image(image, detection_result)
    return cropped_table


def recognize_table(image, ocr_result):
    # Recognize the table structure on the cropped image and assemble cells
    # from the structure detections and the OCR tokens.
    structure_result = table_structure(image)
    print('structure_result:', structure_result)
    table_structures, cells, confidence_score = convert_structure(ocr_result, image, structure_result)
    print('table_structures:', table_structures)
    print('cells:', cells)
    print('confidence_score:', confidence_score)
    image, df, data = visualize_cells(image, table_structures, cells)
    return image, df, data

def process_pdf(image):
    # Entry point for the Gradio app: takes an RGB page image (e.g. a page
    # rendered from a PDF), runs detection, OCR and structure recognition,
    # and returns the annotated table image, a DataFrame and JSON.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    cropped_table = detect_and_crop_table(image)
    ocr_result = ocr(cropped_table)
    image, df, data = recognize_table(cropped_table, ocr_result)
    print('df:', df)

    # Convert back to RGB for display in Gradio.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image, df, data

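# Minimal sketch of running the pipeline without the Gradio UI, assuming
# "table_page.png" is a hypothetical local image of a document page:
#
#   img = cv2.cvtColor(cv2.imread("table_page.png"), cv2.COLOR_BGR2RGB)
#   vis, df, data = process_pdf(img)
#   df.to_csv("table.csv", index=False)
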
title = "Demo: table detection & structure recognition with YOLOv8"
description = """Demo for table extraction: a YOLOv8 detector finds the table, a second YOLOv8 model recognizes its structure, and PaddleOCR reads the cell contents. The result is returned as an annotated image, a DataFrame and JSON."""
examples = [['image.png'], ['mistral_paper.png']]

app = gr.Interface(fn=process_pdf,
                   inputs=gr.Image(type="numpy"),
                   outputs=[gr.Image(type="numpy", label="Detected table"),
                            gr.Dataframe(label="Table as CSV"),
                            gr.JSON(label="Data as JSON")],
                   title=title,
                   description=description,
                   examples=examples)
app.queue()
app.launch()
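
# To expose a temporary public link (e.g. when running on a remote machine),
# Gradio's launch supports share=True instead: app.launch(share=True).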