# DRS_AIP_LBW / app.py
# dschandra's picture
# Update app.py
# 5fb1ae2 verified
# raw
# history blame
# 12.7 kB
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import gradio as gr
from scipy.interpolate import interp1d
import uuid
import os
# Optional 3D stack: PyOpenGL provides the GL/GLU calls and pygame the window
# and event handling. HAS_OPENGL records whether all three imports succeeded.
try:
    from OpenGL.GL import *
    from OpenGL.GLU import *
    from pygame import display, event, QUIT
except ImportError:
    HAS_OPENGL = False
    print("Warning: PyOpenGL or Pygame not found. 3D visualization will be disabled. Install with 'pip install PyOpenGL PyOpenGL_accelerate pygame'.")
else:
    HAS_OPENGL = True
# Load the trained YOLOv8n model
# NOTE: this runs at import time; "best.pt" is a relative path.
model = YOLO("best.pt")
# Constants
STUMPS_WIDTH = 0.2286 # meters
# Frames per second assumed for detection timestamps and for the output
# video; it is not read from the input file.
FRAME_RATE = 20
# Output fps is FRAME_RATE / SLOW_MOTION_FACTOR and every frame is written
# SLOW_MOTION_FACTOR times.
SLOW_MOTION_FACTOR = 2
# Passed as `conf` to model.predict.
CONF_THRESHOLD = 0.3
# PITCH_ZONE_Y and IMPACT_ZONE_Y are not referenced by the code visible in
# this file.
PITCH_ZONE_Y = 0.8
IMPACT_ZONE_Y = 0.7
# Jump threshold used when searching for the impact point; it is scaled by
# STUMPS_HEIGHT / frame_height before use (presumably pixels -- confirm).
IMPACT_DELTA_Y = 20
STUMPS_HEIGHT = 0.711 # meters
PITCH_LENGTH = 20.12 # meters (22 yards)
def process_video(video_path):
    """Detect the ball in every frame of a video clip.

    Each frame is brightness/contrast adjusted and passed to the module-level
    YOLO ``model``. A frame counts as a valid detection only when exactly one
    class-0 box is found: the box centre is recorded and the box is drawn on
    the adjusted frame, which is the one stored. Frames without a valid
    detection are stored unmodified.

    Args:
        video_path: Path of the video to analyse. ``None``, an empty string
            or a path that does not exist is reported as an error instead of
            raising.

    Returns:
        A 4-tuple ``(frames, ball_positions, detection_frames, debug_log)``:
        the stored frames, the ``[cx, cy]`` pixel centre of each valid
        detection, the index into ``frames`` of each valid detection, and a
        newline-joined log string. On a missing file the three lists are
        empty and the log is an error message.
    """
    # Gradio passes None when nothing was uploaded; os.path.exists(None)
    # would raise TypeError, so reject falsy paths first.
    if not video_path or not os.path.exists(video_path):
        return [], [], [], "Error: Video file not found"

    cap = cv2.VideoCapture(video_path)
    frames = []
    ball_positions = []
    detection_frames = []
    debug_log = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_idx = len(frames)
            enhanced = cv2.convertScaleAbs(frame, alpha=1.2, beta=10)
            results = model.predict(enhanced, conf=CONF_THRESHOLD)
            detections = [det for det in results[0].boxes if det.cls == 0]
            if len(detections) == 1:
                x1, y1, x2, y2 = detections[0].xyxy[0].cpu().numpy()
                ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
                detection_frames.append(frame_idx)
                cv2.rectangle(enhanced, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                frames.append(enhanced)
            else:
                # Zero or several candidates: keep the untouched frame.
                frames.append(frame)
            debug_log.append(f"Frame {frame_idx}: {len(detections)} ball detections")
    finally:
        # Release the capture even if decoding or inference raises.
        cap.release()

    if not ball_positions:
        debug_log.append("No valid single-ball detections in any frame")
    else:
        debug_log.append(f"Total valid single-ball detections: {len(ball_positions)}")
    return frames, ball_positions, detection_frames, "\n".join(debug_log)
def estimate_trajectory_3d(ball_positions, detection_frames, frames):
    """Fit a ball trajectory through the valid detections.

    Pixel centres are rescaled so the frame width maps to ``PITCH_LENGTH``
    and the frame height to ``2 * STUMPS_HEIGHT`` (with y flipped so it grows
    upwards); z is a zero placeholder. The pitch point is the first detection
    below ``STUMPS_HEIGHT`` (the first detection if none is). The impact
    point is the first detection above ``STUMPS_HEIGHT`` whose y jumps by
    more than the scaled ``IMPACT_DELTA_Y`` from the previous one (the last
    detection if none does).

    Args:
        ball_positions: ``[cx, cy]`` pixel centres, one per valid detection.
        detection_frames: Frame index of each entry in ``ball_positions``.
        frames: Video frames; only the shape of ``frames[0]`` is read.

    Returns:
        A 7-tuple ``(full_trajectory, vis_trajectory, pitch_point,
        pitch_frame, impact_point, impact_frame, message)``.
        ``full_trajectory`` holds 50 interpolated ``(x, y, z)`` samples
        running to 0.5 s past the impact time; ``vis_trajectory`` holds the
        detections up to and including the impact. With fewer than two
        detections the first six items are ``None`` and ``message`` is an
        error string.
    """
    if len(ball_positions) < 2:
        return None, None, None, None, None, None, "Error: Fewer than 2 valid single-ball detections"

    frame_height, frame_width = frames[0].shape[:2]
    pixels = np.asarray(ball_positions, dtype=float)
    x_coords = pixels[:, 0] / frame_width * PITCH_LENGTH
    y_coords = (frame_height - pixels[:, 1]) / frame_height * STUMPS_HEIGHT * 2
    z_coords = np.zeros_like(x_coords)  # Placeholder for depth
    # Timestamp each detection by its real frame index: detections can skip
    # frames, so the detection ordinal is not a valid clock.
    times = np.asarray(detection_frames, dtype=float) / FRAME_RATE

    pitch_idx = next((i for i, y in enumerate(y_coords) if y < STUMPS_HEIGHT), 0)
    pitch_point = (x_coords[pitch_idx], y_coords[pitch_idx], 0)
    pitch_frame = detection_frames[pitch_idx]

    jump_threshold = IMPACT_DELTA_Y * STUMPS_HEIGHT / frame_height
    impact_idx = next(
        (
            i
            for i in range(1, len(y_coords))
            if y_coords[i] > STUMPS_HEIGHT
            and abs(y_coords[i] - y_coords[i - 1]) > jump_threshold
        ),
        len(y_coords) - 1,
    )
    impact_point = (x_coords[impact_idx], y_coords[impact_idx], 0)
    impact_frame = detection_frames[impact_idx]

    # A cubic spline needs at least four samples; with fewer, use linear
    # interpolation directly instead of relying on a caught ValueError.
    n_fit = impact_idx + 1
    kind = "cubic" if n_fit >= 4 else "linear"
    fit_times = times[:n_fit]
    fx, fy, fz = (
        interp1d(fit_times, axis[:n_fit], kind=kind, fill_value="extrapolate")
        for axis in (x_coords, y_coords, z_coords)
    )

    t_full = np.linspace(fit_times[0], fit_times[-1] + 0.5, 50)
    full_trajectory = list(zip(fx(t_full), fy(t_full), fz(t_full)))
    vis_trajectory = list(zip(x_coords[:n_fit], y_coords[:n_fit], z_coords[:n_fit]))
    return full_trajectory, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, "Trajectory estimated"
def lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_point):
    """Decide an LBW appeal from the pitch point, impact point and trajectory.

    The stumps are modelled at ``x = PITCH_LENGTH / 2``, standing from
    ground level (``y = 0``) up to ``STUMPS_HEIGHT``; "in line" means within
    half of ``STUMPS_WIDTH`` of that x position.

    Args:
        ball_positions: Unused; kept for interface compatibility.
        full_trajectory: Sequence of ``(x, y, z)`` samples, or ``None``.
        frames: Video frames; only checked for being non-empty.
        pitch_point: ``(x, y, z)`` of the pitch point.
        impact_point: ``(x, y, z)`` of the impact point.

    Returns:
        ``(decision, full_trajectory, pitch_point, impact_point)``. When
        ``frames`` or ``full_trajectory`` is empty or ``None`` the decision
        is ``"Error: No data"`` and the other three items are ``None``.
    """
    if not frames or not full_trajectory:
        return "Error: No data", None, None, None

    stumps_x = PITCH_LENGTH / 2
    stumps_y = 0
    in_line_threshold = STUMPS_WIDTH / 2
    pitch_x = pitch_point[0]
    impact_x, impact_y = impact_point[0], impact_point[1]

    if abs(pitch_x - stumps_x) > in_line_threshold:
        return f"Not Out (Pitched outside line at x: {pitch_x:.1f})", full_trajectory, pitch_point, impact_point
    if abs(impact_x - stumps_x) > in_line_threshold or impact_y < stumps_y:
        return f"Not Out (Impact outside line at x: {impact_x:.1f})", full_trajectory, pitch_point, impact_point

    # First trajectory sample inside the stumps' footprint. The stumps stand
    # from stumps_y (ground) up to STUMPS_HEIGHT, so the height test covers
    # that whole span rather than a band centred on the ground.
    stumps_top = stumps_y + STUMPS_HEIGHT
    hit_x = next(
        (
            x
            for x, y, _ in full_trajectory
            if abs(x - stumps_x) < in_line_threshold and stumps_y <= y <= stumps_top
        ),
        None,
    )
    if hit_x is None:
        return "Not Out (Missing stumps)", full_trajectory, pitch_point, impact_point

    # NOTE(review): this gives "Umpire's Call" to balls within 10% of the
    # stump centre line; confirm whether marginal (edge) hits were intended.
    if abs(hit_x - stumps_x) < in_line_threshold * 0.1:
        return "Umpire's Call - Not Out", full_trajectory, pitch_point, impact_point
    return "Out (Ball hits stumps)", full_trajectory, pitch_point, impact_point
def generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, detection_frames, output_path, decision, frame_width, frame_height):
    """Write an annotated slow-motion replay of ``frames`` to ``output_path``.

    Every frame gets a stumps outline and crease line. On frames with a valid
    detection the trajectory up to that detection is drawn; the pitch and
    impact frames get labelled markers. Frames are annotated in place. The
    video is written at ``FRAME_RATE / SLOW_MOTION_FACTOR`` fps with each
    frame repeated ``SLOW_MOTION_FACTOR`` times.

    Args:
        frames: Frames to annotate and write.
        vis_trajectory: ``(x, y, z)`` detections in pitch units, or ``None``.
        pitch_point: ``(x, y, z)`` of the pitch point, or ``None``.
        pitch_frame: Frame index at which the pitch marker is drawn.
        impact_point: ``(x, y, z)`` of the impact point, or ``None``.
        impact_frame: Frame index at which the impact marker is drawn.
        detection_frames: Frame index of each valid detection, in order.
        output_path: Destination video path.
        decision: Decision text; "Wickets" is labelled only for an Out.
        frame_width: Output width in pixels.
        frame_height: Output height in pixels.

    Returns:
        ``output_path``, or ``None`` when ``frames`` is empty.
    """
    if not frames:
        return None

    def to_pixel(point):
        # Inverse of the pitch-unit scaling: x spans PITCH_LENGTH across the
        # width, y spans 2 * STUMPS_HEIGHT over the height and grows upwards.
        px = point[0] * frame_width / PITCH_LENGTH
        py = frame_height - (point[1] * frame_height / (STUMPS_HEIGHT * 2))
        return px, py

    trajectory_points = np.array(
        [to_pixel(p) for p in (vis_trajectory or [])], dtype=np.int32
    ).reshape((-1, 1, 2))

    # Ordinal of each detection frame, built once instead of a list search
    # per frame; setdefault keeps the first occurrence like list.index.
    detection_rank = {}
    for rank, frame_idx in enumerate(detection_frames):
        detection_rank.setdefault(frame_idx, rank)

    # "Not Out ..." and "Umpire's Call - Not Out" both contain "Out", so a
    # substring test would label every verdict; only a true Out starts with it.
    is_out = decision.startswith("Out")

    # Stumps geometry in pixels is the same for every frame.
    stumps_x = frame_width / 2
    stumps_y = frame_height * 0.8
    stumps_width_pixels = frame_width * (STUMPS_WIDTH / PITCH_LENGTH)
    stumps_height_pixels = frame_height * (STUMPS_HEIGHT / (STUMPS_HEIGHT * 2))
    left = int(stumps_x - stumps_width_pixels / 2)
    right = int(stumps_x + stumps_width_pixels / 2)
    base = int(stumps_y)
    top = int(stumps_y - stumps_height_pixels)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
    try:
        for i, frame in enumerate(frames):
            # Stumps outline: base and two uprights.
            cv2.line(frame, (left, base), (right, base), (255, 255, 255), 2)
            cv2.line(frame, (left, top), (left, base), (255, 255, 255), 2)
            cv2.line(frame, (right, top), (right, base), (255, 255, 255), 2)
            # Crease line, drawn over the outline's base.
            cv2.line(frame, (left, base), (right, base), (255, 255, 0), 2)

            rank = detection_rank.get(i)
            if rank is not None and trajectory_points.size > 0:
                count = rank + 1
                # Detections after the impact are not in trajectory_points.
                if count <= len(trajectory_points):
                    # (0, 0, 255) is red in OpenCV's BGR channel order.
                    cv2.polylines(frame, [trajectory_points[:count]], False, (0, 0, 255), 2)

            if pitch_point and i == pitch_frame:
                x, y = to_pixel(pitch_point)
                cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 0), -1)  # Green for pitching
                cv2.putText(frame, "Pitching", (int(x) + 10, int(y) - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

            if impact_point and i == impact_frame:
                x, y = to_pixel(impact_point)
                cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)  # Red for impact
                cv2.putText(frame, "Impact", (int(x) + 10, int(y) + 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
                if is_out:
                    cv2.putText(frame, "Wickets", (int(stumps_x) - 50, int(stumps_y) - 20),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 165, 255), 1)  # Orange for wickets

            for _ in range(SLOW_MOTION_FACTOR):
                out.write(frame)
    finally:
        # Close the file even if annotation fails part-way through.
        out.release()
    return output_path
def draw_3d_scene(trajectory, pitch_point, impact_point, decision):
    """Render one frame of the 3D scene and flip the pygame display.

    Draws the trajectory as line segments, a strip for the pitch, two lines
    for the stumps, spheres at the pitch and impact points, and a "Wickets"
    label when the decision is Out. Does nothing when ``HAS_OPENGL`` is
    false. A GL window must already exist and GLUT must be initialised
    before the sphere and text calls run.

    Args:
        trajectory: Sequence of ``(x, y, z)`` points, or ``None``.
        pitch_point: ``(x, y, z)`` of the pitch point, or ``None``.
        impact_point: ``(x, y, z)`` of the impact point, or ``None``.
        decision: Decision text.
    """
    if not HAS_OPENGL:
        return
    # These GLUT names are not imported at module level, so bring them into
    # scope here.
    from OpenGL.GLUT import GLUT_BITMAP_HELVETICA_12, glutBitmapCharacter, glutSolidSphere

    trajectory = trajectory or []
    glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)

    glColor3f(0, 0, 1)  # Blue trajectory
    glBegin(GL_LINES)
    for start, end in zip(trajectory, trajectory[1:]):
        glVertex3f(start[0], start[1], start[2])
        glVertex3f(end[0], end[1], end[2])
    glEnd()

    glColor3f(0, 1, 0)  # Green pitch
    glBegin(GL_QUADS)
    glVertex3f(0, 0, 0)
    glVertex3f(PITCH_LENGTH, 0, 0)
    glVertex3f(PITCH_LENGTH, 0, -1)
    glVertex3f(0, 0, -1)
    glEnd()

    glColor3f(1, 1, 1)  # White stumps
    glBegin(GL_LINES)
    for stump_x in (PITCH_LENGTH / 2 - STUMPS_WIDTH / 2, PITCH_LENGTH / 2 + STUMPS_WIDTH / 2):
        glVertex3f(stump_x, 0, 0)
        glVertex3f(stump_x, STUMPS_HEIGHT, 0)
    glEnd()

    # Marker spheres: green for the pitch point, red for the impact point.
    for point, colour in ((pitch_point, (0, 1, 0)), (impact_point, (1, 0, 0))):
        if point:
            glColor3f(*colour)
            glPushMatrix()
            glTranslatef(point[0], point[1], point[2])
            glutSolidSphere(0.1, 20, 20)
            glPopMatrix()

    # "Not Out ..." also contains "Out"; only a real Out starts with it.
    if decision.startswith("Out"):
        glColor3f(1, 0.65, 0)  # Orange
        glRasterPos3f(PITCH_LENGTH / 2, STUMPS_HEIGHT, 0)
        for char in "Wickets":
            glutBitmapCharacter(GLUT_BITMAP_HELVETICA_12, ord(char))

    display.flip()
def init_3d_window(width, height):
    """Open a double-buffered OpenGL pygame window and set up the camera.

    Does nothing when ``HAS_OPENGL`` is false.

    Args:
        width: Window width in pixels.
        height: Window height in pixels; also the aspect-ratio divisor.
    """
    if not HAS_OPENGL:
        return
    # The module-level import binds only `display`, `event` and `QUIT`, so
    # `pygame` itself and the display flags are imported here.
    import pygame
    from pygame.locals import DOUBLEBUF, OPENGL

    pygame.init()
    pygame.display.set_mode((width, height), DOUBLEBUF | OPENGL)
    gluPerspective(45, (width / height), 0.1, 50.0)
    glTranslatef(0.0, -5.0, -30)
    glEnable(GL_DEPTH_TEST)
def drs_review(video):
    """Run the full review pipeline on an uploaded clip.

    Detects the ball, estimates the trajectory, makes the LBW decision and
    writes the annotated slow-motion replay. When the optional OpenGL stack
    is available it also renders a short 3D preview.

    Args:
        video: Path of the uploaded video, as passed in by the Gradio
            interface.

    Returns:
        ``(report_text, replay_path)``. ``replay_path`` is ``None`` when the
        video could not be processed or no trajectory could be estimated.
    """
    frames, ball_positions, detection_frames, debug_log = process_video(video)
    if not frames:
        return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None

    (full_trajectory, vis_trajectory, pitch_point, pitch_frame,
     impact_point, impact_frame, trajectory_log) = estimate_trajectory_3d(
        ball_positions, detection_frames, frames)
    decision, full_trajectory, pitch_point, impact_point = lbw_decision(
        ball_positions, full_trajectory, frames, pitch_point, impact_point)
    debug_output = f"{debug_log}\n{trajectory_log}"

    # With no trajectory there is nothing to annotate; report the error
    # decision rather than passing None into the renderers.
    if full_trajectory is None:
        return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", None

    frame_height, frame_width = frames[0].shape[:2]
    output_path = f"output_{uuid.uuid4()}.mp4"
    slow_motion_path = generate_slow_motion(
        frames, vis_trajectory, pitch_point, pitch_frame, impact_point,
        impact_frame, detection_frames, output_path, decision,
        frame_width, frame_height)

    if HAS_OPENGL:
        # The 3D preview is optional: a failure here (for example no display
        # on a headless server) must not discard the finished 2D result.
        try:
            init_3d_window(800, 600)
            from OpenGL.GLUT import glutInit
            glutInit()
            for _ in range(100):  # Limited frames for demo
                draw_3d_scene(full_trajectory, pitch_point, impact_point, decision)
                event.pump()
        except Exception as exc:
            debug_output = f"{debug_output}\n3D visualization skipped: {exc}"

    return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
# Gradio interface: one video in, the decision text and annotated replay out.
iface = gr.Interface(
    title="AI-Powered 3D DRS for LBW",
    description=(
        "Upload a video clip for 3D DRS analysis with pitching (green), "
        "impact (red), and wickets (orange) visualization, "
        "and 2D annotated video output."
    ),
    fn=drs_review,
    inputs=gr.Video(label="Upload Video Clip"),
    outputs=[
        gr.Textbox(label="DRS Decision and Debug Log"),
        gr.Video(label="Slow-Motion Replay with 2D Annotations"),
    ],
)

# Start the web UI only when the file is run as a script.
if __name__ == "__main__":
    iface.launch()