arsath-sm committed on
Commit
6af8008
·
verified ·
1 Parent(s): 95a4f19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -41
app.py CHANGED
@@ -5,6 +5,18 @@ import onnxruntime as ort
5
  from PIL import Image
6
  import tempfile
7
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  # Load the ONNX model
9
  @st.cache_resource
10
  def load_model():
@@ -13,25 +25,18 @@ def load_model():
13
  ort_session = load_model()
14
 
15
  def preprocess_image(image, target_size=(640, 640)):
16
- # Convert PIL Image to numpy array if necessary
17
  if isinstance(image, Image.Image):
18
  image = np.array(image)
19
 
20
- # Convert RGB to BGR
21
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
22
-
23
- # Resize image
24
  image = cv2.resize(image, target_size)
25
- # Normalize
26
  image = image.astype(np.float32) / 255.0
27
- # Transpose for ONNX input
28
  image = np.transpose(image, (2, 0, 1))
29
- # Add batch dimension
30
  image = np.expand_dims(image, axis=0)
31
- return image
32
 
33
- def postprocess_results(output, image_shape, confidence_threshold=0.25, iou_threshold=0.45):
34
- # Handle different possible output formats
35
  if isinstance(output, (list, tuple)):
36
  predictions = output[0]
37
  elif isinstance(output, np.ndarray):
@@ -39,7 +44,6 @@ def postprocess_results(output, image_shape, confidence_threshold=0.25, iou_thre
39
  else:
40
  raise ValueError(f"Unexpected output type: {type(output)}")
41
 
42
- # Reshape if necessary
43
  if len(predictions.shape) == 4:
44
  predictions = predictions.squeeze((0, 1))
45
  elif len(predictions.shape) == 3:
@@ -59,52 +63,93 @@ def postprocess_results(output, image_shape, confidence_threshold=0.25, iou_thre
59
  # Convert boxes from [x, y, w, h] to [x1, y1, x2, y2]
60
  boxes[:, 2:] += boxes[:, :2]
61
 
62
- # Scale boxes to image size
63
- boxes[:, [0, 2]] *= image_shape[1]
64
- boxes[:, [1, 3]] *= image_shape[0]
65
-
66
- # Apply NMS
67
- indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), confidence_threshold, iou_threshold)
68
 
 
69
  results = []
70
- for i in indices:
71
- box = boxes[i]
72
- score = scores[i]
73
- class_id = class_ids[i]
74
- x1, y1, x2, y2 = map(int, box)
75
- results.append((x1, y1, x2, y2, float(score), int(class_id)))
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  return results
78
 
79
  def process_image(image):
80
  orig_image = image.copy()
81
- processed_image = preprocess_image(image)
82
 
83
  # Run inference
84
  inputs = {ort_session.get_inputs()[0].name: processed_image}
85
  outputs = ort_session.run(None, inputs)
86
 
87
- results = postprocess_results(outputs, image.shape)
88
 
89
  # Draw bounding boxes on the image
90
  for x1, y1, x2, y2, score, class_id in results:
