|
import streamlit as st |
|
import cv2 |
|
import numpy as np |
|
import onnxruntime as ort |
|
from PIL import Image |
|
import tempfile |
|
import torch |
|
from ultralytics import YOLO |
|
|
|
|
|
@st.cache_resource
def load_models():
    """Load the three models the app needs.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse the
    loaded objects instead of reloading the weights on every script rerun.

    Returns:
        A ``(license_plate_detector, vehicle_detector, ort_session)`` tuple:
        a ``YOLO`` model loaded from ``license_plate_detector.pt``, a ``YOLO``
        model loaded from ``yolov8n.pt``, and an ONNX Runtime
        ``InferenceSession`` for ``model.onnx``. All paths are relative to the
        working directory.
    """
    plate_model = YOLO('license_plate_detector.pt')
    vehicle_model = YOLO('yolov8n.pt')
    onnx_session = ort.InferenceSession("model.onnx")
    return plate_model, vehicle_model, onnx_session
|
|
|
def draw_border(img, top_left, bottom_right, color=(0, 255, 0), thickness=10, line_length_x=200, line_length_y=200):
    """Draw bracket-style corner marks for a rectangle directly onto ``img``.

    At each of the four corners one horizontal stroke of ``line_length_x``
    pixels and one vertical stroke of ``line_length_y`` pixels is drawn,
    starting at the corner and running along the rectangle's edges.

    Args:
        img: Image passed to ``cv2.line``; it is drawn on in place.
        top_left: ``(x, y)`` of the rectangle's first corner.
        bottom_right: ``(x, y)`` of the opposite corner.
        color: Colour tuple handed to ``cv2.line``.
        thickness: Stroke thickness handed to ``cv2.line``.
        line_length_x: Length of each horizontal stroke.
        line_length_y: Length of each vertical stroke.

    Returns:
        The same ``img`` object that was passed in.
    """
    left, top = top_left
    right, bottom = bottom_right

    # (start, end) pairs, listed corner by corner.
    segments = (
        # top-left corner
        ((left, top), (left, top + line_length_y)),
        ((left, top), (left + line_length_x, top)),
        # bottom-left corner
        ((left, bottom), (left, bottom - line_length_y)),
        ((left, bottom), (left + line_length_x, bottom)),
        # top-right corner
        ((right, top), (right - line_length_x, top)),
        ((right, top), (right, top + line_length_y)),
        # bottom-right corner
        ((right, bottom), (right, bottom - line_length_y)),
        ((right, bottom), (right - line_length_x, bottom)),
    )
    for start, end in segments:
        cv2.line(img, start, end, color, thickness)

    return img
|
|
|
def _run_plate_ocr(ort_session, plate_img):
    """Feed a plate crop through the ONNX session and return the plate text."""
    resized = cv2.resize(plate_img, (640, 640))
    # HWC uint8 -> 1xCxHxW float32 scaled to [0, 1].
    blob = np.transpose(resized, (2, 0, 1)).astype(np.float32) / 255.0
    blob = np.expand_dims(blob, axis=0)

    input_name = ort_session.get_inputs()[0].name
    ort_session.run(None, {input_name: blob})

    # NOTE(review): the model output is not decoded yet; the text is a fixed
    # placeholder. Replace this once the output format of model.onnx is known.
    return "ABC123"


