Spaces:
Running
Running
import streamlit as st | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
from ultralytics import YOLO | |
import Levenshtein | |
import yaml | |
import os | |
import io | |
import tempfile | |
import torchvision | |
class LicensePlateProcessor:
    """Detect a Thai license plate in an image and read its number and province.

    Pipeline implemented by :meth:`process_image`:

    1. ``models/best.pt`` (YOLO) finds regions; class 0 is handled as the
       plate-number region and class 1 as the province region.
    2. Plate-number crops are read character by character with a second YOLO
       model (``models/read_char.pt``); see :meth:`read_plate_characters`.
    3. Province crops are binarised and read with the ``openthaigpt/thai-trocr``
       TrOCR model, then snapped to the nearest entry of ``thai_provinces``.
    """

    # Minimum confidence used for both region detections and character detections.
    CONF_THRESHOLD = 0.3

    # All 77 Thai provinces (76 provinces + Bangkok). OCR output is snapped to
    # the nearest of these, so a missing name can never be produced.
    THAI_PROVINCES = (
        "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร", "ขอนแก่น",
        "จันทบุรี", "ฉะเชิงเทรา", "ชลบุรี", "ชัยนาท", "ชัยภูมิ", "ชุมพร", "เชียงราย",
        "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก", "นครปฐม", "นครพนม", "นครราชสีมา",
        "นครศรีธรรมราช", "นครสวรรค์", "นนทบุรี", "นราธิวาส", "น่าน", "บึงกาฬ", "บุรีรัมย์",
        "ปทุมธานี", "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พระนครศรีอยุธยา", "พะเยา",
        "พังงา", "พัทลุง", "พิจิตร", "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี", "แพร่", "ภูเก็ต",
        "มหาสารคาม", "มุกดาหาร", "แม่ฮ่องสอน", "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง",
        "ราชบุรี", "ลพบุรี", "ลำปาง", "ลำพูน", "เลย", "ศรีสะเกษ", "สกลนคร", "สงขลา", "สตูล",
        "สมุทรปราการ", "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี", "สิงห์บุรี", "สุโขทัย",
        "สุพรรณบุรี", "สุราษฎร์ธานี", "สุรินทร์", "หนองคาย", "หนองบัวลำภู", "อำนาจเจริญ",
        "อุดรธานี", "อุตรดิตถ์", "อุทัยธานี", "อุบลราชธานี", "อ่างทอง",
    )

    def __init__(self):
        """Load the detection/OCR models and the character mapping from disk."""
        # Region detector (plate number / province) and per-character reader.
        self.yolo_detector = YOLO('models/best.pt')
        self.char_reader = YOLO('models/read_char.pt')
        # TrOCR is used only for the province region.
        self.processor_plate = TrOCRProcessor.from_pretrained('openthaigpt/thai-trocr')
        self.model_plate = VisionEncoderDecoderModel.from_pretrained('openthaigpt/thai-trocr')
        # 'names' turns char-reader class ids into labels; the optional
        # 'char_mapping' turns those labels into the characters to display.
        with open('config/data.yaml', 'r', encoding='utf-8') as f:
            data_config = yaml.safe_load(f)
        self.char_mapping = data_config.get('char_mapping', {})
        self.names = data_config['names']
        # Per-instance copy so callers can modify it without touching the class constant.
        self.thai_provinces = list(self.THAI_PROVINCES)

    def _map_class_to_char(self, class_name):
        """Return ``char_mapping[str(class_name)]``, or ``str(class_name)`` if unmapped."""
        key = str(class_name)
        return self.char_mapping.get(key, key)

    @staticmethod
    def _clip_box(x1, y1, x2, y2, width, height):
        """Clamp a box to a ``width`` x ``height`` image.

        Returns the clamped ``(x1, y1, x2, y2)``, or ``None`` when the box has
        no area left. Negative coordinates would otherwise wrap around in a
        NumPy slice, and zero-area crops make OpenCV calls raise.
        """
        x1, x2 = max(0, x1), min(width, x2)
        y1, y2 = max(0, y1), min(height, y2)
        if x2 <= x1 or y2 <= y1:
            return None
        return x1, y1, x2, y2

    def get_closest_province(self, input_text):
        """Return ``(province, distance)`` for the province nearest to *input_text*.

        Distance is the Levenshtein edit distance. On ties the earliest entry
        of ``self.thai_provinces`` wins. Returns ``(None, inf)`` if the list is empty.
        """
        min_distance = float('inf')
        closest_province = None
        for province in self.thai_provinces:
            distance = Levenshtein.distance(input_text, province)
            if distance < min_distance:  # strict: keep the first of equal matches
                min_distance = distance
                closest_province = province
        return closest_province, min_distance

    def read_plate_characters(self, plate_image):
        """Read the plate number from a cropped plate-number image.

        Each character detected by the char-reader YOLO model (confidence
        above ``CONF_THRESHOLD``) is mapped through :meth:`_map_class_to_char`.
        The characters are concatenated left to right by their boxes' left
        edge. Returns ``""`` when nothing is detected.
        """
        results = self.char_reader.predict(plate_image, conf=self.CONF_THRESHOLD)
        detections = []
        for r in results:
            for box in r.boxes:
                left = int(box.xyxy[0][0])
                class_id = int(box.cls[0])
                detections.append((left, self._map_class_to_char(self.names[class_id])))
        # Sort on the left edge only (stable), so equal edges keep detection order.
        detections.sort(key=lambda det: det[0])
        return ''.join(char for _, char in detections)

    def _read_province_text(self, crop):
        """Binarise a BGR province crop and return TrOCR's raw decoded text."""
        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
        equalized = cv2.equalizeHist(gray)
        # Inverted binary threshold at 65 after histogram equalisation.
        _, binary = cv2.threshold(equalized, 65, 255, cv2.THRESH_BINARY_INV)
        rgb = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
        resized = cv2.resize(rgb, (128, 32))  # dsize is (width, height)
        pixel_values = self.processor_plate(resized, return_tensors="pt").pixel_values
        generated_ids = self.model_plate.generate(pixel_values)
        return self.processor_plate.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def process_image(self, image_path: str):
        """Detect and read the plate in the image file at *image_path*.

        Side effect: writes an annotated copy of the image (detection boxes
        drawn) to ``output_detection.jpg`` in the current working directory.

        Returns:
            ``{"plate_number": str, "province": str, "raw_province": str}``.
            Fields stay ``""`` when no region of that class was read. If several
            regions of one class are detected, the last one processed wins.
            Returns ``None`` if the image cannot be read or processing raises.
        """
        try:
            image = cv2.imread(image_path)
            if image is None:
                print(f"Error: Could not read image from {image_path}")
                return None

            height, width = image.shape[:2]
            results = self.yolo_detector(image)
            data = {"plate_number": "", "province": "", "raw_province": ""}
            output_image = image.copy()

            for result in results:
                for box in result.boxes:
                    if float(box.conf) < self.CONF_THRESHOLD:
                        continue
                    coords = self._clip_box(*map(int, box.xyxy.flatten()), width, height)
                    if coords is None:
                        continue  # zero-area box: nothing to crop or read
                    x1, y1, x2, y2 = coords
                    cropped_image = image[y1:y2, x1:x2]
                    class_id = int(box.cls.item())

                    # Green for class 0; (255, 0, 0) is blue in OpenCV's BGR order.
                    color = (0, 255, 0) if class_id == 0 else (255, 0, 0)
                    cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)

                    if class_id == 0:  # plate-number region
                        data["plate_number"] = self.read_plate_characters(cropped_image)
                    elif class_id == 1:  # province region
                        raw_text = self._read_province_text(cropped_image)
                        data["raw_province"] = raw_text
                        data["province"], _ = self.get_closest_province(raw_text)

            cv2.imwrite('output_detection.jpg', output_image)
            return data
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            return None
def main():
    """Streamlit page: upload an image, run LicensePlateProcessor, show the results."""
    import html  # escape OCR text before it is embedded in raw HTML

    st.set_page_config(
        page_title="Thai License Plate Recognition",
        layout="wide"
    )
    st.title("Thai License Plate Recognition")
    st.write("Upload an image to detect and read Thai license plates")

    # Streamlit re-runs this script on every widget interaction. Cache the
    # processor so the two YOLO models and TrOCR are loaded once, not per rerun.
    @st.cache_resource
    def load_processor():
        return LicensePlateProcessor()

    processor = load_processor()

    # process_image always writes its visualization here.
    # NOTE(review): a fixed path in the working directory can still race
    # between concurrent sessions; fixing that needs a process_image change.
    output_path = 'output_detection.jpg'

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("Original Image")
            image = Image.open(uploaded_file)
            st.image(image, use_column_width=True)

        # Normalise any PIL mode (RGBA, palette "P", greyscale "L", CMYK, ...)
        # to 3-channel RGB, then swap to OpenCV's BGR channel order.
        image_cv = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)

        with st.spinner("Processing image..."):
            try:
                # A private temp dir per run avoids clashes between sessions and
                # is removed even if processing raises. PNG avoids a lossy
                # JPEG re-encode before OCR.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    temp_path = os.path.join(tmp_dir, 'temp_input.png')
                    cv2.imwrite(temp_path, image_cv)
                    results = processor.process_image(temp_path)

                # process_image returns a dict even when nothing was found, so
                # inspect the fields instead of the dict's truthiness.
                field_names = ("plate_number", "province", "raw_province")
                detected = results is not None and any(results.get(k) for k in field_names)

                if detected:
                    st.subheader("Detection Results")
                    # OCR output is untrusted text going into unsafe HTML.
                    safe = {k: html.escape(results.get(k) or "") for k in field_names}
                    with st.container():
                        st.markdown(f"""
                        <div style='background-color: #f0f2f6; padding: 20px; border-radius: 10px;'>
                        <h3>License Plate: {safe['plate_number']}</h3>
                        <h3>Province: {safe['province']}</h3>
                        <p>Raw Province Text: {safe['raw_province']}</p>
                        </div>
                        """, unsafe_allow_html=True)
                else:
                    st.error("No license plate detected in the image.")

                if results is not None:
                    with col2:
                        st.subheader("Detection Visualization")
                        output_image = cv2.imread(output_path)
                        if output_image is not None:
                            output_image_rgb = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
                            st.image(output_image_rgb, use_column_width=True)
            except Exception as e:
                st.error(f"Error processing image: {str(e)}")
            finally:
                # Always remove the shared visualization file, success or failure.
                if os.path.exists(output_path):
                    os.remove(output_path)

    # Add information about the application
    with st.expander("About This Application"):
        st.markdown("""
        ### Thai License Plate Recognition System
        This application uses advanced computer vision and deep learning to:
        - Detect license plates in images using YOLO
        - Read Thai license plate numbers using character recognition
        - Identify province names using TrOCR
        - Provide visual detection results

        #### How to Use:
        1. Click the 'Browse files' button above
        2. Select an image containing a Thai license plate
        3. Wait for the processing to complete
        4. View the results and detection visualization

        #### Technologies Used:
        - YOLO for license plate detection
        - Custom YOLO model for character recognition
        - TrOCR for province text recognition
        """)
# Build the Streamlit page only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()