import streamlit as st
import cv2
import numpy as np
import onnxruntime as ort
from PIL import Image
import tempfile
from ultralytics import YOLO
# Load models
@st.cache_resource
def load_models():
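    """Load the two YOLO detectors and the ONNX OCR session.

    st.cache_resource keeps them in memory across Streamlit reruns,
    so the weights are only loaded once per process.
    """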
license_plate_detector = YOLO('license_plate_detector.pt')
vehicle_detector = YOLO('yolov8n.pt')
ort_session = ort.InferenceSession("model.onnx")
return license_plate_detector, vehicle_detector, ort_session
def draw_border(img, top_left, bottom_right, color=(0, 255, 0), thickness=10, line_length_x=200, line_length_y=200):
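    """Draw open corner brackets around a bounding box instead of a full rectangle."""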
    x1, y1 = top_left
    x2, y2 = bottom_right
    # Clamp the corner-line lengths so they never cross on small boxes
    line_length_x = min(line_length_x, (x2 - x1) // 2)
    line_length_y = min(line_length_y, (y2 - y1) // 2)
# Draw corner lines
cv2.line(img, (x1, y1), (x1, y1 + line_length_y), color, thickness) # top-left
cv2.line(img, (x1, y1), (x1 + line_length_x, y1), color, thickness)
cv2.line(img, (x1, y2), (x1, y2 - line_length_y), color, thickness) # bottom-left
cv2.line(img, (x1, y2), (x1 + line_length_x, y2), color, thickness)
cv2.line(img, (x2, y1), (x2 - line_length_x, y1), color, thickness) # top-right
cv2.line(img, (x2, y1), (x2, y1 + line_length_y), color, thickness)
cv2.line(img, (x2, y2), (x2, y2 - line_length_y), color, thickness) # bottom-right
cv2.line(img, (x2, y2), (x2 - line_length_x, y2), color, thickness)
return img
def process_frame(frame, license_plate_detector, vehicle_detector, ort_session):
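    """Detect vehicles, then license plates within each vehicle, then run OCR
    on each plate crop; annotate the frame (BGR) in place and return it."""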
# Detect vehicles
    vehicle_results = vehicle_detector(frame, classes=[2, 3, 5, 7])  # COCO classes: car, motorcycle, bus, truck
# Process each vehicle
for vehicle in vehicle_results[0].boxes.data:
x1, y1, x2, y2, score, class_id = vehicle
if score > 0.5: # Confidence threshold
# Draw vehicle border
draw_border(frame,
(int(x1), int(y1)),
(int(x2), int(y2)),
color=(0, 255, 0),
thickness=25,
line_length_x=200,
line_length_y=200)
# Detect license plate in vehicle region
vehicle_crop = frame[int(y1):int(y2), int(x1):int(x2)]
license_results = license_plate_detector(vehicle_crop)
for license_plate in license_results[0].boxes.data:
lp_x1, lp_y1, lp_x2, lp_y2, lp_score, _ = license_plate
if lp_score > 0.5:
# Adjust coordinates to full frame
abs_lp_x1 = int(x1 + lp_x1)
abs_lp_y1 = int(y1 + lp_y1)
abs_lp_x2 = int(x1 + lp_x2)
abs_lp_y2 = int(y1 + lp_y2)
                    # Extract the plate crop for OCR before drawing the box,
                    # so the red border is not baked into the OCR input
                    license_crop = frame[abs_lp_y1:abs_lp_y2, abs_lp_x1:abs_lp_x2].copy()
                    # Draw license plate box
                    cv2.rectangle(frame,
                                  (abs_lp_x1, abs_lp_y1),
                                  (abs_lp_x2, abs_lp_y2),
                                  (0, 0, 255), 12)
if license_crop.size > 0:
# Prepare license crop for ONNX model
license_crop_resized = cv2.resize(license_crop, (640, 640))
license_crop_processed = np.transpose(license_crop_resized, (2, 0, 1)).astype(np.float32) / 255.0
license_crop_processed = np.expand_dims(license_crop_processed, axis=0)
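                        # NOTE: this assumes the ONNX model expects a 640x640, NCHW,
                        # [0, 1]-scaled float32 input. The crop is still in OpenCV's
                        # BGR channel order; insert a cv2.cvtColor(..., cv2.COLOR_BGR2RGB)
                        # here if the model was exported for RGB input.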
# Run OCR inference
try:
inputs = {ort_session.get_inputs()[0].name: license_crop_processed}
outputs = ort_session.run(None, inputs)
                            # Decode the OCR output here; the exact logic depends on
                            # the ONNX model's output format.
                            license_number = "ABC123"  # Placeholder - replace with real decoding
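                            # Hypothetical sketch (commented out): for a CTC-style
                            # recognizer emitting logits of shape (1, T, num_classes),
                            # greedy decoding could look like this, where CHARSET and
                            # BLANK_ID are assumptions about the model's label map:
                            #   logits = outputs[0][0]             # (T, num_classes)
                            #   ids = logits.argmax(axis=-1)       # best class per step
                            #   collapsed = [i for i, p in zip(ids, [None, *ids]) if i != p]
                            #   license_number = "".join(CHARSET[i] for i in collapsed if i != BLANK_ID)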
# Display license plate number
H, W, _ = license_crop.shape
license_crop_display = cv2.resize(license_crop, (int(W * 400 / H), 400))
                            try:
                                # Overlay the enlarged crop and plate number above the
                                # vehicle; if the overlay would fall outside the frame,
                                # the slice assignment raises and the error is reported below
h_crop, w_crop, _ = license_crop_display.shape
center_x = int((x1 + x2) / 2)
# Display license plate crop
frame[int(y1) - h_crop - 100:int(y1) - 100,
int(center_x - w_crop/2):int(center_x + w_crop/2)] = license_crop_display
# White background for text
cv2.rectangle(frame,
(int(center_x - w_crop/2), int(y1) - h_crop - 400),
(int(center_x + w_crop/2), int(y1) - h_crop - 100),
(255, 255, 255),
-1)
# Draw license number
(text_width, text_height), _ = cv2.getTextSize(
license_number,
cv2.FONT_HERSHEY_SIMPLEX,
4.3,
17)
cv2.putText(frame,
license_number,
(int(center_x - text_width/2), int(y1 - h_crop - 250 + text_height/2)),
cv2.FONT_HERSHEY_SIMPLEX,
4.3,
(0, 0, 0),
17)
except Exception as e:
st.error(f"Error displaying results: {str(e)}")
except Exception as e:
st.error(f"Error in OCR processing: {str(e)}")
return frame
def process_video(video_path, license_plate_detector, vehicle_detector, ort_session):
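    """Run process_frame over every frame of the video and write the
    annotated result to a temporary .mp4 file, returning its path."""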
cap = cv2.VideoCapture(video_path)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # fall back to 30 if FPS metadata is missing
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
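    # NOTE: 'mp4v' output may not play back in all browsers via st.video,
    # which generally expects H.264; re-encode (e.g. with ffmpeg) if the
    # processed video does not display.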
out = cv2.VideoWriter(temp_file.name,
cv2.VideoWriter_fourcc(*'mp4v'),
fps,
(width, height))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
progress_bar = st.progress(0)
frame_count = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
processed_frame = process_frame(frame, license_plate_detector, vehicle_detector, ort_session)
out.write(processed_frame)
frame_count += 1
        progress_bar.progress(min(frame_count / max(total_frames, 1), 1.0))
cap.release()
out.release()
progress_bar.empty()
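    # The caller is responsible for deleting the temp file (delete=False above)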
return temp_file.name
# Streamlit UI
st.title("Advanced Vehicle and License Plate Detection")
try:
license_plate_detector, vehicle_detector, ort_session = load_models()
uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"])
if uploaded_file is not None:
file_type = uploaded_file.type.split('/')[0]
if file_type == "image":
image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_container_width=True)
if st.button("Detect"):
with st.spinner("Processing image..."):
                    # PIL may return RGBA or paletted images; force RGB,
                    # then convert to OpenCV's BGR channel order
                    image_cv = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)
processed_image = process_frame(image_cv, license_plate_detector, vehicle_detector, ort_session)
processed_image = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
                    st.image(processed_image, caption="Processed Image", use_container_width=True)
elif file_type == "video":
        tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        tfile.write(uploaded_file.read())
        tfile.flush()  # ensure the bytes are on disk before st.video reads the path
st.video(tfile.name)
if st.button("Detect"):
with st.spinner("Processing video..."):
processed_video = process_video(tfile.name, license_plate_detector, vehicle_detector, ort_session)
st.video(processed_video)
except Exception as e:
    st.error(f"Error: {str(e)}")