Abs6187 commited on
Commit
9e8d516
·
verified ·
1 Parent(s): dab86c7

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +315 -0
  2. best.pt +3 -0
  3. gitattributes +1 -0
  4. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ from ultralytics import YOLO
6
+ from ultralytics.utils.plotting import Annotator, colors
7
+ from PIL import Image, ImageDraw, ImageFont
8
+ import tempfile
9
+ from pathlib import Path
10
+ import time
11
+ from typing import List, Tuple, Dict, Any, Optional
12
+ import google.generativeai as genai
13
+
14
+ # For tracking
15
+ from collections import defaultdict
16
+
17
+ # Configure Gemini API
18
+ gemini_api_key = "AIzaSyCBs4TumAonKI0AodIzbl4b8Vmu9eM_r9I" # In production, use environment variables
19
+ genai.configure(api_key=gemini_api_key)
20
+
21
+ def get_safety_analysis(stats: Dict[str, int], image_path: Optional[str] = None) -> str:
22
+ """Generate safety analysis using Gemini AI based on detection statistics."""
23
+ try:
24
+ model = genai.GenerativeModel('gemini-2.0-flash')
25
+
26
+ # Create a detailed prompt
27
+ prompt = f"""You are a traffic safety analyst. Based on the following detection statistics:
28
+ - Total Detections: {stats.get('total_detections', 0)}
29
+ - Riders with Helmet: {stats.get('with_helmet', 0)}
30
+ - Riders without Helmet: {stats.get('without_helmet', 0)}
31
+ - License Plates Detected: {stats.get('license_plates', 0)}
32
+
33
+ Provide a concise safety analysis and recommendations. Focus on:
34
+ 1. Helmet compliance rate
35
+ 2. Potential safety concerns
36
+ 3. Suggestions for improvement
37
+
38
+ Keep the response under 100 words."""
39
+
40
+ response = model.generate_content(prompt)
41
+ return response.text
42
+ except Exception as e:
43
+ print(f"Error in Gemini API: {str(e)}")
44
+ return "Safety analysis is currently unavailable. Please check your API key and internet connection."
45
+
46
+ # Download sample images and videos (optional)
47
+ sample_files = {
48
+ 'sample_1.jpg': 'https://github.com/Janno1402/Helmet-License-Plate-Detection/raw/main/Sample-Image-1.jpg',
49
+ 'sample_2.jpg': 'https://github.com/Janno1402/Helmet-License-Plate-Detection/raw/main/Sample-Image-2.jpg',
50
+ 'sample_3.jpg': 'https://github.com/Janno1402/Helmet-License-Plate-Detection/raw/main/Sample-Image-3.jpg',
51
+ 'sample_4.jpg': 'https://github.com/Janno1402/Helmet-License-Plate-Detection/raw/main/Sample-Image-4.jpg',
52
+ 'sample_5.jpg': 'https://github.com/Janno1402/Helmet-License-Plate-Detection/raw/main/Sample-Image-5.jpg',
53
+ 'traffic_violation.mp4': 'https://github.com/anmspro/Traffic-Signal-Violation-Detection-System/raw/master/Resources/input/input.mp4' # Traffic violation video
54
+ }
55
+
56
+ for filename, url in sample_files.items():
57
+ if not Path(filename).exists():
58
+ try:
59
+ torch.hub.download_url_to_file(url, filename)
60
+ except:
61
+ print(f"Could not download {filename}")
62
+
63
+ # Initialize model and tracking
64
+ model = YOLO("best.pt")
65
+
66
+ # Tracking variables
67
+ track_history = defaultdict(lambda: [])
68
+ violations = defaultdict(int)
69
+
70
+
71
+ def process_image(image_path: str, conf_threshold: float = 0.4, iou_threshold: float = 0.5,
72
+ image_size: int = 640, enable_tracking: bool = False) -> Tuple[Image.Image, Dict]:
73
+ """Process a single image and return annotated image and statistics."""
74
+ # Process image
75
+ results = model.predict(
76
+ source=image_path,
77
+ conf=conf_threshold,
78
+ iou=iou_threshold,
79
+ imgsz=image_size,
80
+ verbose=False
81
+ )
82
+
83
+ # Get results
84
+ boxes = results[0].boxes.xyxy.cpu().numpy()
85
+ scores = results[0].boxes.conf.cpu().numpy()
86
+ class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
87
+
88
+ # Initialize statistics with additional metrics
89
+ total_riders = int(sum((class_ids == 0) | (class_ids == 1)))
90
+ helmet_compliance = 0 if total_riders == 0 else int(sum(class_ids == 0) / total_riders * 100)
91
+
92
+ stats = {
93
+ 'total_detections': len(boxes),
94
+ 'with_helmet': int(sum(class_ids == 0)),
95
+ 'without_helmet': int(sum(class_ids == 1)),
96
+ 'license_plates': int(sum(class_ids == 2)),
97
+ 'helmet_compliance': helmet_compliance,
98
+ 'total_riders': total_riders,
99
+ 'violation_rate': 0 if total_riders == 0 else int((sum(class_ids == 1) / total_riders) * 100)
100
+ }
101
+
102
+ # Create annotated image
103
+ img = Image.open(image_path).convert("RGB")
104
+ draw = ImageDraw.Draw(img)
105
+
106
+ # Draw detections
107
+ for box, score, class_id in zip(boxes, scores, class_ids):
108
+ x1, y1, x2, y2 = box
109
+ label = f"{'Helmet' if class_id == 0 else 'No Helmet' if class_id == 1 else 'License Plate'} {score:.2f}"
110
+
111
+ # Draw rectangle
112
+ color = (0, 255, 0) if class_id == 0 else (0, 0, 255) if class_id == 1 else (255, 0, 0)
113
+ draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
114
+
115
+ # Draw label background
116
+ text_bbox = draw.textbbox((x1, y1 - 20), label)
117
+ draw.rectangle(text_bbox, fill=color)
118
+ draw.text((x1, y1 - 20), label, fill=(255, 255, 255))
119
+
120
+ return img, stats
121
+
122
+ def process_video(video_path: str, conf_threshold: float = 0.4, iou_threshold: float = 0.5,
123
+ image_size: int = 640, enable_tracking: bool = True) -> str:
124
+ """Process a video file and return path to the output video."""
125
+ cap = cv2.VideoCapture(video_path)
126
+ if not cap.isOpened():
127
+ return "Error: Could not open video file."
128
+
129
+ # Get video properties
130
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
131
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
132
+ fps = cap.get(cv2.CAP_PROP_FPS)
133
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
134
+
135
+ # Create output video
136
+ output_path = "output_video.mp4"
137
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
138
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
139
+
140
+ # Process video frame by frame
141
+ frame_count = 0
142
+ while cap.isOpened():
143
+ ret, frame = cap.read()
144
+ if not ret:
145
+ break
146
+
147
+ # Process frame
148
+ results = model.track(
149
+ source=frame,
150
+ conf=conf_threshold,
151
+ iou=iou_threshold,
152
+ imgsz=image_size,
153
+ persist=True,
154
+ verbose=False
155
+ )
156
+
157
+ # Get tracking results
158
+ if hasattr(results[0].boxes, 'id') and results[0].boxes.id is not None:
159
+ track_ids = results[0].boxes.id.cpu().numpy().astype(int)
160
+ boxes = results[0].boxes.xyxy.cpu().numpy()
161
+ class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
162
+
163
+ # Update tracking history and detect violations
164
+ for box, track_id, class_id in zip(boxes, track_ids, class_ids):
165
+ if class_id == 1: # No helmet
166
+ violations[track_id] += 1
167
+ if violations[track_id] > 10: # If no helmet for 10 consecutive frames
168
+ # Draw warning
169
+ cv2.putText(frame, "SAFETY VIOLATION: NO HELMET!",
170
+ (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
171
+ 1, (0, 0, 255), 2, cv2.LINE_AA)
172
+
173
+ # Write frame to output video
174
+ out.write(results[0].plot())
175
+ frame_count += 1
176
+
177
+ # Release resources
178
+ cap.release()
179
+ out.release()
180
+
181
+ return output_path
182
+
183
+ def process_input(input_data, input_type, conf_threshold, iou_threshold, image_size, enable_tracking):
184
+ """Process input based on its type (image or video)."""
185
+ if input_type == "image":
186
+ if isinstance(input_data, str):
187
+ img_path = input_data
188
+ else:
189
+ # Save uploaded file temporarily
190
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
191
+ img_path = temp_file.name
192
+ input_data.save(img_path)
193
+
194
+ result_img, stats = process_image(
195
+ img_path, conf_threshold, iou_threshold, image_size, enable_tracking
196
+ )
197
+
198
+ # Generate safety analysis
199
+ safety_analysis = get_safety_analysis(stats)
200
+
201
+ # Create statistics text with safety analysis
202
+ stats_text = f"""
203
+ 🚦 Detection Results:
204
+ - Total Detections: {stats['total_detections']}
205
+ - With Helmet: {stats['with_helmet']}
206
+ - Without Helmet: {stats['without_helmet']}
207
+ - License Plates: {stats['license_plates']}
208
+
209
+ 🔍 Safety Analysis:
210
+ {safety_analysis}
211
+ """
212
+
213
+ return result_img, stats_text, None
214
+
215
+ elif input_type == "video":
216
+ if isinstance(input_data, str):
217
+ video_path = input_data
218
+ else:
219
+ # Save uploaded file temporarily
220
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
221
+ video_path = temp_file.name
222
+ input_data.save(video_path)
223
+
224
+ output_path = process_video(
225
+ video_path, conf_threshold, iou_threshold, image_size, enable_tracking
226
+ )
227
+
228
+ return None, "Video processing complete!", output_path
229
+
230
+ return None, "Unsupported input type", None
231
+
232
+
233
+ # Define Gradio interface components
234
+ with gr.Blocks(title="AI-Powered Helmet & License Plate Detection") as demo:
235
+ gr.Markdown("""
236
+ # 🛵 AI-Powered Helmet & License Plate Detection
237
+
238
+ This application uses YOLOv8 to detect motorcyclists with/without helmets and license plates in images and videos.
239
+ """)
240
+
241
+ with gr.Tabs():
242
+ with gr.TabItem("Image Detection"):
243
+ with gr.Row():
244
+ with gr.Column():
245
+ image_input = gr.Image(type="filepath", label="Upload Image")
246
+ video_input = gr.Video(visible=False)
247
+
248
+ with gr.Row():
249
+ conf_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05,
250
+ label="Confidence Threshold")
251
+ iou_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05,
252
+ label="IOU Threshold")
253
+
254
+ image_size = gr.Slider(minimum=320, maximum=1280, value=640, step=32,
255
+ label="Image Size")
256
+
257
+ process_btn = gr.Button("Process", variant="primary")
258
+
259
+ with gr.Column():
260
+ output_image = gr.Image(label="Detection Results", type="pil")
261
+ stats_output = gr.Textbox(label="Detection Statistics")
262
+ video_output = gr.Video(visible=False)
263
+
264
+ with gr.TabItem("Video Detection"):
265
+ with gr.Row():
266
+ with gr.Column():
267
+ video_input = gr.Video(label="Upload Video", examples=[["traffic_violation.mp4"]])
268
+ image_input = gr.Image(visible=False)
269
+
270
+ with gr.Row():
271
+ conf_slider_vid = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05,
272
+ label="Confidence Threshold")
273
+ iou_slider_vid = gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05,
274
+ label="IOU Threshold")
275
+
276
+ image_size_vid = gr.Slider(minimum=320, maximum=1280, value=640, step=32,
277
+ label="Processing Frame Size")
278
+
279
+ process_vid_btn = gr.Button("Process Video", variant="primary")
280
+
281
+ with gr.Column():
282
+ video_output = gr.Video(label="Processed Video")
283
+ stats_output_vid = gr.Textbox(label="Processing Status")
284
+ output_image = gr.Image(visible=False)
285
+
286
+ # Connect the process buttons to their respective functions
287
+ process_btn.click(
288
+ fn=process_input,
289
+ inputs=[
290
+ image_input,
291
+ gr.Number(value="image", visible=False),
292
+ conf_slider,
293
+ iou_slider,
294
+ image_size,
295
+ gr.Checkbox(value=True, visible=False)
296
+ ],
297
+ outputs=[output_image, stats_output, video_output]
298
+ )
299
+
300
+ process_vid_btn.click(
301
+ fn=process_input,
302
+ inputs=[
303
+ video_input,
304
+ gr.Number(value="video", visible=False),
305
+ conf_slider_vid,
306
+ iou_slider_vid,
307
+ image_size_vid,
308
+ gr.Checkbox(value=True, visible=False)
309
+ ],
310
+ outputs=[output_image, stats_output_vid, video_output]
311
+ )
312
+
313
+ # Launch the app
314
+ if __name__ == "__main__":
315
+ demo.launch(debug=True, share=True)
best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f7830a9d4a0b36119bdfe37d8b9bdb52e951dccae43583c1c5c2e97cf8165b7
3
+ size 5468883
gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ best.pt filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio==3.39.0
2
+ torch
3
+ ultralytics==8.3.40
4
+ numpy