File size: 12,563 Bytes
7ced5c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e11a21
 
 
7ced5c0
 
 
 
 
 
5e11a21
7ced5c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e11a21
 
7ced5c0
 
 
 
 
 
5e11a21
 
7ced5c0
 
 
 
 
 
 
 
5e11a21
 
 
 
 
7ced5c0
5e11a21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ced5c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e11a21
7ced5c0
 
 
 
 
5e11a21
7ced5c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e11a21
7ced5c0
 
 
 
 
 
 
 
 
 
 
 
 
 
5e11a21
7ced5c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e11a21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import streamlit as st
from PIL import Image
import cv2
import numpy as np
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ultralytics import YOLO
import Levenshtein
import yaml
import os
import io
import tempfile
import torchvision

class LicensePlateProcessor:
    """Detect a Thai license plate in an image and read its text and province.

    Three YOLO models are used: one locates the plate, one locates the
    province line, and one reads the individual plate characters. The
    province crop is transcribed with TrOCR and then snapped to the nearest
    entry of ``thai_provinces`` by Levenshtein distance.

    Attributes:
        thai_provinces: List of province names used for fuzzy matching.
        CONF_THRESHOLD: Minimum confidence for a detection to be used.
        char_mapping: Optional class-name -> character map from ``data.yaml``.
        names: Class-id -> class-name lookup from ``data.yaml``.
    """

    # All 77 provinces (76 provinces plus Bangkok). The earlier list omitted
    # นนทบุรี, พระนครศรีอยุธยา, สตูล and อุตรดิตถ์, so plates from those provinces
    # could never be matched correctly.
    THAI_PROVINCES = (
        "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร", "ขอนแก่น",
        "จันทบุรี", "ฉะเชิงเทรา", "ชลบุรี", "ชัยนาท", "ชัยภูมิ", "ชุมพร", "เชียงราย",
        "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก", "นครปฐม", "นครพนม", "นครราชสีมา",
        "นครศรีธรรมราช", "นครสวรรค์", "นนทบุรี", "นราธิวาส", "น่าน", "บึงกาฬ", "บุรีรัมย์", "ปทุมธานี",
        "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พระนครศรีอยุธยา", "พะเยา", "พังงา", "พัทลุง", "พิจิตร",
        "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี", "แพร่", "ภูเก็ต", "มหาสารคาม", "มุกดาหาร",
        "แม่ฮ่องสอน", "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง", "ราชบุรี", "ลพบุรี",
        "ลำปาง", "ลำพูน", "เลย", "ศรีสะเกษ", "สกลนคร", "สงขลา", "สตูล", "สมุทรปราการ", "สมุทรสงคราม",
        "สมุทรสาคร", "สระแก้ว", "สระบุรี", "สิงห์บุรี", "สุโขทัย", "สุพรรณบุรี", "สุราษฎร์ธานี",
        "สุรินทร์", "หนองคาย", "หนองบัวลำภู", "อำนาจเจริญ", "อุดรธานี", "อุตรดิตถ์", "อุทัยธานี",
        "อุบลราชธานี", "อ่างทอง",
    )

    def __init__(self):
        """Load the detection models, the TrOCR model and ``data.yaml``.

        Model weights and ``data.yaml`` are read from paths relative to the
        current working directory. ``data.yaml`` must contain a ``names``
        entry; ``char_mapping`` is optional.
        """
        # Load models for plate detection
        self.yolo_detector = YOLO('detect_plate.pt')  # For license plate detection
        self.province_detector = YOLO('best.pt')  # For province detection
        self.char_reader = YOLO('read_char.pt')  # For character reading

        # Load TrOCR for province text recognition
        self.processor_plate = TrOCRProcessor.from_pretrained('openthaigpt/thai-trocr')
        self.model_plate = VisionEncoderDecoderModel.from_pretrained('openthaigpt/thai-trocr')

        # Load character mapping from yaml
        with open('data.yaml', 'r', encoding='utf-8') as f:
            data_config = yaml.safe_load(f)
        self.char_mapping = data_config.get('char_mapping', {})
        self.names = data_config['names']

        # Per-instance copy so callers that mutate the list do not affect
        # other instances or the class constant.
        self.thai_provinces = list(self.THAI_PROVINCES)

        self.CONF_THRESHOLD = 0.3

    def _map_class_to_char(self, class_name):
        """Return the character mapped to ``class_name``.

        The lookup key is ``str(class_name)``; when the yaml mapping has no
        entry for it, that string itself is returned.
        """
        key = str(class_name)
        return self.char_mapping.get(key, key)

    def get_closest_province(self, input_text):
        """Find the province name closest to ``input_text``.

        Args:
            input_text: Text to match; surrounding whitespace is ignored.

        Returns:
            A ``(province, distance)`` tuple holding the first province with
            the smallest Levenshtein distance. Blank or ``None`` input
            returns ``(None, float('inf'))`` so that an empty OCR result is
            not reported as a (short) province name.
        """
        candidate = input_text.strip() if input_text else ""
        if not candidate:
            return None, float('inf')

        min_distance = float('inf')
        closest_province = None
        for province in self.thai_provinces:
            distance = Levenshtein.distance(candidate, province)
            if distance < min_distance:
                min_distance = distance
                closest_province = province

        return closest_province, min_distance

    def read_plate_characters(self, plate_image):
        """Read the characters on a cropped plate image.

        Args:
            plate_image: Cropped plate image passed to the character model.

        Returns:
            The detected characters joined in left-to-right order (sorted by
            the left edge of each bounding box); ``""`` when none are found.
        """
        # Use the shared threshold instead of a second hard-coded 0.3.
        results = self.char_reader.predict(plate_image, conf=self.CONF_THRESHOLD)

        # (left edge, character) for every detected glyph
        glyphs = []
        for result in results:
            for box in result.boxes:
                left_edge = int(box.xyxy[0][0])
                class_id = int(box.cls[0])
                glyphs.append((left_edge, self._map_class_to_char(self.names[class_id])))

        # Key on the left edge only so that ties keep detection order.
        glyphs.sort(key=lambda glyph: glyph[0])
        return ''.join(char for _, char in glyphs)

    def _confident_boxes(self, results):
        """Yield integer ``(x1, y1, x2, y2)`` for boxes at or above CONF_THRESHOLD."""
        for result in results:
            for box in result.boxes:
                if float(box.conf) < self.CONF_THRESHOLD:
                    continue
                yield tuple(map(int, box.xyxy.flatten()))

    def _recognize_province(self, province_crop):
        """Binarize a BGR province crop and return the raw text TrOCR decodes from it."""
        gray = cv2.cvtColor(province_crop, cv2.COLOR_BGR2GRAY)
        equalized = cv2.equalizeHist(gray)
        _, binary = cv2.threshold(equalized, 65, 255, cv2.THRESH_BINARY_INV)
        rgb = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
        resized = cv2.resize(rgb, (128, 32))

        pixel_values = self.processor_plate(resized, return_tensors="pt").pixel_values
        generated_ids = self.model_plate.generate(pixel_values)
        return self.processor_plate.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def process_image(self, image_path: str):
        """Run plate and province recognition on the image at ``image_path``.

        Draws the accepted detections (green for plates, blue for province
        boxes) on a copy of the image and writes it to
        ``output_detection.jpg`` in the working directory.

        Args:
            image_path: Path of the image file to read with OpenCV.

        Returns:
            A dict with keys ``plate_number``, ``province`` and
            ``raw_province`` (empty strings when nothing was recognised), or
            ``None`` when the image cannot be read or processing raises.
            When several boxes pass the threshold, the last one wins.
        """
        try:
            # Read image
            image = cv2.imread(image_path)
            if image is None:
                print(f"Error: Could not read image from {image_path}")
                return None

            plate_results = self.yolo_detector(image)
            province_results = self.province_detector(image)

            data = {"plate_number": "", "province": "", "raw_province": ""}
            output_image = image.copy()

            # License plate detections (green)
            for x1, y1, x2, y2 in self._confident_boxes(plate_results):
                cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
                plate_crop = image[y1:y2, x1:x2]
                if plate_crop.size == 0:
                    # Degenerate box after integer truncation: nothing to read.
                    continue
                data["plate_number"] = self.read_plate_characters(plate_crop)

            # Province detections (blue)
            for x1, y1, x2, y2 in self._confident_boxes(province_results):
                cv2.rectangle(output_image, (x1, y1), (x2, y2), (255, 0, 0), 2)
                province_crop = image[y1:y2, x1:x2]
                if province_crop.size == 0:
                    # cv2.cvtColor raises on an empty crop; skip it rather
                    # than discarding every other result.
                    continue

                raw_text = self._recognize_province(province_crop)
                matched_province, _ = self.get_closest_province(raw_text)
                data["raw_province"] = raw_text
                # Keep the field a string even when the OCR text was blank.
                data["province"] = matched_province or ""

            # Save the output image
            cv2.imwrite('output_detection.jpg', output_image)
            return data

        except Exception as e:
            # Top-level boundary: callers treat None as "processing failed".
            print(f"Error processing image: {str(e)}")
            return None

