# CadExtractor / app.py
# Martin Krockert
# Demo with Tesseract / Paddle and fine-tuned YOLO12
# fa54254
import functools
import os

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from ultralytics import YOLO

from src.categories import CATEGORIES as categories, generate_colors
from src.datum_ocr import DatumOCR
from src.table_ocr import TableEx
# Point Tesseract at the ./tessdata directory (resolved against the current
# working directory, not this file's location).
# NOTE(review): this runs after the src.* imports; if those modules read
# TESSDATA_PREFIX at import time, it is set too late — confirm.
os.environ.update(TESSDATA_PREFIX='./tessdata')
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the fine-tuned YOLO detector once and reuse it on later calls.

    The weights are read from ``src/yoloCADex.pt`` (relative to the current
    working directory), and the model is moved to the CPU. The result is
    memoised: only the first call pays the load cost, and every later call
    returns the same model instance.

    Returns:
        The loaded ``ultralytics.YOLO`` model.
    """
    # NOTE(review): the cached instance is shared by every caller. Gradio's
    # default per-event concurrency limit of 1 serialises requests, but
    # Ultralytics advises one model per thread if that limit is raised — confirm.
    model = YOLO('src/yoloCADex.pt')
    model.to('cpu')
    return model
def _unwrap_single(items):
    """Return ``items[0]`` when *items* has exactly one element, else *items* itself."""
    return items[0] if len(items) == 1 else items


def _draw_labelled_box(canvas, x1, y1, x2, y2, label, color):
    """Draw a box outline and a filled label banner on *canvas*, in place."""
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
    text_w, text_h = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
    # Filled banner just above the box's top-left corner.
    cv2.rectangle(canvas, (x1, y1 - text_h - 10), (x1 + text_w, y1), color, -1)
    # Black text on light banners, white text on dark ones.
    brightness = (color[0] + color[1] + color[2]) / 3
    text_color = (0, 0, 0) if brightness > 127 else (255, 255, 255)
    cv2.putText(canvas, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, text_color, 2)


def process_image(image):
    """Detect drawing elements in *image* and extract table, GD&T and surface data.

    Args:
        image: The uploaded image; a ``PIL.Image.Image`` is required
            (anything else, including ``None``, is rejected).

    Returns:
        A 5-tuple in the order of the Gradio ``outputs`` list:
        ``(annotated image, first table region or None, table data,
        GD&T data, surface data)``. Each data slot holds the single extracted
        entry when exactly one was found, otherwise the (possibly empty) list
        of entries. For invalid input, both image slots are ``None`` and the
        table-data slot carries ``{"error": "Invalid image input"}``.
    """
    if image is None or not isinstance(image, Image.Image):
        # The error dict must go to the third slot (a gr.JSON output); the
        # second slot feeds a gr.Image component, which cannot render a dict.
        return None, None, {"error": "Invalid image input"}, None, None

    # Normalise RGBA / grayscale / palette uploads to 3-channel RGB so the
    # detector, the OCR helpers and the cv2 drawing calls all see one layout.
    rgb_image = image.convert("RGB")
    img_array = np.array(rgb_image)

    model = load_model()
    category_colors = generate_colors(len(categories))
    table_extractor = TableEx()
    datum_ocr = DatumOCR()

    # Ultralytics interprets raw numpy arrays as BGR but converts PIL images
    # itself, so passing the PIL image yields the channel order it expects.
    results = model(rgb_image)

    # Draw on a copy so img_array stays unannotated for table/OCR extraction.
    img_with_boxes = img_array.copy()
    table_data = []
    table_images = []
    gdnt_data = []
    surface_data = []

    for result in results:
        boxes = result.boxes
        for box in boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])
            cls_id = int(box.cls[0])

            if cls_id < len(categories):
                cls_name = categories[cls_id]
                color = category_colors[cls_id]
                if cls_name == "table":
                    table_region, extracted_info = table_extractor.extract_table_data(
                        img_array, x1, y1, x2, y2
                    )
                    if table_region is not None:
                        table_images.append(table_region)
                    if extracted_info is not None:
                        table_data.append(extracted_info)
            else:
                cls_name = f"Unknown ({cls_id})"
                color = (255, 255, 255)  # White for unknown categories

            _draw_labelled_box(img_with_boxes, x1, y1, x2, y2, f"{cls_name} {conf:.2f}", color)

        # Extract GD&T and Datum OCR information.
        # NOTE(review): the class-id lists presumably index CATEGORIES
        # (4 = GD&T; 3/6/8/9/10/11 = surface-related), and the trailing 0 is
        # interpreted by DatumOCR.read_rois — confirm both.
        gdnt_info = datum_ocr.read_rois(img_array, [4], boxes, 0)
        if gdnt_info is not None:
            gdnt_data.append(gdnt_info)
        surface_info = datum_ocr.read_rois(img_array, [3, 6, 8, 9, 10, 11], boxes, 0)
        if surface_info is not None:
            surface_data.append(surface_info)

    # Table data was extracted but no region crop came back: show a black
    # placeholder in the table image panel.
    if table_data and not table_images:
        table_images = [Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))]

    return (
        Image.fromarray(img_with_boxes),
        table_images[0] if table_images else None,
        _unwrap_single(table_data),
        _unwrap_single(gdnt_data),
        _unwrap_single(surface_data),
    )
# Gradio UI: upload plus extracted table/GD&T/surface panels on the left;
# trigger button, annotated detections and table JSON on the right.
with gr.Blocks(title="CAD 2d Drawing Data Extraction") as app:
    gr.Markdown("# CAD 2d Drawing Data Extraction")
    gr.Markdown("Upload an image to detect objects. Tables will be automatically extracted.")

    with gr.Row():
        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Input Image")
            gr.Markdown("## Extracted Table Region")
            table_image = gr.Image()
            gr.Markdown("### Extracted GD&T Data")
            gdnt_json = gr.JSON(open=True)
            gr.Markdown("### Extracted Surface Data")
            surface_json = gr.JSON(open=True)

        with gr.Column(scale=3):
            submit_btn = gr.Button("Detect Objects", variant="primary")
            gr.Markdown("## Detection Results")
            output_image = gr.Image()
            gr.Markdown("### Extracted Table Data")
            table_json = gr.JSON(open=True)

    # Order must match the 5-tuple returned by process_image.
    detection_outputs = [output_image, table_image, table_json, gdnt_json, surface_json]
    submit_btn.click(process_image, [input_image], detection_outputs)


if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    app.launch()