ShinSeungJ commited on
Commit
ba17a5a
·
verified ·
1 Parent(s): 98ccaaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -27
app.py CHANGED
@@ -4,14 +4,13 @@ import cv2
4
  import numpy as np
5
  from PIL import Image
6
  import json
 
7
 
8
  # Load your custom YOLO model
9
  model = YOLO("fentanyl_oft.pt")
10
 
11
  def detect_keypoints(image):
12
- """
13
- Run YOLO inference and return keypoints data
14
- """
15
  try:
16
  # Convert PIL Image to numpy array
17
  if isinstance(image, Image.Image):
@@ -20,7 +19,7 @@ def detect_keypoints(image):
20
  else:
21
  image_cv2 = image
22
 
23
- # Run inference with your exact parameters
24
  results = model.predict(
25
  source=image_cv2,
26
  conf=0.05,
@@ -32,21 +31,21 @@ def detect_keypoints(image):
32
  )
33
 
34
  keypoints_data = []
35
-
36
- for result in results:
37
- if result.keypoints is not None and len(result.keypoints.data) > 0:
38
- # Extract keypoints from the first detection
39
- kpts = result.keypoints.data[0] # Shape: [num_keypoints, 3] (x, y, confidence)
40
-
41
- h, w = image_cv2.shape[:2]
42
 
43
- for i, (x, y, conf) in enumerate(kpts):
44
- if conf > 0.3: # Confidence threshold for individual keypoints
 
 
45
  keypoints_data.append({
46
- "id": int(i),
47
  "x": float(x),
48
- "y": float(y),
49
- "confidence": float(conf)
50
  })
51
 
52
  return {
@@ -58,20 +57,28 @@ def detect_keypoints(image):
58
  }
59
 
60
  except Exception as e:
61
- return {
62
- "success": False,
63
- "error": str(e),
64
- "keypoints": []
65
- }
 
 
 
 
 
 
 
 
 
66
 
67
- # Create simple Gradio interface
68
- interface = gr.Interface(
69
  fn=detect_keypoints,
70
  inputs=gr.Image(type="pil"),
71
  outputs=gr.JSON(),
72
- title="YOLO Keypoint Detection",
73
- description="Upload an image to detect keypoints using custom YOLO model"
74
  )
75
 
76
- if __name__ == "__main__":
77
- interface.launch()
 
4
  import numpy as np
5
  from PIL import Image
6
  import json
7
+ from fastapi import FastAPI
8
 
9
  # Load your custom YOLO model
10
  model = YOLO("fentanyl_oft.pt")
11
 
12
  def detect_keypoints(image):
13
+ # Your existing detection code...
 
 
14
  try:
15
  # Convert PIL Image to numpy array
16
  if isinstance(image, Image.Image):
 
19
  else:
20
  image_cv2 = image
21
 
22
+ # Run inference
23
  results = model.predict(
24
  source=image_cv2,
25
  conf=0.05,
 
31
  )
32
 
33
  keypoints_data = []
34
+ if results and len(results) > 0:
35
+ result = results[0]
36
+ if result.keypoints is not None:
37
+ kpts = result.keypoints.xy.cpu().numpy()
38
+ conf = result.keypoints.conf.cpu().numpy()
 
 
39
 
40
+ for i in range(kpts.shape[1]):
41
+ if i < len(kpts[0]):
42
+ x, y = kpts[0][i]
43
+ confidence = conf[0][i] if i < len(conf[0]) else 0.0
44
  keypoints_data.append({
45
+ "id": i,
46
  "x": float(x),
47
+ "y": float(y),
48
+ "confidence": float(confidence)
49
  })
50
 
51
  return {
 
57
  }
58
 
59
  except Exception as e:
60
+ return {"success": False, "error": str(e)}
61
+
62
+ # Create FastAPI app for API endpoints
63
+ app = FastAPI()
64
+
65
+ @app.post("/api/detect")
66
+ async def api_detect_keypoints(file: bytes):
67
+ try:
68
+ # Convert bytes to PIL Image
69
+ image = Image.open(io.BytesIO(file))
70
+ result = detect_keypoints(image)
71
+ return result
72
+ except Exception as e:
73
+ return {"success": False, "error": str(e)}
74
 
75
+ # Create Gradio interface
76
+ iface = gr.Interface(
77
  fn=detect_keypoints,
78
  inputs=gr.Image(type="pil"),
79
  outputs=gr.JSON(),
80
+ title="YOLO Keypoint Detection"
 
81
  )
82
 
83
+ # Mount Gradio on FastAPI
84
+ app = gr.mount_gradio_app(app, iface, path="/")