91
- cv2.rectangle(orig_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
92
- label = f"License Plate: {score:.2f}"
93
- cv2.putText(orig_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  return cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
96
 
97
  def process_video(video_path):
98
  cap = cv2.VideoCapture(video_path)
99
 
100
- # Get video properties
101
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
102
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
103
  fps = int(cap.get(cv2.CAP_PROP_FPS))
104
 
105
- # Create a temporary file to store the processed video
106
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
107
- out = cv2.VideoWriter(temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
 
 
 
 
 
 
 
 
 
 
108
 
109
  while cap.isOpened():
110
  ret, frame = cap.read()
@@ -113,15 +158,33 @@ def process_video(video_path):
113
 
114
  processed_frame = process_image(frame)
115
  out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
 
 
 
 
116
 
117
  cap.release()
118
  out.release()
 
119
 
120
  return temp_file.name
121
 
122
- st.title("License Plate Detection")
 
123
 
124
- uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"])
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  if uploaded_file is not None:
127
  file_type = uploaded_file.type.split('/')[0]
@@ -130,9 +193,10 @@ if uploaded_file is not None:
130
  image = Image.open(uploaded_file)
131
  st.image(image, caption="Uploaded Image", use_column_width=True)
132
 
133
- if st.button("Detect License Plates"):
134
- processed_image = process_image(np.array(image))
135
- st.image(processed_image, caption="Processed Image", use_column_width=True)
 
136
 
137
  elif file_type == "video":
138
  tfile = tempfile.NamedTemporaryFile(delete=False)
@@ -140,8 +204,20 @@ if uploaded_file is not None:
140
 
141
  st.video(tfile.name)
142
 
143
- if st.button("Detect License Plates"):
144
- processed_video = process_video(tfile.name)
145
- st.video(processed_video)
146
-
147
- st.write("Upload an image or video to detect license plates.")
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from PIL import Image
6
  import tempfile
7
 
# Display name for each class index the detector reports.
CLASSES = {0: "Vehicle", 1: "License_Plate"}

# Drawing colour for each class index, keyed the same way as CLASSES.
# Tuples are written as (R, G, B): red for vehicles, green for license plates.
COLORS = {0: (255, 0, 0), 1: (0, 255, 0)}
19
+
20
  # Load the ONNX model
21
  @st.cache_resource
22
  def load_model():
 
25
  ort_session = load_model()
26
 
def preprocess_image(image, target_size=(640, 640)):
    """Turn an image into the float32 NCHW tensor fed to the ONNX session.

    Args:
        image: A ``PIL.Image.Image`` or a NumPy array. Arrays with a channel
            axis are handled as RGB(A); 2-D arrays are handled as grayscale.
        target_size: ``(width, height)`` handed to ``cv2.resize``.

    Returns:
        A tuple ``(tensor, original_shape)``. ``tensor`` is the resized image
        divided by 255, transposed to channels-first, with a leading batch
        axis of 1. ``original_shape`` is the ``(height, width)`` of the input
        before resizing.
    """
    if isinstance(image, Image.Image):
        # Palette, grayscale and CMYK images would otherwise reach cvtColor
        # with a channel layout it rejects or misreads.
        image = np.array(image.convert("RGB"))

    if np.ndim(image) == 2:
        # cv2.COLOR_RGB2BGR needs 3 or 4 channels; expand single-channel
        # input instead of letting cvtColor raise.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    else:
        # NOTE(review): the tensor is fed to the model in BGR order —
        # confirm this matches the channel order the model was exported with.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Remember the pre-resize size so boxes can be scaled back by the caller.
    original_shape = image.shape[:2]

    image = cv2.resize(image, target_size)
    image = image.astype(np.float32) / 255.0
    image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
    image = np.expand_dims(image, axis=0)  # add batch axis
    return image, original_shape
38
 
39
+ def postprocess_results(output, original_shape, confidence_threshold=0.25, iou_threshold=0.45):
 
40
  if isinstance(output, (list, tuple)):
41
  predictions = output[0]
42
  elif isinstance(output, np.ndarray):
 
44
  else:
45
  raise ValueError(f"Unexpected output type: {type(output)}")
46
 
 
47
  if len(predictions.shape) == 4:
48
  predictions = predictions.squeeze((0, 1))
49
  elif len(predictions.shape) == 3:
 
63
  # Convert boxes from [x, y, w, h] to [x1, y1, x2, y2]
64
  boxes[:, 2:] += boxes[:, :2]
65
 
66
+ # Scale boxes to original image size
67
+ h, w = original_shape
68
+ boxes[:, [0, 2]] *= w
69
+ boxes[:, [1, 3]] *= h
 
 
70
 
71
+ # Apply NMS for each class separately
72
  results = []
73
+ for class_id in np.unique(class_ids):
74
+ class_mask = class_ids == class_id
75
+ class_boxes = boxes[class_mask]
76
+ class_scores = scores[class_mask]
77
+
78
+ indices = cv2.dnn.NMSBoxes(
79
+ class_boxes.tolist(),
80
+ class_scores.tolist(),
81
+ confidence_threshold,
82
+ iou_threshold
83
+ )
84
+
85
+ for i in indices:
86
+ box = class_boxes[i]
87
+ score = class_scores[i]
88
+ x1, y1, x2, y2 = map(int, box)
89
+ results.append((x1, y1, x2, y2, float(score), int(class_id)))
90
 
91
  return results
92
 
def _draw_detection(canvas, box, score, class_id):
    """Draw one detection (box plus filled label tag) onto ``canvas`` in place."""
    x1, y1, x2, y2 = box
    # Fall back to a neutral colour/name instead of raising KeyError when the
    # model reports a class index that has no entry in COLORS/CLASSES.
    color = COLORS.get(class_id, (255, 255, 0))
    name = CLASSES.get(class_id, f"Class {class_id}")

    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)

    label = f"{name}: {score:.2f}"
    (text_width, text_height), _ = cv2.getTextSize(
        label, cv2.FONT_HERSHEY_SIMPLEX, 0.9, 2
    )

    # The tag normally sits directly above the box. When the box is too close
    # to the top edge for that, place the tag just inside the box so the
    # label stays visible.
    tag_height = text_height + 10
    if y1 - tag_height >= 0:
        tag_bottom = y1
    else:
        tag_bottom = max(y1, 0) + tag_height

    # Filled background rectangle, then white text on top of it.
    cv2.rectangle(
        canvas,
        (x1, tag_bottom - tag_height),
        (x1 + text_width, tag_bottom),
        color,
        -1,
    )
    cv2.putText(
        canvas,
        label,
        (x1, tag_bottom - 5),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.9,
        (255, 255, 255),
        2,
    )


def process_image(image, confidence_threshold=0.25, iou_threshold=0.45):
    """Run the detector on one image and return a copy with detections drawn.

    Args:
        image: NumPy image array. It is copied before drawing, so the
            caller's array is not modified.
        confidence_threshold: Forwarded to ``postprocess_results``.
        iou_threshold: Forwarded to ``postprocess_results``.

    Returns:
        The annotated copy after a ``cv2.COLOR_BGR2RGB`` channel swap.
    """
    orig_image = image.copy()
    processed_image, original_shape = preprocess_image(image)

    # Run inference
    inputs = {ort_session.get_inputs()[0].name: processed_image}
    outputs = ort_session.run(None, inputs)

    results = postprocess_results(
        outputs, original_shape, confidence_threshold, iou_threshold
    )

    # Draw bounding boxes on the image
    for x1, y1, x2, y2, score, class_id in results:
        _draw_detection(orig_image, (x1, y1, x2, y2), score, class_id)

    # NOTE(review): this swap assumes ``image`` arrived in BGR order. The
    # video path passes OpenCV frames, but the image path appears to pass an
    # array built from a PIL image (presumably RGB), in which case the
    # returned picture and box colours have red and blue exchanged — confirm
    # and normalise the channel order at the call sites.
    return cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
133
 
134
  def process_video(video_path):
135
  cap = cv2.VideoCapture(video_path)
136
 
 
137
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
138
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
139
  fps = int(cap.get(cv2.CAP_PROP_FPS))
140
 
 
141
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
142
+ out = cv2.VideoWriter(
143
+ temp_file.name,
144
+ cv2.VideoWriter_fourcc(*'mp4v'),
145
+ fps,
146
+ (width, height)
147
+ )
148
+
149
+ # Add progress bar for video processing
150
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
151
+ progress_bar = st.progress(0)
152
+ frame_count = 0
153
 
154
  while cap.isOpened():
155
  ret, frame = cap.read()
 
158
 
159
  processed_frame = process_image(frame)
160
  out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
161
+
162
+ # Update progress bar
163
+ frame_count += 1
164
+ progress_bar.progress(frame_count / total_frames)
165
 
166
  cap.release()
167
  out.release()
168
+ progress_bar.empty()
169
 
170
  return temp_file.name
171
 
# Streamlit UI: page title, threshold slider and upload widget.
st.title("Vehicle and License Plate Detection")

# Confidence threshold picked by the user: 0.0 to 1.0 in steps of 0.05,
# starting at 0.25.
# NOTE(review): no statement in the visible part of this file reads this
# value — confirm it is forwarded to the detection call.
confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.05)

# Still images and MP4 video are accepted.
uploaded_file = st.file_uploader(
    "Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"]
)
188
 
