import gradio as gr
import numpy as np
import cv2
import os
from ultralytics import YOLO
from PIL import Image
from src.table_ocr import TableEx
from src.datum_ocr import DatumOCR
from src.categories import CATEGORIES as categories, generate_colors
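
# Point Tesseract at the traineddata files bundled in ./tessdata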
os.environ['TESSDATA_PREFIX'] = './tessdata'

def load_model():
    """Load the custom YOLO model using Ultralytics."""
    # Load the model using the Ultralytics YOLO class
    model = YOLO('src/yoloCADex.pt')
    model.to('cpu')
    return model
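
# Note: process_image() below calls load_model() on every request, so the YOLO
# weights are re-read from disk each time. A minimal sketch of one way to cache
# the model at module level instead (assumption: a single shared CPU model is
# acceptable for this Space):
#
#     from functools import lru_cache
#
#     @lru_cache(maxsize=1)
#     def load_model():
#         model = YOLO('src/yoloCADex.pt')
#         model.to('cpu')
#         return model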

def process_image(image):
    """Process the uploaded image with the YOLO model and return the results."""
    # Check that the input is a valid PIL image
    if image is None or not isinstance(image, Image.Image):
        return None, {"error": "Invalid image input"}, None, None, None

    model = load_model()
    category_colors = generate_colors(len(categories))
    img_array = np.array(image)        # Convert to the array format expected by the model
    table_extractor = TableEx()        # Initialize TableEx for table extraction
    date_extractor = DatumOCR()        # Initialize DatumOCR for OCR
    results = model(img_array)         # Run inference (the model was loaded on CPU)
    img_with_boxes = img_array.copy()  # Copy of the image for drawing annotations

    detections = []      # Per-box detection records
    table_data = []      # Extracted table data
    table_images = []    # Cropped table region images
    gdnt_data = []       # Extracted GD&T data
    surface_data = []    # Extracted surface data
    # Process results
    for result in results:
        boxes = result.boxes
        for box in boxes:
            # Get box coordinates
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            # Get confidence and class
            conf = float(box.conf[0])
            cls_id = int(box.cls[0])

            if cls_id < len(categories):
                cls_name = categories[cls_id]
                color = category_colors[cls_id]
                if cls_name == "table":
                    table_region, extracted_info = table_extractor.extract_table_data(img_array, x1, y1, x2, y2)
                    if table_region is not None:
                        table_images.append(table_region)
                    if extracted_info is not None:
                        table_data.append(extracted_info)
            else:
                cls_name = f"Unknown ({cls_id})"
                color = (255, 255, 255)  # White for unknown categories

            label = f"{cls_name} {conf:.2f}"
            # Draw rectangle with the category-specific color
            cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), color, 2)
            # Filled rectangle as a text background
            text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
            cv2.rectangle(img_with_boxes, (x1, y1 - text_size[1] - 10), (x1 + text_size[0], y1), color, -1)
            # Choose black or white text based on background brightness for contrast
            brightness = (color[0] + color[1] + color[2]) / 3
            text_color = (0, 0, 0) if brightness > 127 else (255, 255, 255)
            cv2.putText(img_with_boxes, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, text_color, 2)

            # Store the detection (with its color) for the results table
            detections.append({
                "category": cls_name,
                "confidence": conf,
                "position": (x1, y1, x2, y2),
                "color": color
            })
        # Extract GD&T and datum OCR information from the relevant class IDs
        gdnt_info = date_extractor.read_rois(img_array, [4], boxes, 0)
        if gdnt_info is not None:
            gdnt_data.append(gdnt_info)
        surface_info = date_extractor.read_rois(img_array, [3, 6, 8, 9, 10, 11], boxes, 0)
        if surface_info is not None:
            surface_data.append(surface_info)
    # If tables were detected but no region image could be extracted, fall back to a placeholder
    if len(table_data) > 0 and len(table_images) == 0:
        table_images = [Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))]

    # Return the annotated image, any extracted table image, and the JSON data
    return (
        Image.fromarray(img_with_boxes),                              # Main detection image
        table_images[0] if table_images else None,                    # First table image or None
        table_data[0] if len(table_data) == 1 else table_data,        # Table data for gr.JSON
        gdnt_data[0] if len(gdnt_data) == 1 else gdnt_data,           # GD&T data for gr.JSON
        surface_data[0] if len(surface_data) == 1 else surface_data   # Surface data for gr.JSON
    )
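
# Minimal local usage sketch (assumption: a test drawing exists at this path):
#
#     annotated, table_img, table_json, gdnt_json, surface_json = process_image(
#         Image.open("samples/drawing.png")
#     )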

# Create the Gradio interface
with gr.Blocks(title="CAD 2D Drawing Data Extraction") as app:
    gr.Markdown("# CAD 2D Drawing Data Extraction")
    gr.Markdown("Upload an image to detect objects. Tables will be automatically extracted.")

    with gr.Row():
        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Input Image")
            gr.Markdown("## Extracted Table Region")
            table_image = gr.Image()
            gr.Markdown("### Extracted GD&T Data")
            gdnt_json = gr.JSON(open=True)
            gr.Markdown("### Extracted Surface Data")
            surface_json = gr.JSON(open=True)
        with gr.Column(scale=3):
            submit_btn = gr.Button("Detect Objects", variant="primary")
            gr.Markdown("## Detection Results")
            output_image = gr.Image()
            gr.Markdown("### Extracted Table Data")
            table_json = gr.JSON(open=True)

    submit_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[output_image, table_image, table_json, gdnt_json, surface_json]
    )
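
# Note: the order of `outputs` above matches the 5-tuple returned by process_image():
# (annotated image, table region image, table data, GD&T data, surface data).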

# Launch the app
if __name__ == "__main__":
    app.launch()