# Hugging Face Spaces status: Runtime error
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
from deep_sort_realtime.deepsort_tracker import DeepSort
from ultralytics import YOLO
class ObjectTracker:
    """Detect people with YOLO and follow them across frames with DeepSort.

    DeepSort keeps track state between ``process_frame`` calls, so use one
    instance per video to get an independent set of track IDs.
    """

    def __init__(self, person_model_path='yolov8n.pt', conf_threshold=0.5,
                 max_age=30, n_init=3):
        """
        Initialize object tracker with YOLO and DeepSort.

        Args:
            person_model_path: Weights file / model name handed to ``YOLO``.
            conf_threshold: Minimum YOLO confidence for a detection to be
                forwarded to the tracker.
            max_age: Passed to ``DeepSort``; tracks can be lost for up to
                this many frames before being dropped.
            n_init: Passed to ``DeepSort``; number of consecutive detections
                before a track is confirmed.
        """
        # Load YOLO model for person detection
        self.model = YOLO(person_model_path)
        self.conf_threshold = conf_threshold
        # Initialize DeepSort tracker
        self.tracker = DeepSort(max_age=max_age, n_init=n_init)
        # Tracking statistics:
        #   tracking_data maps track_id -> number of frames the confirmed
        #   track was drawn in; person_count is the number of distinct
        #   confirmed track IDs seen so far.
        self.person_count = 0
        self.tracking_data = {}

    @staticmethod
    def _extract_detections(results):
        """Convert YOLO results into ``([left, top, w, h], conf, cls)`` tuples.

        That tuple layout is the raw-detection format
        ``DeepSort.update_tracks`` accepts.
        """
        detections = []
        for result in results:
            for box in result.boxes:
                # YOLO gives corner coordinates; DeepSort wants left/top/w/h.
                x1, y1, x2, y2 = (float(v) for v in box.xyxy[0])
                detections.append(
                    ([x1, y1, x2 - x1, y2 - y1],
                     box.conf.item(),
                     int(box.cls.item()))
                )
        return detections

    def process_frame(self, frame):
        """
        Detect and track people in one frame, drawing the results onto it.

        Args:
            frame: Image array as returned by ``cv2.VideoCapture.read``;
                it is annotated in place.

        Returns:
            The same ``frame`` object, with a green box and an ``ID: <n>``
            label drawn for every confirmed track.
        """
        # classes=[0] restricts YOLO to the "person" class.
        results = self.model(frame, classes=[0], conf=self.conf_threshold,
                             verbose=False)
        detections = self._extract_detections(results)

        # Step the tracker on every frame, even with no detections, so that
        # unmatched tracks age out according to max_age.
        tracks = self.tracker.update_tracks(detections, frame=frame)

        for track in tracks:
            if not track.is_confirmed():
                continue
            track_id = track.track_id
            left, top, right, bottom = (int(v) for v in track.to_ltrb())
            self.tracking_data[track_id] = self.tracking_data.get(track_id, 0) + 1

            cv2.rectangle(frame, (left, top), (right, bottom), (0, 255, 0), 2)
            # Keep the label inside the image when the box touches the top edge.
            cv2.putText(
                frame,
                f'ID: {track_id}',
                (left, max(top - 10, 20)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.9,
                (0, 255, 0),
                2,
            )

        self.person_count = len(self.tracking_data)
        return frame
def process_video(input_video):
    """
    Main video processing function for Gradio: track people in a video.

    Every frame is run through a fresh ``ObjectTracker`` and the annotated
    frames are written to a new MP4 file.

    Args:
        input_video: Path of the uploaded video as supplied by the
            ``gr.Video`` input component (may be empty if nothing was
            uploaded).

    Returns:
        Path of the annotated video. A new temporary file is created per
        call, so concurrent requests do not overwrite each other's output.
        The file is not deleted automatically.

    Raises:
        gr.Error: If no video was given, it cannot be opened, or the output
            writer cannot be created.
    """
    if not input_video:
        raise gr.Error("Please upload a video to track.")

    # A new tracker per call so DeepSort state does not carry over between videos.
    tracker = ObjectTracker()

    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        cap.release()
        raise gr.Error(f"Could not open the uploaded video: {input_video}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Some files report 0 (or NaN) FPS; `not fps > 0` catches both and falls
    # back to a usable rate for the writer.
    if not fps > 0:
        fps = 30.0

    # Unique output path instead of a shared fixed filename in the CWD.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
        output_path = tmp.name

    # NOTE(review): 'mp4v' output may not play natively in every browser;
    # confirm Gradio's Video output re-encodes it in the deployed version.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        if not out.isOpened():
            raise gr.Error("Could not create the output video file.")
        while True:
            ok, frame = cap.read()
            if not ok:
                break  # end of stream (or unreadable frame)
            out.write(tracker.process_frame(frame))
    finally:
        # Release both handles even if detection/tracking raises mid-video.
        cap.release()
        out.release()

    return output_path
# Gradio UI: one uploaded video in, the person-tracked video out.
iface = gr.Interface(
    process_video,
    gr.Video(label="Upload Video for Tracking"),
    gr.Video(label="Tracked Video"),
    title="Person Tracking with YOLO and DeepSort",
    description="Upload a video to track and annotate person movements",
)

if __name__ == "__main__":
    # Serve the app only when executed as a script, not when imported.
    iface.launch()