|
import torch |
|
from ultralytics import YOLO |
|
from PIL import Image, ImageDraw, ImageFont |
|
import numpy as np |
|
import pandas as pd |
|
import os |
|
import cv2 |
|
import time |
|
import zipfile |
|
import io |
|
from datetime import datetime |
|
|
|
|
|
# Optional basic OCR backend. Its absence is tolerated: OCR_AVAILABLE records
# whether `extract_license_plate_text` could be imported.
try:
    from license_plate_ocr import extract_license_plate_text
except ImportError as import_error:
    OCR_AVAILABLE = False
    print(f"Basic OCR module not available: {import_error}")
else:
    OCR_AVAILABLE = True
    print("Basic OCR module loaded successfully")
|
|
|
# Optional advanced OCR backend (extraction plus model selection). Its absence
# is tolerated: ADVANCED_OCR_AVAILABLE records whether the import succeeded.
try:
    from advanced_ocr import (
        extract_license_plate_text_advanced,
        get_available_models,
        set_ocr_model,
    )
except ImportError as import_error:
    ADVANCED_OCR_AVAILABLE = False
    print(f"Advanced OCR module not available: {import_error}")
else:
    ADVANCED_OCR_AVAILABLE = True
    print("Advanced OCR module loaded successfully")
|
|
|
|
|
# Display labels for the integer class ids predicted by the model.
# NOTE(review): the order must match the class order best.pt was trained
# with — verify if the weights are ever replaced.
class_names = dict(enumerate(["With Helmet", "Without Helmet", "License Plate"]))

# Detector weights are loaded once, at import time, from the working directory.
model = YOLO("best.pt")
|
|
|
|
|
def _load_pil_image(image):
    """Normalise a file path, numpy array or PIL image to a PIL image.

    Returns ``None`` (after printing the reason) when the file is missing, the
    type is unsupported, decoding fails, or either dimension is zero.
    """
    try:
        if isinstance(image, str):
            if not os.path.exists(image):
                print(f"Error: Image file not found: {image}")
                return None
            image = Image.open(image)
            # Decode eagerly so decode errors surface here (not mid-loop) and
            # Pillow can release the file handle it opened.
            image.load()
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            print(f"Error: Unsupported image type: {type(image)}")
            return None

        if image.size[0] == 0 or image.size[1] == 0:
            print("Error: Image has zero dimensions")
            return None
    except Exception as e:
        print(f"Error loading image: {e}")
        return None
    return image


def _parse_plate_box(detection, index, image_width, image_height):
    """Convert a detection row's Position/Dimensions strings into a crop box.

    Expects ``Position`` like ``"(x, y)"`` and ``Dimensions`` like ``"WxH"``.
    Returns ``(x1, y1, x2, y2)`` clamped to the image, or ``None`` for a
    malformed or empty box. Raises ``ValueError`` if a number is not an int.
    """
    pos_str = detection["Position"].strip("()")
    if "," not in pos_str:
        print(
            f"Error: Invalid position format for detection {index}: {detection['Position']}"
        )
        return None
    # Split on "," (not ", ") so "(10,20)" parses as well as "(10, 20)";
    # int() ignores the surrounding whitespace.
    x1, y1 = map(int, pos_str.split(","))

    dims_str = detection["Dimensions"]
    if "x" not in dims_str:
        print(
            f"Error: Invalid dimensions format for detection {index}: {detection['Dimensions']}"
        )
        return None
    width, height = map(int, dims_str.split("x"))
    if width <= 0 or height <= 0:
        print(f"Error: Invalid dimensions for detection {index}: {width}x{height}")
        return None

    x2, y2 = x1 + width, y1 + height
    if x1 < 0 or y1 < 0 or x2 > image_width or y2 > image_height:
        print(
            f"Warning: Bounding box extends beyond image boundaries for detection {index}"
        )
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(image_width, x2), min(image_height, y2)

    if x2 <= x1 or y2 <= y1:
        print(
            f"Error: Invalid crop coordinates for detection {index}: ({x1},{y1}) to ({x2},{y2})"
        )
        return None
    return x1, y1, x2, y2


