# Spaces: Runtime error
# (Page-status text captured along with this listing; it is not part of the program.)
import cv2
import numpy as np
from scipy.optimize import curve_fit
from segment_anything import SamPredictor, sam_model_registry
class GolfTrajectoryPredictor:
    """Estimate a golf ball's flight path from a video and user-selected points.

    A Segment Anything (SAM) predictor segments the club in the first frame to
    obtain a starting guess for the launch angle; a drag-free projectile model
    is then least-squares fitted to the selected ball positions.

    Conventions used throughout the class:
      * pixel points are ``(x, y)`` with ``y`` growing downward (image layout);
      * physical points are ``(x, y)`` with ``y`` growing upward;
      * ``theta`` passed to / returned by the physics model is in degrees.
    """

    def __init__(self, sam_checkpoint, model_type="vit_h"):
        """Load the SAM model and wrap it in a predictor.

        Args:
            sam_checkpoint: Path to the SAM checkpoint file.
            model_type: Key into ``sam_model_registry``; defaults to the
                ``"vit_h"`` model this class has always used.
        """
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.predictor = SamPredictor(self.sam)

    def get_club_metrics(self, frame, point):
        """Segment the club around ``point`` and fit a straight line to it.

        Args:
            frame: Image handed to ``SamPredictor.set_image`` unchanged, so it
                must be in the channel order that call expects.
            point: ``(x, y)`` pixel coordinate used as a foreground prompt.

        Returns:
            ``(angle, (x0, y0))`` where ``angle`` is ``np.arctan2(vy, vx)`` of
            the fitted line direction, in radians and in image axes, and
            ``(x0, y0)`` is a point on that line in ``(x, y)`` pixels.
            ``(None, None)`` when the mask has fewer than two pixels.
        """
        self.predictor.set_image(frame)
        masks, scores, _ = self.predictor.predict(
            point_coords=np.array([point]),
            point_labels=np.array([1]),
        )
        # Keep the candidate mask SAM scored highest.
        club_mask = masks[np.argmax(scores)]

        coords = np.column_stack(np.where(club_mask))
        if len(coords) < 2:
            return None, None

        # np.where yields (row, col) int64 indices, whereas cv2.fitLine takes
        # (x, y) points stored as 32-bit values: reverse the columns and cast.
        pixels_xy = np.ascontiguousarray(coords[:, ::-1], dtype=np.float32)
        vx, vy, x0, y0 = cv2.fitLine(pixels_xy, cv2.DIST_L2, 0, 0.01, 0.01)
        angle = np.arctan2(vy, vx)
        return angle, (x0[0], y0[0])

    def physics_trajectory(self, t, v0, theta, h0, g=9.81):
        """Evaluate drag-free projectile motion.

        Args:
            t: Time or array of times since launch.
            v0: Launch speed.
            theta: Launch angle in degrees, measured from the +x axis.
            h0: Height at ``t = 0``.
            g: Downward acceleration (default 9.81).

        Returns:
            Array of shape ``(len(t), 2)`` holding ``(x, y)`` per time, with
            ``x`` starting at 0 and ``y`` starting at ``h0``.
        """
        theta_rad = np.radians(theta)
        v0x = v0 * np.cos(theta_rad)
        v0y = v0 * np.sin(theta_rad)
        x = v0x * t
        y = h0 + v0y * t - 0.5 * g * t**2
        return np.column_stack((x, y))

    def fit_trajectory(self, points, initial_height, club_angle=None, times=None):
        """Fit launch speed and angle so the physics model matches ``points``.

        Args:
            points: Sequence of ``(x, y)`` positions with ``y`` pointing up.
                ``x`` is measured relative to the first point, because the
                model always starts at ``x = 0``. The input is not modified.
            initial_height: Model height at ``t = 0``.
            club_angle: Optional starting guess for the launch angle, in
                degrees (scalar or one-element array); 45 when omitted.
            times: Optional time stamp for each point. Defaults to evenly
                spaced values over ``[0, 1]``.

        Returns:
            Array ``[v0, theta]`` (``theta`` in degrees), or ``None`` when
            fewer than two points are given or the optimiser fails.

        Raises:
            ValueError: If ``times`` does not hold one value per point.
        """
        xy = np.array(points, dtype=float)  # np.array copies: caller's data stays intact
        n_points = len(xy)
        if n_points < 2:
            # Two unknowns cannot be constrained by fewer than two positions.
            return None

        if times is None:
            times = np.linspace(0, 1, n_points)
        else:
            times = np.asarray(times, dtype=float)
            if times.shape != (n_points,):
                raise ValueError("times must contain exactly one value per point")

        # The model is pinned to (0, initial_height) at t = 0, so express x
        # relative to the first observation before comparing.
        xy[:, 0] -= xy[0, 0]

        # curve_fit needs a scalar guess; tolerate a one-element array.
        if club_angle is None:
            initial_theta = 45.0
        else:
            initial_theta = float(np.ravel(club_angle)[0])

        # Speed guess from the first displacement; fall back to a fixed value
        # when that displacement or its time step is degenerate.
        step = xy[1] - xy[0]
        distance = float(np.hypot(step[0], step[1]))
        dt = float(times[1] - times[0])
        initial_v0 = distance / dt if dt > 0 and distance > 0 else 50.0

        def flat_model(t, v0, theta):
            # curve_fit works on a flat residual vector, so interleave the
            # model's (x, y) pairs the same way the data is flattened below.
            return self.physics_trajectory(t, v0, theta, initial_height).ravel()

        try:
            params, _ = curve_fit(
                flat_model,
                times,
                xy.ravel(),
                p0=[initial_v0, initial_theta],
            )
        except RuntimeError:
            # Raised by curve_fit when the optimisation does not converge.
            return None
        return params

    @staticmethod
    def _pixels_to_physical(pixel_points, frame_height, scale_factor):
        """Scale pixel points and flip y so that it grows upward."""
        pixel_points = np.asarray(pixel_points, dtype=float)
        return np.column_stack((
            pixel_points[:, 0] * scale_factor,
            (frame_height - pixel_points[:, 1]) * scale_factor,
        ))

    @staticmethod
    def _physical_to_pixels(physical_points, frame_height, scale_factor):
        """Inverse of ``_pixels_to_physical``."""
        physical_points = np.asarray(physical_points, dtype=float)
        return np.column_stack((
            physical_points[:, 0] / scale_factor,
            frame_height - physical_points[:, 1] / scale_factor,
        ))

    def predict_full_trajectory(self, video_path, user_selected_points,
                                scale_factor=0.01):
        """Predict the ball trajectory for a video from selected pixel points.

        Args:
            video_path: Path of the video opened with ``cv2.VideoCapture``.
            user_selected_points: ``(x, y)`` pixel positions of the ball.
                NOTE(review): treated as one point per consecutive frame,
                which is what the ``len(points) / fps`` horizon implies --
                confirm against the caller.
            scale_factor: Metres per pixel. The default 0.01 is a placeholder;
                a real value requires camera calibration.

        Returns:
            ``(pixel_trajectory, params)`` where ``pixel_trajectory`` is a
            ``(100, 2)`` array of ``(x, y)`` pixels and ``params`` is
            ``[v0, theta]`` from ``fit_trajectory``; ``None`` when no points
            are given, the first frame cannot be read, the frame rate is not
            a positive number, or the fit fails.
        """
        pixel_points = np.asarray(user_selected_points, dtype=float)
        if pixel_points.size == 0:
            return None

        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            ret, first_frame = cap.read()
        finally:
            # Always hand the capture back, including on early exits.
            cap.release()
        if not ret:
            return None
        if not np.isfinite(fps) or fps <= 0:
            # Without a usable frame rate the points cannot be time-stamped.
            return None

        # OpenCV decodes frames as BGR, while SamPredictor.set_image defaults
        # to image_format="RGB".
        rgb_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
        club_angle, _club_pos = self.get_club_metrics(
            rgb_frame, user_selected_points[0]
        )

        # get_club_metrics reports radians in image axes (y down); the model
        # wants degrees with y up, hence the conversion and the sign flip.
        club_angle_deg = None
        if club_angle is not None:
            club_angle_deg = -float(np.degrees(np.ravel(club_angle)[0]))

        frame_height = first_frame.shape[0]
        physical_points = self._pixels_to_physical(
            pixel_points, frame_height, scale_factor
        )
        launch_x, launch_height = physical_points[0]

        # Fit and evaluate on the same time base: seconds since the first point.
        n_points = len(pixel_points)
        params = self.fit_trajectory(
            physical_points,
            initial_height=launch_height,
            club_angle=club_angle_deg,
            times=np.arange(n_points) / fps,
        )
        if params is None:
            return None

        t = np.linspace(0, n_points / fps, 100)
        full_trajectory = self.physics_trajectory(
            t, params[0], params[1], launch_height
        )
        # The model starts at x = 0; move it back to the first point's x.
        full_trajectory[:, 0] += launch_x

        pixel_trajectory = self._physical_to_pixels(
            full_trajectory, frame_height, scale_factor
        )
        return pixel_trajectory, params

    def visualize_trajectory(self, frame, trajectory, user_points):
        """Draw the predicted path and the selected points on a copy of ``frame``.

        Args:
            frame: Image to annotate; it is copied, not modified.
            trajectory: ``(n, 2)`` array of ``(x, y)`` pixel positions.
            user_points: Iterable of ``(x, y)`` pixel positions.

        Returns:
            The annotated copy. Colour tuples are passed to OpenCV as-is, so
            they are interpreted in the frame's own channel order.
        """
        vis_frame = frame.copy()

        # tolist() yields plain Python ints, the same form the user points are
        # converted to below.
        path = np.asarray(trajectory).astype(np.int32).tolist()
        for start, end in zip(path, path[1:]):
            cv2.line(vis_frame, tuple(start), tuple(end), (0, 255, 0), 2)

        for point in user_points:
            cv2.circle(vis_frame, tuple(map(int, point)), 5, (255, 0, 0), -1)
        return vis_frame