David Ko commited on
Commit
c28eadf
·
1 Parent(s): cd3a15b

Gradio UI에 벡터 DB 저장 및 검색 기능 통합

Browse files
Files changed (1) hide show
  1. app.py +222 -35
app.py CHANGED
@@ -69,7 +69,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69
  print(f"Using device: {device}")
70
 
71
  # 벡터 DB에 객체 저장 함수
72
- def save_objects_to_vector_db(image, detection_results):
73
  if image is None or detection_results is None:
74
  return "이미지나 객체 인식 결과가 없습니다."
75
 
@@ -79,36 +79,126 @@ def save_objects_to_vector_db(image, detection_results):
79
  image.save(buffered, format="JPEG")
80
  img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
81
 
82
- # 객체 정보 추출
83
- objects = []
84
- for obj in detection_results['objects']:
85
- objects.append({
86
- "class": obj['class'],
87
- "confidence": obj['confidence'],
88
- "bbox": obj['bbox']
89
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- # API 요청 데이터 구성
92
- data = {
93
- "image": img_str,
94
- "objects": objects,
95
- "image_id": str(uuid.uuid4())
96
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  # API 호출
99
  response = requests.post(
100
- "http://localhost:7860/api/add-detected-objects",
101
  json=data
102
  )
103
 
104
  if response.status_code == 200:
105
- result = response.json()
106
- return f"벡터 DB에 {len(objects)}개 객체 저장 성공! 저장된 객체 ID: {', '.join(result.get('object_ids', [])[:3])}..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  else:
108
- return f"저장 실패: {response.text}"
109
-
110
  except Exception as e:
111
- return f"오류 발생: {str(e)}"
112
 
113
  # Define model inference functions
114
  def process_yolo(image):
@@ -334,36 +424,133 @@ with gr.Blocks(title="Object Detection Demo") as demo:
334
  # 벡터 DB 저장 버튼 및 결과 표시
335
  with gr.Row():
336
  with gr.Column():
337
- save_to_db_button = gr.Button("YOLOv8 인식 결과를 벡터 DB 저장", variant="primary")
 
 
 
338
  save_result = gr.Textbox(label="벡터 DB 저장 결과")
 
 
 
 
 
 
 
339
 
340
  # 객체 인식 결과 저장용 상태 변수
341
- detection_state = gr.State(None)
 
 
342
 
343
  # Set up event handlers
344
- yolo_result = yolo_button.click(
345
  fn=process_yolo,
346
  inputs=input_image,
347
- outputs=[yolo_output, yolo_text, detection_state]
348
  )
349
 
350
- # 벡터 DB 저장 버튼 이벤트 핸들러
351
- save_to_db_button.click(
352
- fn=save_objects_to_vector_db,
353
- inputs=[input_image, detection_state],
354
- outputs=save_result
355
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
  detr_button.click(
358
- fn=process_detr,
359
  inputs=input_image,
360
- outputs=[detr_output, detr_text]
361
  )
362
 
363
  vit_button.click(
364
- fn=process_vit,
365
  inputs=input_image,
366
- outputs=vit_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  )
368
 
369
 
 
69
  print(f"Using device: {device}")
70
 
71
  # 벡터 DB에 객체 저장 함수
72
+ def save_objects_to_vector_db(image, detection_results, model_type='yolo'):
73
  if image is None or detection_results is None:
74
  return "이미지나 객체 인식 결과가 없습니다."
75
 
 
79
  image.save(buffered, format="JPEG")
80
  img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
81
 
82
+ # 모델 타입에 따라 다른 API 엔드포인트 호출
83
+ if model_type in ['yolo', 'detr']:
84
+ # 객체 정보 추출
85
+ objects = []
86
+ for obj in detection_results['objects']:
87
+ objects.append({
88
+ "class": obj['class'],
89
+ "confidence": obj['confidence'],
90
+ "bbox": obj['bbox']
91
+ })
92
+
93
+ # API 요청 데이터 구성
94
+ data = {
95
+ "image": img_str,
96
+ "objects": objects,
97
+ "image_id": str(uuid.uuid4())
98
+ }
99
+
100
+ # API 호출
101
+ response = requests.post(
102
+ "http://localhost:7860/api/add-detected-objects",
103
+ json=data
104
+ )
105
+
106
+ if response.status_code == 200:
107
+ result = response.json()
108
+ if 'error' in result:
109
+ return f"오류 발생: {result['error']}"
110
+ return f"벡터 DB에 {len(objects)}개 객체 저장 완료! ID: {result.get('ids', '알 수 없음')}"
111
 
112
+ elif model_type == 'vit':
113
+ # ViT 분류 결과 저장
114
+ data = {
115
+ "image": img_str,
116
+ "metadata": {
117
+ "model": "vit",
118
+ "classifications": detection_results.get('classifications', [])
119
+ }
120
+ }
121
+
122
+ # API 호출
123
+ response = requests.post(
124
+ "http://localhost:7860/api/add-image",
125
+ json=data
126
+ )
127
+
128
+ if response.status_code == 200:
129
+ result = response.json()
130
+ if 'error' in result:
131
+ return f"오류 발생: {result['error']}"
132
+ return f"벡터 DB에 이미지 및 분류 결과 저장 완료! ID: {result.get('id', '알 수 없음')}"
133
+
134
+ else:
135
+ return "지원하지 않는 모델 타입입니다."
136
+
137
+ if response.status_code != 200:
138
+ return f"API 오류: {response.status_code}"
139
+ except Exception as e:
140
+ return f"오류 발생: {str(e)}"
141
+
142
+ # 벡터 DB에서 유사 객체 검색 함수
143
+ def search_similar_objects(image=None, class_name=None):
144
+ try:
145
+ data = {}
146
+
147
+ if image is not None:
148
+ # 이미지를 base64로 인코딩
149
+ buffered = BytesIO()
150
+ image.save(buffered, format="JPEG")
151
+ img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
152
+ data["image"] = img_str
153
+ data["n_results"] = 5
154
+ elif class_name is not None and class_name.strip():
155
+ data["class_name"] = class_name.strip()
156
+ data["n_results"] = 5
157
+ else:
158
+ return "이미지나 클래스 이름 중 하나는 제공해야 합니다.", []
159
 
160
  # API 호출
161
  response = requests.post(
162
+ "http://localhost:7860/api/search-similar-objects",
163
  json=data
164
  )
165
 
166
  if response.status_code == 200:
167
+ results = response.json()
168
+ if isinstance(results, dict) and 'error' in results:
169
+ return f"오류 발생: {results['error']}", []
170
+
171
+ # 결과 포맷팅
172
+ formatted_results = []
173
+ for i, result in enumerate(results):
174
+ similarity = (1 - result.get('distance', 0)) * 100
175
+ img_data = result.get('image', '')
176
+
177
+ # 이미지 데이터를 PIL 이미지로 변환
178
+ if img_data:
179
+ try:
180
+ img_bytes = base64.b64decode(img_data)
181
+ img = Image.open(BytesIO(img_bytes))
182
+ except Exception:
183
+ img = None
184
+ else:
185
+ img = None
186
+
187
+ # 메타데이터 추출
188
+ metadata = result.get('metadata', {})
189
+ class_name = metadata.get('class', 'N/A')
190
+ confidence = metadata.get('confidence', 0) * 100 if metadata.get('confidence') else 'N/A'
191
+
192
+ formatted_results.append({
193
+ 'image': img,
194
+ 'info': f"결과 #{i+1} | 유사도: {similarity:.2f}% | 클래스: {class_name} | 신뢰도: {confidence if isinstance(confidence, str) else f'{confidence:.2f}%'} | ID: {result.get('id', 'N/A')}"
195
+ })
196
+
197
+ return f"{len(formatted_results)}개의 유사 객체를 찾았습니다.", formatted_results
198
  else:
199
+ return f"API 오류: {response.status_code}", []
 
200
  except Exception as e:
201
+ return f"오류 발생: {str(e)}", []
202
 
203
  # Define model inference functions
204
  def process_yolo(image):
 
424
  # 벡터 DB 저장 버튼 및 결과 표시
425
  with gr.Row():
426
  with gr.Column():
427
+ gr.Markdown("### 벡터 DB 저장")
428
+ save_yolo_button = gr.Button("YOLOv8 인식 결과 저장", variant="primary")
429
+ save_detr_button = gr.Button("DETR 인식 결과 저장", variant="primary")
430
+ save_vit_button = gr.Button("ViT 분류 결과 저장", variant="primary")
431
  save_result = gr.Textbox(label="벡터 DB 저장 결과")
432
+
433
+ with gr.Column():
434
+ gr.Markdown("### 벡터 DB 검색")
435
+ search_class = gr.Textbox(label="클래스 이름으로 검색")
436
+ search_button = gr.Button("검색", variant="secondary")
437
+ search_result_text = gr.Textbox(label="검색 결과 정보")
438
+ search_result_gallery = gr.Gallery(label="검색 결과", columns=5, height=300)
439
 
440
  # 객체 인식 결과 저장용 상태 변수
441
+ yolo_detection_state = gr.State(None)
442
+ detr_detection_state = gr.State(None)
443
+ vit_classification_state = gr.State(None)
444
 
445
  # Set up event handlers
446
+ yolo_button.click(
447
  fn=process_yolo,
448
  inputs=input_image,
449
+ outputs=[yolo_output, yolo_text, yolo_detection_state]
450
  )
451
 
452
+ # DETR 결과 처리 함수 수정 - 상태 저장 추가
453
+ def process_detr_with_state(image):
454
+ result_image, result_text = process_detr(image)
455
+
456
+ # 객체 인식 결과 추출
457
+ detection_results = {"objects": []}
458
+
459
+ # 결과 텍스트에서 객체 정보 추출
460
+ lines = result_text.split('\n')
461
+ for line in lines:
462
+ if ': ' in line and ' at ' in line:
463
+ try:
464
+ class_conf, location = line.split(' at ')
465
+ class_name, confidence = class_conf.split(': ')
466
+ confidence = float(confidence)
467
+
468
+ # 바운딩 박스 정보 추출
469
+ bbox_str = location.strip('[]').split(', ')
470
+ bbox = [int(coord) for coord in bbox_str]
471
+
472
+ detection_results["objects"].append({
473
+ "class": class_name,
474
+ "confidence": confidence,
475
+ "bbox": bbox
476
+ })
477
+ except Exception:
478
+ pass
479
+
480
+ return result_image, result_text, detection_results
481
+
482
+ # ViT 결과 처리 함수 수정 - 상태 저장 추가
483
+ def process_vit_with_state(image):
484
+ result_text = process_vit(image)
485
+
486
+ # 분류 결과 추출
487
+ classifications = []
488
+
489
+ # 결과 텍스트에서 분류 정보 추출
490
+ lines = result_text.split('\n')
491
+ for line in lines:
492
+ if '. ' in line and ': ' in line:
493
+ try:
494
+ rank_class, confidence = line.split(': ')
495
+ _, class_name = rank_class.split('. ')
496
+ confidence = float(confidence)
497
+
498
+ classifications.append({
499
+ "class": class_name,
500
+ "confidence": confidence
501
+ })
502
+ except Exception:
503
+ pass
504
+
505
+ return result_text, {"classifications": classifications}
506
 
507
  detr_button.click(
508
+ fn=process_detr_with_state,
509
  inputs=input_image,
510
+ outputs=[detr_output, detr_text, detr_detection_state]
511
  )
512
 
513
  vit_button.click(
514
+ fn=process_vit_with_state,
515
  inputs=input_image,
516
+ outputs=[vit_text, vit_classification_state]
517
+ )
518
+
519
+ # 벡터 DB 저장 버튼 이벤트 핸들러
520
+ save_yolo_button.click(
521
+ fn=lambda img, det: save_objects_to_vector_db(img, det, 'yolo'),
522
+ inputs=[input_image, yolo_detection_state],
523
+ outputs=save_result
524
+ )
525
+
526
+ save_detr_button.click(
527
+ fn=lambda img, det: save_objects_to_vector_db(img, det, 'detr'),
528
+ inputs=[input_image, detr_detection_state],
529
+ outputs=save_result
530
+ )
531
+
532
+ save_vit_button.click(
533
+ fn=lambda img, det: save_objects_to_vector_db(img, det, 'vit'),
534
+ inputs=[input_image, vit_classification_state],
535
+ outputs=save_result
536
+ )
537
+
538
+ # 검색 버튼 이벤트 핸들러
539
+ def format_search_results(result_text, results):
540
+ images = []
541
+ captions = []
542
+
543
+ for result in results:
544
+ if result.get('image'):
545
+ images.append(result['image'])
546
+ captions.append(result['info'])
547
+
548
+ return result_text, [(img, cap) for img, cap in zip(images, captions)]
549
+
550
+ search_button.click(
551
+ fn=lambda class_name: search_similar_objects(class_name=class_name),
552
+ inputs=search_class,
553
+ outputs=[search_result_text, search_result_gallery]
554
  )
555
 
556