Sompote committed on
Commit
5e11a21
·
verified ·
1 Parent(s): f3308a3

Upload 6 files

Browse files
Files changed (4) hide show
  1. app.py +45 -33
  2. best.pt +3 -0
  3. data.yaml +52 -0
  4. detect_plate.pt +3 -0
app.py CHANGED
@@ -11,19 +11,19 @@ import io
11
  import tempfile
12
  import torchvision
13
 
14
-
15
  class LicensePlateProcessor:
16
  def __init__(self):
17
  # Load models for plate detection
18
- self.yolo_detector = YOLO('models/best.pt') # For plate detection
19
- self.char_reader = YOLO('models/read_char.pt') # For character reading
 
20
 
21
  # Load TrOCR for province detection
22
  self.processor_plate = TrOCRProcessor.from_pretrained('openthaigpt/thai-trocr')
23
  self.model_plate = VisionEncoderDecoderModel.from_pretrained('openthaigpt/thai-trocr')
24
 
25
  # Load character mapping from yaml
26
- with open('config/data.yaml', 'r', encoding='utf-8') as f:
27
  data_config = yaml.safe_load(f)
28
  self.char_mapping = data_config.get('char_mapping', {})
29
  self.names = data_config['names']
@@ -100,14 +100,16 @@ class LicensePlateProcessor:
100
  return None
101
 
102
  # Detect license plate location
103
- results = self.yolo_detector(image)
 
104
 
105
  data = {"plate_number": "", "province": "", "raw_province": ""}
106
 
107
  # Save visualization
108
  output_image = image.copy()
109
 
110
- for result in results:
 
111
  for box in result.boxes:
112
  confidence = float(box.conf)
113
  if confidence < self.CONF_THRESHOLD:
@@ -116,29 +118,39 @@ class LicensePlateProcessor:
116
  x1, y1, x2, y2 = map(int, box.xyxy.flatten())
117
  cropped_image = image[y1:y2, x1:x2]
118
 
119
- # Draw rectangle on output image
120
- color = (0, 255, 0) if int(box.cls.item()) == 0 else (255, 0, 0)
121
- cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)
 
 
122
 
123
- if int(box.cls.item()) == 0: # License plate number
124
- # Read characters using YOLO character reader
125
- data["plate_number"] = self.read_plate_characters(cropped_image)
126
-
127
- elif int(box.cls.item()) == 1: # Province
128
- # Process province using TrOCR
129
- cropped_image_gray = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2GRAY)
130
- equalized_image = cv2.equalizeHist(cropped_image_gray)
131
- _, thresh_image = cv2.threshold(equalized_image, 65, 255, cv2.THRESH_BINARY_INV)
132
- cropped_image_3d = cv2.cvtColor(thresh_image, cv2.COLOR_GRAY2RGB)
133
- resized_image = cv2.resize(cropped_image_3d, (128, 32))
134
-
135
- pixel_values = self.processor_plate(resized_image, return_tensors="pt").pixel_values
136
- generated_ids = self.model_plate.generate(pixel_values)
137
- generated_text = self.processor_plate.batch_decode(generated_ids, skip_special_tokens=True)[0]
138
-
139
- generated_province, _ = self.get_closest_province(generated_text)
140
- data["raw_province"] = generated_text
141
- data["province"] = generated_province
 
 
 
 
 
 
 
 
142
 
143
  # Save the output image
144
  cv2.imwrite('output_detection.jpg', output_image)
@@ -170,13 +182,13 @@ def main():
170
  if uploaded_file is not None:
171
  # Create columns for side-by-side display
172
  col1, col2 = st.columns(2)
173
-
174
  # Display original image
175
  with col1:
176
  st.subheader("Original Image")
177
  image = Image.open(uploaded_file)
178
  st.image(image, use_column_width=True)
179
-
180
  # Convert PIL Image to OpenCV format for processing
181
  image_array = np.array(image)
182
  if len(image_array.shape) == 3 and image_array.shape[2] == 4:
@@ -196,7 +208,7 @@ def main():
196
 
197
  # Clean up temporary input file
198
  os.remove(temp_path)
199
-
200
  if results:
201
  # Display results
202
  st.subheader("Detection Results")
@@ -211,7 +223,7 @@ def main():
211
  <p>Raw Province Text: {results['raw_province']}</p>
212
  </div>
213
  """, unsafe_allow_html=True)
214
-
215
  # Display detection visualization
216
  with col2:
217
  st.subheader("Detection Visualization")
@@ -257,4 +269,4 @@ def main():
257
  """)
258
 
259
  if __name__ == "__main__":
260
- main()
 
11
  import tempfile
12
  import torchvision
13
 
 
14
  class LicensePlateProcessor:
15
  def __init__(self):
16
  # Load models for plate detection
17
+ self.yolo_detector = YOLO('detect_plate.pt') # For license plate detection
18
+ self.province_detector = YOLO('best.pt') # For province detection
19
+ self.char_reader = YOLO('read_char.pt') # For character reading
20
 
21
  # Load TrOCR for province detection
22
  self.processor_plate = TrOCRProcessor.from_pretrained('openthaigpt/thai-trocr')
23
  self.model_plate = VisionEncoderDecoderModel.from_pretrained('openthaigpt/thai-trocr')
24
 
25
  # Load character mapping from yaml
26
+ with open('data.yaml', 'r', encoding='utf-8') as f:
27
  data_config = yaml.safe_load(f)
28
  self.char_mapping = data_config.get('char_mapping', {})
29
  self.names = data_config['names']
 
100
  return None
101
 
102
  # Detect license plate location
