import gradio as gr
import numpy as np
import cv2
import os

from ultralytics import YOLO
from PIL import Image

from src.table_ocr import TableEx
from src.datum_ocr import DatumOCR
from src.categories import CATEGORIES as categories, generate_colors

os.environ['TESSDATA_PREFIX'] = './tessdata'


def load_model():
    """Load the custom YOLO model using Ultralytics"""
    # Load the model using the Ultralytics YOLO class and pin it to CPU
    model = YOLO('src/yoloCADex.pt')
    model.to('cpu')
    return model


def process_image(image):
    """Process the uploaded image with the YOLO model and return the results"""
    # Check if image is valid (order matches the five Gradio outputs)
    if image is None or not isinstance(image, Image.Image):
        return None, None, {"error": "Invalid image input"}, None, None

    model = load_model()
    category_colors = generate_colors(len(categories))

    # Convert to the format expected by the model
    img_array = np.array(image)

    table_extractor = TableEx()    # Table extraction
    datum_extractor = DatumOCR()   # Datum / GD&T OCR

    # Run inference (the model was moved to CPU in load_model)
    results = model(img_array)

    # Copy of the image for drawing detections
    img_with_boxes = img_array.copy()

    detections = []     # Detection records (currently informational, not returned to the UI)
    table_data = []     # Extracted table data
    table_images = []   # Extracted table region images
    gdnt_data = []      # Extracted GD&T data
    surface_data = []   # Extracted surface data

    # Process results
    for result in results:
        boxes = result.boxes
        for box in boxes:
            # Get box coordinates
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            # Get confidence and class
            conf = float(box.conf[0])
            cls_id = int(box.cls[0])

            if cls_id < len(categories):
                cls_name = categories[cls_id]
                color = category_colors[cls_id]
                if cls_name == "table":
                    table_region, extracted_info = table_extractor.extract_table_data(
                        img_array, x1, y1, x2, y2)
                    if table_region is not None:
                        table_images.append(table_region)
                    if extracted_info is not None:
                        table_data.append(extracted_info)
            else:
                cls_name = f"Unknown ({cls_id})"
                color = (255, 255, 255)  # White for unknown categories

            label = f"{cls_name} {conf:.2f}"

            # Draw rectangle with category-specific color
            cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), color, 2)

            # Create a filled rectangle for the text background
            text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
            cv2.rectangle(img_with_boxes, (x1, y1 - text_size[1] - 10),
                          (x1 + text_size[0], y1), color, -1)

            # Choose black or white label text based on background brightness
            brightness = (color[0] + color[1] + color[2]) / 3
            text_color = (0, 0, 0) if brightness > 127 else (255, 255, 255)
            cv2.putText(img_with_boxes, label, (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, text_color, 2)

            # Store detection with color information
            detections.append({
                "category": cls_name,
                "confidence": conf,
                "position": (x1, y1, x2, y2),
                "color": color
            })

        # Extract GD&T and datum OCR information from the relevant class ROIs
        gdnt_info = datum_extractor.read_rois(img_array, [4], boxes, 0)
        if gdnt_info is not None:
            gdnt_data.append(gdnt_info)

        surface_info = datum_extractor.read_rois(img_array, [3, 6, 8, 9, 10, 11], boxes, 0)
        if surface_info is not None:
            surface_data.append(surface_info)

    # If tables were detected but no region images extracted, use a blank placeholder
    if len(table_data) > 0 and len(table_images) == 0:
        table_images = [Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))]

    # Return the detection image, any extracted table image, and the JSON data
    return (
        Image.fromarray(img_with_boxes),                              # Main detection image
        table_images[0] if table_images else None,                    # First table image or None
        table_data[0] if len(table_data) == 1 else table_data,        # Table data for gr.JSON
        gdnt_data[0] if len(gdnt_data) == 1 else gdnt_data,           # GD&T data for gr.JSON
        surface_data[0] if len(surface_data) == 1 else surface_data,  # Surface data for gr.JSON
    )
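# Note: process_image() calls load_model() on every request, so the weights are
# reloaded from disk each time. A minimal sketch of one way to avoid that, assuming
# a single-process deployment, is to memoize the loader:
#
#     from functools import lru_cache
#
#     @lru_cache(maxsize=1)
#     def load_model():
#         model = YOLO('src/yoloCADex.pt')
#         model.to('cpu')
#         return model
#
# This is an optional optimization, not part of the original pipeline.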

# Create the Gradio interface
with gr.Blocks(title="CAD 2D Drawing Data Extraction") as app:
    gr.Markdown("# CAD 2D Drawing Data Extraction")
    gr.Markdown("Upload an image to detect objects. Tables will be extracted automatically.")

    with gr.Row():
        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Input Image")
            gr.Markdown("## Extracted Table Region")
            table_image = gr.Image()
            gr.Markdown("### Extracted GD&T Data")
            gdnt_json = gr.JSON(open=True)
            gr.Markdown("### Extracted Surface Data")
            surface_json = gr.JSON(open=True)
        with gr.Column(scale=3):
            submit_btn = gr.Button("Detect Objects", variant="primary")
            gr.Markdown("## Detection Results")
            output_image = gr.Image()
            gr.Markdown("### Extracted Table Data")
            table_json = gr.JSON(open=True)

    submit_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[output_image, table_image, table_json, gdnt_json, surface_json]
    )

# Launch the app
if __name__ == "__main__":
    app.launch()
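# Note: app.launch() serves on http://127.0.0.1:7860 by default. A sketch for
# containerized or hosted deployments (an assumption, not part of the original
# setup) would be app.launch(server_name="0.0.0.0") to listen on all interfaces,
# optionally preceded by app.queue() to serialize concurrent requests.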