# Page-extraction residue preserved as a comment (not part of the program):
# "Spaces: Sleeping" / "File size: 5,322 Bytes" / "68f0979 ce3cabf 68f0979".
# The line-number gutter (1-151) that followed was dropped.
import gradio as gr
import cv2
import numpy as np
from ultralytics import YOLO
import torch
import os
def create_gradient_circle(radius, color=(0, 255, 0), alpha=0.7):
    """Create an RGBA disc whose opacity fades from the centre to the rim.

    Args:
        radius: Disc radius in pixels (non-negative integer). The returned
            image is ``2 * radius + 1`` pixels on each side.
        color: Three channel values written to every pixel inside the disc.
        alpha: Peak opacity at the centre, as a fraction of 255.

    Returns:
        A ``(size, size, 4)`` ``uint8`` array. Pixels inside the disc carry
        ``color`` in the first three channels and an opacity of
        ``255 * alpha * (1 - (d / radius) ** 2)`` in the fourth, where ``d``
        is the distance from the centre. Pixels outside the disc are zero.

    Raises:
        ValueError: If ``radius`` is negative.
    """
    if radius < 0:
        raise ValueError("radius must be non-negative")

    size = radius * 2 + 1
    circle_img = np.zeros((size, size, 4), dtype=np.uint8)

    # The falloff is computed per pixel rather than by stacking filled
    # circles: drawing filled discs from the smallest to the largest lets
    # the outermost (zero-opacity) disc overwrite every inner one, which
    # leaves the whole alpha channel at zero.
    ys, xs = np.ogrid[-radius:radius + 1, -radius:radius + 1]
    dist_sq = xs * xs + ys * ys
    inside = dist_sq <= radius * radius

    # max(..., 1) keeps radius == 0 from dividing by zero; the single
    # centre pixel then gets the full peak opacity.
    falloff = 1.0 - dist_sq / max(radius * radius, 1)
    opacity = np.clip(255 * alpha * falloff, 0, 255)

    circle_img[inside, :3] = color
    circle_img[..., 3] = np.where(inside, opacity, 0).astype(np.uint8)
    return circle_img
