Spaces:
Running
Running
File size: 12,563 Bytes
7ced5c0 5e11a21 7ced5c0 5e11a21 7ced5c0 5e11a21 7ced5c0 5e11a21 7ced5c0 5e11a21 7ced5c0 5e11a21 7ced5c0 5e11a21 7ced5c0 5e11a21 7ced5c0 5e11a21 7ced5c0 5e11a21 7ced5c0 5e11a21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 |
import streamlit as st
from PIL import Image
import cv2
import numpy as np
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ultralytics import YOLO
import Levenshtein
import yaml
import os
import io
import tempfile
import torchvision
class LicensePlateProcessor:
    """Detect a Thai license plate in an image and read its number and province.

    Three YOLO models are loaded from weight files in the working directory
    (plate detector, province-region detector, character reader) together
    with the ``openthaigpt/thai-trocr`` TrOCR processor/model, which is used
    to read the province text. ``data.yaml`` supplies the class names of the
    character reader and an optional class-name -> character mapping.
    """

    # Minimum confidence for a detection (plate, province region or single
    # character) to be used.
    CONF_THRESHOLD = 0.3

    # All 77 Thai provinces (76 provinces plus Bangkok). OCR output is snapped
    # to the closest entry of this list, so a missing province can never be
    # returned and is silently replaced by a wrong neighbour.
    THAI_PROVINCES = (
        "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร", "ขอนแก่น",
        "จันทบุรี", "ฉะเชิงเทรา", "ชลบุรี", "ชัยนาท", "ชัยภูมิ", "ชุมพร", "เชียงราย",
        "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก", "นครปฐม", "นครพนม", "นครราชสีมา",
        "นครศรีธรรมราช", "นครสวรรค์", "นนทบุรี", "นราธิวาส", "น่าน", "บึงกาฬ", "บุรีรัมย์",
        "ปทุมธานี", "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พระนครศรีอยุธยา", "พะเยา",
        "พังงา", "พัทลุง", "พิจิตร", "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี", "แพร่", "ภูเก็ต",
        "มหาสารคาม", "มุกดาหาร", "แม่ฮ่องสอน", "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง",
        "ระยอง", "ราชบุรี", "ลพบุรี", "ลำปาง", "ลำพูน", "เลย", "ศรีสะเกษ", "สกลนคร",
        "สงขลา", "สตูล", "สมุทรปราการ", "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี",
        "สิงห์บุรี", "สุโขทัย", "สุพรรณบุรี", "สุราษฎร์ธานี", "สุรินทร์", "หนองคาย",
        "หนองบัวลำภู", "อำนาจเจริญ", "อุดรธานี", "อุตรดิตถ์", "อุทัยธานี", "อุบลราชธานี",
        "อ่างทอง",
    )

    def __init__(self):
        """Load every model and the character configuration from disk."""
        # YOLO models: plate localisation, province-region localisation and
        # per-character recognition.
        self.yolo_detector = YOLO('detect_plate.pt')
        self.province_detector = YOLO('best.pt')
        self.char_reader = YOLO('read_char.pt')

        # TrOCR reads the text inside the detected province region.
        self.processor_plate = TrOCRProcessor.from_pretrained('openthaigpt/thai-trocr')
        self.model_plate = VisionEncoderDecoderModel.from_pretrained('openthaigpt/thai-trocr')

        with open('data.yaml', 'r', encoding='utf-8') as f:
            data_config = yaml.safe_load(f)

        # YAML may load unquoted scalars (e.g. ``1``) as ints. Lookups are done
        # on ``str(class_name)``, so keys are normalised to str here; a null
        # ``char_mapping`` entry is treated as "no mapping".
        raw_mapping = data_config.get('char_mapping') or {}
        self.char_mapping = {str(key): value for key, value in raw_mapping.items()}
        self.names = data_config['names']

        # Instance copy so callers may customise the list without touching
        # the class-level constant.
        self.thai_provinces = list(self.THAI_PROVINCES)

    def _map_class_to_char(self, class_name):
        """Return the character for ``class_name`` as a string.

        The YAML mapping is consulted first; an unmapped class name is
        returned as-is. The result is always ``str`` so that it can be joined
        into the plate text even when the YAML value was loaded as a number.
        """
        key = str(class_name)
        return str(self.char_mapping.get(key, key))

    def get_closest_province(self, input_text):
        """Return ``(province, distance)`` for the province closest to ``input_text``.

        Closeness is the Levenshtein edit distance. On a tie the province that
        appears first in ``self.thai_provinces`` wins. When the province list
        is empty ``(None, inf)`` is returned.
        """
        best = min(
            (
                (Levenshtein.distance(input_text, province), index)
                for index, province in enumerate(self.thai_provinces)
            ),
            default=None,
        )
        if best is None:
            return None, float('inf')
        distance, index = best
        return self.thai_provinces[index], distance

    def read_plate_characters(self, plate_image):
        """Return the plate text read from a cropped plate image.

        Every character detected by the character reader is mapped to its
        character string and the characters are concatenated in left-to-right
        order of their bounding boxes.
        """
        results = self.char_reader.predict(plate_image, conf=self.CONF_THRESHOLD)
        found = []  # (left edge, character)
        for result in results:
            for box in result.boxes:
                left = int(box.xyxy[0][0])
                class_id = int(box.cls[0])
                found.append((left, self._map_class_to_char(self.names[class_id])))
        # Sort on the left edge only; the sort is stable, so characters that
        # share a left edge keep their detection order.
        found.sort(key=lambda item: item[0])
        return ''.join(char for _, char in found)

    def _confident_boxes(self, results):
        """Yield ``(x1, y1, x2, y2)`` for each box at or above ``CONF_THRESHOLD``."""
        for result in results:
            for box in result.boxes:
                if float(box.conf) < self.CONF_THRESHOLD:
                    continue
                yield tuple(map(int, box.xyxy.flatten()))

    def _read_province_text(self, region):
        """Run TrOCR on a BGR province crop and return the decoded text."""
        # Pre-processing: grayscale -> histogram equalisation -> inverted
        # binary threshold -> back to 3 channels -> fixed 128x32 input size.
        gray = cv2.cvtColor(region, cv2.COLOR_BGR2GRAY)
        equalized = cv2.equalizeHist(gray)
        _, binary = cv2.threshold(equalized, 65, 255, cv2.THRESH_BINARY_INV)
        rgb = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
        resized = cv2.resize(rgb, (128, 32))
        pixel_values = self.processor_plate(resized, return_tensors="pt").pixel_values
        generated_ids = self.model_plate.generate(pixel_values)
        return self.processor_plate.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def process_image(self, image_path: str, output_path: str = 'output_detection.jpg'):
        """Detect and read the plate in the image at ``image_path``.

        Args:
            image_path: Path of the image file to analyse.
            output_path: Where the annotated copy of the image (green box for
                plates, blue box for province regions) is written. Defaults to
                the historical fixed file name.

        Returns:
            A dict with keys ``plate_number``, ``province`` (closest known
            province) and ``raw_province`` (unmodified OCR text); values stay
            empty strings when nothing was detected. ``None`` is returned when
            the image cannot be read or any processing step fails.
            When several plates or province regions are detected, the values
            of the last one processed are kept.
        """
        try:
            image = cv2.imread(image_path)
            if image is None:
                print(f"Error: Could not read image from {image_path}")
                return None

            plate_results = self.yolo_detector(image)
            province_results = self.province_detector(image)
            data = {"plate_number": "", "province": "", "raw_province": ""}
            output_image = image.copy()

            for x1, y1, x2, y2 in self._confident_boxes(plate_results):
                plate_crop = image[y1:y2, x1:x2]
                if plate_crop.size == 0:
                    # Degenerate box after integer truncation: nothing to read.
                    continue
                cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
                data["plate_number"] = self.read_plate_characters(plate_crop)

            for x1, y1, x2, y2 in self._confident_boxes(province_results):
                province_crop = image[y1:y2, x1:x2]
                if province_crop.size == 0:
                    continue
                cv2.rectangle(output_image, (x1, y1), (x2, y2), (255, 0, 0), 2)
                raw_text = self._read_province_text(province_crop)
                closest_province, _ = self.get_closest_province(raw_text)
                data["raw_province"] = raw_text
                data["province"] = closest_province

            cv2.imwrite(output_path, output_image)
            return data
        except Exception as e:
            # Deliberate catch-all boundary: callers rely on ``None`` to signal
            # that processing failed.
            print(f"Error processing image: {str(e)}")
            return None