def _draw_plate_overlay(frame, plate_img, label, center_x, box_top):
    """Paste ``plate_img`` 100 px above ``box_top`` and write ``label`` above it.

    Returns ``True`` when the overlay was drawn. Returns ``False`` and leaves
    ``frame`` untouched when the plate image would not fit entirely inside it.
    """
    frame_h, frame_w = frame.shape[:2]
    plate_h, plate_w = plate_img.shape[:2]

    left = int(center_x - plate_w / 2)
    right = left + plate_w
    bottom = box_top - 100
    top = bottom - plate_h

    # A negative slice bound would wrap around to the opposite side of the
    # frame, and a partly off-frame slice would not match the plate's shape,
    # so the overlay is only drawn when it fits completely.
    if left < 0 or top < 0 or right > frame_w or bottom > frame_h:
        return False

    frame[top:bottom, left:right] = plate_img

    # White 300 px banner directly above the pasted plate (cv2 clips any part
    # that falls outside the frame).
    cv2.rectangle(frame, (left, top - 300), (right, top), (255, 255, 255), -1)

    (text_width, text_height), _ = cv2.getTextSize(
        label, cv2.FONT_HERSHEY_SIMPLEX, 4.3, 17)
    # Centre the text horizontally on the vehicle and vertically in the banner.
    text_origin = (int(center_x - text_width / 2),
                   int(top - 150 + text_height / 2))
    cv2.putText(frame, label, text_origin,
                cv2.FONT_HERSHEY_SIMPLEX, 4.3, (0, 0, 0), 17)
    return True


def process_frame(frame, license_plate_detector, vehicle_detector, ort_session):
    """Annotate one frame with vehicle borders, plate boxes and plate text.

    For every vehicle detection scoring above 0.5 a corner border is drawn,
    the plate detector is run on the vehicle crop, and for every plate scoring
    above 0.5 a red box is drawn, the crop is passed through the ONNX session
    and an enlarged copy of the plate plus its text is pasted above the
    vehicle when there is room for it.

    Args:
        frame: 3-channel image array as used by OpenCV; it is drawn on in
            place.
        license_plate_detector: Callable detector whose result exposes
            ``[0].boxes.data`` rows of ``x1, y1, x2, y2, score, class``.
        vehicle_detector: Same kind of detector, called with
            ``classes=[2, 3, 5, 7]`` (presumably the COCO car / motorcycle /
            bus / truck ids -- confirm against the loaded weights).
        ort_session: ONNX Runtime session used by the plate-reading step.

    Returns:
        The same ``frame`` object, now annotated.
    """
    vehicle_results = vehicle_detector(frame, classes=[2, 3, 5, 7])

    # All crops come from this untouched copy so that borders, boxes and
    # overlays painted on `frame` never leak into the detector / OCR inputs.
    clean = frame.copy()

    for vehicle in vehicle_results[0].boxes.data:
        x1, y1, x2, y2, score, _ = vehicle
        if not score > 0.5:
            continue

        vx1, vy1, vx2, vy2 = int(x1), int(y1), int(x2), int(y2)
        draw_border(frame,
                    (vx1, vy1),
                    (vx2, vy2),
                    color=(0, 255, 0),
                    thickness=25,
                    line_length_x=200,
                    line_length_y=200)

        vehicle_crop = clean[vy1:vy2, vx1:vx2]
        if vehicle_crop.size == 0:
            # Degenerate box: nothing to run the plate detector on.
            continue
        license_results = license_plate_detector(vehicle_crop)

        for license_plate in license_results[0].boxes.data:
            lp_x1, lp_y1, lp_x2, lp_y2, lp_score, _ = license_plate
            if not lp_score > 0.5:
                continue

            # Plate coordinates are relative to the crop, whose origin is the
            # integer pixel (vx1, vy1).
            abs_lp_x1 = int(vx1 + lp_x1)
            abs_lp_y1 = int(vy1 + lp_y1)
            abs_lp_x2 = int(vx1 + lp_x2)
            abs_lp_y2 = int(vy1 + lp_y2)

            cv2.rectangle(frame,
                          (abs_lp_x1, abs_lp_y1),
                          (abs_lp_x2, abs_lp_y2),
                          (0, 0, 255), 12)

            license_crop = clean[abs_lp_y1:abs_lp_y2, abs_lp_x1:abs_lp_x2]
            if license_crop.size == 0:
                continue

            try:
                license_number = _run_plate_ocr(ort_session, license_crop)
            except Exception as e:
                st.error(f"Error in OCR processing: {str(e)}")
                continue

            try:
                crop_h, crop_w = license_crop.shape[:2]
                # Scale the plate to 400 px tall, keeping at least 1 px of
                # width so cv2.resize never receives a zero-sized target.
                display_w = max(1, int(crop_w * 400 / crop_h))
                license_crop_display = cv2.resize(license_crop, (display_w, 400))

                center_x = int((x1 + x2) / 2)
                _draw_plate_overlay(frame, license_crop_display,
                                    license_number, center_x, vy1)
            except Exception as e:
                st.error(f"Error displaying results: {str(e)}")

    return frame