def draw_advanced_keypoint(frame, center, keypoint_id, conf, radius=12):
    """Draw a keypoint marker with a glow and an "id:confidence" label.

    The frame is modified in place; nothing is returned.

    Args:
        frame: Image array to draw on, indexed as ``frame[row, column]``.
        center: ``(x, y)`` pixel position of the keypoint.
        keypoint_id: Identifier shown before the colon in the label.
        conf: Confidence value, formatted with two decimals in the label.
        radius: Radius of the marker ring in pixels; the glow extends four
            pixels beyond it.
    """
    x, y = center
    color = (0, 255, 0)

    # --- Glow -----------------------------------------------------------
    glow_radius = radius + 4
    gradient = create_gradient_circle(glow_radius, color)

    # Top-left corner of the glow patch in frame coordinates, before it is
    # clipped to the frame.
    patch_x0, patch_y0 = x - glow_radius, y - glow_radius
    gx1, gy1 = max(0, patch_x0), max(0, patch_y0)
    gx2 = min(frame.shape[1], x + glow_radius + 1)
    gy2 = min(frame.shape[0], y + glow_radius + 1)

    if gx1 < gx2 and gy1 < gy2:
        roi = frame[gy1:gy2, gx1:gx2]
        # Offset the gradient slice by however much was clipped on the
        # left/top so the glow stays centred on the keypoint near edges.
        gradient_roi = gradient[gy1 - patch_y0:gy2 - patch_y0,
                                gx1 - patch_x0:gx2 - patch_x0]
        alpha = gradient_roi[:, :, 3:4] / 255.0
        roi[:] = roi * (1 - alpha) + gradient_roi[:, :, :3] * alpha

    # --- Marker: outline, filled core, thin white ring -------------------
    cv2.circle(frame, center, radius, color, 2)
    cv2.circle(frame, center, radius - 2, color, -1)
    cv2.circle(frame, center, radius - 1, (255, 255, 255), 1)

    # --- Label -----------------------------------------------------------
    label_text = f"{keypoint_id}:{conf:.2f}"
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.0
    thickness = 2
    (text_w, text_h), _ = cv2.getTextSize(label_text, font, font_scale, thickness)
    margin = 2

    # Six-point outline: a box above the marker whose bottom edge narrows
    # to a short stem ending just below the top of the marker ring.
    bg_pts = np.array([
        [x - text_w // 2 - margin, y - radius - text_h - margin * 2],
        [x + text_w // 2 + margin, y - radius - text_h - margin * 2],
        [x + text_w // 2 + margin, y - radius - margin],
        [x + margin, y - radius + margin],
        [x - margin, y - radius + margin],
        [x - text_w // 2 - margin, y - radius - margin],
    ], np.int32)
    cv2.fillPoly(frame, [bg_pts], (0, 0, 0))
    cv2.polylines(frame, [bg_pts], True, color, 1)
    cv2.putText(frame, label_text,
                (x - text_w // 2, y - radius - margin * 2),
                font, font_scale, (255, 255, 255), thickness)
def process_image(input_image, conf_threshold=0.5):
    """Run keypoint detection on an image and return an annotated copy.

    Args:
        input_image: Either a path to an image file (``str``) or an RGB
            image array. ``None`` is accepted and passed straight through.
        conf_threshold: Confidence passed to the model's ``predict`` call
            and also used to decide which keypoints are drawn.

    Returns:
        The annotated image converted back to RGB, or ``None`` when
        ``input_image`` is ``None``.

    Raises:
        ValueError: If ``input_image`` is a path that OpenCV cannot read.
    """
    # Nothing to process (e.g. the form was submitted without an image).
    if input_image is None:
        return None

    # Bring the input into OpenCV's BGR layout.
    if isinstance(input_image, str):
        frame = cv2.imread(input_image)
        if frame is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise ValueError(f"Could not read image: {input_image}")
    else:
        frame = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)

    # NOTE(review): the model is loaded on every call. Caching one shared
    # instance would be faster, but confirm how concurrent requests are
    # served first; sharing a model across threads may not be safe.
    model_path = "HockeyRink.pt"
    model = YOLO(model_path)

    results = model.predict(frame, conf=conf_threshold)

    # Annotate a copy so the frame given to the model stays untouched.
    annotated_frame = frame.copy()
    frame_h, frame_w = annotated_frame.shape[:2]

    for result in results:
        # No keypoints at all, or no detections in this result.
        if result.keypoints is None or len(result.keypoints.data) == 0:
            continue

        # Only the first detection is annotated. Converting once to plain
        # Python numbers avoids per-element tensor indexing below.
        # Each entry is expected to be [x, y, confidence] — TODO confirm
        # against the model's keypoint layout.
        keypoints = result.keypoints.data[0].tolist()
        if not keypoints:
            continue

        # Draw class label
        boxes = getattr(result, "boxes", None)
        if boxes is not None and len(boxes.cls) > 0:
            class_id = int(boxes.cls[0])
            class_conf = float(boxes.conf[0])

            # Anchor the label to the keypoints that will actually be drawn
            # so points below the threshold cannot drag it away; fall back
            # to all keypoints when none pass.
            anchor = [kp for kp in keypoints if kp[2] > conf_threshold] or keypoints
            text_x = int(min(kp[0] for kp in anchor))
            text_y = int(min(kp[1] for kp in anchor)) - 40

            main_label = f"Class ID:{class_id} ({class_conf:.2f})"
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 1.2
            thickness = 2
            (text_w, text_h), _ = cv2.getTextSize(main_label, font, font_scale, thickness)

            # Keep the label (and its 5 px padding) inside the frame.
            text_x = max(5, min(text_x, frame_w - text_w - 5))
            text_y = max(text_h + 5, min(text_y, frame_h - 5))

            cv2.rectangle(annotated_frame,
                          (text_x - 5, text_y - text_h - 5),
                          (text_x + text_w + 5, text_y + 5),
                          (0, 0, 0), -1)
            cv2.putText(annotated_frame, main_label,
                        (text_x, text_y),
                        font, font_scale,
                        (255, 255, 255), thickness)

        # Draw keypoints
        for idx, kp in enumerate(keypoints):
            x, y, conf = int(kp[0]), int(kp[1]), kp[2]
            if conf > conf_threshold:
                draw_advanced_keypoint(annotated_frame, (x, y), idx, conf)

    # Convert back to RGB for Gradio
    return cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
# Build the Gradio UI
def create_interface():
    """Assemble and return the Gradio interface wrapping ``process_image``.

    The interface has one image input, one image output, and four example
    images referenced by file name.
    """
    example_files = ("exm_1.jpg", "exm_2.jpg", "exm_3.jpg", "exm_4.jpg")
    # One single-element row per example, since there is one input.
    example_rows = [[file_name] for file_name in example_files]

    image_input = gr.Image(type="numpy", label="Input Image")
    image_output = gr.Image(type="numpy", label="Detected Poses")

    return gr.Interface(
        fn=process_image,
        inputs=[image_input],
        outputs=image_output,
        title="HockeyRink: A Model for Precise Ice Hockey Rink Keypoint Mapping and Analytics",
        description="Upload an image of ice hockey to detect keypoints on the rink.",
        examples=example_rows,
        theme=gr.themes.Base(),
    )
if __name__ == "__main__":
    # Build the interface and launch it only when run as a script.
    iface = create_interface()
    iface.launch()