import gradio as gr
from ultralytics import YOLO
import tempfile
import os
import cv2
import numpy as np
import torch
import atexit
import uuid

# Load the YOLOv8 pose estimation model once at the start
model = YOLO("yolov8n-pose.pt")  

# Define the skeleton connections based on COCO keypoints
COCO_KEYPOINTS = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle"
]

# Define the skeleton as pairs of keypoints indices
SKELETON_CONNECTIONS = [
    (0, 1), (0, 2),       # Nose to eyes
    (1, 3), (2, 4),       # Eyes to ears
    (0, 5), (0, 6),       # Nose to shoulders
    (5, 6),               # Shoulders to each other
    (5, 7), (6, 8),       # Shoulders to elbows
    (7, 9), (8, 10),      # Elbows to wrists
    (5, 11), (6, 12),     # Shoulders to hips
    (11, 12),             # Hips to each other
    (11, 13), (12, 14),   # Hips to knees
    (13, 15), (14, 16)    # Knees to ankles
]
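# Each index above refers to a position in COCO_KEYPOINTS (e.g. 5 = left_shoulder, 11 = left_hip).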

def calculate_torso_angle(keypoints, frame_height):
    """
    Calculate the angle of the torso with respect to the vertical axis.
    
    Args:
        keypoints (numpy.ndarray): Array of shape (17, 3) representing COCO keypoints.
        frame_height (int): Height of the video frame in pixels (currently unused).
    
    Returns:
        float: Angle in degrees. Returns None if keypoints are not detected properly.
    """
    try:
        # COCO keypoint indices
        LEFT_SHOULDER = 5
        RIGHT_SHOULDER = 6
        LEFT_HIP = 11
        RIGHT_HIP = 12

        # Extract shoulder and hip coordinates
        left_shoulder = keypoints[LEFT_SHOULDER][:2]
        right_shoulder = keypoints[RIGHT_SHOULDER][:2]
        left_hip = keypoints[LEFT_HIP][:2]
        right_hip = keypoints[RIGHT_HIP][:2]

        # Require all four keypoints to be detected with confidence >= 0.3
        if (keypoints[LEFT_SHOULDER][2] < 0.3 or keypoints[RIGHT_SHOULDER][2] < 0.3 or
            keypoints[LEFT_HIP][2] < 0.3 or keypoints[RIGHT_HIP][2] < 0.3):
            return None

        # Calculate mid points
        mid_shoulder = (left_shoulder + right_shoulder) / 2
        mid_hip = (left_hip + right_hip) / 2

        # Calculate the vector of the torso
        vector = mid_hip - mid_shoulder

        # Angle of the torso vector relative to the vertical image axis:
        # 0 deg when upright, approaching 90 deg when lying horizontally
        angle_rad = np.arctan2(vector[0], vector[1])
        angle_deg = np.degrees(angle_rad)

        return angle_deg
    except Exception as e:
        print(f"Error calculating torso angle: {e}")
        return None
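
# Worked example for calculate_torso_angle (hypothetical pixel coordinates; image y grows downward):
#   upright person:    mid_shoulder=(200, 100), mid_hip=(200, 220) -> vector=(0, 120)  -> angle  =  0 deg
#   person lying flat: mid_shoulder=(100, 300), mid_hip=(220, 305) -> vector=(120, 5)  -> angle ~= 88 deg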

def draw_skeleton(frame, keypoints, show_labels=True):
    """
    Draws the skeleton on the frame based on keypoints.
    
    Args:
        frame (numpy.ndarray): The current video frame.
        keypoints (numpy.ndarray): Array of shape (17, 3) representing COCO keypoints.
        show_labels (bool): Whether to display keypoint indices.
    
    Returns:
        numpy.ndarray: Annotated frame with skeleton.
    """
    for connection in SKELETON_CONNECTIONS:
        start_idx, end_idx = connection
        x_start, y_start, conf_start = keypoints[start_idx]
        x_end, y_end, conf_end = keypoints[end_idx]
        
        # Only draw if both keypoints have sufficient confidence
        if conf_start > 0.5 and conf_end > 0.5:
            start_point = (int(x_start), int(y_start))
            end_point = (int(x_end), int(y_end))
            cv2.line(frame, start_point, end_point, (255, 0, 0), 2)  # Blue lines

    if show_labels:
        # Draw keypoints indices
        for idx, (x, y, conf) in enumerate(keypoints):
            if conf > 0.5:
                cv2.putText(frame, f"{idx}", (int(x), int(y)), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 0, 0), 1)  # Blue labels

    return frame

def detect_fall(video_path, angle_threshold=30, consecutive_frames=3, frame_sampling_rate=1, confidence_threshold=0.3, show_labels=True):
    """
    Detects falls in the uploaded video using pose estimation.
    
    Args:
        video_path (str): The path to the input video file uploaded by the user.
        angle_threshold (float): Angle threshold to classify a fall (in degrees).
        consecutive_frames (int): Number of consecutive frames to confirm a fall.
        frame_sampling_rate (int): Process every nth frame.
        confidence_threshold (float): Minimum confidence required for keypoint detection.
        show_labels (bool): Whether to display keypoint indices.
    
    Returns:
        tuple: (annotated_video_path, notification_message)
    """
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Unable to open the video file.")

        # Video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        width  = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')

        # Create a unique temporary file for the annotated video
        unique_id = uuid.uuid4().hex
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", prefix=f"annotated_{unique_id}_") as tmp:
            annotated_video_path = tmp.name

        out = cv2.VideoWriter(annotated_video_path, fourcc, fps, (width, height))

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        current_frame = 0
        consecutive_fall_frames = 0
        total_falls = 0
        fall_frames = []  # To store frames where falls were detected
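        # Note: a single fall counter is kept for the whole video, so detections from
        # multiple people in the same frame (if present) are pooled into that one counter.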

        while True:
            ret, frame = cap.read()
            if not ret:
                break  # End of video

            current_frame += 1

            # Implement frame sampling
            if current_frame % frame_sampling_rate != 0:
                out.write(frame)
                continue

            print(f"Processing frame {current_frame}/{frame_count}")

            # Run pose estimation
            results = model.predict(source=frame, conf=confidence_threshold, save=False, stream=False)

            # Iterate through detected persons
            for result in results:
                if not hasattr(result, 'keypoints') or result.keypoints is None:
                    continue
                for keypoints in result.keypoints.data:
                    # keypoints should be a tensor of shape (17,3)
                    if keypoints is None or not hasattr(keypoints, 'cpu'):
                        continue
                    # Convert to NumPy array
                    if isinstance(keypoints, torch.Tensor):
                        kpts = keypoints.cpu().numpy()
                    elif isinstance(keypoints, np.ndarray):
                        kpts = keypoints
                    else:
                        print(f"Unexpected keypoints data type: {type(keypoints)}")
                        continue

                    if kpts.size == 0 or kpts.shape[0] < 17:
                        print(f"Insufficient keypoints for processing in frame {current_frame}")
                        continue

                    angle = calculate_torso_angle(kpts, height)
                    if angle is None:
                        continue

                    # Determine if it's a fall
                    if abs(angle) > angle_threshold:
                        consecutive_fall_frames += 1
                        label = "Fall Detected!"
                        color = (0, 0, 255)  # Red
                    else:
                        if consecutive_fall_frames >= consecutive_frames:
                            total_falls += 1
                            fall_frames.append(current_frame)
                        consecutive_fall_frames = 0
                        label = "Normal"
                        color = (0, 255, 0)  # Green

                    # Overlay the warning label only once the fall has persisted for the required number of processed frames
                    if consecutive_fall_frames >= consecutive_frames:
                        cv2.putText(frame, label, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

                    # Draw keypoints and skeleton
                    frame = draw_skeleton(frame, kpts, show_labels=show_labels)

            # Write annotated frame
            out.write(frame)

        # Release resources
        cap.release()
        out.release()

        # Final check for falls that persisted until the end of the video
        if consecutive_fall_frames >= consecutive_frames:
            total_falls += 1
            fall_frames.append(current_frame)

        # Generate notification message
        if total_falls > 0:
            if total_falls == 1:
                notification = f"A fall was detected at frame {fall_frames[0]}."
            else:
                frames = ', '.join(map(str, fall_frames))
                notification = f"{total_falls} falls were detected at frames: {frames}."
        else:
            notification = "No falls were detected in the video."

        # Check if annotated video was created
        if not os.path.exists(annotated_video_path):
            raise FileNotFoundError("Annotated video was not found. Please check the model and processing steps.")

        return annotated_video_path, notification

    except Exception as e:
        # Report the error and release any resources that were opened before the failure
        print(f"Error during fall detection: {e}")
        if 'cap' in locals() and cap.isOpened():
            cap.release()
        if 'out' in locals():
            out.release()
        return None, f"An error occurred during fall detection: {e}"

def create_gradio_interface():
    # Define the Gradio interface with adjustable parameters
    iface = gr.Interface(
        fn=detect_fall,
        inputs=[
            gr.Video(label="Upload Video"),
            gr.Slider(
                label="Angle Threshold (degrees)",
                minimum=0,
                maximum=90,
                step=1,
                value=30,
                interactive=True,
                info="Adjust the torso angle threshold to classify a fall. Lower values increase sensitivity."
            ),
            gr.Slider(
                label="Consecutive Frames to Confirm Fall",
                minimum=1,
                maximum=10,
                step=1,
                value=3,
                interactive=True,
                info="Number of consecutive frames exceeding the angle threshold required to confirm a fall."
            ),
            gr.Slider(
                label="Frame Sampling Rate",
                minimum=1,
                maximum=10,
                step=1,
                value=1,
                interactive=True,
                info="Process every nth frame to speed up detection. Higher values reduce processing time."
            ),
            gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.3,
                interactive=True,
                info="Minimum confidence required for keypoint detection. Higher values reduce false positives."
            ),
            gr.Checkbox(
                label="Show Keypoint Labels",
                value=True,
                interactive=True,
                info="Toggle the display of keypoint indices on the video."
            )
        ],
        outputs=[
            gr.Video(label="Annotated Video"),
            gr.Textbox(label="Fall Detection Notification")
        ],
        title="Fall Detection App 🚨",
        description=(
            "Upload a video of a person falling, and the app will detect and annotate the fall "
            "using pose estimation. Adjust the angle threshold, consecutive frames, frame sampling rate, "
            "and confidence threshold to fine-tune detection sensitivity and performance. "
            "The annotated video will display keypoints, skeleton lines, and indicate when a fall is detected."
        ),
        examples=[
            ["demo/person falling.mp4", 30, 3, 1, 0.3, True]
        ],  # Example video with the default parameter values
        flagging_mode="never",  # Disable flagging of results
    )
    return iface

# Create the Gradio interface
iface = create_gradio_interface()

# Hook for cleaning up temporary files on exit (currently a placeholder)
def cleanup_temp_dirs():
    temp_dir = tempfile.gettempdir()
    # Add cleanup logic here if needed (e.g. remove annotated_*.mp4 files from temp_dir)

atexit.register(cleanup_temp_dirs)

# Launch the app
if __name__ == "__main__":
    iface.launch()
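    # Optionally pass share=True to iface.launch() to expose a temporary public URL.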