import streamlit as st
import cv2
import numpy as np
import onnxruntime as ort
from PIL import Image
import tempfile
from ultralytics import YOLO

# Load models
@st.cache_resource
def load_models():
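    # Model weights are read from the working directory and cached for the
    # lifetime of the Streamlit session via st.cache_resource.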
    license_plate_detector = YOLO('license_plate_detector.pt')
    vehicle_detector = YOLO('yolov8n.pt')
    ort_session = ort.InferenceSession("model.onnx")
    return license_plate_detector, vehicle_detector, ort_session

def draw_border(img, top_left, bottom_right, color=(0, 255, 0), thickness=10, line_length_x=200, line_length_y=200):
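    # Draw bracket-style corner markers around a bounding box instead of a full rectangle.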
    x1, y1 = top_left
    x2, y2 = bottom_right

    # Draw corner lines
    cv2.line(img, (x1, y1), (x1, y1 + line_length_y), color, thickness)  # top-left
    cv2.line(img, (x1, y1), (x1 + line_length_x, y1), color, thickness)

    cv2.line(img, (x1, y2), (x1, y2 - line_length_y), color, thickness)  # bottom-left
    cv2.line(img, (x1, y2), (x1 + line_length_x, y2), color, thickness)

    cv2.line(img, (x2, y1), (x2 - line_length_x, y1), color, thickness)  # top-right
    cv2.line(img, (x2, y1), (x2, y1 + line_length_y), color, thickness)

    cv2.line(img, (x2, y2), (x2, y2 - line_length_y), color, thickness)  # bottom-right
    cv2.line(img, (x2, y2), (x2 - line_length_x, y2), color, thickness)

    return img

def process_frame(frame, license_plate_detector, vehicle_detector, ort_session):
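    # Pipeline: detect vehicles, detect license plates inside each vehicle crop,
    # run the ONNX OCR model on every plate found, and annotate the frame in place.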
    # Detect vehicles
    vehicle_results = vehicle_detector(frame, classes=[2, 3, 5, 7])  # COCO classes: car, motorcycle, bus, truck
    
    # Process each vehicle
    for vehicle in vehicle_results[0].boxes.data:
        x1, y1, x2, y2, score, class_id = vehicle
        if score > 0.5:  # Confidence threshold
            # Draw vehicle border
            draw_border(frame, 
                       (int(x1), int(y1)), 
                       (int(x2), int(y2)), 
                       color=(0, 255, 0), 
                       thickness=25,
                       line_length_x=200, 
                       line_length_y=200)
            
            # Detect license plate in vehicle region
            vehicle_crop = frame[int(y1):int(y2), int(x1):int(x2)]
            license_results = license_plate_detector(vehicle_crop)
            
            for license_plate in license_results[0].boxes.data:
                lp_x1, lp_y1, lp_x2, lp_y2, lp_score, _ = license_plate
                if lp_score > 0.5:
                    # Adjust coordinates to full frame
                    abs_lp_x1 = int(x1 + lp_x1)
                    abs_lp_y1 = int(y1 + lp_y1)
                    abs_lp_x2 = int(x1 + lp_x2)
                    abs_lp_y2 = int(y1 + lp_y2)
                    
                    # Draw license plate box
                    cv2.rectangle(frame, 
                                (abs_lp_x1, abs_lp_y1), 
                                (abs_lp_x2, abs_lp_y2), 
                                (0, 0, 255), 12)
                    
                    # Extract and process license plate for OCR
                    license_crop = frame[abs_lp_y1:abs_lp_y2, abs_lp_x1:abs_lp_x2]
                    if license_crop.size > 0:
                        # Prepare license crop for ONNX model
                        license_crop_resized = cv2.resize(license_crop, (640, 640))
                        license_crop_processed = np.transpose(license_crop_resized, (2, 0, 1)).astype(np.float32) / 255.0
                        license_crop_processed = np.expand_dims(license_crop_processed, axis=0)
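                        # NOTE: the crop is fed to the OCR model in BGR channel order
                        # as a 1x3x640x640 float tensor scaled to [0, 1]; if the ONNX
                        # model expects RGB input, add a cv2.cvtColor conversion before
                        # the transpose above.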
                        
                        # Run OCR inference
                        try:
                            inputs = {ort_session.get_inputs()[0].name: license_crop_processed}
                            outputs = ort_session.run(None, inputs)
                            
                            # Decode the OCR output into a plate string here; this is a
                            # placeholder that must be replaced with decoding logic
                            # matching your ONNX model's output format.
                            license_number = "ABC123"  # Placeholder value
                            
                            # Display license plate number
                            H, W, _ = license_crop.shape
                            license_crop_display = cv2.resize(license_crop, (int(W * 400 / H), 400))
                            
                            try:
                                # Display license crop and number above vehicle
                                h_crop, w_crop, _ = license_crop_display.shape
                                center_x = int((x1 + x2) / 2)
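                                # NOTE: if the vehicle sits near the top of the frame, the
                                # overlay region below falls outside the image and the slice
                                # assignment raises; the surrounding try/except then skips
                                # drawing the overlay for that vehicle.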
                                
                                # Display license plate crop
                                frame[int(y1) - h_crop - 100:int(y1) - 100,
                                      int(center_x - w_crop/2):int(center_x + w_crop/2)] = license_crop_display
                                
                                # White background for text
                                cv2.rectangle(frame,
                                            (int(center_x - w_crop/2), int(y1) - h_crop - 400),
                                            (int(center_x + w_crop/2), int(y1) - h_crop - 100),
                                            (255, 255, 255),
                                            -1)
                                
                                # Draw license number
                                (text_width, text_height), _ = cv2.getTextSize(
                                    license_number,
                                    cv2.FONT_HERSHEY_SIMPLEX,
                                    4.3,
                                    17)
                                
                                cv2.putText(frame,
                                            license_number,
                                            (int(center_x - text_width/2), int(y1 - h_crop - 250 + text_height/2)),
                                            cv2.FONT_HERSHEY_SIMPLEX,
                                            4.3,
                                            (0, 0, 0),
                                            17)
                            except Exception as e:
                                st.error(f"Error displaying results: {str(e)}")
                        except Exception as e:
                            st.error(f"Error in OCR processing: {str(e)}")
    
    return frame

def process_video(video_path, license_plate_detector, vehicle_detector, ort_session):
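    # Read the input video frame by frame, annotate each frame, and write the
    # result to a temporary .mp4 file whose path is returned.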
    cap = cv2.VideoCapture(video_path)
    
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    out = cv2.VideoWriter(temp_file.name,
                         cv2.VideoWriter_fourcc(*'mp4v'),
                         fps,
                         (width, height))
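    # NOTE: 'mp4v' (MPEG-4 Part 2) output does not play inline in every browser;
    # if st.video shows a blank player, re-encoding the result to H.264
    # (e.g. with ffmpeg) is a common workaround.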
    
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    progress_bar = st.progress(0)
    frame_count = 0
    
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        
        processed_frame = process_frame(frame, license_plate_detector, vehicle_detector, ort_session)
        out.write(processed_frame)
        
        frame_count += 1
        if total_frames > 0:
            progress_bar.progress(min(frame_count / total_frames, 1.0))
    
    cap.release()
    out.release()
    progress_bar.empty()
    
    return temp_file.name

# Streamlit UI
st.title("Advanced Vehicle and License Plate Detection")

try:
    license_plate_detector, vehicle_detector, ort_session = load_models()
    
    uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"])

    if uploaded_file is not None:
        file_type = uploaded_file.type.split('/')[0]
        
        if file_type == "image":
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)
            
            if st.button("Detect"):
                with st.spinner("Processing image..."):
                    # Convert PIL Image to CV2 format
                    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
                    processed_image = process_frame(image_cv, license_plate_detector, vehicle_detector, ort_session)
                    processed_image = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
                    st.image(processed_image, caption="Processed Image", use_column_width=True)
        
        elif file_type == "video":
            tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
            tfile.write(uploaded_file.read())
            st.video(tfile.name)
            
            if st.button("Detect"):
                with st.spinner("Processing video..."):
                    processed_video = process_video(tfile.name, license_plate_detector, vehicle_detector, ort_session)
                    st.video(processed_video)

except Exception as e:
    st.error(f"Error loading models: {str(e)}")