import streamlit as st
import os
import numpy as np
import cv2
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ultralytics import YOLO
import Levenshtein
# Page config
# NOTE(review): because of the trailing comma this statement, as shown, binds
# a one-element tuple to `page_title`. It looks like a keyword argument left
# over from a page-configuration call whose opening line is not visible in
# this view -- confirm against the full file.
page_title="Thai License Plate Detection",
# Initialize session state for models
# The 'models_loaded' flag is read later in this script to decide whether the
# models still have to be loaded for the current session.
if 'models_loaded' not in st.session_state:
st.session_state['models_loaded'] = False
def load_ocr_models():
"""Load OCR models with proper error handling"""
# Returns a `(processor, ocr_model)` pair on success; when an exception is
# caught, the error is reported with `st.error` and `(None, None)` is
# returned instead.
# NOTE(review): in this view the `try:` matching the `except` below, the
# arguments and closing parentheses of both `from_pretrained(` calls, and the
# body indentation are not visible -- confirm against the full file.
# Set environment variables to suppress warnings
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# Load processor with specific config
processor = TrOCRProcessor.from_pretrained(
# Load OCR model with specific config
ocr_model = VisionEncoderDecoderModel.from_pretrained(
# Move model to CPU explicitly
# NOTE(review): this line looks like the tail of a "move to 'cpu'" call; the
# start of the expression is not visible here -- confirm.
ocr_model ='cpu')
return processor, ocr_model
except Exception as e:
st.error(f"Error loading OCR models: {str(e)}")
st.error("Detailed error information:")
# NOTE(review): `traceback` is imported here but no use of it is visible in
# this view; presumably the formatted traceback was meant to be displayed
# after the message above -- confirm.
import traceback
return None, None
# Load models
def load_models():
# Loads the YOLO detector and the OCR pair from `load_ocr_models()`.
# Returns `(processor, ocr_model, yolo_model)` on success and
# `(None, None, None)` on every failure path; failures are reported to the
# user through `st.error`.
# NOTE(review): the `try:` lines matching the two `except` clauses and the
# body indentation are not visible in this view -- confirm against the full
# file.
# Check if YOLO weights exist
# NOTE(review): the path literal is empty in this view. `os.path.exists('')`
# is always False, so as shown this branch always fires; the real weights
# path is presumably missing from this view -- confirm. The error message
# below also has an unclosed "(" where a file name seems to be missing.
if not os.path.exists(''):
st.error("YOLO model weights ( not found in the current directory!")
return None, None, None
# Load YOLO model
# NOTE(review): same empty path literal as in the existence check above.
yolo_model = YOLO('', task='detect')
except Exception as yolo_error:
st.error(f"Error loading YOLO model: {str(yolo_error)}")
return None, None, None
# Load OCR models
processor, ocr_model = load_ocr_models()
# `load_ocr_models` signals failure with `(None, None)`; propagate that as a
# failed load of the whole set.
if processor is None or ocr_model is None:
return None, None, None
return processor, ocr_model, yolo_model
except Exception as e:
st.error(f"Error in model loading: {str(e)}")
st.error("Detailed error information:")
# NOTE(review): `traceback` is imported but no use of it is visible here --
# confirm whether the formatted traceback should be displayed.
import traceback
return None, None, None
# Thai provinces list
# Candidate names used to snap raw OCR output to a real province.
# Thailand has 76 provinces plus Bangkok (77 entries in total). An earlier
# version of this list omitted นนทบุรี, พระนครศรีอยุธยา, สตูล and อุตรดิตถ์,
# so plates from those provinces were always matched to a wrong name.
thai_provinces = [
    "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร", "ขอนแก่น", "จันทบุรี", "ฉะเชิงเทรา",
    "ชลบุรี", "ชัยนาท", "ชัยภูมิ", "ชุมพร", "เชียงราย", "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก",
    "นครปฐม", "นครพนม", "นครราชสีมา", "นครศรีธรรมราช", "นครสวรรค์", "นนทบุรี", "นราธิวาส", "น่าน", "บึงกาฬ",
    "บุรีรัมย์", "ปทุมธานี", "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พระนครศรีอยุธยา", "พะเยา", "พังงา", "พัทลุง",
    "พิจิตร", "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี", "แพร่", "ภูเก็ต", "มหาสารคาม", "มุกดาหาร", "แม่ฮ่องสอน",
    "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง", "ราชบุรี", "ลพบุรี", "ลำปาง", "ลำพูน", "เลย",
    "ศรีสะเกษ", "สกลนคร", "สงขลา", "สตูล", "สมุทรปราการ", "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี",
    "สิงห์บุรี", "สุโขทัย", "สุพรรณบุรี", "สุราษฎร์ธานี", "สุรินทร์", "หนองคาย", "หนองบัวลำภู", "อำนาจเจริญ",
    "อุดรธานี", "อุตรดิตถ์", "อุทัยธานี", "อุบลราชธานี", "อ่างทอง",
]
def get_closest_province(input_text, provinces):
    """Return the candidate in *provinces* closest to *input_text*.

    Closeness is measured with ``Levenshtein.distance`` (edit distance), so
    the function can snap noisy OCR output to a known province name.

    Args:
        input_text: The text to match, e.g. the raw OCR string.
        provinces: Iterable of candidate province names.

    Returns:
        A ``(closest_province, min_distance)`` tuple. If *provinces* is
        empty the result is ``(None, float('inf'))``. When several
        candidates share the smallest distance, the first one encountered
        wins.
    """
    min_distance = float('inf')
    closest_province = None
    for province in provinces:
        distance = Levenshtein.distance(input_text, province)
        if distance < min_distance:
            min_distance = distance
            closest_province = province
            if distance == 0:
                # Exact match: no later candidate can be strictly closer,
                # so the remaining comparisons cannot change the result.
                break
    return closest_province, min_distance
def process_image(image, processor, ocr_model, yolo_model):
# Detects plate/province regions in `image` with `yolo_model`, runs OCR on
# each detected crop with `processor` + `ocr_model`, and returns a dict with
# the recognised text and the crops (as PIL images) for display.
# NOTE(review): body indentation and several lines (flagged below) are not
# visible in this view -- confirm against the full file.
#
# Result skeleton: text fields default to "" and crops to None, so callers
# can test for "nothing detected" with plain truthiness / `is not None`.
data = {"plate_number": "", "province": "", "raw_province": "", "plate_crop": None, "province_crop": None}
# Convert PIL Image to cv2 format
# NOTE(review): assumes `image` is an RGB PIL image (the conversion code is
# RGB2BGR) -- confirm against the caller.
image = np.array(image)
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
# Image enhancement
# CLAHE is applied to the first (L) channel of the LAB representation only;
# the a/b channels are merged back unchanged.
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
l, a, b = cv2.split(lab)
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
cl = clahe.apply(l)
enhanced = cv2.merge((cl,a,b))
image = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)
# YOLO detection
# Runs on the enhanced image, not on the original upload.
results = yolo_model(image)
# Process detections
# Each entry is a `(class_id, confidence, (x1, y1, x2, y2))` tuple.
detections = []
for result in results:
for box in result.boxes:
confidence = float(box.conf)
class_id = int(box.cls.item())
# NOTE(review): `CONF_THRESHOLD` is not defined anywhere in the visible
# source, and the statement guarded by this `if` is not visible either;
# presumably low-confidence boxes are skipped -- confirm.
if confidence < CONF_THRESHOLD:
x1, y1, x2, y2 = map(int, box.xyxy.flatten())
detections.append((class_id, confidence, (x1, y1, x2, y2)))
# Sort by class_id
# Lower class ids are therefore handled first (0 before 1).
detections.sort(key=lambda x: x[0])
for class_id, confidence, (x1, y1, x2, y2) in detections:
cropped_image = image[y1:y2, x1:x2]
# NOTE(review): the body of this empty-crop guard is not visible here;
# presumably the detection is skipped -- confirm.
if cropped_image.size == 0:
# Preprocess for OCR
# Pipeline: grayscale -> adaptive threshold -> morphological close with a
# 2x2 rectangular kernel -> back to 3 channels -> resize to (128, 32).
cropped_image_gray = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2GRAY)
# NOTE(review): the arguments and closing parenthesis of this call are not
# visible in this view.
thresh_image = cv2.adaptiveThreshold(
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2,2))
thresh_image = cv2.morphologyEx(thresh_image, cv2.MORPH_CLOSE, kernel)
cropped_image_3d = cv2.cvtColor(thresh_image, cv2.COLOR_GRAY2RGB)
resized_image = cv2.resize(cropped_image_3d, (128, 32))
# OCR processing
# The processor turns the crop into `pixel_values`, the model generates
# token ids, and the first decoded string is used as the OCR result.
pixel_values = processor(resized_image, return_tensors="pt").pixel_values
generated_ids = ocr_model.generate(pixel_values)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Convert crop to PIL for display
# Uses the enhanced colour crop, not the thresholded image fed to OCR.
cropped_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
# If several detections share a class id, each one overwrites the fields
# written by the previous one, so the last detection of that class wins.
if class_id == 0: # License plate
data["plate_number"] = generated_text
data["plate_crop"] = cropped_pil
elif class_id == 1: # Province
# `distance` is not used after this call in the visible code; the closest
# province is stored however far it is from the OCR text.
generated_province, distance = get_closest_province(generated_text, thai_provinces)
data["raw_province"] = generated_text
data["province"] = generated_province
data["province_crop"] = cropped_pil
return data
# Main app
# Script body: load the models once per session, accept an uploaded image,
# run `process_image` on it and render the results in two columns.
# NOTE(review): indentation and several lines (`try:` / `else:` lines, call
# arguments) are not visible in this view; the notes below flag each spot --
# confirm against the full file.
st.title("Thai License Plate Detection 🚗")
# Load models
if not st.session_state['models_loaded']:
with st.spinner("Loading models... (this may take a minute)"):
processor, ocr_model, yolo_model = load_models()
# NOTE(review): `load_models` returns `(None, None, None)` on failure. As
# shown, the flag below is set to True without checking for that, which
# would store None models and skip reloading on later reruns -- confirm.
st.session_state['models_loaded'] = True
st.session_state['processor'] = processor
st.session_state['ocr_model'] = ocr_model
st.session_state['yolo_model'] = yolo_model
except Exception as e:
st.error(f"Error loading models: {str(e)}")
# File uploader
uploaded_file = st.file_uploader("Upload an image of a Thai license plate", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
# Display the uploaded image
col1, col2 = st.columns(2)
with col1:
st.subheader("Uploaded Image")
# NOTE(review): the right-hand side of this assignment is not visible in
# this view; presumably the upload is opened as a PIL image -- confirm.
image =
# NOTE(review): `use_column_width` has been deprecated in newer Streamlit
# releases in favour of `use_container_width` -- check the installed version.
st.image(image, use_column_width=True)
# Process the image
with col2:
st.subheader("Detection Results")
with st.spinner("Processing image..."):
# NOTE(review): the arguments and closing parenthesis of this call are not
# visible here. `process_image` takes (image, processor, ocr_model,
# yolo_model); the models are presumably read from `st.session_state`.
results = process_image(
# `results` is the dict returned by `process_image`; an empty
# "plate_number" means no plate text was produced.
if results["plate_number"]:
st.success("Detection successful!")
st.write("📝 License Plate:", results['plate_number'])
if results['plate_crop'] is not None:
st.subheader("Cropped License Plate")
st.image(results['plate_crop'], caption="Detected License Plate Region")
if results['raw_province']:
st.write("🔍 Detected Province Text:", results['raw_province'])
if results['province']:
st.write("🏠 Matched Province:", results['province'])
# NOTE(review): this message and the two fallback messages further down
# ("No province text detected", "No license plate detected ...") read like
# `else` branches whose `else:` lines are not visible in this view.
st.write("⚠️ No close province match found")
if results['province_crop'] is not None:
st.subheader("Cropped Province")
st.image(results['province_crop'], caption="Detected Province Region")
st.write("⚠️ No province text detected")
st.error("No license plate detected in the image.")
except Exception as e:
st.error(f"An error occurred: {str(e)}")
# NOTE(review): the numbered lines after the heading below are bare text in
# this view; they were presumably the body of a multi-line markdown string
# whose delimiters are not visible -- confirm.
st.markdown("### Instructions")
1. Upload an image containing a Thai license plate
2. Wait for the processing to complete
3. View the detected license plate number and province
# Add footer with GitHub link
