Abs6187 committed (verified)
Commit 72729f3 · 1 Parent(s): b4c97e7

Create advanced_ocr.py

Files changed (1):
1. advanced_ocr.py  +368  -0
advanced_ocr.py ADDED
@@ -0,0 +1,368 @@
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq, pipeline
from PIL import Image, ImageEnhance, ImageFilter
import cv2
import numpy as np
import re
import os
from typing import Dict, List, Optional, Union
import requests
from io import BytesIO

class AdvancedLicensePlateOCR:
    def __init__(self):
        # Registry of OCR/detection back-ends; the heavy model objects are
        # loaded lazily by load_model().
        self.models = {
            "trocr_license": {
                "name": "TrOCR License Plates (Recommended)",
                "model_id": "DunnBC22/trocr-base-printed_license_plates_ocr",
                "type": "transformers",
                "processor": None,
                "model": None,
                "loaded": False,
                "description": "Specialized TrOCR model trained on license plates"
            },
            "detr_license": {
                "name": "DETR License Plate Detection + OCR",
                "model_id": "nickmuchi/detr-resnet50-license-plate-detection",
                "type": "object_detection",
                "processor": None,
                "model": None,
                "loaded": False,
                "description": "End-to-end detection and recognition"
            },
            "yolo_license": {
                "name": "YOLO License Plate (Fast)",
                "model_id": "keremberke/yolov5n-license-plate",
                "type": "yolo",
                "processor": None,
                "model": None,
                "loaded": False,
                "description": "Fast YOLO-based license plate detection"
            },
            "trocr_base": {
                "name": "TrOCR Base (General)",
                "model_id": "microsoft/trocr-base-printed",
                "type": "transformers",
                "processor": None,
                "model": None,
                "loaded": False,
                "description": "General purpose OCR model"
            },
            "easyocr": {
                "name": "EasyOCR (Fallback)",
                "model_id": "easyocr",
                "type": "easyocr",
                "processor": None,
                "model": None,
                "loaded": False,
                "description": "Traditional OCR approach"
            }
        }

        self.current_model = "trocr_license"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def list_available_models(self) -> Dict[str, Dict]:
        return {
            key: {
                "name": model["name"],
                "description": model["description"],
                "type": model["type"],
                "loaded": model["loaded"]
            }
            for key, model in self.models.items()
        }

    def load_model(self, model_key: str) -> bool:
        if model_key not in self.models:
            print(f"Model {model_key} not found")
            return False

        model_info = self.models[model_key]

        if model_info["loaded"]:
            print(f"Model {model_info['name']} already loaded")
            return True

        try:
            print(f"Loading {model_info['name']}...")

            if model_info["type"] == "transformers":
                model_info["processor"] = AutoProcessor.from_pretrained(model_info["model_id"])
                model_info["model"] = AutoModelForVision2Seq.from_pretrained(model_info["model_id"])
                model_info["model"].to(self.device)

            elif model_info["type"] == "object_detection":
                try:
                    model_info["model"] = pipeline(
                        "object-detection",
                        model=model_info["model_id"],
                        device=0 if torch.cuda.is_available() else -1
                    )
                except Exception as e:
                    print(f"Failed to load as pipeline, trying alternative: {e}")
                    model_info["processor"] = AutoProcessor.from_pretrained(model_info["model_id"])
                    model_info["model"] = AutoModelForVision2Seq.from_pretrained(model_info["model_id"])
                    model_info["model"].to(self.device)

            elif model_info["type"] == "yolo":
                try:
                    # ultralytics' YOLO() may not fetch this Hub checkpoint directly;
                    # any failure is caught and reported below.
                    from ultralytics import YOLO
                    model_info["model"] = YOLO(model_info["model_id"])
                except Exception as e:
                    print(f"YOLO model loading failed: {e}")
                    return False

            elif model_info["type"] == "easyocr":
                try:
                    import easyocr
                    model_info["model"] = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
                except Exception as e:
                    print(f"EasyOCR loading failed: {e}")
                    return False

            model_info["loaded"] = True
            self.current_model = model_key
            print(f"✅ Successfully loaded {model_info['name']}")
            return True

        except Exception as e:
            print(f"❌ Failed to load {model_info['name']}: {e}")
            return False

    def preprocess_image_advanced(self, image: Image.Image) -> List[Image.Image]:
        # Build several enhanced variants of the input so the OCR model gets
        # multiple attempts on difficult plates.
        variants = []

        try:
            # Normalize the colour mode first so every variant is RGB.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            original = image.copy()
            variants.append(original)

            enhancer = ImageEnhance.Contrast(image)
            high_contrast = enhancer.enhance(2.5)
            variants.append(high_contrast)

            sharpened = high_contrast.filter(ImageFilter.SHARPEN)
            variants.append(sharpened)

            img_array = np.array(image)
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)

            # Adaptive histogram equalization
            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
            clahe_img = clahe.apply(gray)
            clahe_pil = Image.fromarray(clahe_img).convert('RGB')
            variants.append(clahe_pil)

            # Otsu binarization
            _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            binary_pil = Image.fromarray(binary).convert('RGB')
            variants.append(binary_pil)

            # Edge-preserving denoising
            denoised = cv2.bilateralFilter(gray, 9, 75, 75)
            denoised_pil = Image.fromarray(denoised).convert('RGB')
            variants.append(denoised_pil)

        except Exception as e:
            print(f"Preprocessing error: {e}")
            variants = [image]

        return variants

    def extract_with_trocr(self, image: Image.Image, model_key: str) -> str:
        model_info = self.models[model_key]

        if not model_info["loaded"]:
            if not self.load_model(model_key):
                return "Model loading failed"

        try:
            processor = model_info["processor"]
            model = model_info["model"]

            pixel_values = processor(image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(self.device)

            with torch.no_grad():
                generated_ids = model.generate(pixel_values, max_length=50)

            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            return text.strip()

        except Exception as e:
            print(f"TrOCR extraction error: {e}")
            return f"TrOCR Error: {str(e)}"

    def extract_with_easyocr(self, image: Image.Image) -> str:
        model_info = self.models["easyocr"]

        if not model_info["loaded"]:
            if not self.load_model("easyocr"):
                return "EasyOCR loading failed"

        try:
            reader = model_info["model"]
            img_array = np.array(image)
            results = reader.readtext(img_array, detail=False, paragraph=False)

            if results:
                return ' '.join(results).strip()
            return "No text detected"

        except Exception as e:
            print(f"EasyOCR extraction error: {e}")
            return f"EasyOCR Error: {str(e)}"

    def extract_with_detr(self, image: Image.Image) -> str:
        model_info = self.models["detr_license"]

        if not model_info["loaded"]:
            if not self.load_model("detr_license"):
                return "DETR model loading failed"

        try:
            if hasattr(model_info["model"], '__call__'):
                # Detection-pipeline path: report how many candidate regions were found.
                results = model_info["model"](image)
                if results and len(results) > 0:
                    return f"Detected {len(results)} objects"
                return ""  # no detections, nothing to report
            else:
                # Fallback path when the checkpoint was loaded as a vision2seq model.
                return self.extract_with_trocr(image, "detr_license")

        except Exception as e:
            print(f"DETR extraction error: {e}")
            return f"DETR Error: {str(e)}"

    def clean_license_text(self, text: str) -> str:
        # Pass error/status messages through untouched so they are never
        # mistaken for plate text downstream.
        if not text or "error" in text.lower() or "failed" in text.lower():
            return text

        text = text.upper().strip()
        text = re.sub(r'[^A-Z0-9\s-]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()

        # Common OCR character confusions; applied only when the string is
        # digit-heavy, where letter look-alikes are most likely misreads.
        common_corrections = {
            'O': '0', 'I': '1', 'S': '5', 'B': '8', 'G': '6', 'Z': '2'
        }

        if sum(c.isdigit() for c in text) > sum(c.isalpha() for c in text):
            for old, new in common_corrections.items():
                text = text.replace(old, new)

        return text

    def extract_text_with_model(self, image: Union[Image.Image, str],
                                model_key: Optional[str] = None,
                                use_preprocessing: bool = True) -> Dict:

        if isinstance(image, str):
            if os.path.exists(image):
                image = Image.open(image)
            else:
                return {"error": f"Image file not found: {image}"}

        if model_key is None:
            model_key = self.current_model

        if model_key not in self.models:
            return {"error": f"Unknown model: {model_key}"}

        result = {
            "model_used": self.models[model_key]["name"],
            "model_key": model_key,
            "preprocessing": use_preprocessing,
            "extractions": [],
            "best_result": "",
            "confidence": 0.0
        }

        try:
            images_to_process = self.preprocess_image_advanced(image) if use_preprocessing else [image]

            for i, processed_img in enumerate(images_to_process):
                try:
                    if self.models[model_key]["type"] == "transformers":
                        raw_text = self.extract_with_trocr(processed_img, model_key)
                    elif self.models[model_key]["type"] == "object_detection":
                        raw_text = self.extract_with_detr(processed_img)
                    elif self.models[model_key]["type"] == "easyocr":
                        raw_text = self.extract_with_easyocr(processed_img)
                    else:
                        # "yolo" entries are detection-only and are not wired into
                        # this text-extraction path.
                        raw_text = "Unsupported model type"

                    cleaned_text = self.clean_license_text(raw_text)

                    extraction = {
                        "step": i,
                        "raw_text": raw_text,
                        "cleaned_text": cleaned_text,
                        "length": len(cleaned_text) if cleaned_text else 0
                    }

                    result["extractions"].append(extraction)

                    # Keep the longest non-error extraction; confidence is a
                    # length-based heuristic, not a model probability.
                    if cleaned_text and "error" not in cleaned_text.lower() and "failed" not in cleaned_text.lower():
                        if len(cleaned_text) > len(result["best_result"]):
                            result["best_result"] = cleaned_text
                            result["confidence"] = min(0.95, 0.8 + len(cleaned_text) * 0.02)

                except Exception as e:
                    print(f"Error processing image variant {i}: {e}")
                    continue

            if not result["best_result"]:
                if result["extractions"]:
                    result["best_result"] = result["extractions"][0].get("raw_text", "No text found")
                    result["confidence"] = 0.3
                else:
                    result["best_result"] = "No text extracted"
                    result["confidence"] = 0.0

            return result

        except Exception as e:
            return {"error": f"Extraction failed: {str(e)}"}

# Module-level singleton and convenience wrappers around it
advanced_ocr = AdvancedLicensePlateOCR()


def get_available_models():
    return advanced_ocr.list_available_models()


def set_ocr_model(model_key: str) -> bool:
    return advanced_ocr.load_model(model_key)


def extract_license_plate_text_advanced(image: Union[Image.Image, str],
                                        model_key: Optional[str] = None) -> str:
    try:
        result = advanced_ocr.extract_text_with_model(image, model_key)

        if "error" in result:
            return f"Error: {result['error']}"

        return result.get("best_result", "No text found")

    except Exception as e:
        return f"Error: {str(e)}"


def get_detailed_analysis(image: Union[Image.Image, str],
                          model_key: Optional[str] = None) -> Dict:
    return advanced_ocr.extract_text_with_model(image, model_key)

if __name__ == "__main__":
    print("Advanced License Plate OCR System")
    print("=" * 40)

    models = get_available_models()
    print("Available models:")
    for key, info in models.items():
        status = "✅" if info["loaded"] else "⚪"
        print(f"{status} {key}: {info['name']} - {info['description']}")

    print("\nRecommended models (in order):")
    print("1. trocr_license - Best for license plates")
    print("2. detr_license - End-to-end detection")
    print("3. easyocr - Reliable fallback")

    print("\nUsage:")
    print("from advanced_ocr import extract_license_plate_text_advanced, set_ocr_model")
    print("set_ocr_model('trocr_license')")
    print("text = extract_license_plate_text_advanced('license_plate.jpg')")
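
The __main__ block only prints usage hints. As a minimal, non-authoritative sketch of how the module's convenience functions might be driven from another script (assuming torch, transformers, opencv-python, Pillow and, for the fallback, easyocr are installed; "license_plate.jpg" is an illustrative placeholder for a local crop of a plate, not a file shipped with this commit):

# usage_sketch.py - illustrative only; not part of the committed module
from advanced_ocr import (
    extract_license_plate_text_advanced,
    get_detailed_analysis,
    set_ocr_model,
)

# Prefer the plate-specific TrOCR checkpoint; fall back to EasyOCR if loading fails.
if not set_ocr_model("trocr_license"):
    set_ocr_model("easyocr")

# One-line extraction: returns the cleaned best guess as a string.
text = extract_license_plate_text_advanced("license_plate.jpg")
print("Plate:", text)

# Full report: per-variant raw/cleaned text plus the length-based confidence heuristic.
report = get_detailed_analysis("license_plate.jpg")
if "error" not in report:
    for extraction in report["extractions"]:
        print(extraction["step"], extraction["raw_text"], "->", extraction["cleaned_text"])
    print("Best:", report["best_result"], "confidence:", report["confidence"])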