import gradio as gr
import numpy as np
import cv2
import os
from ultralytics import YOLO
from PIL import Image
from src.table_ocr import TableEx
from src.datum_ocr import DatumOCR
from src.categories import CATEGORIES as categories, generate_colors

# Point Tesseract at the bundled language data directory used by the OCR helpers
os.environ['TESSDATA_PREFIX'] = './tessdata'

_model = None

def load_model():
    """Load the custom YOLO model using Ultralytics (cached after the first call)"""
    # Load the model once with the Ultralytics YOLO class and pin it to CPU;
    # reloading the weights on every request would make each detection needlessly slow.
    global _model
    if _model is None:
        _model = YOLO('src/yoloCADex.pt')
        _model.to('cpu')
    return _model


def process_image(image):
    """Process the uploaded image with the YOLO model and return the results"""
    
    # Check if image is valid; route the error payload to the table JSON output
    if image is None or not isinstance(image, Image.Image):
        return None, None, {"error": "Invalid image input"}, None, None
    
    model = load_model()
    category_colors = generate_colors(len(categories))
    img_array = np.array(image)                 # Convert to the array format expected by the model
    table_extractor = TableEx()                 # Initialize TableEx for table extraction
    datum_extractor = DatumOCR()                # Initialize DatumOCR for GD&T/datum OCR
    results = model(img_array)                  # Run inference (model is pinned to CPU in load_model)
    img_with_boxes = img_array.copy()           # Create a copy of the image for drawing
    detections = []                             # Per-box detection metadata (not currently surfaced in the UI)
    
    table_data = []                             # Storage for extracted table data and images
    table_images = []
    
    gdnt_data = []                             # Storage for GD&T and surface finish OCR results
    surface_data = []


    # Process results
    for result in results:
        boxes = result.boxes
        for box in boxes:
            # Get box coordinates
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            
            # Get confidence and class
            conf = float(box.conf[0])
            cls_id = int(box.cls[0])
            
            if cls_id < len(categories):
                cls_name = categories[cls_id]
                color = category_colors[cls_id]
                
                if cls_name == "table":
                    table_region, extracted_info = table_extractor.extract_table_data(img_array, x1, y1, x2, y2)
                    if table_region is not None:
                        table_images.append(table_region)
                    if extracted_info is not None:
                        table_data.append(extracted_info)
            else:
                cls_name = f"Unknown ({cls_id})"
                color = (255, 255, 255)  # White for unknown categories

            label = f"{cls_name} {conf:.2f}"
            
            # Draw rectangle with category-specific color
            cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), color, 2)
            
            # Create a filled rectangle for text background
            text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
            cv2.rectangle(img_with_boxes, (x1, y1 - text_size[1] - 10), (x1 + text_size[0], y1), color, -1)
            
            # Add label with contrasting text color
            # Choose black or white text based on background brightness
            brightness = (color[0] + color[1] + color[2]) / 3
            text_color = (0, 0, 0) if brightness > 127 else (255, 255, 255)
            cv2.putText(img_with_boxes, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, text_color, 2)
            
            # Store detection for table with color information
            detections.append({
                "category": cls_name,
                "confidence": conf,
                "position": (x1, y1, x2, y2),
                "color": color
            })
    

        # Extract GD&T and datum OCR information from the detected boxes.
        # The class-ID lists follow the ordering in src/categories; here
        # [4] is taken to be the GD&T frame class and [3, 6, 8, 9, 10, 11]
        # the surface-finish-related classes.
        gdnt_info = datum_extractor.read_rois(img_array, [4], boxes, 0)
        if gdnt_info is not None:
            gdnt_data.append(gdnt_info)

        surface_info = datum_extractor.read_rois(img_array, [3, 6, 8, 9, 10, 11], boxes, 0)
        if surface_info is not None:
            surface_data.append(surface_info)

    # If tables were detected but no crops could be extracted, show a blank placeholder
    if len(table_data) > 0 and len(table_images) == 0:
        table_images = [Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))]
    
    # Return the detection image, the first extracted table crop, and the JSON payloads
    return (
        Image.fromarray(img_with_boxes),  # Main detection image
        table_images[0] if table_images else None,  # First table image or None
        table_data[0] if len(table_data) == 1 else table_data,  # Table data for gr.JSON
        gdnt_data[0] if len(gdnt_data) == 1 else gdnt_data,  # GD&T data for gr.JSON
        surface_data[0] if len(surface_data) == 1 else surface_data  # Surface data for gr.JSON
    )
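
# A minimal usage sketch outside the Gradio UI (the sample path is hypothetical):
#   img = Image.open('samples/drawing.png')
#   boxes_img, table_img, tables, gdnt, surfaces = process_image(img)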

# Create Gradio interface
with gr.Blocks(title="CAD 2D Drawing Data Extraction") as app:
    gr.Markdown("# CAD 2D Drawing Data Extraction")
    gr.Markdown("Upload an image to detect objects. Tables will be automatically extracted.")
    
    with gr.Row():
        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Input Image")
            gr.Markdown("## Extracted Table Region")
            table_image = gr.Image()
            gr.Markdown("### Extracted GD&T Data")
            gdnt_json = gr.JSON(open=True)
            gr.Markdown("### Extracted Surface Data")
            surface_json = gr.JSON(open=True)
        with gr.Column(scale=3):
            submit_btn = gr.Button("Detect Objects", variant="primary")
            gr.Markdown("## Detection Results")
            output_image = gr.Image()
            gr.Markdown("### Extracted Table Data")
            table_json = gr.JSON(open=True)

    submit_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[output_image, table_image, table_json, gdnt_json, surface_json]
    )
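
    # Optionally run detection automatically whenever a new image is uploaded;
    # this reuses the same handler and outputs (a sketch, disabled by default):
    # input_image.change(
    #     fn=process_image,
    #     inputs=[input_image],
    #     outputs=[output_image, table_image, table_json, gdnt_json, surface_json]
    # )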

# Launch the app
if __name__ == "__main__":
    app.launch()
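    # For container or Hugging Face Space deployments, a common variant (sketch):
    # app.launch(server_name="0.0.0.0", server_port=7860)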