Sompote's picture
Upload 6 files
7ced5c0 verified
raw
history blame
12 kB
import streamlit as st
from PIL import Image
import cv2
import numpy as np
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ultralytics import YOLO
import Levenshtein
import yaml
import os
import io
import tempfile
import torchvision
class LicensePlateProcessor:
    """Detect Thai license plates in an image and read their text.

    A YOLO detector (``models/best.pt``) finds plate-number regions (class 0)
    and province regions (class 1). Plate characters are read with a second
    YOLO model (``models/read_char.pt``). The province line is read with the
    ``openthaigpt/thai-trocr`` TrOCR model, and the result is snapped to the
    closest official province name by Levenshtein distance.
    """

    # Minimum detector confidence for a plate/province region to be used.
    CONF_THRESHOLD = 0.3
    # Minimum confidence for an individual character detection.
    CHAR_CONF_THRESHOLD = 0.3

    # All 77 Thai provinces (76 provinces plus Bangkok), used to snap noisy
    # OCR output to a valid name.
    THAI_PROVINCES = (
        "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร", "ขอนแก่น",
        "จันทบุรี", "ฉะเชิงเทรา", "ชลบุรี", "ชัยนาท", "ชัยภูมิ", "ชุมพร", "เชียงราย",
        "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก", "นครปฐม", "นครพนม", "นครราชสีมา",
        "นครศรีธรรมราช", "นครสวรรค์", "นนทบุรี", "นราธิวาส", "น่าน", "บึงกาฬ", "บุรีรัมย์", "ปทุมธานี",
        "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พระนครศรีอยุธยา", "พะเยา", "พังงา", "พัทลุง", "พิจิตร",
        "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี", "แพร่", "ภูเก็ต", "มหาสารคาม", "มุกดาหาร",
        "แม่ฮ่องสอน", "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง", "ราชบุรี", "ลพบุรี",
        "ลำปาง", "ลำพูน", "เลย", "ศรีสะเกษ", "สกลนคร", "สงขลา", "สตูล", "สมุทรปราการ",
        "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี", "สิงห์บุรี", "สุโขทัย", "สุพรรณบุรี",
        "สุราษฎร์ธานี", "สุรินทร์", "หนองคาย", "หนองบัวลำภู", "อำนาจเจริญ", "อุดรธานี",
        "อุตรดิตถ์", "อุทัยธานี", "อุบลราชธานี", "อ่างทอง",
    )

    def __init__(self):
        """Load the detection, character and OCR models plus the label config."""
        # Region detector: class 0 = plate number, class 1 = province
        # (see the branches in ``process_image``).
        self.yolo_detector = YOLO('models/best.pt')
        # Character reader run on each cropped plate-number region.
        self.char_reader = YOLO('models/read_char.pt')
        # TrOCR processor/model used to read the province line.
        self.processor_plate = TrOCRProcessor.from_pretrained('openthaigpt/thai-trocr')
        self.model_plate = VisionEncoderDecoderModel.from_pretrained('openthaigpt/thai-trocr')
        # ``names`` maps class id -> class name; ``char_mapping`` maps a class
        # name to the character it stands for (optional in the YAML).
        with open('config/data.yaml', 'r', encoding='utf-8') as f:
            data_config = yaml.safe_load(f)
        self.char_mapping = data_config.get('char_mapping', {})
        self.names = data_config['names']
        # Mutable per-instance copy, kept as an attribute for existing callers.
        self.thai_provinces = list(self.THAI_PROVINCES)

    def _map_class_to_char(self, class_name):
        """Map a character-model class name to its character.

        Falls back to ``str(class_name)`` when the YAML has no mapping for it.
        """
        key = str(class_name)
        return self.char_mapping.get(key, key)

    @staticmethod
    def _clip_box(x1, y1, x2, y2, width, height):
        """Clamp a pixel box to ``[0, width] x [0, height]``.

        Returns the clipped ``(x1, y1, x2, y2)``, or None when no area is left.
        Clipping matters because a negative start index would wrap around in
        NumPy slicing and select the wrong region.
        """
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(width, x2), min(height, y2)
        if x2 <= x1 or y2 <= y1:
            return None
        return x1, y1, x2, y2

    def get_closest_province(self, input_text):
        """Find the province name closest to ``input_text``.

        Returns:
            ``(province, distance)`` where ``distance`` is the Levenshtein
            edit distance. On ties, the earliest province in the list wins.
            Returns ``(None, inf)`` if the province list is empty.
        """
        min_distance = float('inf')
        closest_province = None
        for province in self.thai_provinces:
            distance = Levenshtein.distance(input_text, province)
            if distance < min_distance:
                min_distance = distance
                closest_province = province
        return closest_province, min_distance

    def read_plate_characters(self, plate_image):
        """Read the plate number from a cropped plate-number image.

        Each character detection at or above ``CHAR_CONF_THRESHOLD`` is mapped
        to its character. The characters are joined in order of their boxes'
        left edge (x1).

        Returns:
            The plate text; an empty string when nothing is detected.
        """
        detections = []
        for result in self.char_reader.predict(plate_image, conf=self.CHAR_CONF_THRESHOLD):
            for box in result.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                detections.append({
                    'char': self._map_class_to_char(self.names[int(box.cls[0])]),
                    'confidence': float(box.conf[0]),
                    'bbox': (x1, y1, x2, y2),
                })
        detections.sort(key=lambda det: det['bbox'][0])
        return ''.join(det['char'] for det in detections)

    def _read_province_text(self, cropped_image):
        """Run TrOCR on a BGR province crop and return the raw decoded text."""
        gray = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2GRAY)
        equalized = cv2.equalizeHist(gray)
        # Fixed inverse-binary threshold of 65 after histogram equalization;
        # presumably tuned for this model's training data — TODO confirm.
        _, thresh = cv2.threshold(equalized, 65, 255, cv2.THRESH_BINARY_INV)
        rgb = cv2.cvtColor(thresh, cv2.COLOR_GRAY2RGB)
        resized = cv2.resize(rgb, (128, 32))  # cv2 dsize is (width, height)
        pixel_values = self.processor_plate(resized, return_tensors="pt").pixel_values
        generated_ids = self.model_plate.generate(pixel_values)
        return self.processor_plate.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def process_image(self, image_path: "str | np.ndarray", output_path: str = 'output_detection.jpg'):
        """Detect and read the plate number and province in one image.

        Args:
            image_path: Path to an image file, or an already-loaded BGR image
                array (as returned by ``cv2.imread``).
            output_path: Where to write the visualization. Plate-number boxes
                are drawn in green and other boxes in blue (BGR).

        Returns:
            A dict with ``plate_number``, ``province`` (snapped to a known
            province) and ``raw_province`` (the raw OCR text). Values are empty
            strings when the region was not found. If there are several
            regions of one class, the last one processed wins. Returns None if
            the image cannot be read or any error occurs; the error is printed.
        """
        try:
            if isinstance(image_path, np.ndarray):
                image = image_path
            else:
                image = cv2.imread(image_path)
            if image is None:
                print(f"Error: Could not read image from {image_path}")
                return None

            height, width = image.shape[:2]
            results = self.yolo_detector(image)
            data = {"plate_number": "", "province": "", "raw_province": ""}
            output_image = image.copy()

            for result in results:
                for box in result.boxes:
                    if float(box.conf) < self.CONF_THRESHOLD:
                        continue
                    x1, y1, x2, y2 = map(int, box.xyxy.flatten())
                    clipped = self._clip_box(x1, y1, x2, y2, width, height)
                    if clipped is None:
                        # Degenerate or fully out-of-frame box: nothing to read.
                        continue
                    x1, y1, x2, y2 = clipped
                    cropped_image = image[y1:y2, x1:x2]

                    class_id = int(box.cls.item())
                    color = (0, 255, 0) if class_id == 0 else (255, 0, 0)
                    cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)

                    if class_id == 0:  # plate number
                        data["plate_number"] = self.read_plate_characters(cropped_image)
                    elif class_id == 1:  # province
                        raw_text = self._read_province_text(cropped_image)
                        data["raw_province"] = raw_text
                        data["province"], _ = self.get_closest_province(raw_text)

            cv2.imwrite(output_path, output_image)
            return data
        except Exception as e:
            # Best-effort boundary: report the error and signal failure with None.
            print(f"Error processing image: {str(e)}")
            return None
