# golf_tracking / v1.py
# rehctiw25's picture
# Upload folder using huggingface_hub
# 013216e verified
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
import cv2
from scipy.optimize import curve_fit
class GolfTrajectoryPredictor:
    """Estimate a golf ball's flight path from a few user-selected image points.

    The club is segmented with Segment Anything (SAM) on the first video frame to
    obtain an initial guess for the launch angle. A drag-free projectile model
    (``physics_trajectory``) is then fitted to the selected points with
    ``scipy.optimize.curve_fit``.
    """

    def __init__(self, sam_checkpoint, model_type="vit_h"):
        """Load a SAM model and wrap it in a ``SamPredictor``.

        Args:
            sam_checkpoint: Path to the SAM checkpoint file.
            model_type: Key into ``sam_model_registry``. It must match the
                checkpoint; the default ``"vit_h"`` is the previously hard-coded value.
        """
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.predictor = SamPredictor(self.sam)

    def get_club_metrics(self, frame, point, image_format="RGB"):
        """Extract the golf club's angle and position using SAM.

        Args:
            frame: Image to segment.
            point: ``(x, y)`` pixel coordinate of a prompt point on the club.
            image_format: Channel order of ``frame`` ("RGB" or "BGR"). It is forwarded
                to ``SamPredictor.set_image``. Frames read with OpenCV are BGR.

        Returns:
            ``(angle, (x0, y0))``, where:

            - ``angle`` is ``atan2(vy, vx)`` of the line fitted to the club mask, in
              radians. It is measured in image coordinates (x right, y down) and is
              only meaningful modulo pi, because a line has no direction.
            - ``(x0, y0)`` is a point on that line, in pixels.

            Returns ``(None, None)`` when the mask has fewer than two pixels.
        """
        self.predictor.set_image(frame, image_format=image_format)
        masks, scores, _ = self.predictor.predict(
            point_coords=np.array([point]),
            point_labels=np.array([1]),
        )
        # SAM returns several candidate masks; keep the highest-scoring one.
        club_mask = masks[np.argmax(scores)]

        # np.nonzero yields (rows, cols) == (y, x). cv2.fitLine wants (x, y) points
        # stored as float32 or int32, so reorder the axes and cast.
        ys, xs = np.nonzero(club_mask)
        if len(xs) < 2:
            return None, None
        coords = np.column_stack((xs, ys)).astype(np.float32)

        vx, vy, x0, y0 = cv2.fitLine(coords, cv2.DIST_L2, 0, 0.01, 0.01).ravel()
        angle = float(np.arctan2(vy, vx))
        return angle, (float(x0), float(y0))

    def physics_trajectory(self, t, v0, theta, h0, g=9.81):
        """Return drag-free projectile positions at times ``t``.

        Args:
            t: Time(s), array-like.
            v0: Launch speed.
            theta: Launch angle in degrees, measured from the +x axis towards +y.
            h0: Height (y) at t = 0.
            g: Gravitational acceleration. The default 9.81 assumes metres and seconds.

        Returns:
            An (N, 2) array of ``(x, y)`` positions. x starts at 0 and y points up.
        """
        t = np.asarray(t, dtype=float)
        theta_rad = np.radians(theta)
        x = v0 * np.cos(theta_rad) * t
        y = h0 + v0 * np.sin(theta_rad) * t - 0.5 * g * t**2
        return np.column_stack((x, y))

    def fit_trajectory(self, points, initial_height, club_angle=None, times=None):
        """Fit launch speed and angle to observed ball positions.

        Args:
            points: ``(x, y)`` positions in the ``physics_trajectory`` frame. x is
                measured from the launch point, y points up, and lengths use the
                same units as ``g``.
            initial_height: Height ``h0`` at t = 0. It is held fixed during the fit.
            club_angle: Optional initial guess for the launch angle, in degrees.
                A 1-element array is accepted. Defaults to 45.
            times: Optional observation time for each point. Defaults to
                ``np.linspace(0, 1, len(points))``, i.e. the points are assumed to be
                evenly spread over one time unit.

        Returns:
            ``array([v0, theta_degrees])`` on success. Returns ``None`` when there are
            fewer than two points or the optimiser fails.

        Raises:
            ValueError: If ``times`` does not have one entry per point.
        """
        points = np.asarray(points, dtype=float).reshape(-1, 2)
        n = len(points)
        # A single point cannot constrain both speed and angle.
        if n < 2:
            return None

        if times is None:
            times = np.linspace(0, 1, n)
        else:
            times = np.asarray(times, dtype=float)
            if times.shape != (n,):
                raise ValueError("times must have exactly one entry per point")

        initial_theta = 45.0 if club_angle is None else float(np.squeeze(club_angle))

        # Rough speed guess from the first displacement. Fall back when it is
        # degenerate (zero time step or zero movement).
        dt = times[1] - times[0]
        step = float(np.hypot(*(points[1] - points[0])))
        initial_v0 = step / dt if (dt > 0 and step > 0) else 50.0

        def model(t, v0, theta):
            # curve_fit/leastsq work on a flat residual vector, so compare the
            # flattened (N, 2) model output against the flattened observations.
            return self.physics_trajectory(t, v0, theta, initial_height).ravel()

        try:
            params, _ = curve_fit(
                model,
                times,
                points.ravel(),
                p0=[initial_v0, initial_theta],
            )
        except (RuntimeError, ValueError):
            # RuntimeError: no convergence.
            # ValueError: non-finite data.
            return None
        return params

    @staticmethod
    def _club_angle_to_launch_degrees(club_angle, dx):
        """Convert a ``get_club_metrics`` angle into a launch-angle guess in degrees.

        ``club_angle`` is in radians, in image coordinates (y down), and defined only
        modulo 180 degrees. The result uses the ``physics_trajectory`` convention
        (degrees, y up). It is folded into [-90, 90) and mirrored to 180 - angle
        when the ball moves towards -x (``dx < 0``).
        """
        deg = -np.degrees(float(np.squeeze(club_angle)))  # flip y: down -> up
        deg = (deg + 90.0) % 180.0 - 90.0
        return 180.0 - deg if dx < 0 else deg

    def predict_full_trajectory(self, video_path, user_selected_points,
                                scale_factor=0.01, duration=None):
        """Predict the ball's trajectory from selected points in a video.

        Args:
            video_path: Path to the video. Only its first frame is read, for club
                analysis.
            user_selected_points: ``(x, y)`` pixel positions of the ball. They are
                assumed to come from consecutive frames starting at launch, and are
                time-stamped as ``frame_index / fps``. The first point is also used
                as SAM's prompt on the first frame.
            scale_factor: Metres per pixel. Real use needs camera calibration; 0.01
                is a placeholder.
            duration: Time span (s) of the generated trajectory. Defaults to
                ``len(user_selected_points) / fps``.

        Returns:
            ``(pixel_trajectory, params)`` on success, where:

            - ``pixel_trajectory`` is a (100, 2) array of image ``(x, y)``
              coordinates.
            - ``params`` is ``array([v0, theta_degrees])``.

            Returns ``None`` when there are fewer than two points, the first frame
            cannot be read, the FPS is unavailable, or the fit fails.

        Raises:
            ValueError: If ``scale_factor`` is not positive.
        """
        if not scale_factor > 0:
            raise ValueError("scale_factor must be positive")
        pixel_points = np.asarray(user_selected_points, dtype=float).reshape(-1, 2)
        if len(pixel_points) < 2:
            return None

        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            ret, first_frame = cap.read()
        finally:
            cap.release()
        # Some containers report 0 (or NaN) FPS, and we cannot time-stamp points then.
        if not ret or not (fps > 0):
            return None

        # OpenCV frames are BGR; tell SAM so the segmentation sees correct colours.
        club_angle, _ = self.get_club_metrics(
            first_frame, pixel_points[0], image_format="BGR"
        )

        # Express the points in the model's frame:
        # - origin at the first point (the model's x starts at 0),
        # - y flipped (image y grows downwards, the model's y grows upwards),
        # - lengths scaled from pixels to metres.
        origin = pixel_points[0]
        physical_points = (pixel_points - origin) * scale_factor
        physical_points[:, 1] *= -1.0

        launch_guess = None
        if club_angle is not None:
            launch_guess = self._club_angle_to_launch_degrees(
                club_angle, physical_points[1, 0] - physical_points[0, 0]
            )

        times = np.arange(len(physical_points)) / fps
        params = self.fit_trajectory(
            physical_points,
            initial_height=0.0,
            club_angle=launch_guess,
            times=times,
        )
        if params is None:
            return None

        span = len(physical_points) / fps if duration is None else duration
        t = np.linspace(0, span, 100)
        full_trajectory = self.physics_trajectory(t, params[0], params[1], 0.0)

        # Undo the frame change: metres -> pixels, flip y back, restore the origin.
        pixel_trajectory = np.empty_like(full_trajectory)
        pixel_trajectory[:, 0] = origin[0] + full_trajectory[:, 0] / scale_factor
        pixel_trajectory[:, 1] = origin[1] - full_trajectory[:, 1] / scale_factor
        return pixel_trajectory, params

    def visualize_trajectory(self, frame, trajectory, user_points):
        """Draw the predicted trajectory and the user-selected points on a copy of ``frame``.

        Args:
            frame: Image to draw on. It is not modified.
            trajectory: (N, 2) array of ``(x, y)`` pixel coordinates, or ``None``.
            user_points: Iterable of ``(x, y)`` pixel coordinates.

        Returns:
            The annotated copy. The trajectory is drawn as a polyline with colour
            (0, 255, 0). The points are drawn as filled radius-5 circles with colour
            (255, 0, 0), which is blue for a BGR frame.
        """
        vis_frame = frame.copy()

        if trajectory is not None:
            pts = np.asarray(trajectory, dtype=float).reshape(-1, 2)
            # Non-finite samples cannot be turned into pixel coordinates.
            pts = pts[np.isfinite(pts).all(axis=1)]
            if len(pts) >= 2:
                polyline = pts.astype(np.int32).reshape(-1, 1, 2)
                cv2.polylines(vis_frame, [polyline], isClosed=False,
                              color=(0, 255, 0), thickness=2)

        for x, y in user_points:
            # OpenCV drawing functions require plain Python ints.
            cv2.circle(vis_frame, (int(x), int(y)), 5, (255, 0, 0), -1)
        return vis_frame