EC2 Default User
don't draw text
c4c5de7
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import io
import cv2
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import csv
import pandas as pd
from ultralytics import YOLO
import torch
from paddleocr import PaddleOCR
import postprocess
import gradio as gr
# Run inference on the GPU when torch reports one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Two YOLOv8 checkpoints: one used by table_detection(), the other by
# table_structure().
detection_model = YOLO(
    'yolov8/runs/detect/yolov8s-custom-detection/weights/best.pt'
).to(device)
structure_model = YOLO(
    'yolov8/runs/detect/yolov8s-custom-structure-all/weights/best.pt'
).to(device)

# TODO use large det_limit_side_len to get better OCR result
ocr_model = PaddleOCR(use_angle_cls=True, lang="ch", det_limit_side_len=1920)

# Class names, in class-id order, for each model.
detection_class_names = ['table', 'table rotated']
structure_class_names = [
    'table',
    'table column',
    'table row',
    'table column header',
    'table projected row header',
    'table spanning cell',
    'no object',
]

# Reverse lookup: class name -> class id.
structure_class_map = {
    name: index for index, name in enumerate(structure_class_names)
}

# Per-class score thresholds handed to postprocess.objects_to_cells().
# 'no object' gets 10 -- presumably so that it can never pass; confirm in
# postprocess.
structure_class_thresholds = {
    name: (10 if name == 'no object' else 0.5)
    for name in structure_class_names
}
def table_detection(image):
    """Run the table-detection model on ``image`` and list its boxes.

    Only the first prediction returned by the model is used.  Every box
    becomes one flat list: the four ``xywhn`` values followed by the
    confidence and the class id.
    """
    boxes = detection_model.predict(image, imgsz=800)[0].boxes.cpu().numpy()
    return [
        list(xywhn) + [conf, cls]
        for xywhn, conf, cls in zip(boxes.xywhn, boxes.conf, boxes.cls)
    ]
def table_structure(image):
    """Run the table-structure model on ``image`` and list its boxes.

    Only the first prediction returned by the model is used.  Every box
    becomes one flat list: the four ``xywhn`` values followed by the
    confidence and the class id.
    """
    boxes = structure_model.predict(image, imgsz=1024)[0].boxes.cpu().numpy()
    return [
        list(xywhn) + [conf, cls]
        for xywhn, conf, cls in zip(boxes.xywhn, boxes.conf, boxes.cls)
    ]
def crop_image(image, detection_result, padding=10):
    """Crop ``image`` to the first detected table, enlarged by ``padding``.

    Args:
        image: Array indexed height first, then width (any trailing axes,
            such as colour channels, are kept untouched).
        detection_result: Sequence of rows ``[cx, cy, w, h, conf, cls]``
            where the box is a centre/size pair expressed as fractions of
            the image width and height.
        padding: Number of pixels added on every side of the box before it
            is clamped to the image bounds.

    Returns:
        The cropped region of ``image``, or ``image`` itself when
        ``detection_result`` is empty.
    """
    # TODO only return first detected table
    if len(detection_result) == 0:
        return image

    height, width = image.shape[:2]
    center_x, center_y, box_w, box_h = detection_result[0][:4]

    # Convert the normalised centre/size box to pixel corners, grow it by
    # ``padding`` and keep it inside the image.
    x1 = max(0, int((center_x - box_w / 2) * width) - padding)
    y1 = max(0, int((center_y - box_h / 2) * height) - padding)
    x2 = min(width, int((center_x + box_w / 2) * width) + padding)
    y2 = min(height, int((center_y + box_h / 2) * height) + padding)

    return image[y1:y2, x1:x2]
def convert_stucture(ocr_result, image, structure_result):
    """Turn structure-model boxes and OCR tokens into table cells.

    Args:
        ocr_result: List of token dicts; each must carry a ``'bbox'`` entry.
        image: Array whose ``shape`` starts with ``(height, width)``; used
            to scale the normalised boxes to pixels.
        structure_result: Rows of ``[cx, cy, w, h, conf, cls]`` with the box
            expressed as a centre/size pair in fractions of the image size.

    Returns:
        The ``(table_structures, cells, confidence_score)`` triple produced
        by ``postprocess.objects_to_cells``.
    """
    width = image.shape[1]
    height = image.shape[0]

    # One object dict per detected box, with pixel-space corner coordinates.
    table_objects = []
    for result in structure_result:
        center_x, center_y, box_w, box_h = result[:4]
        bbox = [
            int((center_x - box_w / 2) * width),
            int((center_y - box_h / 2) * height),
            int((center_x + box_w / 2) * width),
            int((center_y + box_h / 2) * height),
        ]
        table_objects.append(
            {'bbox': bbox, 'score': float(result[4]), 'label': int(result[5])}
        )

    table = {'objects': table_objects, 'page_num': 0}

    # Use the highest-scoring 'table' box as the table outline; when the
    # model found none, fall back to a fixed default box.
    table_label = structure_class_map['table']
    table_candidates = [
        obj for obj in table_objects if obj['label'] == table_label
    ]
    if table_candidates:
        best_table = max(table_candidates, key=lambda obj: obj['score'])
        table_bbox = list(best_table['bbox'])
    else:
        table_bbox = (0, 0, 1000, 1000)

    # Keep only tokens whose postprocess.iob() score against the table box
    # is at least 0.5.
    tokens_in_table = [
        token for token in ocr_result
        if postprocess.iob(token['bbox'], table_bbox) >= 0.5
    ]

    table_structures, cells, confidence_score = postprocess.objects_to_cells(
        table,
        table_objects,
        tokens_in_table,
        structure_class_names,
        structure_class_thresholds,
    )
    return table_structures, cells, confidence_score