def _ocr_plate_text(cropped_plate, selected_ocr_model, plate_number):
    """Run OCR on one cropped plate and return the text to display.

    Never raises: OCR failures are reported as ``"OCR Failed: ..."``.
    """
    if not (OCR_AVAILABLE or ADVANCED_OCR_AVAILABLE):
        return "OCR not available"
    try:
        print(
            f"Extracting text from license plate {plate_number} using {selected_ocr_model}..."
        )
        # Use the basic backend only when it was requested (or is the only
        # one present) AND it actually imported; otherwise the name
        # `extract_license_plate_text` is undefined.
        use_advanced = ADVANCED_OCR_AVAILABLE and (
            selected_ocr_model != "basic" or not OCR_AVAILABLE
        )
        if use_advanced:
            # "auto" (or a "basic" request we cannot honour) lets the advanced
            # module pick its own model.
            model_name = (
                None if selected_ocr_model in ("auto", "basic") else selected_ocr_model
            )
            if model_name is not None:
                set_ocr_model(model_name)
            plate_text = extract_license_plate_text_advanced(cropped_plate, model_name)
        else:
            plate_text = extract_license_plate_text(cropped_plate)

        if plate_text and plate_text.strip() and not plate_text.startswith("Error"):
            print(f"Extracted text: {plate_text.strip()}")
            return plate_text.strip()
        print(f"No text found in plate {plate_number}")
        return "No text detected"
    except Exception as e:
        print(f"OCR extraction failed for plate {plate_number}: {e}")
        return f"OCR Failed: {str(e)}"


def crop_license_plates(image, detections, extract_text=False, selected_ocr_model="auto"):
    """Crop license plates and (optionally) run OCR on the crops.

    Args:
        image: file path, numpy array or PIL image the detections refer to.
        detections: rows with ``Object``, ``Confidence``, ``Position`` and
            ``Dimensions`` keys; only ``Object == "License Plate"`` rows are
            cropped. ``None`` is treated as no detections.
        extract_text: run OCR on each crop.
        selected_ocr_model: ``"auto"``, ``"basic"`` or a model name understood
            by the advanced OCR module; ``None`` is treated as ``"auto"``.

    Returns:
        List of dicts with keys ``image`` (PIL crop), ``confidence``,
        ``position``, ``crop_coords`` and ``text``. Malformed rows are skipped
        with a printed message; an unloadable image yields ``[]``.
    """
    cropped_plates = []

    pil_image = _load_pil_image(image)
    if pil_image is None:
        return cropped_plates

    selected_ocr_model = selected_ocr_model or "auto"

    for i, detection in enumerate(detections or []):
        try:
            if detection.get("Object") != "License Plate":
                continue

            box = _parse_plate_box(detection, i, pil_image.width, pil_image.height)
            if box is None:
                continue

            cropped_plate = pil_image.crop(box)
            if cropped_plate.size[0] == 0 or cropped_plate.size[1] == 0:
                print(f"Error: Cropped image has zero dimensions for detection {i}")
                continue

            x1, y1, x2, y2 = box
            # Build the record before OCR so a malformed row fails fast
            # instead of after an expensive OCR call.
            plate_data = {
                "image": cropped_plate,
                "confidence": detection["Confidence"],
                "position": detection["Position"],
                "crop_coords": f"({x1},{y1}) to ({x2},{y2})",
                "text": "OCR disabled",
            }
            if extract_text:
                plate_data["text"] = _ocr_plate_text(
                    cropped_plate, selected_ocr_model, i + 1
                )

            cropped_plates.append(plate_data)
        except ValueError as e:
            print(f"Error parsing coordinates for detection {i}: {e}")
        except Exception as e:
            print(f"Error cropping license plate {i}: {e}")

    return cropped_plates