def main():
st.set_page_config(
page_title="Thai License Plate Recognition",
layout="wide"
)
st.title("Thai License Plate Recognition")
st.write("Upload an image to detect and read Thai license plates")
# Initialize processor
@st.cache_resource
def load_processor():
return LicensePlateProcessor()
processor = load_processor()
# File uploader
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
# Create columns for side-by-side display
col1, col2 = st.columns(2)
# Display original image
with col1:
st.subheader("Original Image")
image = Image.open(uploaded_file)
st.image(image, use_column_width=True)
# Convert PIL Image to OpenCV format for processing
image_array = np.array(image)
if len(image_array.shape) == 3 and image_array.shape[2] == 4:
# Convert RGBA to RGB if needed
image_array = cv2.cvtColor(image_array, cv2.COLOR_RGBA2RGB)
image_cv = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)
# Process image
with st.spinner("Processing image..."):
try:
# Save the OpenCV image for processing
temp_path = 'temp_input.jpg'
cv2.imwrite(temp_path, image_cv)
# Process the image using the processor
results = processor.process_image(temp_path)
# Clean up temporary input file
os.remove(temp_path)
if results:
# Display results
st.subheader("Detection Results")
# Create a styled container for results
results_container = st.container()
with results_container:
st.markdown(f"""
<div style='background-color: #f0f2f6; padding: 20px; border-radius: 10px;'>
<h3>License Plate: {results['plate_number']}</h3>
<h3>Province: {results['province']}</h3>
<p>Raw Province Text: {results['raw_province']}</p>
</div>
""", unsafe_allow_html=True)
# Display detection visualization
with col2:
st.subheader("Detection Visualization")
if os.path.exists('output_detection.jpg'):
# Read and convert the output image from BGR to RGB
output_image = cv2.imread('output_detection.jpg')
output_image_rgb = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
st.image(output_image_rgb, use_column_width=True)
# Clean up output image
os.remove('output_detection.jpg')
else:
st.error("No license plate detected in the image.")
except Exception as e:
st.error(f"Error processing image: {str(e)}")
# Clean up any temporary files in case of error
if os.path.exists('temp_input.jpg'):
os.remove('temp_input.jpg')
if os.path.exists('output_detection.jpg'):
os.remove('output_detection.jpg')
# Add information about the application
with st.expander("About This Application"):
st.markdown("""
### Thai License Plate Recognition System
This application uses advanced computer vision and deep learning to:
- Detect license plates in images using YOLO
- Read Thai license plate numbers using character recognition
- Identify province names using TrOCR
- Provide visual detection results
#### How to Use:
1. Click the 'Browse files' button above
2. Select an image containing a Thai license plate
3. Wait for the processing to complete
4. View the results and detection visualization
#### Technologies Used:
- YOLO for license plate detection
- Custom YOLO model for character recognition
- TrOCR for province text recognition
""")
# Entry point: run the Streamlit page only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()