def visualize_cells(image, table_structures, cells):
    """Outline every cell on ``image`` and collect the cell texts in a table.

    Args:
        image: Image array; cell rectangles are drawn onto it in place.
        table_structures: Dict with ``'columns'`` and ``'rows'`` sequences
            whose lengths set the size of the output grid.
        cells: Cell dicts with ``'bbox'``, ``'column_nums'``, ``'row_nums'``
            and ``'spans'`` entries.

    Returns:
        ``(image, df, json)``: the annotated image, a DataFrame holding the
        cell texts (column labels converted to strings) and that DataFrame
        serialised with ``DataFrame.to_json()``.
    """
    num_cols = len(table_structures['columns'])
    num_rows = len(table_structures['rows'])
    data_rows = [['' for _ in range(num_cols)] for _ in range(num_rows)]

    for cell in cells:
        x1, y1, x2, y2 = (int(coord) for coord in cell['bbox'][:4])

        # A cell may list several rows/columns; its text is stored only at
        # the first row/column index.
        col_num = cell['column_nums'][0]
        row_num = cell['row_nums'][0]
        data_rows[row_num][col_num] = ''.join(
            span['text'] for span in cell['spans'] if 'text' in span
        )

        cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))

    df = pd.DataFrame(data_rows)
    # String column labels instead of the default integer ones.
    df.columns = df.columns.astype(str)
    return image, df, df.to_json()
def ocr(image):
    """Run OCR on ``image`` and return the recognised text as token dicts.

    Only the first entry of the OCR model's output is used.

    Returns:
        A list of ``{'bbox': [x1, y1, x2, y2], 'text': str}`` dicts, where
        the box is built from the first and third corner points of each
        detected quadrilateral.  The list is empty when nothing was
        recognised.
    """
    pages = ocr_model.ocr(image, cls=True)
    lines = pages[0] if pages else None
    if not lines:
        return []

    tokens = []
    for line in lines:
        quad = line[0]
        text = line[1][0]
        tokens.append({
            'bbox': [quad[0][0], quad[0][1], quad[2][0], quad[2][1]],
            'text': text,
        })
    return tokens
def detect_and_crop_table(image):
    """Detect tables in ``image`` and return the crop made by ``crop_image``."""
    return crop_image(image, table_detection(image))
def recognize_table(image, ocr_result):
    """Recognise the structure of a cropped table and extract its contents.

    Runs the structure model on ``image``, combines its boxes with
    ``ocr_result`` into cells, prints the intermediate results and returns
    the ``(image, df, data)`` triple from ``visualize_cells``.
    """
    boxes = table_structure(image)
    print('structure_result:', boxes)

    table_structures, cells, confidence_score = convert_stucture(
        ocr_result, image, boxes
    )
    print('table_structures:', table_structures)
    print('cells:', cells)
    print('confidence_score:', confidence_score)

    annotated, df, data = visualize_cells(image, table_structures, cells)
    return annotated, df, data
def process_pdf(image):
    """Gradio callback: extract the first detected table from an RGB image.

    The input is converted to BGR for processing, the table is cropped,
    OCR'd and recognised, and the annotated crop is converted back to RGB.

    Returns:
        ``(annotated_image, df, data)`` as produced by ``recognize_table``,
        with the image in RGB channel order.
    """
    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    table_crop = detect_and_crop_table(bgr_image)
    tokens = ocr(table_crop)
    annotated, df, data = recognize_table(table_crop, tokens)
    print('df:', df)
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB), df, data
# Texts and example inputs shown in the Gradio interface.
title = "Demo: table detection & recognition with Table Structure Recognition (Yolov8)."
description = "Demo for table extraction with the Table Structure Recognition (Yolov8)."
examples = [['image.png'], ['mistral_paper.png']]

# One image in; annotated image, DataFrame and JSON out -- matching the
# three values returned by process_pdf().
app = gr.Interface(
    fn=process_pdf,
    inputs=gr.Image(type="numpy"),
    outputs=[
        gr.Image(type="numpy", label="Detected table"),
        gr.Dataframe(label="Table as CSV"),
        gr.JSON(label="Data as JSON"),
    ],
    title=title,
    description=description,
    examples=examples,
)
app.queue()
# app.launch(debug=True, share=True)
app.launch()