import os
import re
from io import BytesIO
from typing import Dict, List, Optional, Union

import cv2
import numpy as np
import requests
import torch
from PIL import Image, ImageEnhance, ImageFilter
from transformers import AutoProcessor, AutoModelForVision2Seq, pipeline


class AdvancedLicensePlateOCR:
    """License-plate OCR with several lazily loaded backends.

    Each entry of ``self.models`` describes one backend: its Hugging Face
    model id (or ``"easyocr"``), its backend ``type``, and ``processor`` /
    ``model`` slots that stay ``None`` until :meth:`load_model` fills them.
    """

    # Prefixes of the messages the extract_with_* helpers (and the dispatcher
    # in extract_text_with_model) return when no OCR text was produced. Text
    # starting with one of these is a failure report, not a plate reading,
    # and must never be selected as the best result.
    _FAILURE_PREFIXES = (
        "Error:",
        "Failed",
        "Model loading failed",
        "TrOCR Error:",
        "EasyOCR loading failed",
        "EasyOCR Error:",
        "No text detected",
        "DETR model loading failed",
        "DETR Error:",
        "Unsupported model type",
    )

    # Letter -> digit substitutions for characters OCR commonly confuses,
    # applied only when a reading is mostly digits.
    _DIGIT_CORRECTIONS = str.maketrans("OISBGZ", "015862")

    def __init__(self):
        self.models = {
            "trocr_license": {
                "name": "TrOCR License Plates (Recommended)",
                "model_id": "DunnBC22/trocr-base-printed_license_plates_ocr",
                "type": "transformers",
                "processor": None,
                "model": None,
                "loaded": False,
                "description": "Specialized TrOCR model trained on license plates",
            },
            "detr_license": {
                "name": "DETR License Plate Detection + OCR",
                "model_id": "nickmuchi/detr-resnet50-license-plate-detection",
                "type": "object_detection",
                "processor": None,
                "model": None,
                "loaded": False,
                "description": "End-to-end detection and recognition",
            },
            "yolo_license": {
                "name": "YOLO License Plate (Fast)",
                "model_id": "keremberke/yolov5n-license-plate",
                "type": "yolo",
                "processor": None,
                "model": None,
                "loaded": False,
                "description": "Fast YOLO-based license plate detection",
            },
            "trocr_base": {
                "name": "TrOCR Base (General)",
                "model_id": "microsoft/trocr-base-printed",
                "type": "transformers",
                "processor": None,
                "model": None,
                "loaded": False,
                "description": "General purpose OCR model",
            },
            "easyocr": {
                "name": "EasyOCR (Fallback)",
                "model_id": "easyocr",
                "type": "easyocr",
                "processor": None,
                "model": None,
                "loaded": False,
                "description": "Traditional OCR approach",
            },
        }
        self.current_model = "trocr_license"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @classmethod
    def _is_failure(cls, text: str) -> bool:
        """Return True when *text* is one of this class's failure messages."""
        return text.startswith(cls._FAILURE_PREFIXES)

    def list_available_models(self) -> Dict[str, Dict]:
        """Return name, description, type and load status for every backend."""
        return {
            key: {
                "name": model["name"],
                "description": model["description"],
                "type": model["type"],
                "loaded": model["loaded"],
            }
            for key, model in self.models.items()
        }

    def _load_backend(self, model_key: str) -> bool:
        """Load *model_key*'s weights if needed, without changing the current model.

        Returns True when the backend is (now) loaded, False on unknown key or
        any loading failure (the error is printed, not raised).
        """
        if model_key not in self.models:
            print(f"Model {model_key} not found")
            return False

        model_info = self.models[model_key]
        if model_info["loaded"]:
            print(f"Model {model_info['name']} already loaded")
            return True

        try:
            print(f"Loading {model_info['name']}...")

            if model_info["type"] == "transformers":
                model_info["processor"] = AutoProcessor.from_pretrained(model_info["model_id"])
                model_info["model"] = AutoModelForVision2Seq.from_pretrained(model_info["model_id"])
                model_info["model"].to(self.device)

            elif model_info["type"] == "object_detection":
                try:
                    model_info["model"] = pipeline(
                        "object-detection",
                        model=model_info["model_id"],
                        device=0 if torch.cuda.is_available() else -1,
                    )
                except Exception as e:
                    # Alternative path: a processor + seq2seq model pair; the
                    # processor being non-None is how extract_with_detr tells
                    # this path apart from the pipeline path.
                    print(f"Failed to load as pipeline, trying alternative: {e}")
                    model_info["processor"] = AutoProcessor.from_pretrained(model_info["model_id"])
                    model_info["model"] = AutoModelForVision2Seq.from_pretrained(model_info["model_id"])
                    model_info["model"].to(self.device)

            elif model_info["type"] == "yolo":
                try:
                    from ultralytics import YOLO
                    model_info["model"] = YOLO(model_info["model_id"])
                except Exception as e:
                    print(f"YOLO model loading failed: {e}")
                    return False

            elif model_info["type"] == "easyocr":
                try:
                    import easyocr
                    model_info["model"] = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
                except Exception as e:
                    print(f"EasyOCR loading failed: {e}")
                    return False

            model_info["loaded"] = True
            print(f"✅ Successfully loaded {model_info['name']}")
            return True

        except Exception as e:
            print(f"❌ Failed to load {model_info['name']}: {e}")
            return False

    def load_model(self, model_key: str) -> bool:
        """Load *model_key* if needed and make it the default for extraction.

        The default is switched even when the backend was already loaded, so
        ``set_ocr_model`` always takes effect. Returns False on failure, in
        which case the current model is left unchanged.
        """
        if not self._load_backend(model_key):
            return False
        self.current_model = model_key
        return True

    def preprocess_image_advanced(self, image: Image.Image) -> List[Image.Image]:
        """Return the original image plus several enhanced variants for OCR.

        Variants, in order: original, high contrast, sharpened high contrast,
        CLAHE-equalized gray, Otsu-binarized gray, bilateral-denoised gray
        (all gray variants converted back to RGB). On any error, only the
        (possibly RGB-converted) input image is returned.
        """
        variants = []

        try:
            original = image.copy()
            variants.append(original)

            if image.mode != 'RGB':
                image = image.convert('RGB')

            enhancer = ImageEnhance.Contrast(image)
            high_contrast = enhancer.enhance(2.5)
            variants.append(high_contrast)

            sharpened = high_contrast.filter(ImageFilter.SHARPEN)
            variants.append(sharpened)

            img_array = np.array(image)
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)

            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
            clahe_img = clahe.apply(gray)
            variants.append(Image.fromarray(clahe_img).convert('RGB'))

            _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            variants.append(Image.fromarray(binary).convert('RGB'))

            denoised = cv2.bilateralFilter(gray, 9, 75, 75)
            variants.append(Image.fromarray(denoised).convert('RGB'))

        except Exception as e:
            print(f"Preprocessing error: {e}")
            variants = [image]

        return variants

    def extract_with_trocr(self, image: Image.Image, model_key: str) -> str:
        """Run a processor + seq2seq model backend on *image* and return its text.

        Returns a failure message ("Model loading failed" / "TrOCR Error: ...")
        instead of raising.
        """
        model_info = self.models[model_key]
        if not model_info["loaded"]:
            if not self.load_model(model_key):
                return "Model loading failed"

        try:
            processor = model_info["processor"]
            model = model_info["model"]

            pixel_values = processor(image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(self.device)

            with torch.no_grad():
                generated_ids = model.generate(pixel_values, max_length=50)

            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            return text.strip()

        except Exception as e:
            print(f"TrOCR extraction error: {e}")
            return f"TrOCR Error: {str(e)}"

    def extract_with_easyocr(self, image: Image.Image) -> str:
        """Run EasyOCR on *image* and join all detected text fragments with spaces.

        Returns "No text detected" when EasyOCR finds nothing, or a failure
        message instead of raising.
        """
        model_info = self.models["easyocr"]
        if not model_info["loaded"]:
            if not self.load_model("easyocr"):
                return "EasyOCR loading failed"

        try:
            reader = model_info["model"]
            img_array = np.array(image)
            results = reader.readtext(img_array, detail=False, paragraph=False)

            if results:
                return ' '.join(results).strip()
            return "No text detected"

        except Exception as e:
            print(f"EasyOCR extraction error: {e}")
            return f"EasyOCR Error: {str(e)}"

    @staticmethod
    def _crop_best_detection(image: Image.Image, detections) -> Image.Image:
        """Crop *image* to the highest-scoring detection box.

        Expects object-detection pipeline output: a list of dicts with
        ``"score"`` and ``"box"`` (``xmin``/``ymin``/``xmax``/``ymax``).
        Returns *image* unchanged when there is no usable box.
        """
        boxes = [d for d in (detections or []) if isinstance(d, dict) and "box" in d]
        if not boxes:
            return image

        box = max(boxes, key=lambda d: d.get("score", 0.0))["box"]
        # Clip to the image so a slightly out-of-range box still crops.
        left = max(0, int(box["xmin"]))
        top = max(0, int(box["ymin"]))
        right = min(image.width, int(box["xmax"]))
        bottom = min(image.height, int(box["ymax"]))
        if right <= left or bottom <= top:
            return image
        return image.crop((left, top, right, bottom))

    def extract_with_detr(self, image: Image.Image) -> str:
        """Locate the plate with the DETR backend, then OCR the located region.

        With the object-detection pipeline, the highest-scoring box is cropped
        and read by the TrOCR license-plate model (the whole image is read
        when nothing is detected). If the backend was loaded through the
        processor + model fallback, that pair reads the image directly.
        Returns a failure message instead of raising.
        """
        model_info = self.models["detr_license"]
        if not model_info["loaded"]:
            if not self.load_model("detr_license"):
                return "DETR model loading failed"

        try:
            if model_info["processor"] is not None:
                # Fallback load path: no pipeline, use the pair as an OCR model.
                return self.extract_with_trocr(image, "detr_license")

            detections = model_info["model"](image)
            region = self._crop_best_detection(image, detections)

            # Load the reader without making it the user's current model.
            reader = self.models["trocr_license"]
            if not reader["loaded"] and not self._load_backend("trocr_license"):
                return "Model loading failed"
            return self.extract_with_trocr(region, "trocr_license")

        except Exception as e:
            print(f"DETR extraction error: {e}")
            return f"DETR Error: {str(e)}"

    def clean_license_text(self, text: str) -> str:
        """Normalize an OCR reading into plate-style text.

        Uppercases, drops everything except A-Z, 0-9, whitespace and '-',
        and collapses whitespace runs. When the result has more digits than
        letters, commonly confused letters (O, I, S, B, G, Z) become digits.
        Empty input and failure messages are returned unchanged.
        """
        if not text or self._is_failure(text):
            return text

        text = text.upper().strip()
        text = re.sub(r'[^A-Z0-9\s-]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()

        digits = sum(c.isdigit() for c in text)
        letters = sum(c.isalpha() for c in text)
        if digits > letters:
            text = text.translate(self._DIGIT_CORRECTIONS)

        return text

    def extract_text_with_model(self, image: Union[Image.Image, str],
                                model_key: Optional[str] = None,
                                use_preprocessing: bool = True) -> Dict:
        """OCR *image* (PIL image or file path) with one backend.

        Every preprocessing variant (or just the input when
        ``use_preprocessing`` is False) is read. The longest cleaned,
        non-failure reading becomes ``best_result``. Its ``confidence`` is a
        length heuristic capped at 1.0. When no variant yields text,
        ``best_result`` falls back to the first raw output with confidence
        0.3. Returns ``{"error": ...}`` for missing or unreadable files,
        unknown models, or unexpected failures.
        """
        if isinstance(image, str):
            if not os.path.isfile(image):
                return {"error": f"Image file not found: {image}"}
            try:
                # Context manager closes the file; convert() forces a full load.
                with Image.open(image) as img:
                    image = img.convert("RGB")
            except (OSError, ValueError) as e:
                return {"error": f"Could not read image {image}: {e}"}

        if model_key is None:
            model_key = self.current_model

        if model_key not in self.models:
            return {"error": f"Unknown model: {model_key}"}

        model_type = self.models[model_key]["type"]
        result = {
            "model_used": self.models[model_key]["name"],
            "model_key": model_key,
            "preprocessing": use_preprocessing,
            "extractions": [],
            "best_result": "",
            "confidence": 0.0,
        }

        extractors = {
            "transformers": lambda img: self.extract_with_trocr(img, model_key),
            "object_detection": self.extract_with_detr,
            "easyocr": self.extract_with_easyocr,
        }
        extract = extractors.get(model_type)

        try:
            # OCR backends expect 3-channel input; normalize RGBA/L/P images.
            if isinstance(image, Image.Image) and image.mode != "RGB":
                image = image.convert("RGB")

            images_to_process = self.preprocess_image_advanced(image) if use_preprocessing else [image]

            for i, processed_img in enumerate(images_to_process):
                try:
                    raw_text = extract(processed_img) if extract else "Unsupported model type"
                    cleaned_text = self.clean_license_text(raw_text)

                    result["extractions"].append({
                        "step": i,
                        "raw_text": raw_text,
                        "cleaned_text": cleaned_text,
                        "length": len(cleaned_text) if cleaned_text else 0,
                    })

                    is_reading = bool(cleaned_text) and not self._is_failure(raw_text)
                    if is_reading and len(cleaned_text) > len(result["best_result"]):
                        result["best_result"] = cleaned_text
                        result["confidence"] = min(1.0, 0.8 + len(cleaned_text) * 0.02)

                except Exception as e:
                    print(f"Error processing image variant {i}: {e}")
                    continue

            if not result["best_result"]:
                if result["extractions"]:
                    result["best_result"] = result["extractions"][0].get("raw_text", "No text found")
                    result["confidence"] = 0.3
                else:
                    result["best_result"] = "No text extracted"
                    result["confidence"] = 0.0

            return result

        except Exception as e:
            return {"error": f"Extraction failed: {str(e)}"}


# Shared module-level engine used by the convenience functions below.
advanced_ocr = AdvancedLicensePlateOCR()


def get_available_models():
    """Return the backend catalogue of the shared OCR engine."""
    return advanced_ocr.list_available_models()


def set_ocr_model(model_key: str) -> bool:
    """Load *model_key* and make it the shared engine's default; False on failure."""
    return advanced_ocr.load_model(model_key)


def extract_license_plate_text_advanced(image: Union[Image.Image, str],
                                        model_key: Optional[str] = None) -> str:
    """Return only the best plate reading, or an "Error: ..." string."""
    try:
        result = advanced_ocr.extract_text_with_model(image, model_key)

        if "error" in result:
            return f"Error: {result['error']}"

        return result.get("best_result", "No text found")

    except Exception as e:
        return f"Error: {str(e)}"


def get_detailed_analysis(image: Union[Image.Image, str], model_key: Optional[str] = None) -> Dict:
    """Return the full per-variant extraction report from the shared engine."""
    return advanced_ocr.extract_text_with_model(image, model_key)


if __name__ == "__main__":
    print("Advanced License Plate OCR System")
    print("=" * 40)

    models = get_available_models()
    print("Available models:")
    for key, info in models.items():
        status = "✅" if info["loaded"] else "⚪"
        print(f"{status} {key}: {info['name']} - {info['description']}")

    print("\nRecommended models (in order):")
    print("1. trocr_license - Best for license plates")
    print("2. detr_license - End-to-end detection")
    print("3. easyocr - Reliable fallback")

    print("\nUsage:")
    print("from advanced_ocr import extract_license_plate_text_advanced, set_ocr_model")
    print("set_ocr_model('trocr_license')")
    print("text = extract_license_plate_text_advanced('license_plate.jpg')")