# detect_kpts / app.py
# Last commit: c2a39aa "Update app.py" by ShinSeungJ (verified), 2.29 kB
import gradio as gr
import io
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
import json
# Path to the custom YOLO weights. The model is loaded once, at import time,
# and the module-level `model` object is reused by detect_keypoints() for every
# request.
MODEL_WEIGHTS = "fentanyl_oft.pt"
model = YOLO(MODEL_WEIGHTS)
def _to_bgr(image):
    """Convert a PIL image to the BGR numpy layout that OpenCV expects.

    PIL images are first normalised to 3-channel RGB so that RGBA, grayscale
    ("L") and palette ("P") inputs do not make ``cv2.cvtColor`` raise.
    Any other input is returned unchanged; it is assumed to already be a
    BGR ``numpy`` array.
    """
    if isinstance(image, Image.Image):
        rgb = np.array(image.convert("RGB"))
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    return image


def _extract_keypoints(result):
    """Return the keypoints of the first detection in ``result`` as dicts.

    Each dict has the keys ``id``, ``x``, ``y`` (taken from ``Keypoints.xy``)
    and ``confidence``. An empty list is returned when the result carries no
    keypoints or no detections. When the model reports no per-keypoint
    confidence (``Keypoints.conf`` is ``None``), confidence is reported as 0.0.
    """
    keypoints = result.keypoints
    if keypoints is None:
        return []
    xy = keypoints.xy.cpu().numpy()
    # No detections -> no first row; indexing xy[0] would raise IndexError.
    if len(xy) == 0:
        return []
    points = xy[0]
    # ultralytics returns conf=None when the keypoint tensor has no
    # visibility/score column; calling .cpu() on it would raise.
    if keypoints.conf is None:
        scores = np.zeros(len(points))
    else:
        scores = keypoints.conf.cpu().numpy()[0]
    return [
        {"id": i, "x": float(x), "y": float(y), "confidence": float(score)}
        for i, ((x, y), score) in enumerate(zip(points, scores))
    ]


def detect_keypoints(image):
    """Run YOLO inference on ``image`` and return its keypoints as a JSON dict.

    Args:
        image: A ``PIL.Image.Image`` (what the Gradio input delivers) or an
            already-BGR ``numpy`` array. ``None`` is rejected up front.

    Returns:
        On success: ``{"success": True, "keypoints": [...],
        "image_width": int, "image_height": int, "num_keypoints": int}``,
        where ``keypoints`` holds the points of at most one detection
        (``max_det=1``) and may be empty.
        On failure: ``{"success": False, "error": str}``. Exceptions raised
        during conversion or inference are caught and reported this way.
    """
    if image is None:
        # Gradio passes None when the form is submitted without an image.
        # Reject it here: ultralytics substitutes its bundled sample images
        # when `source` is None, so inference would run on the wrong input.
        return {"success": False, "error": "No image provided"}
    try:
        image_cv2 = _to_bgr(image)
        results = model.predict(
            source=image_cv2,
            conf=0.05,  # low threshold so even weak detections are returned
            iou=0.7,
            max_det=1,  # only the single top detection is reported
            imgsz=1440,
            device='cpu',
            verbose=False,
        )
        keypoints_data = _extract_keypoints(results[0]) if results else []
        return {
            "success": True,
            "keypoints": keypoints_data,
            "image_width": image_cv2.shape[1],
            "image_height": image_cv2.shape[0],
            "num_keypoints": len(keypoints_data),
        }
    except Exception as e:
        # API boundary: report any failure as a JSON payload instead of
        # letting the exception propagate to Gradio.
        return {"success": False, "error": str(e)}
# Gradio front end: one PIL image in, the detect_keypoints() result dict out
# as JSON. `api_name` names the programmatic endpoint (e.g. gradio_client's
# api_name="/predict"). The exact HTTP route it maps to depends on the Gradio
# version, so check the app's "Use via API" page rather than hard-coding it.
image_input = gr.Image(type="pil")
json_output = gr.JSON()
iface = gr.Interface(
    fn=detect_keypoints,
    inputs=image_input,
    outputs=json_output,
    title="YOLO Keypoint Detection",
    description="Upload an image to detect keypoints using custom YOLO model",
    api_name="predict",
)
# Launch with API enabled
if __name__ == "__main__":
iface.launch(share=False)