def main():
    """Run the Streamlit page: upload an image, recognise the plate, show results.

    The uploaded image is shown in the left column, written to a unique
    temporary JPEG, and passed to ``LicensePlateProcessor.process_image``.
    The recognised text and the annotated image (``output_detection.jpg``,
    written by the processor) are then displayed. Temporary files are
    removed whether or not processing succeeds.
    """
    # Function-scope import: only needed to escape OCR text placed in HTML.
    import html

    st.set_page_config(
        page_title="Thai License Plate Recognition",
        layout="wide"
    )

    st.title("Thai License Plate Recognition")
    st.write("Upload an image to detect and read Thai license plates")

    # Initialize processor (cached so the models are loaded only once)
    @st.cache_resource
    def load_processor():
        return LicensePlateProcessor()

    processor = load_processor()

    # File uploader
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        # Create columns for side-by-side display
        col1, col2 = st.columns(2)

        # Display original image
        with col1:
            st.subheader("Original Image")
            image = Image.open(uploaded_file)
            st.image(image, use_column_width=True)

        output_path = 'output_detection.jpg'  # written by process_image
        temp_path = None

        # Process image
        with st.spinner("Processing image..."):
            try:
                # Normalise every PIL mode (L, P, RGBA, CMYK, ...) to 3-channel
                # RGB first; a 2-D grayscale/palette array would otherwise make
                # the RGB->BGR conversion raise.
                image_array = np.array(image.convert("RGB"))
                image_cv = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)

                # Unique temp file so concurrent sessions do not overwrite
                # each other's input.
                fd, temp_path = tempfile.mkstemp(suffix=".jpg")
                os.close(fd)
                if not cv2.imwrite(temp_path, image_cv):
                    raise IOError("Could not write temporary image file")

                # Process the image using the processor
                results = processor.process_image(temp_path)

                # The result dict is always non-empty, so check the
                # recognised fields to tell "nothing found" from success.
                detected = results is not None and bool(
                    results['plate_number'] or results['province']
                )

                if detected:
                    # Display results
                    st.subheader("Detection Results")

                    # OCR output is rendered as raw HTML below, so escape it.
                    plate_number = html.escape(str(results['plate_number']))
                    province = html.escape(str(results['province']))
                    raw_province = html.escape(str(results['raw_province']))

                    # Create a styled container for results
                    results_container = st.container()
                    with results_container:
                        st.markdown(f"""
                        <div style='background-color: #f0f2f6; padding: 20px; border-radius: 10px;'>
                            <h3>License Plate: {plate_number}</h3>
                            <h3>Province: {province}</h3>
                            <p>Raw Province Text: {raw_province}</p>
                        </div>
                        """, unsafe_allow_html=True)

                    # Display detection visualization
                    with col2:
                        st.subheader("Detection Visualization")
                        if os.path.exists(output_path):
                            # Read and convert the output image from BGR to RGB
                            output_image = cv2.imread(output_path)
                            if output_image is not None:
                                output_image_rgb = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
                                st.image(output_image_rgb, use_column_width=True)
                else:
                    st.error("No license plate detected in the image.")

            except Exception as e:
                st.error(f"Error processing image: {str(e)}")
            finally:
                # Single cleanup point for both the success and error paths.
                for path in (temp_path, output_path):
                    if path and os.path.exists(path):
                        os.remove(path)

    # Add information about the application
    with st.expander("About This Application"):
        st.markdown("""
        ### Thai License Plate Recognition System

        This application uses advanced computer vision and deep learning to:
        - Detect license plates in images using YOLO
        - Read Thai license plate numbers using character recognition
        - Identify province names using TrOCR
        - Provide visual detection results

        #### How to Use:
        1. Click the 'Browse files' button above
        2. Select an image containing a Thai license plate
        3. Wait for the processing to complete
        4. View the results and detection visualization

        #### Technologies Used:
        - YOLO for license plate detection
        - Custom YOLO model for character recognition
        - TrOCR for province text recognition
        """)
# Start the app only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()