ShinSeungJ committed on
Commit
7a87091
·
verified ·
1 Parent(s): eadeab6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -20
app.py CHANGED
@@ -1,3 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import io
3
  from ultralytics import YOLO
@@ -6,15 +84,28 @@ import numpy as np
6
  from PIL import Image
7
  import json
8
 
9
- # Load your custom YOLO model
10
- model = YOLO("fentanyl_oft.pt")
11
- # model = YOLO("avatar_ckpt.pt")
12
 
13
- def detect_keypoints(image):
14
  """
15
  Run YOLO inference and return keypoints data
 
 
 
16
  """
17
  try:
 
 
 
 
 
 
 
 
 
 
18
  # Convert PIL Image to numpy array
19
  if isinstance(image, Image.Image):
20
  image_np = np.array(image)
@@ -22,13 +113,13 @@ def detect_keypoints(image):
22
  else:
23
  image_cv2 = image
24
 
25
- # Run inference
26
  results = model.predict(
27
  source=image_cv2,
28
  conf=0.05,
29
  iou=0.7,
30
- max_det=1,
31
- imgsz=1440,
32
  device='cpu',
33
  verbose=False
34
  )
@@ -40,35 +131,54 @@ def detect_keypoints(image):
40
  kpts = result.keypoints.xy.cpu().numpy()
41
  conf = result.keypoints.conf.cpu().numpy()
42
 
43
- for i in range(kpts.shape[1]):
44
- if i < len(kpts[0]):
45
- x, y = kpts[0][i]
46
- confidence = conf[0][i] if i < len(conf[0]) else 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  keypoints_data.append({
48
- "id": i,
49
- "x": float(x),
50
- "y": float(y),
51
- "confidence": float(confidence)
52
  })
53
 
54
  return {
55
  "success": True,
 
56
  "keypoints": keypoints_data,
57
  "image_width": image_cv2.shape[1],
58
  "image_height": image_cv2.shape[0],
59
- "num_keypoints": len(keypoints_data)
 
60
  }
61
 
62
  except Exception as e:
63
- return {"success": False, "error": str(e)}
64
 
65
- # Create Gradio interface with API access enabled
66
  iface = gr.Interface(
67
  fn=detect_keypoints,
68
- inputs=gr.Image(type="pil"),
 
 
 
69
  outputs=gr.JSON(),
70
  title="YOLO Keypoint Detection",
71
- description="Upload an image to detect keypoints using custom YOLO model",
72
  api_name="predict" # This enables API access at /api/predict
73
  )
74
 
 
1
+ # import gradio as gr
2
+ # import io
3
+ # from ultralytics import YOLO
4
+ # import cv2
5
+ # import numpy as np
6
+ # from PIL import Image
7
+ # import json
8
+
9
+ # # Load your custom YOLO model
10
+ # model = YOLO("fentanyl_oft.pt")
11
+ # # model = YOLO("avatar_ckpt.pt")
12
+
13
+ # def detect_keypoints(image):
14
+ # """
15
+ # Run YOLO inference and return keypoints data
16
+ # """
17
+ # try:
18
+ # # Convert PIL Image to numpy array
19
+ # if isinstance(image, Image.Image):
20
+ # image_np = np.array(image)
21
+ # image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
22
+ # else:
23
+ # image_cv2 = image
24
+
25
+ # # Run inference
26
+ # results = model.predict(
27
+ # source=image_cv2,
28
+ # conf=0.05,
29
+ # iou=0.7,
30
+ # max_det=1,
31
+ # imgsz=1440,
32
+ # device='cpu',
33
+ # verbose=False
34
+ # )
35
+
36
+ # keypoints_data = []
37
+ # if results and len(results) > 0:
38
+ # result = results[0]
39
+ # if result.keypoints is not None:
40
+ # kpts = result.keypoints.xy.cpu().numpy()
41
+ # conf = result.keypoints.conf.cpu().numpy()
42
+
43
+ # for i in range(kpts.shape[1]):
44
+ # if i < len(kpts[0]):
45
+ # x, y = kpts[0][i]
46
+ # confidence = conf[0][i] if i < len(conf[0]) else 0.0
47
+ # keypoints_data.append({
48
+ # "id": i,
49
+ # "x": float(x),
50
+ # "y": float(y),
51
+ # "confidence": float(confidence)
52
+ # })
53
+
54
+ # return {
55
+ # "success": True,
56
+ # "keypoints": keypoints_data,
57
+ # "image_width": image_cv2.shape[1],
58
+ # "image_height": image_cv2.shape[0],
59
+ # "num_keypoints": len(keypoints_data)
60
+ # }
61
+
62
+ # except Exception as e:
63
+ # return {"success": False, "error": str(e)}
64
+
65
+ # # Create Gradio interface with API access enabled
66
+ # iface = gr.Interface(
67
+ # fn=detect_keypoints,
68
+ # inputs=gr.Image(type="pil"),
69
+ # outputs=gr.JSON(),
70
+ # title="YOLO Keypoint Detection",
71
+ # description="Upload an image to detect keypoints using custom YOLO model",
72
+ # api_name="predict" # This enables API access at /api/predict
73
+ # )
74
+
75
+ # # Launch with API enabled
76
+ # if __name__ == "__main__":
77
+ # iface.launch(share=False)
78
+
79
  import gradio as gr
