File size: 5,322 Bytes
68f0979
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce3cabf
 
68f0979
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
import cv2
import numpy as np
from ultralytics import YOLO
import torch
import os

def create_gradient_circle(radius, color=(0, 255, 0), alpha=0.7):
    """Create a square RGBA patch holding a disc that fades out toward its rim.

    Args:
        radius: Disc radius in pixels (non-negative integer). The returned
            patch is ``(2 * radius + 1)`` pixels on each side, with the disc
            centred on pixel ``(radius, radius)``.
        color: Three channel values written unchanged into channels 0-2 of
            every pixel inside the disc.
        alpha: Peak opacity at the centre, as a fraction of 255.

    Returns:
        ``uint8`` array of shape ``(size, size, 4)``. Pixels outside the disc
        are all zero. Inside, channel 3 is
        ``int(255 * alpha * (1 - (d / radius) ** 2))`` where ``d`` is the
        pixel's distance from the centre, so opacity is highest in the middle
        and reaches zero at the rim.

    Raises:
        ValueError: If ``radius`` is negative.
    """
    if radius < 0:
        raise ValueError("radius must be non-negative")

    size = radius * 2 + 1
    circle_img = np.zeros((size, size, 4), dtype=np.uint8)

    # Squared distance of every pixel from the centre pixel.
    offsets = np.arange(size) - radius
    dist_sq = offsets[:, None] ** 2 + offsets[None, :] ** 2
    inside = dist_sq <= radius ** 2

    # The opacity is computed per pixel instead of stacking filled discs:
    # painting discs from the smallest to the largest radius lets the last
    # (alpha == 0) disc overwrite every earlier one, leaving the whole patch
    # transparent.
    if radius > 0:
        falloff = 1.0 - dist_sq / float(radius ** 2)
    else:
        # A single-pixel patch has no rim to fade toward (and would otherwise
        # divide by zero), so it simply gets the peak opacity.
        falloff = np.ones(dist_sq.shape, dtype=float)

    opacity = np.clip(255 * alpha * falloff, 0, 255).astype(np.uint8)

    circle_img[inside, :3] = color
    circle_img[inside, 3] = opacity[inside]

    return circle_img