103
+ plate_results = self.yolo_detector(image)
104
+ province_results = self.province_detector(image)
105
 
106
  data = {"plate_number": "", "province": "", "raw_province": ""}
107
 
108
  # Save visualization
109
  output_image = image.copy()
110
 
111
+ # Process license plate detections
112
+ for result in plate_results:
113
  for box in result.boxes:
114
  confidence = float(box.conf)
115
  if confidence < self.CONF_THRESHOLD:
 
118
  x1, y1, x2, y2 = map(int, box.xyxy.flatten())
119
  cropped_image = image[y1:y2, x1:x2]
120
 
121
+ # Draw rectangle on output image (green for plate)
122
+ cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
123
+
124
+ # Read characters using YOLO character reader
125
+ data["plate_number"] = self.read_plate_characters(cropped_image)
126
 
127
+ # Process province detections
128
+ for result in province_results:
129
+ for box in result.boxes:
130
+ confidence = float(box.conf)
131
+ if confidence < self.CONF_THRESHOLD:
132
+ continue
133
+
134
+ x1, y1, x2, y2 = map(int, box.xyxy.flatten())
135
+ cropped_image = image[y1:y2, x1:x2]
136
+
137
+ # Draw rectangle on output image (blue for province)
138
+ cv2.rectangle(output_image, (x1, y1), (x2, y2), (255, 0, 0), 2)
139
+
140
+ # Process province using TrOCR
141
+ cropped_image_gray = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2GRAY)
142
+ equalized_image = cv2.equalizeHist(cropped_image_gray)
143
+ _, thresh_image = cv2.threshold(equalized_image, 65, 255, cv2.THRESH_BINARY_INV)
144
+ cropped_image_3d = cv2.cvtColor(thresh_image, cv2.COLOR_GRAY2RGB)
145
+ resized_image = cv2.resize(cropped_image_3d, (128, 32))
146
+
147
+ pixel_values = self.processor_plate(resized_image, return_tensors="pt").pixel_values
148
+ generated_ids = self.model_plate.generate(pixel_values)
149
+ generated_text = self.processor_plate.batch_decode(generated_ids, skip_special_tokens=True)[0]
150
+
151
+ generated_province, _ = self.get_closest_province(generated_text)
152
+ data["raw_province"] = generated_text
153
+ data["province"] = generated_province
154
 
155
  # Save the output image
156
  cv2.imwrite('output_detection.jpg', output_image)
 
182
  if uploaded_file is not None:
183
  # Create columns for side-by-side display
184
  col1, col2 = st.columns(2)
185
+
186
  # Display original image
187
  with col1:
188
  st.subheader("Original Image")
189
  image = Image.open(uploaded_file)
190
  st.image(image, use_column_width=True)
191
+
192
  # Convert PIL Image to OpenCV format for processing
193
  image_array = np.array(image)
194
  if len(image_array.shape) == 3 and image_array.shape[2] == 4:
 
208
 
209
  # Clean up temporary input file
210
  os.remove(temp_path)
211
+
212
  if results:
213
  # Display results
214
  st.subheader("Detection Results")
 
223
  <p>Raw Province Text: {results['raw_province']}</p>
224
  </div>
225
  """, unsafe_allow_html=True)
226
+
227
  # Display detection visualization
228
  with col2:
229
  st.subheader("Detection Visualization")
 
269
  """)
270
 
271
  if __name__ == "__main__":
272
+ main()
best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b1da8d9362a1005aa5b060b0ac53b4622677e753eded2893da10b6a69bc9fb7
3
+ size 5468691
data.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train: /CarLicensePlate/iotproject-license-plate-3/train
2
+ val: /CarLicensePlate/iotproject-license-plate-3/valid
3
+ test: /CarLicensePlate/iotproject-license-plate-3/test
4
+ nc: 47
5
+ names: ['0', '1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '4', '40', '41', '42', '43', '44', '45', '46', '5', '6', '7', '8', '9']
6
+
7
+
8
+ char_mapping:
9
+ '10': 'ก'
10
+ '11': 'ข'
11
+ '12': 'ค'
12
+ '13': 'ฆ'
13
+ '14': 'ง'
14
+ '15': 'จ'
15
+ '16': 'ฉ'
16
+ '17': 'ช'
17
+ '18': 'ฌ'
18
+ '19': 'ญ'
19
+ '20': 'ฎ'
20
+ '21': 'ฐ'
21
+ '22': 'ฒ'
22
+ '23': 'ณ'
23
+ '24': 'ด'
24
+ '25': 'ต'
25
+ '26': 'ถ'
26
+ '27': 'ท'
27
+ '28': 'ธ'
28
+ '29': 'น'
29
+ '30': 'บ'
30
+ '31': 'ผ'
31
+ '32': 'พ'
32
+ '33': 'ฟ'
33
+ '34': 'ภ'
34
+ '35': 'ม'
35
+ '36': 'ย'
36
+ '37': 'ร'
37
+ '38': 'ล'
38
+ '39': 'ว'
39
+ '40': 'ศ'
40
+ '41': 'ษ'
41
+ '42': 'ส'
42
+ '43': 'ห'
43
+ '44': 'ฬ'
44
+ '45': 'อ'
45
+ '46': 'ฮ'
46
+
47
+ roboflow:
48
+ workspace: magarthai
49
+ project: iotproject-license-plate
50
+ version: 3
51
+ license: CC BY 4.0
52
+ url: https://universe.roboflow.com/magarthai/iotproject-license-plate/dataset/3
detect_plate.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40d605afb93097eb60af1ea2bd33d5ae25ef778176f3a7e28be79add2d331890
3
+ size 19188819