80
  import io
81
  from ultralytics import YOLO
 
84
  from PIL import Image
85
  import json
86
 
87
+ # Load both models
88
+ single_animal_model = YOLO("fentanyl_oft.pt") # Single animal model
89
+ multi_animal_model = YOLO("avatar_ckpt.pt") # Multi-animal model
90
 
91
+ def detect_keypoints(image, mode="single"):
92
  """
93
  Run YOLO inference and return keypoints data
94
+ Args:
95
+ image: PIL Image
96
+ mode: "single" or "multi" to determine which model to use
97
  """
98
  try:
99
+ # Select model and parameters based on mode
100
+ if mode == "multi":
101
+ model = multi_animal_model
102
+ imgsz = 1504
103
+ max_det = 5
104
+ else: # default to single
105
+ model = single_animal_model
106
+ imgsz = 1440
107
+ max_det = 1
108
+
109
  # Convert PIL Image to numpy array
110
  if isinstance(image, Image.Image):
111
  image_np = np.array(image)
 
113
  else:
114
  image_cv2 = image
115
 
116
+ # Run inference with mode-specific parameters
117
  results = model.predict(
118
  source=image_cv2,
119
  conf=0.05,
120
  iou=0.7,
121
+ max_det=max_det,
122
+ imgsz=imgsz,
123
  device='cpu',
124
  verbose=False
125
  )
 
131
  kpts = result.keypoints.xy.cpu().numpy()
132
  conf = result.keypoints.conf.cpu().numpy()
133
 
134
+ # Handle multiple detections (for multi-animal mode)
135
+ for detection_idx in range(kpts.shape[0]):
136
+ detection_keypoints = []
137
+ for i in range(kpts.shape[1]):
138
+ if i < len(kpts[detection_idx]):
139
+ x, y = kpts[detection_idx][i]
140
+ confidence = conf[detection_idx][i] if i < len(conf[detection_idx]) else 0.0
141
+ detection_keypoints.append({
142
+ "id": i,
143
+ "x": float(x),
144
+ "y": float(y),
145
+ "confidence": float(confidence)
146
+ })
147
+
148
+ # For single animal mode, flatten the structure
149
+ if mode == "single":
150
+ keypoints_data = detection_keypoints
151
+ break # Only take first detection
152
+ else:
153
+ # For multi-animal mode, keep detection structure
154
  keypoints_data.append({
155
+ "detection_id": detection_idx,
156
+ "keypoints": detection_keypoints
 
 
157
  })
158
 
159
  return {
160
  "success": True,
161
+ "mode": mode,
162
  "keypoints": keypoints_data,
163
  "image_width": image_cv2.shape[1],
164
  "image_height": image_cv2.shape[0],
165
+ "num_detections": len(keypoints_data) if mode == "multi" else (1 if keypoints_data else 0),
166
+ "num_keypoints": len(keypoints_data) if mode == "single" else sum(len(det["keypoints"]) for det in keypoints_data) if mode == "multi" else 0
167
  }
168
 
169
  except Exception as e:
170
+ return {"success": False, "error": str(e), "mode": mode}
171
 
172
+ # Create Gradio interface with mode parameter
173
  iface = gr.Interface(
174
  fn=detect_keypoints,
175
+ inputs=[
176
+ gr.Image(type="pil"),
177
+ gr.Dropdown(choices=["single", "multi"], value="single", label="Detection Mode")
178
+ ],
179
  outputs=gr.JSON(),
180
  title="YOLO Keypoint Detection",
181
+ description="Upload an image to detect keypoints using custom YOLO model. Choose single or multi-animal mode.",
182
  api_name="predict" # This enables API access at /api/predict
183
  )
184