def draw_advanced_keypoint(frame, center, keypoint_id, conf, radius=12):
    """Draw a keypoint marker with an "id:confidence" label onto ``frame`` in place.

    The marker is made of an alpha-blended patch from
    ``create_gradient_circle``, a green ring with a filled core and a white
    inner outline, and a black callout box above the marker containing
    ``"{keypoint_id}:{conf:.2f}"``.

    Args:
        frame: Image array that is modified in place.
            NOTE(review): the blend assumes a 3-channel frame -- confirm
            against callers.
        center: ``(x, y)`` integer pixel coordinates of the keypoint.
        keypoint_id: Value shown before the colon in the label.
        conf: Confidence shown with two decimals; anything ``float()`` accepts.
        radius: Radius of the marker ring in pixels; the blended patch extends
            4 pixels beyond it.
    """
    x, y = center
    color = (0, 255, 0)

    glow_radius = radius + 4
    gradient = create_gradient_circle(glow_radius, color)

    # Top-left corner of the (unclipped) patch in frame coordinates.
    left, top = x - glow_radius, y - glow_radius

    # Clip the patch rectangle to the frame.
    gx1, gy1 = max(0, left), max(0, top)
    gx2 = min(frame.shape[1], left + gradient.shape[1])
    gy2 = min(frame.shape[0], top + gradient.shape[0])

    if gx1 < gx2 and gy1 < gy2:
        roi = frame[gy1:gy2, gx1:gx2]
        # Take the matching window of the patch: when the rectangle is clipped
        # on the left/top edge the window must start the same number of pixels
        # into the patch, otherwise the patch is drawn shifted off the keypoint.
        gradient_roi = gradient[gy1 - top:gy2 - top, gx1 - left:gx2 - left]
        alpha = gradient_roi[:, :, 3:4] / 255.0
        blended = roi * (1 - alpha) + gradient_roi[:, :, :3] * alpha
        roi[:] = blended.astype(frame.dtype)

    cv2.circle(frame, center, radius, color, 2)
    cv2.circle(frame, center, radius-2, color, -1)
    cv2.circle(frame, center, radius-1, (255, 255, 255), 1)

    label_text = f"{keypoint_id}:{float(conf):.2f}"
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.0
    thickness = 2

    (text_w, text_h), _baseline = cv2.getTextSize(label_text, font, font_scale, thickness)

    # Callout outline: a box above the marker whose bottom edge dips down to
    # a short tab centred on the keypoint's x coordinate.
    margin = 2
    bg_pts = np.array([
        [x - text_w//2 - margin, y - radius - text_h - margin*2],
        [x + text_w//2 + margin, y - radius - text_h - margin*2],
        [x + text_w//2 + margin, y - radius - margin],
        [x + margin, y - radius + margin],
        [x - margin, y - radius + margin],
        [x - text_w//2 - margin, y - radius - margin],
    ], np.int32)

    cv2.fillPoly(frame, [bg_pts], (0, 0, 0))
    cv2.polylines(frame, [bg_pts], True, color, 1)

    cv2.putText(frame, label_text,
                (x - text_w//2, y - radius - margin*2),
                font, font_scale, (255, 255, 255), thickness)

_MODEL_PATH = "HockeyRink.pt"
_model_cache = {}


def _load_model(model_path=_MODEL_PATH):
    """Return the YOLO model for ``model_path``, loading it only on first use."""
    # NOTE(review): Ultralytics advises against sharing one model instance
    # across threads -- confirm the app never runs this handler concurrently.
    if model_path not in _model_cache:
        _model_cache[model_path] = YOLO(model_path)
    return _model_cache[model_path]


def _draw_class_label(frame, keypoints, class_id, class_conf, conf_threshold):
    """Draw the "Class ID" banner near the top-left extent of the keypoints."""
    # Anchor on the keypoints that will actually be drawn; fall back to all of
    # them so the banner is still shown when none passes the threshold.
    anchors = [kp for kp in keypoints if kp[2] > conf_threshold] or keypoints

    main_label = f"Class ID:{class_id} ({class_conf:.2f})"
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.2
    thickness = 2
    pad = 5

    (text_w, text_h), _baseline = cv2.getTextSize(main_label, font, font_scale, thickness)

    text_x = int(min(kp[0] for kp in anchors))
    text_y = int(min(kp[1] for kp in anchors)) - 40

    # Keep the banner inside the frame; the 40 px upward shift would otherwise
    # push it off-screen for keypoints near the top edge.
    text_x = max(pad, min(text_x, frame.shape[1] - text_w - pad))
    text_y = max(text_h + pad, min(text_y, frame.shape[0] - pad))

    cv2.rectangle(frame,
                  (text_x - pad, text_y - text_h - pad),
                  (text_x + text_w + pad, text_y + pad),
                  (0, 0, 0), -1)

    cv2.putText(frame, main_label,
                (text_x, text_y),
                font, font_scale,
                (255, 255, 255), thickness)


def process_image(input_image, conf_threshold=0.5):
    """Run keypoint detection on an image and return an annotated RGB copy.

    Args:
        input_image: Either a file path (read with ``cv2.imread``) or an RGB
            image array.
        conf_threshold: Passed to the model as its detection confidence, and
            also used as the minimum per-keypoint confidence for drawing.

    Returns:
        The annotated image converted back to RGB. When a result has no
        detections, nothing is drawn for it.

    Raises:
        ValueError: If ``input_image`` is ``None`` or the path cannot be read.
    """
    if input_image is None:
        raise ValueError("No input image provided")

    # Convert the input to OpenCV's BGR layout.
    if isinstance(input_image, str):
        frame = cv2.imread(input_image)
        if frame is None:
            raise ValueError(f"Could not read image file: {input_image}")
    else:
        frame = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)

    # The model is cached so it is not reloaded from disk on every request.
    results = _load_model().predict(frame, conf=conf_threshold)

    # Annotate a copy so the array given to the model stays untouched.
    annotated_frame = frame.copy()

    for result in results:
        # With no detections there is no first instance to index.
        if result.keypoints is None or len(result.keypoints.data) == 0:
            continue

        # Only the first detected instance is annotated. tolist() yields plain
        # Python floats: rows of [x, y, confidence].
        keypoints = result.keypoints.data[0].tolist()
        if not keypoints:
            continue

        boxes = getattr(result, 'boxes', None)
        if boxes is not None and len(boxes.cls) > 0:
            _draw_class_label(
                annotated_frame,
                keypoints,
                int(boxes.cls[0]),
                float(boxes.conf[0]),
                conf_threshold,
            )

        for idx, kp in enumerate(keypoints):
            kp_conf = kp[2]
            if kp_conf > conf_threshold:
                draw_advanced_keypoint(
                    annotated_frame,
                    (int(kp[0]), int(kp[1])),
                    idx,
                    kp_conf,
                )

    # Convert back to RGB for Gradio
    return cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

# Create Gradio interface
def create_interface():
    """Build and return the Gradio Interface that serves ``process_image``.

    The interface takes one image input (delivered as a NumPy array), shows
    one image output, and lists four bundled example files.
    """
    sample_files = ("exm_1.jpg", "exm_2.jpg", "exm_3.jpg", "exm_4.jpg")

    image_in = gr.Image(type="numpy", label="Input Image")
    image_out = gr.Image(type="numpy", label="Detected Poses")

    return gr.Interface(
        fn=process_image,
        inputs=[image_in],
        outputs=image_out,
        title="HockeyRink: A Model for Precise Ice Hockey Rink Keypoint Mapping and Analytics",
        description="Upload an image of ice hockey to detect keypoints on the rink.",
        # One row per example, each row holding the single image input.
        examples=[[name] for name in sample_files],
        theme=gr.themes.Base(),
    )

if __name__ == "__main__":
    # Script entry point: build the Gradio interface and start serving it.
    iface = create_interface()
    iface.launch()