|
|
|
|
|
def _jpeg_ready(img):
    """Return ``img`` unchanged if JPEG can store its mode, else an RGB copy.

    Pillow refuses to write modes such as RGBA or palette ("P") as JPEG, which
    made such plate crops silently vanish from the download bundle. Objects
    without a ``mode`` attribute are returned untouched.
    """
    if getattr(img, "mode", None) in (None, "1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        return img
    return img.convert("RGB")


def create_download_files(annotated_image, cropped_plates, detections, output_dir="temp"):
    """Save the annotated image, plate crops and a CSV report, then zip them.

    Args:
        annotated_image: PIL image with the detections drawn on it.
        cropped_plates: dicts as produced by ``crop_license_plates`` (keys
            ``image``, ``confidence``, ``position`` and optionally ``text``).
        detections: detection rows copied verbatim into the CSV report.
        output_dir: directory the files are written to (created if missing).

    Returns:
        ``(zip_path, annotated_path, plate_paths)``. ``zip_path`` is ``None``
        when the archive could not be built; ``(None, None, [])`` is returned
        when the annotated image cannot be saved or anything else fails.
    """
    try:
        # Microseconds keep two requests in the same second from overwriting
        # each other's files.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        os.makedirs(output_dir, exist_ok=True)

        annotated_name = f"annotated_image_{timestamp}.jpg"
        annotated_path = os.path.join(output_dir, annotated_name)
        try:
            _jpeg_ready(annotated_image).save(annotated_path, quality=95)
        except Exception as e:
            print(f"Error saving annotated image: {e}")
            return None, None, []

        plate_paths = []
        for i, plate_data in enumerate(cropped_plates or [], start=1):
            plate_path = os.path.join(output_dir, f"license_plate_{i}_{timestamp}.jpg")
            try:
                _jpeg_ready(plate_data["image"]).save(plate_path, quality=95)
            except Exception as e:
                print(f"Error saving license plate {i}: {e}")
                continue
            plate_paths.append(plate_path)

        # Report = every detection row, followed by one row per plate carrying
        # its OCR text in an extra "Text" column.
        report_data = list(detections or [])
        report_data.extend(
            {
                "Object": f"License Plate {i} - Text",
                "Confidence": plate_data["confidence"],
                "Position": plate_data["position"],
                "Dimensions": "Extracted Text",
                "Text": plate_data.get("text", "N/A"),
            }
            for i, plate_data in enumerate(cropped_plates or [], start=1)
        )

        report_path = None
        if report_data:
            candidate_path = os.path.join(output_dir, f"detection_report_{timestamp}.csv")
            try:
                pd.DataFrame(report_data).to_csv(candidate_path, index=False)
                report_path = candidate_path
            except Exception as e:
                print(f"Error creating detection report: {e}")

        zip_path = os.path.join(output_dir, f"detection_results_{timestamp}.zip")
        try:
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
                if os.path.exists(annotated_path):
                    zipf.write(annotated_path, annotated_name)
                for plate_path in plate_paths:
                    if os.path.exists(plate_path):
                        zipf.write(plate_path, os.path.basename(plate_path))
                if report_path and os.path.exists(report_path):
                    zipf.write(report_path, os.path.basename(report_path))
        except Exception as e:
            print(f"Error creating ZIP file: {e}")
            return None, annotated_path, plate_paths

        return zip_path, annotated_path, plate_paths

    except Exception as e:
        print(f"Error in create_download_files: {e}")
        return None, None, []
|
|
|
|
|
def _detections_from_result(result):
    """Turn one ultralytics result's boxes into the rows shown in the UI table.

    Position is the top-left corner ``"(x, y)"`` and Dimensions is ``"WxH"``,
    both truncated to ints — the format ``crop_license_plates`` parses back.
    """
    boxes = result.boxes
    if boxes is None or len(boxes) == 0:
        return []
    detections = []
    for box, cls, conf in zip(boxes.xyxy, boxes.cls, boxes.conf):
        x1, y1, x2, y2 = box.tolist()
        class_id = int(cls)
        detections.append(
            {
                "Object": class_names.get(class_id, f"Class {class_id}"),
                "Confidence": f"{float(conf):.2f}",
                "Position": f"({int(x1)}, {int(y1)})",
                "Dimensions": f"{int(x2 - x1)}x{int(y2 - y1)}",
            }
        )
    return detections


def _format_plate_texts(cropped_plates, should_extract_text, ocr_available, violation_mode):
    """Build the per-plate lines for the plate-text output box."""
    if should_extract_text and ocr_available:
        print("Extracting text from license plates...")
        lines = []
        for number, plate_data in enumerate(cropped_plates, start=1):
            text = plate_data.get("text", "No text detected")
            print(f"Plate {number} text: {text}")
            if violation_mode:
                lines.append(f"🚨 No Helmet Violation - Plate {number}: {text}")
            else:
                lines.append(f"Plate {number}: {text}")
        return lines
    if should_extract_text:
        return [
            "OCR not available - install requirements: pip install transformers easyocr"
        ]
    return [
        f"Plate {number}: Text extraction disabled"
        for number in range(1, len(cropped_plates) + 1)
    ]


def _build_stats_text(detections, cropped_plates, has_no_helmet, show_plate_text, violation_mode):
    """Build the textual detection summary (per-class counts, plates, violations)."""
    counts = pd.DataFrame(detections)["Object"].value_counts().to_dict()
    parts = ["Detection Summary:\n"]
    parts.extend(f"- {obj}: {count}\n" for obj, count in counts.items())
    if cropped_plates:
        parts.append(f"\nLicense Plates Cropped: {len(cropped_plates)}\n")
    # A violation is reported even when no plate was cropped (or cropping is off).
    if has_no_helmet:
        parts.append("⚠️ HELMET VIOLATION DETECTED!\n")
    if cropped_plates and show_plate_text:
        parts.append("Extracted Text:\n")
        for number, plate_data in enumerate(cropped_plates, start=1):
            text = plate_data.get("text", "No text")
            if violation_mode:
                parts.append(f"🚨 Violation - Plate {number}: {text}\n")
            else:
                parts.append(f"- Plate {number}: {text}\n")
    return "".join(parts)


def yolov8_detect(
    image=None,
    image_size=640,
    conf_threshold=0.4,
    iou_threshold=0.5,
    show_stats=True,
    show_confidence=True,
    crop_plates=True,
    extract_text=False,
    ocr_on_no_helmet=False,
    selected_ocr_model="auto",
):
    """Main detection function: run the model on one image and build all UI outputs.

    Args:
        image: input passed to ``model.predict`` and ``crop_license_plates``.
            ``None`` returns empty outputs without running the model.
        image_size: square inference size; ``None`` means 640.
        conf_threshold: confidence threshold forwarded to ``model.predict``.
        iou_threshold: NMS IoU threshold forwarded to ``model.predict``.
        show_stats: build the textual detection summary.
        show_confidence: draw confidence scores on the annotated image.
        crop_plates: crop plates, optionally OCR them and build the downloads.
        extract_text: OCR every cropped plate.
        ocr_on_no_helmet: also OCR plates when a "Without Helmet" box exists.
        selected_ocr_model: OCR backend forwarded to ``crop_license_plates``.

    Returns:
        ``(annotated_image, detection_table, stats_text,
        license_plate_gallery, download_files, plate_text_output)``.
    """
    empty_table = pd.DataFrame(columns=["Object", "Confidence", "Position", "Dimensions"])
    if image is None:
        # ultralytics substitutes its bundled demo images when the source is
        # None, so stop here instead of annotating pictures nobody uploaded.
        return None, empty_table, "", [], None, "No image provided"

    image_size = 640 if image_size is None else int(image_size)
    results = model.predict(
        image, conf=conf_threshold, iou=iou_threshold, imgsz=[image_size, image_size]
    )
    result = results[0]

    # plot() returns a BGR (OpenCV-order) array; convert to RGB for PIL/Gradio.
    annotated_image = result.plot(conf=bool(show_confidence))
    if isinstance(annotated_image, np.ndarray):
        annotated_image = Image.fromarray(cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB))

    detections = _detections_from_result(result)

    has_no_helmet = any(d["Object"] == "Without Helmet" for d in detections)
    violation_mode = ocr_on_no_helmet and has_no_helmet
    should_extract_text = extract_text or violation_mode
    ocr_available = OCR_AVAILABLE or ADVANCED_OCR_AVAILABLE

    cropped_plates = []
    license_plate_gallery = []
    plate_texts = []
    download_files = None

    if crop_plates and detections:
        try:
            plate_count = sum(d["Object"] == "License Plate" for d in detections)
            print(f"Processing {plate_count} license plates...")
            if violation_mode:
                print("⚠️ No helmet detected - OCR will be performed on license plates")

            cropped_plates = crop_license_plates(
                image, detections, should_extract_text, selected_ocr_model
            )
            print(f"Successfully cropped {len(cropped_plates)} license plates")
            license_plate_gallery = [plate_data["image"] for plate_data in cropped_plates]
            plate_texts = _format_plate_texts(
                cropped_plates, should_extract_text, ocr_available, violation_mode
            )

            # detections is non-empty here, so there is always a report to write.
            download_files, _, _ = create_download_files(
                annotated_image, cropped_plates, detections
            )
            if download_files is None:
                print("Warning: Could not create download files")
        except Exception as e:
            print(f"Error in license plate processing: {e}")
            cropped_plates = []
            license_plate_gallery = []
            plate_texts = ["Error processing license plates"]
            download_files = None

    stats_text = ""
    if show_stats and detections:
        stats_text = _build_stats_text(
            detections,
            cropped_plates,
            has_no_helmet,
            should_extract_text and ocr_available,
            violation_mode,
        )

    detection_table = pd.DataFrame(detections) if detections else empty_table
    plate_text_output = (
        "\n".join(plate_texts)
        if plate_texts
        else "No license plates detected or OCR disabled"
    )

    return (
        annotated_image,
        detection_table,
        stats_text,
        license_plate_gallery,
        download_files,
        plate_text_output,
    )
|
|
|
|
|
def download_sample_images():
    """Download sample images for testing.

    Fetches ``Sample-Image-1.jpg`` … ``Sample-Image-8.jpg`` from the
    Abs6187/Helmet-Detection GitHub repository and stores each one as
    ``sample_<n>.jpg`` in the current working directory.
    """
    repo_base = "https://github.com/Abs6187/Helmet-Detection/blob/main"
    for index in range(1, 9):
        source_url = f"{repo_base}/Sample-Image-{index}.jpg?raw=true"
        torch.hub.download_url_to_file(source_url, f"sample_{index}.jpg")
|
|
|
|
|
def get_ocr_status():
    """Report which OCR backends were importable when this module loaded.

    Returns:
        dict with ``basic_available``, ``advanced_available`` and
        ``any_available`` (true when at least one backend imported).
    """
    basic_ok = OCR_AVAILABLE
    advanced_ok = ADVANCED_OCR_AVAILABLE
    return dict(
        basic_available=basic_ok,
        advanced_available=advanced_ok,
        any_available=basic_ok or advanced_ok,
    )