import streamlit as st | |
import os | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
import torch | |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
from ultralytics import YOLO | |
import Levenshtein | |
# Configure the browser tab title/icon and the page layout before any other
# element of the app is rendered.
st.set_page_config(page_title="Thai License Plate Detection", page_icon="🚗", layout="centered")

# Per-session flag telling the main script whether the models still need
# to be loaded for this browser session.
if 'models_loaded' not in st.session_state:
    st.session_state.models_loaded = False

def load_ocr_models():
    """Load the Thai TrOCR processor and vision encoder-decoder model.

    Both objects come from the ``openthaigpt/thai-trocr`` Hugging Face
    repository, and the model is moved to the CPU.

    Returns:
        tuple: ``(processor, ocr_model)`` on success, or ``(None, None)`` if
        anything fails. On failure the error message and full traceback are
        rendered on the Streamlit page.
    """
    import traceback

    model_id = 'openthaigpt/thai-trocr'
    try:
        # Silence the Hugging Face tokenizers parallelism warning.
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'

        # Same options for processor and model. `token=False` replaces the
        # deprecated `use_auth_token=False` (do not send stored HF credentials).
        # NOTE(review): trust_remote_code=True executes code from the hub repo;
        # TrOCR is a built-in transformers architecture, so this may be
        # unnecessary — confirm and drop if so.
        load_kwargs = dict(
            revision='main',
            token=False,
            trust_remote_code=True,
            local_files_only=False,
        )
        processor = TrOCRProcessor.from_pretrained(model_id, **load_kwargs)
        ocr_model = VisionEncoderDecoderModel.from_pretrained(model_id, **load_kwargs)

        # Explicitly run inference on the CPU.
        ocr_model = ocr_model.to('cpu')
        return processor, ocr_model
    except Exception as e:
        st.error(f"Error loading OCR models: {str(e)}")
        st.error("Detailed error information:")
        st.code(traceback.format_exc())
        return None, None

# Load models
def load_models(weights_path='best.pt'):
    """Load the YOLO plate detector and the TrOCR processor/model.

    Args:
        weights_path: Path to the YOLO weights file, resolved relative to the
            current working directory. NOTE(review): the original file name
            was lost from this source; ``best.pt`` is an assumed default —
            confirm against the deployed weights.

    Returns:
        tuple: ``(processor, ocr_model, yolo_model)`` on success, or
        ``(None, None, None)`` on any failure. Failures are reported on the
        Streamlit page rather than raised.
    """
    import traceback

    try:
        # isfile (not exists) so a directory with that name is rejected up front.
        if not os.path.isfile(weights_path):
            st.error(f"YOLO model weights ({weights_path}) not found in the current directory!")
            return None, None, None

        try:
            yolo_model = YOLO(weights_path, task='detect')
        except Exception as yolo_error:
            st.error(f"Error loading YOLO model: {str(yolo_error)}")
            return None, None, None

        # load_ocr_models reports its own errors and returns (None, None) on failure.
        processor, ocr_model = load_ocr_models()
        if processor is None or ocr_model is None:
            return None, None, None

        return processor, ocr_model, yolo_model
    except Exception as e:
        st.error(f"Error in model loading: {str(e)}")
        st.error("Detailed error information:")
        st.code(traceback.format_exc())
        return None, None, None

