arsath-sm commited on
Commit
ab634f0
·
verified ·
1 Parent(s): 76107e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -15
app.py CHANGED
@@ -13,6 +13,13 @@ def load_model():
13
  ort_session = load_model()
14
 
15
  def preprocess_image(image, target_size=(640, 640)):
 
 
 
 
 
 
 
16
  # Resize image
17
  image = cv2.resize(image, target_size)
18
  # Normalize
@@ -24,10 +31,24 @@ def preprocess_image(image, target_size=(640, 640)):
24
  return image
25
 
26
  def postprocess_results(output, image_shape, confidence_threshold=0.25, iou_threshold=0.45):
27
- # Assuming YOLO v5 output format
28
- boxes = output[0]
29
- scores = output[1]
30
- class_ids = output[2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # Filter by confidence
33
  mask = scores > confidence_threshold
@@ -35,6 +56,13 @@ def postprocess_results(output, image_shape, confidence_threshold=0.25, iou_thre
35
  scores = scores[mask]
36
  class_ids = class_ids[mask]
37
 
 
 
 
 
 
 
 
38
  # Apply NMS
39
  indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), confidence_threshold, iou_threshold)
40
 
@@ -43,12 +71,8 @@ def postprocess_results(output, image_shape, confidence_threshold=0.25, iou_thre
43
  box = boxes[i]
44
  score = scores[i]
45
  class_id = class_ids[i]
46
- x, y, w, h = box
47
- x1 = int(x * image_shape[1])
48
- y1 = int(y * image_shape[0])
49
- x2 = int((x + w) * image_shape[1])
50
- y2 = int((y + h) * image_shape[0])
51
- results.append((x1, y1, x2, y2, score, class_id))
52
 
53
  return results
54
 
@@ -68,7 +92,7 @@ def process_image(image):
68
  label = f"License Plate: {score:.2f}"
69
  cv2.putText(orig_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
70
 
71
- return orig_image
72
 
73
  def process_video(video_path):
74
  cap = cv2.VideoCapture(video_path)
@@ -88,7 +112,7 @@ def process_video(video_path):
88
  break
89
 
90
  processed_frame = process_image(frame)
91
- out.write(processed_frame)
92
 
93
  cap.release()
94
  out.release()
@@ -104,12 +128,10 @@ if uploaded_file is not None:
104
 
105
  if file_type == "image":
106
  image = Image.open(uploaded_file)
107
- image = np.array(image)
108
-
109
  st.image(image, caption="Uploaded Image", use_column_width=True)
110
 
111
  if st.button("Detect License Plates"):
112
- processed_image = process_image(image)
113
  st.image(processed_image, caption="Processed Image", use_column_width=True)
114
 
115
  elif file_type == "video":
 
13
  ort_session = load_model()
14
 
15
  def preprocess_image(image, target_size=(640, 640)):
16
+ # Convert PIL Image to numpy array if necessary
17
+ if isinstance(image, Image.Image):
18
+ image = np.array(image)
19
+
20
+ # Convert RGB to BGR
21
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
22
+
23
  # Resize image
24
  image = cv2.resize(image, target_size)
25
  # Normalize
 
31
  return image
32
 
33
  def postprocess_results(output, image_shape, confidence_threshold=0.25, iou_threshold=0.45):
34
+ # Handle different possible output formats
35
+ if isinstance(output, (list, tuple)):
36
+ predictions = output[0]
37
+ elif isinstance(output, np.ndarray):
38
+ predictions = output
39
+ else:
40
+ raise ValueError(f"Unexpected output type: {type(output)}")
41
+
42
+ # Reshape if necessary
43
+ if len(predictions.shape) == 4:
44
+ predictions = predictions.squeeze((0, 1))
45
+ elif len(predictions.shape) == 3:
46
+ predictions = predictions.squeeze(0)
47
+
48
+ # Extract boxes, scores, and class_ids
49
+ boxes = predictions[:, :4]
50
+ scores = predictions[:, 4]
51
+ class_ids = predictions[:, 5]
52
 
53
  # Filter by confidence
54
  mask = scores > confidence_threshold
 
56
  scores = scores[mask]
57
  class_ids = class_ids[mask]
58
 
59
+ # Convert boxes from [x, y, w, h] to [x1, y1, x2, y2]
60
+ boxes[:, 2:] += boxes[:, :2]
61
+
62
+ # Scale boxes to image size
63
+ boxes[:, [0, 2]] *= image_shape[1]
64
+ boxes[:, [1, 3]] *= image_shape[0]
65
+
66
  # Apply NMS
67
  indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), confidence_threshold, iou_threshold)
68
 
 
71
  box = boxes[i]
72
  score = scores[i]
73
  class_id = class_ids[i]
74
+ x1, y1, x2, y2 = map(int, box)
75
+ results.append((x1, y1, x2, y2, float(score), int(class_id)))
 
 
 
 
76
 
77
  return results
78
 
 
92
  label = f"License Plate: {score:.2f}"
93
  cv2.putText(orig_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
94
 
95
+ return cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
96
 
97
  def process_video(video_path):
98
  cap = cv2.VideoCapture(video_path)
 
112
  break
113
 
114
  processed_frame = process_image(frame)
115
+ out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
116
 
117
  cap.release()
118
  out.release()
 
128
 
129
  if file_type == "image":
130
  image = Image.open(uploaded_file)
 
 
131
  st.image(image, caption="Uploaded Image", use_column_width=True)
132
 
133
  if st.button("Detect License Plates"):
134
+ processed_image = process_image(np.array(image))
135
  st.image(processed_image, caption="Processed Image", use_column_width=True)
136
 
137
  elif file_type == "video":