|
|
|
def process_video(video_path, license_plate_detector, vehicle_detector, ort_session):
    """Run ``process_frame`` over every frame of a video and save the result.

    The annotated frames are written with the ``mp4v`` codec to a new
    temporary ``.mp4`` file, and a Streamlit progress bar is shown while the
    video is processed.

    Args:
        video_path: Path of the input video, opened with ``cv2.VideoCapture``.
        license_plate_detector: Forwarded to ``process_frame``.
        vehicle_detector: Forwarded to ``process_frame``.
        ort_session: Forwarded to ``process_frame``.

    Returns:
        Path of the temporary output file. The file is created with
        ``delete=False``, so removing it is left to the caller.
    """
    cap = cv2.VideoCapture(video_path)
    out = None
    progress_bar = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Keep fractional rates such as 29.97 (truncating them makes the
        # output drift), and fall back to 30 fps when the container reports
        # no usable rate, since a writer opened with 0 fps produces no video.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        # Only the name is needed: close our handle straight away so the
        # VideoWriter is the sole owner of the file.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
            output_path = temp_file.name

        out = cv2.VideoWriter(output_path,
                              cv2.VideoWriter_fourcc(*'mp4v'),
                              fps,
                              (width, height))

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        progress_bar = st.progress(0)
        frame_count = 0

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            processed_frame = process_frame(frame, license_plate_detector, vehicle_detector, ort_session)
            out.write(processed_frame)

            frame_count += 1
            # The reported frame count can be 0 or too small; st.progress only
            # accepts fractions in [0.0, 1.0].
            if total_frames > 0:
                progress_bar.progress(min(frame_count / total_frames, 1.0))
    finally:
        # Release the native handles even if a frame fails to process.
        cap.release()
        if out is not None:
            out.release()
        if progress_bar is not None:
            progress_bar.empty()

    return output_path
|
|
|
|
|
# ---------------------------------------------------------------------------
# Streamlit page: upload an image or an mp4, press "Detect", view the result.
# ---------------------------------------------------------------------------
st.title("Advanced Vehicle and License Plate Detection")

try:
    license_plate_detector, vehicle_detector, ort_session = load_models()
except Exception as e:
    st.error(f"Error loading models: {str(e)}")
else:
    # Failures past this point are upload / processing problems, so they are
    # reported separately from model-loading failures.
    try:
        uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"])

        if uploaded_file is not None:
            # MIME type looks like "image/png" or "video/mp4"; keep the family.
            file_type = uploaded_file.type.split('/')[0]

            if file_type == "image":
                # Force 3-channel RGB: PNG uploads may be RGBA, palette or
                # grayscale, which cv2.COLOR_RGB2BGR cannot convert.
                image = Image.open(uploaded_file).convert("RGB")
                st.image(image, caption="Uploaded Image", use_column_width=True)

                if st.button("Detect"):
                    with st.spinner("Processing image..."):
                        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
                        processed_image = process_frame(image_cv, license_plate_detector, vehicle_detector, ort_session)
                        processed_image = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
                        st.image(processed_image, caption="Processed Image", use_column_width=True)

            elif file_type == "video":
                # Write the upload to disk and close the handle so the data is
                # fully flushed before anything reads the file by name.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tfile:
                    tfile.write(uploaded_file.getvalue())
                    video_path = tfile.name
                st.video(video_path)

                if st.button("Detect"):
                    with st.spinner("Processing video..."):
                        processed_video = process_video(video_path, license_plate_detector, vehicle_detector, ort_session)
                        st.video(processed_video)
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")