ChronoStellar committed
Commit 91073e2 · verified
1 parent: 75baf9a

Create app.py

Files changed (1)
app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
+ import gradio as gr
+ import cv2
+ import joblib
+ import numpy as np
+ from skimage.feature import hog
+ import tensorflow as tf
+ from tensorflow import keras
+ from transformers import AutoModelForImageTextToText, TrOCRProcessor
+ import torch
+ from PIL import Image
+
+ # Available model pipelines, in the order used by the dropdown and the dispatcher
+ MODEL_TYPES = ["HOG & Logistic Regression", "CRNN CTC", "Fine Tuned TrOCR"]
+
+ # Pipeline 1: logistic regression over HOG features of segmented characters
+ clf_hog = joblib.load('/content/HOG_LogRes.pkl')
+
+ # Pipeline 2: CRNN trained with CTC loss, plus the lookup that maps ids back to characters
+ clf_crnn = tf.keras.models.load_model('/content/crnn_ctc.keras')
+ num_to_char = joblib.load('./decoder.joblib')
+
+ # Pipeline 3: TrOCR fine-tuned on Indonesian license plates
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
+ clf_trocr = AutoModelForImageTextToText.from_pretrained("ChronoStellar/TrOCR_IndonesianLPR")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ clf_trocr.to(device)
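+ # NOTE: the artifact paths above mix '/content/...' (Colab-style) and './...'
+ # (relative); they are assumed to exist wherever this app actually runs and
+ # should be adjusted if the files live elsewhere.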
+ # Preprocessing and prediction functions for each model
+
+ # Model 1: segment characters via contours, then classify each crop with HOG + LogReg
+ def ocr_model_1(file_path):
+     im = cv2.imread(file_path)
+     im_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
+     _, im_th = cv2.threshold(im_gray, 120, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+     ctrs, _ = cv2.findContours(im_th, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+     bboxes = [cv2.boundingRect(c) for c in ctrs]
+     # Sort left to right so characters are read in plate order
+     sorted_bboxes = sorted(bboxes, key=lambda b: b[0])
+
+     plate_char = []
+     image_height, image_width = im.shape[:2]
+     # Keep boxes tall enough to be characters but narrower than the plate itself
+     height_threshold = image_height * 0.3
+     width_threshold = image_width * 0.3
+
+     for x, y, w, h in sorted_bboxes:
+         if h > height_threshold and w < width_threshold:
+             roi = im_gray[y:y + h, x:x + w]
+             roi = cv2.resize(roi, (64, 128), interpolation=cv2.INTER_AREA)
+             # HOG parameters must match the ones used when training clf_hog
+             roi_hog_fd = hog(roi, orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1))
+             nbr = clf_hog.predict(np.array([roi_hog_fd]))
+             plate_char.append(str(nbr[0]))
+
+     return ''.join(plate_char)
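+ # NOTE: with the THRESH_OTSU flag set, OpenCV picks the threshold itself and the
+ # fixed value 120 is ignored. RETR_TREE also returns nested contours, so the
+ # size filter above is what keeps character boxes and drops the plate outline.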
+
+ # CRNN input geometry and maximum decoded length; these must match the training setup
+ max_length = 9
+ img_width = 200
+ img_height = 50
+
+ def decode_batch_predictions(pred):
+     # Every sequence in the batch uses the full number of timesteps
+     input_len = np.ones(pred.shape[0]) * pred.shape[1]
+     # Greedy CTC decoding, truncated to the maximum plate length
+     results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][:, :max_length]
+     output_text = []
+     for res in results:
+         res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
+         res = res.replace('[UNK]', '')
+         output_text.append(res)
+     return output_text
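+ # NOTE: greedy CTC decoding takes the argmax at each timestep, then collapses
+ # repeats and drops the blank token; num_to_char is assumed to be the
+ # StringLookup-style mapping saved at training time, with '[UNK]' stripped here.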
+
+ # Model 2: CRNN + CTC over the whole plate image
+ def ocr_model_2(file_path):
+     img = tf.io.read_file(file_path)
+     # decode_image handles PNG and JPEG alike (decode_png would reject JPEG uploads)
+     img = tf.io.decode_image(img, channels=1, expand_animations=False)
+     img = tf.image.convert_image_dtype(img, tf.float32)
+     img = tf.image.resize(img, [img_height, img_width])
+     # Transpose so the width dimension becomes the time axis the CRNN expects
+     img = tf.transpose(img, perm=[1, 0, 2])
+     img = tf.expand_dims(img, axis=0)
+     preds = clf_crnn.predict(img)
+     pred_text = decode_batch_predictions(preds)
+     return pred_text[0]
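+ # NOTE: the 200x50 resize and the transpose mirror what the CRNN's training
+ # pipeline is assumed to have done; if the model was trained on a different
+ # layout, these constants and the transpose must change with it.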
+
+
+ # Model 3: fine-tuned TrOCR (vision encoder-decoder)
+ def ocr_model_3(file_path):
+     pil_image = Image.open(file_path).convert("RGB")
+     # The processor resizes and normalizes the image into the encoder's expected format
+     pixel_values = processor(pil_image, return_tensors="pt").pixel_values
+     pixel_values = pixel_values.to(device)
+     clf_trocr.eval()
+     with torch.no_grad():
+         generated_ids = clf_trocr.generate(pixel_values)
+     predicted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     return predicted_text
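+ # NOTE: generate() runs with its default decoding settings here; options such as
+ # max_new_tokens or num_beams could trade speed for accuracy, but the defaults
+ # are kept since the fine-tuned checkpoint presumably ships sensible ones.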
+
+
+ # Master OCR function that dispatches to the chosen pipeline
+ def ocr(file_path, model_name):
+     if model_name == MODEL_TYPES[0]:
+         return ocr_model_1(file_path)
+     elif model_name == MODEL_TYPES[1]:
+         return ocr_model_2(file_path)
+     elif model_name == MODEL_TYPES[2]:
+         return ocr_model_3(file_path)
+     # Fall through: return an explicit message instead of None if no model was picked
+     return "Please choose a model."
+
+ # Create Gradio interface
+ interface = gr.Interface(
+     fn=ocr,
+     inputs=[
+         # type="filepath" hands the path of the saved upload to ocr()
+         gr.Image(type="filepath"),
+         gr.Dropdown(choices=MODEL_TYPES, label="Choose Model")
+     ],
+     outputs=gr.Textbox(label="Predicted License Plate"),
+     title="Automatic License Plate Recognition",
+     description="Upload a license plate image and choose a model; the system will predict the text on the plate. All three models were trained on the same dataset, so one may perform better than another.",
+     examples=[
+         ['/content/B8837NR.jpg', MODEL_TYPES[0]],
+         ['/content/E5105OD.jpg', MODEL_TYPES[0]]
+     ]
+ )
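+ # NOTE: the example images are assumed to exist at these '/content/' paths in
+ # the runtime environment; adjust them alongside the model paths above. The
+ # second element of each example pre-fills the dropdown with a valid choice.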
+
+ # Launch the Gradio app
+ interface.launch()
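+ # When running locally or in a notebook, launch(share=True) would also create a
+ # temporary public link; on a hosted Space the plain launch() above is enough.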