Abs6187 commited on
Commit
428e3e7
·
verified ·
1 Parent(s): 713bd0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -61
app.py CHANGED
@@ -1,64 +1,51 @@
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
- from tensorflow.keras.models import load_model
5
- from sklearn.preprocessing import StandardScaler
6
- from ultralytics import YOLO
7
-
8
- # Load models
9
- lstm_model = load_model('suspicious_activity_model.h5')
10
- yolo_model = YOLO('yolov8n-pose.pt') # Ensure this model supports keypoint detection
11
- scaler = StandardScaler()
12
-
13
- # Function to extract keypoints from a frame
14
- def extract_keypoints(frame):
15
- results = yolo_model(frame, verbose=False)
16
- for r in results:
17
- if r.keypoints is not None and len(r.keypoints) > 0:
18
- keypoints = r.keypoints.xyn.tolist()[0] # Use the first person's keypoints
19
- flattened_keypoints = [kp for keypoint in keypoints for kp in keypoint[:2]] # Flatten x, y values
20
- return flattened_keypoints
21
- return None # Return None if no keypoints are detected
22
-
23
- # Function to process each frame
24
- def process_frame(frame):
25
- results = yolo_model(frame, verbose=False)
26
-
27
- for box in results[0].boxes:
28
- cls = int(box.cls[0]) # Class ID
29
- confidence = float(box.conf[0])
30
-
31
- if cls == 0 and confidence > 0.5: # Detect persons only
32
- x1, y1, x2, y2 = map(int, box.xyxy[0]) # Bounding box coordinates
33
-
34
- # Extract ROI for classification
35
- roi = frame[y1:y2, x1:x2]
36
- if roi.size > 0:
37
- keypoints = extract_keypoints(roi)
38
- if keypoints is not None and len(keypoints) > 0:
39
- keypoints_scaled = scaler.fit_transform([keypoints])
40
- keypoints_reshaped = keypoints_scaled.reshape((1, 1, len(keypoints)))
41
-
42
- prediction = (lstm_model.predict(keypoints_reshaped) > 0.5).astype(int)[0][0]
43
-
44
- color = (0, 0, 255) if prediction == 1 else (0, 255, 0)
45
- label = 'Suspicious' if prediction == 1 else 'Normal'
46
- cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
47
- cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
48
- return frame
49
-
50
- # Gradio video streaming function
51
- def video_processing(video_frame):
52
- frame = cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB) # Convert to RGB
53
- processed_frame = process_frame(frame)
54
- return processed_frame
55
-
56
- # Launch Gradio app
57
- gr.Interface(
58
- fn=video_processing,
59
- inputs=gr.Video(streaming=True), # Correct the Video component
60
- outputs="video",
61
- live=True,
62
- title="Suspicious Activity Detection"
63
- ).launch(debug=True)
64
-
 
1
+ # app.py
2
  import gradio as gr
3
  import cv2
4
  import numpy as np
5
+ from model import SuspiciousActivityModel # Import the model
6
+
7
+ # Initialize the model paths
8
+ lstm_model_path = 'suspicious_activity_model.h5' # Path to your LSTM model
9
+ yolo_model_path = 'yolov8n-pose.pt' # Path to your YOLO model
10
+
11
+ # Initialize the suspicious activity model
12
+ model = SuspiciousActivityModel(lstm_model_path, yolo_model_path)
13
+
14
+ # Function to process video frame
15
+ def process_video(video_frame):
16
+ # Check if the input frame is a valid NumPy array
17
+ if isinstance(video_frame, np.ndarray):
18
+ print(f"Frame shape: {video_frame.shape}") # Print the shape of the frame for debugging
19
+
20
+ # Convert frame from BGR to RGB (OpenCV uses BGR by default)
21
+ try:
22
+ frame_rgb = cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB)
23
+ except cv2.error as e:
24
+ print(f"Error in cvtColor: {e}")
25
+ return video_frame # Return the original frame if error occurs
26
+
27
+ # Call model to detect activity in the frame
28
+ label = model.detect_activity(frame_rgb)
29
+
30
+ # Add label to the frame (Optional: you can also draw bounding boxes)
31
+ cv2.putText(frame_rgb, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
32
+
33
+ # Convert back to BGR for Gradio (since it expects BGR format)
34
+ frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
35
+ return frame_bgr
36
+ else:
37
+ print("Received invalid frame format")
38
+ return video_frame # Return the original frame if it's not valid
39
+
40
+ # Gradio interface
41
+ iface = gr.Interface(
42
+ fn=process_video, # Function that processes each frame
43
+ inputs=gr.Video(source="webcam", streaming=True), # Use webcam as input
44
+ outputs=gr.Video(), # Output is also a video
45
+ live=True, # Stream the video in real time
46
+ title="Suspicious Activity Detection" # Interface title
47
+ )
48
+
49
+ # Launch the app with public link
50
+ if __name__ == "__main__":
51
+ iface.launch(share=True) # Set share=True to create a public link