David Ko committed on
Commit
1aa3dcb
·
1 Parent(s): aa7ad0c

Add vector DB save feature to Gradio UI for object detection results

Browse files
Files changed (1) hide show
  1. app.py +78 -3
app.py CHANGED
@@ -3,6 +3,11 @@ import torch
3
  from PIL import Image
4
  import numpy as np
5
  import os
 
 
 
 
 
6
 
7
  # Model initialization
8
  print("Loading models... This may take a moment.")
@@ -63,10 +68,52 @@ import torch
63
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
  print(f"Using device: {device}")
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # Define model inference functions
67
  def process_yolo(image):
68
  if yolo_model is None:
69
- return None, "YOLOv8 model not loaded"
70
 
71
  # Measure inference time
72
  import time
@@ -91,13 +138,23 @@ def process_yolo(image):
91
 
92
  # Format detection results
93
  detections = []
 
 
94
  for box in boxes:
95
  class_id = int(box.cls[0].item())
96
  class_name = class_names[class_id]
97
  confidence = round(box.conf[0].item(), 2)
98
  bbox = box.xyxy[0].tolist()
99
  bbox = [round(x) for x in bbox]
 
100
  detections.append("{}: {} at {}".format(class_name, confidence, bbox))
 
 
 
 
 
 
 
101
 
102
  # Calculate inference time
103
  inference_time = time.time() - start_time
@@ -108,6 +165,8 @@ def process_yolo(image):
108
  detection_text = "\n".join(detections) if detections else "No objects detected"
109
  detection_text += performance_info
110
 
 
 
111
  return result_image, detection_text
112
 
113
  def process_detr(image):
@@ -272,11 +331,27 @@ with gr.Blocks(title="Object Detection Demo") as demo:
272
  with gr.Column():
273
  vit_text = gr.Textbox(label="ViT Classification Results")
274
 
 
 
 
 
 
 
 
 
 
275
  # Set up event handlers
276
- yolo_button.click(
277
  fn=process_yolo,
278
  inputs=input_image,
279
- outputs=[yolo_output, yolo_text]
 
 
 
 
 
 
 
280
  )
281
 
282
  detr_button.click(
 
3
  from PIL import Image
4
  import numpy as np
5
  import os
6
+ import requests
7
+ import json
8
+ import base64
9
+ from io import BytesIO
10
+ import uuid
11
 
12
  # Model initialization
13
  print("Loading models... This may take a moment.")
 
68
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69
  print(f"Using device: {device}")
70
 
71
# Save detected objects to the vector DB
def save_objects_to_vector_db(image, detection_results):
    """Send an image and its detected objects to the vector DB endpoint.

    The image is JPEG-encoded and base64-wrapped, bundled with the
    class/confidence/bbox of every entry in ``detection_results['objects']``
    and a freshly generated UUID, then POSTed as JSON to the local
    ``/api/add-detected-objects`` endpoint.

    Args:
        image: Object exposing a PIL-style ``save(fp, format=...)`` method.
            Assumed to be a ``PIL.Image.Image`` -- TODO confirm against the
            ``type`` of the Gradio image input that feeds this handler.
        detection_results: Mapping with an ``'objects'`` list whose items
            carry ``'class'``, ``'confidence'`` and ``'bbox'`` keys.

    Returns:
        A status string for display in the UI: a success message with up to
        three of the returned object IDs, the response body on a non-200
        status, or an error message. Exceptions are reported through the
        returned string rather than raised.
    """
    if image is None or detection_results is None:
        return "이미지나 객체 인식 결과가 없습니다."

    try:
        # JPEG cannot store alpha or palette data; without this conversion
        # RGBA/P/LA inputs (e.g. PNG uploads) fail inside ``save``.
        if getattr(image, "mode", "RGB") not in ("RGB", "L", "CMYK"):
            image = image.convert("RGB")

        # Encode the image as base64 text so it can travel inside JSON.
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')

        # Keep only the fields the API expects for each detected object.
        objects = [
            {
                "class": obj['class'],
                "confidence": obj['confidence'],
                "bbox": obj['bbox'],
            }
            for obj in detection_results['objects']
        ]

        # Build the request payload.
        data = {
            "image": img_str,
            "objects": objects,
            "image_id": str(uuid.uuid4()),
        }

        # Always pass a timeout: without one an unresponsive server would
        # block this UI handler indefinitely.
        response = requests.post(
            "http://localhost:7860/api/add-detected-objects",
            json=data,
            timeout=30,
        )

        if response.status_code == 200:
            result = response.json()
            # IDs may come back as non-strings (or null); normalise before
            # joining so a successful save is not reported as an error.
            object_ids = [
                str(object_id)
                for object_id in (result.get('object_ids') or [])[:3]
            ]
            return f"벡터 DB에 {len(objects)}개 객체 저장 성공! 저장된 객체 ID: {', '.join(object_ids)}..."
        return f"저장 실패: {response.text}"

    except Exception as e:
        # UI boundary: surface any failure as text instead of crashing the
        # Gradio callback.
        return f"오류 발생: {str(e)}"
112
+
113
  # Define model inference functions
114
  def process_yolo(image):
115
  if yolo_model is None:
116
+ return None, "YOLOv8 model not loaded", None
117
 
118
  # Measure inference time
119
  import time
 
138
 
139
  # Format detection results
140
  detections = []
141
+ detection_objects = {'objects': []}
142
+
143
  for box in boxes:
144
  class_id = int(box.cls[0].item())
145
  class_name = class_names[class_id]
146
  confidence = round(box.conf[0].item(), 2)
147
  bbox = box.xyxy[0].tolist()
148
  bbox = [round(x) for x in bbox]
149
+
150
  detections.append("{}: {} at {}".format(class_name, confidence, bbox))
151
+
152
+ # 벡터 DB 저장용 객체 정보 추가
153
+ detection_objects['objects'].append({
154
+ 'class': class_name,
155
+ 'confidence': confidence,
156
+ 'bbox': bbox
157
+ })
158
 
159
  # Calculate inference time
160
  inference_time = time.time() - start_time
 
165
  detection_text = "\n".join(detections) if detections else "No objects detected"
166
  detection_text += performance_info
167
 
168
+ return result_image, detection_text, detection_objects
169
+
170
  return result_image, detection_text
171
 
172
  def process_detr(image):
 
331
  with gr.Column():
332
  vit_text = gr.Textbox(label="ViT Classification Results")
333
 
334
+ # 벡터 DB 저장 버튼 및 결과 표시
335
+ with gr.Row():
336
+ with gr.Column():
337
+ save_to_db_button = gr.Button("YOLOv8 인식 결과를 벡터 DB에 저장", variant="primary")
338
+ save_result = gr.Textbox(label="벡터 DB 저장 결과")
339
+
340
+ # 객체 인식 결과 저장용 상태 변수
341
+ detection_state = gr.State(None)
342
+
343
  # Set up event handlers
344
+ yolo_result = yolo_button.click(
345
  fn=process_yolo,
346
  inputs=input_image,
347
+ outputs=[yolo_output, yolo_text, detection_state]
348
+ )
349
+
350
+ # 벡터 DB 저장 버튼 이벤트 핸들러
351
+ save_to_db_button.click(
352
+ fn=save_objects_to_vector_db,
353
+ inputs=[input_image, detection_state],
354
+ outputs=save_result
355
  )
356
 
357
  detr_button.click(