def main():
    """Run the Streamlit page: upload an image, process it, show the results.

    The uploaded image is shown on the left; the recognised plate number and
    province are shown below it and the annotated detection image on the
    right. Temporary files are removed whether or not processing succeeds.
    """
    # Local import: only needed here, to escape OCR text embedded in HTML.
    import html

    st.set_page_config(
        page_title="Thai License Plate Recognition",
        layout="wide"
    )
    st.title("Thai License Plate Recognition")
    st.write("Upload an image to detect and read Thai license plates")

    # The models are expensive to load, so one processor is cached and reused
    # across reruns.
    @st.cache_resource
    def load_processor():
        return LicensePlateProcessor()

    processor = load_processor()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("Original Image")
            image = Image.open(uploaded_file)
            st.image(image, use_column_width=True)

        # Normalise every PIL mode (grayscale, palette, RGBA, ...) to 3-channel
        # RGB first; the RGB->BGR conversion below fails on 2-D arrays.
        image_array = np.array(image.convert("RGB"))
        image_cv = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)

        # process_image writes its visualisation to this fixed default path.
        output_path = 'output_detection.jpg'
        temp_path = None
        with st.spinner("Processing image..."):
            try:
                # Unique temporary input file so concurrent sessions do not
                # overwrite each other's upload.
                fd, temp_path = tempfile.mkstemp(suffix='.jpg')
                os.close(fd)
                cv2.imwrite(temp_path, image_cv)

                results = processor.process_image(temp_path)

                # process_image returns None on failure and a dict with empty
                # strings when nothing was found; both mean "nothing to show".
                detected = bool(results) and bool(
                    results.get('plate_number') or results.get('province')
                )
                if detected:
                    # OCR output is arbitrary text; escape it before embedding
                    # it in markup rendered with unsafe_allow_html.
                    plate_number = html.escape(str(results['plate_number']))
                    province = html.escape(str(results['province']))
                    raw_province = html.escape(str(results['raw_province']))

                    st.subheader("Detection Results")
                    results_container = st.container()
                    with results_container:
                        st.markdown(f"""
<div style='background-color: #f0f2f6; padding: 20px; border-radius: 10px;'>
<h3>License Plate: {plate_number}</h3>
<h3>Province: {province}</h3>
<p>Raw Province Text: {raw_province}</p>
</div>
""", unsafe_allow_html=True)

                    with col2:
                        st.subheader("Detection Visualization")
                        if os.path.exists(output_path):
                            # OpenCV loads BGR; Streamlit expects RGB.
                            output_image = cv2.imread(output_path)
                            output_image_rgb = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
                            st.image(output_image_rgb, use_column_width=True)
                else:
                    st.error("No license plate detected in the image.")
            except Exception as e:
                st.error(f"Error processing image: {str(e)}")
            finally:
                # Always remove the temporary input and the visualisation file.
                for path in (temp_path, output_path):
                    if path and os.path.exists(path):
                        os.remove(path)

    with st.expander("About This Application"):
        st.markdown("""
### Thai License Plate Recognition System
This application uses advanced computer vision and deep learning to:
- Detect license plates in images using YOLO
- Read Thai license plate numbers using character recognition
- Identify province names using TrOCR
- Provide visual detection results
#### How to Use:
1. Click the 'Browse files' button above
2. Select an image containing a Thai license plate
3. Wait for the processing to complete
4. View the results and detection visualization
#### Technologies Used:
- YOLO for license plate detection
- Custom YOLO model for character recognition
- TrOCR for province text recognition
""")
# Start the Streamlit app only when this file is run as a script.
if __name__ == "__main__":
    main()