# All 77 Thai provinces (76 provinces plus Bangkok), used to snap noisy OCR
# output of the province line on a plate to a valid province name.
# Previously missing: นนทบุรี, พระนครศรีอยุธยา, สตูล, อุตรดิตถ์.
thai_provinces = [
    "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร", "ขอนแก่น", "จันทบุรี", "ฉะเชิงเทรา",
    "ชลบุรี", "ชัยนาท", "ชัยภูมิ", "ชุมพร", "เชียงราย", "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก",
    "นครปฐม", "นครพนม", "นครราชสีมา", "นครศรีธรรมราช", "นครสวรรค์", "นนทบุรี", "นราธิวาส", "น่าน", "บึงกาฬ",
    "บุรีรัมย์", "ปทุมธานี", "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พระนครศรีอยุธยา", "พะเยา", "พังงา", "พัทลุง",
    "พิจิตร", "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี", "แพร่", "ภูเก็ต", "มหาสารคาม", "มุกดาหาร", "แม่ฮ่องสอน",
    "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง", "ราชบุรี", "ลพบุรี", "ลำปาง", "ลำพูน", "เลย",
    "ศรีสะเกษ", "สกลนคร", "สงขลา", "สตูล", "สมุทรปราการ", "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี",
    "สิงห์บุรี", "สุโขทัย", "สุพรรณบุรี", "สุราษฎร์ธานี", "สุรินทร์", "หนองคาย", "หนองบัวลำภู", "อำนาจเจริญ",
    "อุดรธานี", "อุตรดิตถ์", "อุทัยธานี", "อุบลราชธานี", "อ่างทอง",
]

def get_closest_province(input_text, provinces):
    """Return the entry of *provinces* nearest to *input_text* by Levenshtein distance.

    When several provinces share the smallest distance, the one appearing
    first in *provinces* is chosen.

    Returns:
        tuple: ``(closest_province, distance)``, or ``(None, float('inf'))``
        when *provinces* is empty.
    """
    # The index makes each tuple unique, so min() never compares names and
    # ties resolve to the earliest province.
    scored = [
        (Levenshtein.distance(input_text, name), position, name)
        for position, name in enumerate(provinces)
    ]
    if not scored:
        return None, float('inf')
    best_distance, _, best_name = min(scored)
    return best_name, best_distance

def _enhance_contrast(bgr_image):
    """Return the BGR image with CLAHE applied to the L channel of its LAB form."""
    lab = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2LAB)
    lightness, a_chan, b_chan = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    merged = cv2.merge((clahe.apply(lightness), a_chan, b_chan))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2BGR)


def _best_detection_per_class(results, image_shape, conf_threshold):
    """Return {class_id: (confidence, (x1, y1, x2, y2))}, keeping the most confident non-empty box per class."""
    height, width = image_shape[:2]
    best = {}
    for result in results:
        for box in result.boxes:
            confidence = float(box.conf)
            if confidence < conf_threshold:
                continue
            class_id = int(box.cls.item())
            x1, y1, x2, y2 = map(int, box.xyxy.flatten())
            # Clamp to the image: negative indices would wrap around when slicing.
            x1, x2 = max(0, x1), min(width, x2)
            y1, y2 = max(0, y1), min(height, y2)
            if x2 <= x1 or y2 <= y1:
                continue  # empty crop, nothing to OCR
            if class_id not in best or confidence > best[class_id][0]:
                best[class_id] = (confidence, (x1, y1, x2, y2))
    return best


