Spaces:
Sleeping
Sleeping
File size: 5,969 Bytes
bbf40b2 fa54254 bbf40b2 fa54254 bbf40b2 fa54254 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import gradio as gr
import numpy as np
import cv2
import os
from ultralytics import YOLO
from PIL import Image
from src.table_ocr import TableEx
from src.datum_ocr import DatumOCR
from src.categories import CATEGORIES as categories, generate_colors
# Set TESSDATA_PREFIX to the relative ./tessdata directory at import time.
# NOTE(review): presumably read by Tesseract through the OCR helpers imported
# above (TableEx / DatumOCR) — confirm. The path is relative, so it resolves
# against the working directory the process is started from.
os.environ['TESSDATA_PREFIX'] = './tessdata'
def load_model():
    """Build the custom Ultralytics YOLO detector and move it to the CPU.

    Returns:
        The ``YOLO`` instance created from ``src/yoloCADex.pt``.
    """
    # The checkpoint path is relative to the process working directory.
    detector = YOLO('src/yoloCADex.pt')
    # Pin the detector to the CPU before handing it back.
    detector.to('cpu')
    return detector
def _unwrap_single(items):
    """Return the sole element of a one-item list, otherwise the list itself."""
    return items[0] if len(items) == 1 else items


def _draw_labeled_box(canvas, x1, y1, x2, y2, label, color):
    """Draw a bounding box with a filled label tag above it on ``canvas`` in place."""
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
    # Filled background sized to the rendered label so the text stays legible.
    text_w, text_h = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
    cv2.rectangle(canvas, (x1, y1 - text_h - 10), (x1 + text_w, y1), color, -1)
    # Black text on bright backgrounds, white text on dark ones.
    brightness = (color[0] + color[1] + color[2]) / 3
    text_color = (0, 0, 0) if brightness > 127 else (255, 255, 255)
    cv2.putText(canvas, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, text_color, 2)


def process_image(image):
    """Run detection on an uploaded image and collect table / OCR extracts.

    Args:
        image: The uploaded picture; must be a ``PIL.Image.Image``.

    Returns:
        A 5-tuple matching the Gradio outputs:
        ``(annotated_image, first_table_image_or_None, table_data,
        gdnt_data, surface_data)``. Each of the three data entries is the
        single collected item when exactly one was collected, otherwise the
        (possibly empty) list of collected items. For invalid input the
        tuple is ``(None, {"error": "Invalid image input"}, None, None, None)``.
    """
    # Reject missing or non-PIL input before any model work is done.
    if image is None or not isinstance(image, Image.Image):
        return None, {"error": "Invalid image input"}, None, None, None

    model = load_model()
    category_colors = generate_colors(len(categories))
    img_array = np.array(image)  # Array form handed to the model and extractors
    table_extractor = TableEx()
    date_extractor = DatumOCR()
    results = model(img_array)
    img_with_boxes = img_array.copy()  # Draw on a copy; extractors see the clean image

    table_data = []
    table_images = []
    gdnt_data = []
    surface_data = []

    for result in results:
        boxes = result.boxes
        for box in boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])
            cls_id = int(box.cls[0])

            if cls_id < len(categories):
                cls_name = categories[cls_id]
                color = category_colors[cls_id]
                if cls_name == "table":
                    table_region, extracted_info = table_extractor.extract_table_data(
                        img_array, x1, y1, x2, y2
                    )
                    if table_region is not None:
                        table_images.append(table_region)
                    if extracted_info is not None:
                        table_data.append(extracted_info)
            else:
                cls_name = f"Unknown ({cls_id})"
                color = (255, 255, 255)  # White for unknown categories

            _draw_labeled_box(
                img_with_boxes, x1, y1, x2, y2, f"{cls_name} {conf:.2f}", color
            )

        # OCR passes run once per result over all of its boxes.
        # NOTE(review): the lists are class ids selected for each pass and the
        # trailing 0 is an opaque read_rois argument — confirm both against
        # DatumOCR.read_rois and the category table.
        gdnt_info = date_extractor.read_rois(img_array, [4], boxes, 0)
        if gdnt_info is not None:
            gdnt_data.append(gdnt_info)
        surface_info = date_extractor.read_rois(img_array, [3, 6, 8, 9, 10, 11], boxes, 0)
        if surface_info is not None:
            surface_data.append(surface_info)

    # Table data was extracted but no crop came back: show a black placeholder.
    if table_data and not table_images:
        table_images = [Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))]

    return (
        Image.fromarray(img_with_boxes),  # Main detection image
        table_images[0] if table_images else None,  # First table image or None
        _unwrap_single(table_data),  # JSON data for gr.JSON
        # Each panel falls back to its own list (previously table_data by mistake).
        _unwrap_single(gdnt_data),
        _unwrap_single(surface_data),
    )
# Build the Gradio Blocks UI: one row holding two columns, plus the click wiring.
with gr.Blocks(title="CAD 2d Drawing Data Extraction") as app:
    gr.Markdown("# CAD 2d Drawing Data Extraction")
    gr.Markdown("Upload an image to detect objects. Tables will be automatically extracted.")

    with gr.Row():
        # First column (scale 2): upload widget, table crop and OCR extracts.
        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Input Image")

            gr.Markdown("## Extracted Table Region")
            table_image = gr.Image()

            gr.Markdown("### Extracted GD&T Data")
            gdnt_json = gr.JSON(open=True)

            gr.Markdown("### Extracted Surface Data")
            surface_json = gr.JSON(open=True)

        # Second column (scale 3): trigger button, annotated image and table JSON.
        with gr.Column(scale=3):
            submit_btn = gr.Button("Detect Objects", variant="primary")

            gr.Markdown("## Detection Results")
            output_image = gr.Image()

            gr.Markdown("### Extracted Table Data")
            table_json = gr.JSON(open=True)

    # The order here must match the tuple returned by process_image.
    result_components = [output_image, table_image, table_json, gdnt_json, surface_json]
    submit_btn.click(fn=process_image, inputs=[input_image], outputs=result_components)
# Start the Gradio server only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app.launch()