189
  if uploaded_file is not None:
190
  file_type = uploaded_file.type.split('/')[0]
 
193
  image = Image.open(uploaded_file)
194
  st.image(image, caption="Uploaded Image", use_column_width=True)
195
 
196
+ if st.button("Detect Objects"):
197
+ with st.spinner("Processing image..."):
198
+ processed_image = process_image(np.array(image))
199
+ st.image(processed_image, caption="Processed Image", use_column_width=True)
200
 
201
  elif file_type == "video":
202
  tfile = tempfile.NamedTemporaryFile(delete=False)
 
204
 
205
  st.video(tfile.name)
206
 
207
+ if st.button("Detect Objects"):
208
+ with st.spinner("Processing video..."):
209
+ processed_video = process_video(tfile.name)
210
+ st.video(processed_video)
211
+
212
+ # Add legend
213
+ st.markdown("### Detection Legend")
214
+ for class_id, class_name in CLASSES.items():
215
+ color = COLORS[class_id]
216
+ st.markdown(
217
+ f'<div style="display: flex; align-items: center;">'
218
+ f'<div style="width: 20px; height: 20px; background-color: rgb{color}; margin-right: 10px;"></div>'
219
+ f'<span>{class_name}</span></div>',
220
+ unsafe_allow_html=True
221
+ )
222
+
223
+ st.write("Upload an image or video to detect vehicles and license plates.")