def _ocr_crop(crop_bgr, processor, ocr_model):
    """Binarize a BGR crop and return the text the TrOCR model decodes from it."""
    gray = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2GRAY)
    # NOTE(review): the adaptive method and threshold type were missing from
    # the original call (a TypeError); Gaussian + binary is an assumed choice.
    binary = cv2.adaptiveThreshold(
        gray,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        11,  # neighbourhood block size
        2,   # constant subtracted from the local mean
    )
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    rgb = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
    resized = cv2.resize(rgb, (128, 32))  # cv2 size is (width, height)
    pixel_values = processor(resized, return_tensors="pt").pixel_values
    generated_ids = ocr_model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def process_image(image, processor, ocr_model, yolo_model, conf_threshold=0.5):
    """Detect the plate-number and province regions in *image* and OCR them.

    Args:
        image: Input picture, normally a PIL image (converted to RGB first so
            grayscale/palette uploads do not break the colour conversion).
        processor: TrOCR processor used to prepare crops and decode ids.
        ocr_model: TrOCR vision encoder-decoder model.
        yolo_model: Ultralytics YOLO detector; class 0 is treated as the
            license-plate number and class 1 as the province line.
        conf_threshold: Minimum detection confidence. NOTE(review): the
            original referenced an undefined ``CONF_THRESHOLD``; 0.5 is an
            assumed default — tune as needed.

    Returns:
        dict: ``plate_number``, ``province`` (closest valid name),
        ``raw_province`` (raw OCR text), ``plate_crop`` and ``province_crop``
        (PIL images or None). Empty strings/None mean nothing was found.
    """
    data = {"plate_number": "", "province": "", "raw_province": "", "plate_crop": None, "province_crop": None}

    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    image = _enhance_contrast(image)

    results = yolo_model(image)
    best = _best_detection_per_class(results, image.shape, conf_threshold)

    # Only classes 0 and 1 are used, so OCR is run on nothing else.
    for class_id in (0, 1):
        if class_id not in best:
            continue
        _, (x1, y1, x2, y2) = best[class_id]
        crop = image[y1:y2, x1:x2]
        text = _ocr_crop(crop, processor, ocr_model)
        crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))

        if class_id == 0:  # License plate
            data["plate_number"] = text
            data["plate_crop"] = crop_pil
        else:  # Province
            province, _ = get_closest_province(text, thai_provinces)
            data["raw_province"] = text
            data["province"] = province
            data["province_crop"] = crop_pil

    return data

# Main app
st.title("Thai License Plate Detection 🚗")

# Load the models once per browser session and keep them in session_state.
if not st.session_state['models_loaded']:
    try:
        with st.spinner("Loading models... (this may take a minute)"):
            processor, ocr_model, yolo_model = load_models()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.stop()
    # load_models() reports its own errors and signals failure with Nones.
    # Halt without marking the session as loaded: otherwise every upload would
    # crash inside process_image and a rerun would never retry loading.
    if processor is None or ocr_model is None or yolo_model is None:
        st.stop()
    st.session_state['models_loaded'] = True
    st.session_state['processor'] = processor
    st.session_state['ocr_model'] = ocr_model
    st.session_state['yolo_model'] = yolo_model

# File uploader
uploaded_file = st.file_uploader("Upload an image of a Thai license plate", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        col1, col2 = st.columns(2)

        # Left column: the uploaded image.
        with col1:
            st.subheader("Uploaded Image")
            # Normalize to RGB: grayscale or palette PNGs would otherwise fail
            # the RGB->BGR conversion done before detection.
            image = Image.open(uploaded_file).convert("RGB")
            st.image(image, use_column_width=True)

        # Right column: detection + OCR results.
        with col2:
            st.subheader("Detection Results")
            with st.spinner("Processing image..."):
                results = process_image(
                    image,
                    st.session_state['processor'],
                    st.session_state['ocr_model'],
                    st.session_state['yolo_model']
                )

            if results["plate_number"]:
                st.success("Detection successful!")
                st.write("📝 License Plate:", results['plate_number'])
                if results['plate_crop'] is not None:
                    st.subheader("Cropped License Plate")
                    st.image(results['plate_crop'], caption="Detected License Plate Region")

                if results['raw_province']:
                    st.write("🔍 Detected Province Text:", results['raw_province'])
                    if results['province']:
                        st.write("🏠 Matched Province:", results['province'])
                    else:
                        st.write("⚠️ No close province match found")
                    if results['province_crop'] is not None:
                        st.subheader("Cropped Province")
                        st.image(results['province_crop'], caption="Detected Province Region")
                else:
                    st.write("⚠️ No province text detected")
            else:
                st.error("No license plate detected in the image.")
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")

# Usage instructions shown below the results, between horizontal dividers.
_INSTRUCTIONS_MD = """
1. Upload an image containing a Thai license plate
2. Wait for the processing to complete
3. View the detected license plate number and province
"""

for _section in ("---", "### Instructions", _INSTRUCTIONS_MD):
    st.markdown(_section)

# Closing divider. NOTE(review): a GitHub footer link was apparently intended
# here but is not rendered.
st.markdown("---")
