Spaces:
Runtime error
Runtime error
David Ko
commited on
Commit
·
1aa3dcb
1
Parent(s):
aa7ad0c
Add vector DB save feature to Gradio UI for object detection results
Browse files
app.py
CHANGED
@@ -3,6 +3,11 @@ import torch
|
|
3 |
from PIL import Image
|
4 |
import numpy as np
|
5 |
import os
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
# Model initialization
|
8 |
print("Loading models... This may take a moment.")
|
@@ -63,10 +68,52 @@ import torch
|
|
63 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
64 |
print(f"Using device: {device}")
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
# Define model inference functions
|
67 |
def process_yolo(image):
|
68 |
if yolo_model is None:
|
69 |
-
return None, "YOLOv8 model not loaded"
|
70 |
|
71 |
# Measure inference time
|
72 |
import time
|
@@ -91,13 +138,23 @@ def process_yolo(image):
|
|
91 |
|
92 |
# Format detection results
|
93 |
detections = []
|
|
|
|
|
94 |
for box in boxes:
|
95 |
class_id = int(box.cls[0].item())
|
96 |
class_name = class_names[class_id]
|
97 |
confidence = round(box.conf[0].item(), 2)
|
98 |
bbox = box.xyxy[0].tolist()
|
99 |
bbox = [round(x) for x in bbox]
|
|
|
100 |
detections.append("{}: {} at {}".format(class_name, confidence, bbox))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
# Calculate inference time
|
103 |
inference_time = time.time() - start_time
|
@@ -108,6 +165,8 @@ def process_yolo(image):
|
|
108 |
detection_text = "\n".join(detections) if detections else "No objects detected"
|
109 |
detection_text += performance_info
|
110 |
|
|
|
|
|
111 |
return result_image, detection_text
|
112 |
|
113 |
def process_detr(image):
|
@@ -272,11 +331,27 @@ with gr.Blocks(title="Object Detection Demo") as demo:
|
|
272 |
with gr.Column():
|
273 |
vit_text = gr.Textbox(label="ViT Classification Results")
|
274 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
# Set up event handlers
|
276 |
-
yolo_button.click(
|
277 |
fn=process_yolo,
|
278 |
inputs=input_image,
|
279 |
-
outputs=[yolo_output, yolo_text]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
)
|
281 |
|
282 |
detr_button.click(
|
|
|
3 |
from PIL import Image
|
4 |
import numpy as np
|
5 |
import os
|
6 |
+
import requests
|
7 |
+
import json
|
8 |
+
import base64
|
9 |
+
from io import BytesIO
|
10 |
+
import uuid
|
11 |
|
12 |
# Model initialization
|
13 |
print("Loading models... This may take a moment.")
|
|
|
68 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
69 |
print(f"Using device: {device}")
|
70 |
|
71 |
+
# 벡터 DB에 객체 저장 함수
|
72 |
+
def save_objects_to_vector_db(image, detection_results):
|
73 |
+
if image is None or detection_results is None:
|
74 |
+
return "이미지나 객체 인식 결과가 없습니다."
|
75 |
+
|
76 |
+
try:
|
77 |
+
# 이미지를 base64로 인코딩
|
78 |
+
buffered = BytesIO()
|
79 |
+
image.save(buffered, format="JPEG")
|
80 |
+
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
81 |
+
|
82 |
+
# 객체 정보 추출
|
83 |
+
objects = []
|
84 |
+
for obj in detection_results['objects']:
|
85 |
+
objects.append({
|
86 |
+
"class": obj['class'],
|
87 |
+
"confidence": obj['confidence'],
|
88 |
+
"bbox": obj['bbox']
|
89 |
+
})
|
90 |
+
|
91 |
+
# API 요청 데이터 구성
|
92 |
+
data = {
|
93 |
+
"image": img_str,
|
94 |
+
"objects": objects,
|
95 |
+
"image_id": str(uuid.uuid4())
|
96 |
+
}
|
97 |
+
|
98 |
+
# API 호출
|
99 |
+
response = requests.post(
|
100 |
+
"http://localhost:7860/api/add-detected-objects",
|
101 |
+
json=data
|
102 |
+
)
|
103 |
+
|
104 |
+
if response.status_code == 200:
|
105 |
+
result = response.json()
|
106 |
+
return f"벡터 DB에 {len(objects)}개 객체 저장 성공! 저장된 객체 ID: {', '.join(result.get('object_ids', [])[:3])}..."
|
107 |
+
else:
|
108 |
+
return f"저장 실패: {response.text}"
|
109 |
+
|
110 |
+
except Exception as e:
|
111 |
+
return f"오류 발생: {str(e)}"
|
112 |
+
|
113 |
# Define model inference functions
|
114 |
def process_yolo(image):
|
115 |
if yolo_model is None:
|
116 |
+
return None, "YOLOv8 model not loaded", None
|
117 |
|
118 |
# Measure inference time
|
119 |
import time
|
|
|
138 |
|
139 |
# Format detection results
|
140 |
detections = []
|
141 |
+
detection_objects = {'objects': []}
|
142 |
+
|
143 |
for box in boxes:
|
144 |
class_id = int(box.cls[0].item())
|
145 |
class_name = class_names[class_id]
|
146 |
confidence = round(box.conf[0].item(), 2)
|
147 |
bbox = box.xyxy[0].tolist()
|
148 |
bbox = [round(x) for x in bbox]
|
149 |
+
|
150 |
detections.append("{}: {} at {}".format(class_name, confidence, bbox))
|
151 |
+
|
152 |
+
# 벡터 DB 저장용 객체 정보 추가
|
153 |
+
detection_objects['objects'].append({
|
154 |
+
'class': class_name,
|
155 |
+
'confidence': confidence,
|
156 |
+
'bbox': bbox
|
157 |
+
})
|
158 |
|
159 |
# Calculate inference time
|
160 |
inference_time = time.time() - start_time
|
|
|
165 |
detection_text = "\n".join(detections) if detections else "No objects detected"
|
166 |
detection_text += performance_info
|
167 |
|
168 |
+
return result_image, detection_text, detection_objects
|
169 |
+
|
170 |
return result_image, detection_text
|
171 |
|
172 |
def process_detr(image):
|
|
|
331 |
with gr.Column():
|
332 |
vit_text = gr.Textbox(label="ViT Classification Results")
|
333 |
|
334 |
+
# 벡터 DB 저장 버튼 및 결과 표시
|
335 |
+
with gr.Row():
|
336 |
+
with gr.Column():
|
337 |
+
save_to_db_button = gr.Button("YOLOv8 인식 결과를 벡터 DB에 저장", variant="primary")
|
338 |
+
save_result = gr.Textbox(label="벡터 DB 저장 결과")
|
339 |
+
|
340 |
+
# 객체 인식 결과 저장용 상태 변수
|
341 |
+
detection_state = gr.State(None)
|
342 |
+
|
343 |
# Set up event handlers
|
344 |
+
yolo_result = yolo_button.click(
|
345 |
fn=process_yolo,
|
346 |
inputs=input_image,
|
347 |
+
outputs=[yolo_output, yolo_text, detection_state]
|
348 |
+
)
|
349 |
+
|
350 |
+
# 벡터 DB 저장 버튼 이벤트 핸들러
|
351 |
+
save_to_db_button.click(
|
352 |
+
fn=save_objects_to_vector_db,
|
353 |
+
inputs=[input_image, detection_state],
|
354 |
+
outputs=save_result
|
355 |
)
|
356 |
|
357 |
detr_button.click(
|