Spaces:
Running
Running
David Ko
commited on
Commit
·
c28eadf
1
Parent(s):
cd3a15b
Gradio UI에 벡터 DB 저장 및 검색 기능 통합
Browse files
app.py
CHANGED
@@ -69,7 +69,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
69 |
print(f"Using device: {device}")
|
70 |
|
71 |
# 벡터 DB에 객체 저장 함수
|
72 |
-
def save_objects_to_vector_db(image, detection_results):
|
73 |
if image is None or detection_results is None:
|
74 |
return "이미지나 객체 인식 결과가 없습니다."
|
75 |
|
@@ -79,36 +79,126 @@ def save_objects_to_vector_db(image, detection_results):
|
|
79 |
image.save(buffered, format="JPEG")
|
80 |
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
81 |
|
82 |
-
#
|
83 |
-
|
84 |
-
|
85 |
-
objects
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
# API 호출
|
99 |
response = requests.post(
|
100 |
-
"http://localhost:7860/api/
|
101 |
json=data
|
102 |
)
|
103 |
|
104 |
if response.status_code == 200:
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
else:
|
108 |
-
return f"
|
109 |
-
|
110 |
except Exception as e:
|
111 |
-
return f"오류 발생: {str(e)}"
|
112 |
|
113 |
# Define model inference functions
|
114 |
def process_yolo(image):
|
@@ -334,36 +424,133 @@ with gr.Blocks(title="Object Detection Demo") as demo:
|
|
334 |
# 벡터 DB 저장 버튼 및 결과 표시
|
335 |
with gr.Row():
|
336 |
with gr.Column():
|
337 |
-
|
|
|
|
|
|
|
338 |
save_result = gr.Textbox(label="벡터 DB 저장 결과")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
339 |
|
340 |
# 객체 인식 결과 저장용 상태 변수
|
341 |
-
|
|
|
|
|
342 |
|
343 |
# Set up event handlers
|
344 |
-
|
345 |
fn=process_yolo,
|
346 |
inputs=input_image,
|
347 |
-
outputs=[yolo_output, yolo_text,
|
348 |
)
|
349 |
|
350 |
-
#
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
356 |
|
357 |
detr_button.click(
|
358 |
-
fn=
|
359 |
inputs=input_image,
|
360 |
-
outputs=[detr_output, detr_text]
|
361 |
)
|
362 |
|
363 |
vit_button.click(
|
364 |
-
fn=
|
365 |
inputs=input_image,
|
366 |
-
outputs=vit_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
)
|
368 |
|
369 |
|
|
|
69 |
print(f"Using device: {device}")
|
70 |
|
71 |
# 벡터 DB에 객체 저장 함수
|
72 |
+
def save_objects_to_vector_db(image, detection_results, model_type='yolo'):
|
73 |
if image is None or detection_results is None:
|
74 |
return "이미지나 객체 인식 결과가 없습니다."
|
75 |
|
|
|
79 |
image.save(buffered, format="JPEG")
|
80 |
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
81 |
|
82 |
+
# 모델 타입에 따라 다른 API 엔드포인트 호출
|
83 |
+
if model_type in ['yolo', 'detr']:
|
84 |
+
# 객체 정보 추출
|
85 |
+
objects = []
|
86 |
+
for obj in detection_results['objects']:
|
87 |
+
objects.append({
|
88 |
+
"class": obj['class'],
|
89 |
+
"confidence": obj['confidence'],
|
90 |
+
"bbox": obj['bbox']
|
91 |
+
})
|
92 |
+
|
93 |
+
# API 요청 데이터 구성
|
94 |
+
data = {
|
95 |
+
"image": img_str,
|
96 |
+
"objects": objects,
|
97 |
+
"image_id": str(uuid.uuid4())
|
98 |
+
}
|
99 |
+
|
100 |
+
# API 호출
|
101 |
+
response = requests.post(
|
102 |
+
"http://localhost:7860/api/add-detected-objects",
|
103 |
+
json=data
|
104 |
+
)
|
105 |
+
|
106 |
+
if response.status_code == 200:
|
107 |
+
result = response.json()
|
108 |
+
if 'error' in result:
|
109 |
+
return f"오류 발생: {result['error']}"
|
110 |
+
return f"벡터 DB에 {len(objects)}개 객체 저장 완료! ID: {result.get('ids', '알 수 없음')}"
|
111 |
|
112 |
+
elif model_type == 'vit':
|
113 |
+
# ViT 분류 결과 저장
|
114 |
+
data = {
|
115 |
+
"image": img_str,
|
116 |
+
"metadata": {
|
117 |
+
"model": "vit",
|
118 |
+
"classifications": detection_results.get('classifications', [])
|
119 |
+
}
|
120 |
+
}
|
121 |
+
|
122 |
+
# API 호출
|
123 |
+
response = requests.post(
|
124 |
+
"http://localhost:7860/api/add-image",
|
125 |
+
json=data
|
126 |
+
)
|
127 |
+
|
128 |
+
if response.status_code == 200:
|
129 |
+
result = response.json()
|
130 |
+
if 'error' in result:
|
131 |
+
return f"오류 발생: {result['error']}"
|
132 |
+
return f"벡터 DB에 이미지 및 분류 결과 저장 완료! ID: {result.get('id', '알 수 없음')}"
|
133 |
+
|
134 |
+
else:
|
135 |
+
return "지원하지 않는 모델 타입입니다."
|
136 |
+
|
137 |
+
if response.status_code != 200:
|
138 |
+
return f"API 오류: {response.status_code}"
|
139 |
+
except Exception as e:
|
140 |
+
return f"오류 발생: {str(e)}"
|
141 |
+
|
142 |
+
# 벡터 DB에서 유사 객체 검색 함수
|
143 |
+
def search_similar_objects(image=None, class_name=None):
|
144 |
+
try:
|
145 |
+
data = {}
|
146 |
+
|
147 |
+
if image is not None:
|
148 |
+
# 이미지를 base64로 인코딩
|
149 |
+
buffered = BytesIO()
|
150 |
+
image.save(buffered, format="JPEG")
|
151 |
+
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
152 |
+
data["image"] = img_str
|
153 |
+
data["n_results"] = 5
|
154 |
+
elif class_name is not None and class_name.strip():
|
155 |
+
data["class_name"] = class_name.strip()
|
156 |
+
data["n_results"] = 5
|
157 |
+
else:
|
158 |
+
return "이미지나 클래스 이름 중 하나는 제공해야 합니다.", []
|
159 |
|
160 |
# API 호출
|
161 |
response = requests.post(
|
162 |
+
"http://localhost:7860/api/search-similar-objects",
|
163 |
json=data
|
164 |
)
|
165 |
|
166 |
if response.status_code == 200:
|
167 |
+
results = response.json()
|
168 |
+
if isinstance(results, dict) and 'error' in results:
|
169 |
+
return f"오류 발생: {results['error']}", []
|
170 |
+
|
171 |
+
# 결과 포맷팅
|
172 |
+
formatted_results = []
|
173 |
+
for i, result in enumerate(results):
|
174 |
+
similarity = (1 - result.get('distance', 0)) * 100
|
175 |
+
img_data = result.get('image', '')
|
176 |
+
|
177 |
+
# 이미지 데이터를 PIL 이미지로 변환
|
178 |
+
if img_data:
|
179 |
+
try:
|
180 |
+
img_bytes = base64.b64decode(img_data)
|
181 |
+
img = Image.open(BytesIO(img_bytes))
|
182 |
+
except Exception:
|
183 |
+
img = None
|
184 |
+
else:
|
185 |
+
img = None
|
186 |
+
|
187 |
+
# 메타데이터 추출
|
188 |
+
metadata = result.get('metadata', {})
|
189 |
+
class_name = metadata.get('class', 'N/A')
|
190 |
+
confidence = metadata.get('confidence', 0) * 100 if metadata.get('confidence') else 'N/A'
|
191 |
+
|
192 |
+
formatted_results.append({
|
193 |
+
'image': img,
|
194 |
+
'info': f"결과 #{i+1} | 유사도: {similarity:.2f}% | 클래스: {class_name} | 신뢰도: {confidence if isinstance(confidence, str) else f'{confidence:.2f}%'} | ID: {result.get('id', 'N/A')}"
|
195 |
+
})
|
196 |
+
|
197 |
+
return f"{len(formatted_results)}개의 유사 객체를 찾았습니다.", formatted_results
|
198 |
else:
|
199 |
+
return f"API 오류: {response.status_code}", []
|
|
|
200 |
except Exception as e:
|
201 |
+
return f"오류 발생: {str(e)}", []
|
202 |
|
203 |
# Define model inference functions
|
204 |
def process_yolo(image):
|
|
|
424 |
# 벡터 DB 저장 버튼 및 결과 표시
|
425 |
with gr.Row():
|
426 |
with gr.Column():
|
427 |
+
gr.Markdown("### 벡터 DB 저장")
|
428 |
+
save_yolo_button = gr.Button("YOLOv8 인식 결과 저장", variant="primary")
|
429 |
+
save_detr_button = gr.Button("DETR 인식 결과 저장", variant="primary")
|
430 |
+
save_vit_button = gr.Button("ViT 분류 결과 저장", variant="primary")
|
431 |
save_result = gr.Textbox(label="벡터 DB 저장 결과")
|
432 |
+
|
433 |
+
with gr.Column():
|
434 |
+
gr.Markdown("### 벡터 DB 검색")
|
435 |
+
search_class = gr.Textbox(label="클래스 이름으로 검색")
|
436 |
+
search_button = gr.Button("검색", variant="secondary")
|
437 |
+
search_result_text = gr.Textbox(label="검색 결과 정보")
|
438 |
+
search_result_gallery = gr.Gallery(label="검색 결과", columns=5, height=300)
|
439 |
|
440 |
# 객체 인식 결과 저장용 상태 변수
|
441 |
+
yolo_detection_state = gr.State(None)
|
442 |
+
detr_detection_state = gr.State(None)
|
443 |
+
vit_classification_state = gr.State(None)
|
444 |
|
445 |
# Set up event handlers
|
446 |
+
yolo_button.click(
|
447 |
fn=process_yolo,
|
448 |
inputs=input_image,
|
449 |
+
outputs=[yolo_output, yolo_text, yolo_detection_state]
|
450 |
)
|
451 |
|
452 |
+
# DETR 결과 처리 함수 수정 - 상태 저장 추가
|
453 |
+
def process_detr_with_state(image):
|
454 |
+
result_image, result_text = process_detr(image)
|
455 |
+
|
456 |
+
# 객체 인식 결과 추출
|
457 |
+
detection_results = {"objects": []}
|
458 |
+
|
459 |
+
# 결과 텍스트에서 객체 정보 추출
|
460 |
+
lines = result_text.split('\n')
|
461 |
+
for line in lines:
|
462 |
+
if ': ' in line and ' at ' in line:
|
463 |
+
try:
|
464 |
+
class_conf, location = line.split(' at ')
|
465 |
+
class_name, confidence = class_conf.split(': ')
|
466 |
+
confidence = float(confidence)
|
467 |
+
|
468 |
+
# 바운딩 박스 정보 추출
|
469 |
+
bbox_str = location.strip('[]').split(', ')
|
470 |
+
bbox = [int(coord) for coord in bbox_str]
|
471 |
+
|
472 |
+
detection_results["objects"].append({
|
473 |
+
"class": class_name,
|
474 |
+
"confidence": confidence,
|
475 |
+
"bbox": bbox
|
476 |
+
})
|
477 |
+
except Exception:
|
478 |
+
pass
|
479 |
+
|
480 |
+
return result_image, result_text, detection_results
|
481 |
+
|
482 |
+
# ViT 결과 처리 함수 수정 - 상태 저장 추가
|
483 |
+
def process_vit_with_state(image):
|
484 |
+
result_text = process_vit(image)
|
485 |
+
|
486 |
+
# 분류 결과 추출
|
487 |
+
classifications = []
|
488 |
+
|
489 |
+
# 결과 텍스트에서 분류 정보 추출
|
490 |
+
lines = result_text.split('\n')
|
491 |
+
for line in lines:
|
492 |
+
if '. ' in line and ': ' in line:
|
493 |
+
try:
|
494 |
+
rank_class, confidence = line.split(': ')
|
495 |
+
_, class_name = rank_class.split('. ')
|
496 |
+
confidence = float(confidence)
|
497 |
+
|
498 |
+
classifications.append({
|
499 |
+
"class": class_name,
|
500 |
+
"confidence": confidence
|
501 |
+
})
|
502 |
+
except Exception:
|
503 |
+
pass
|
504 |
+
|
505 |
+
return result_text, {"classifications": classifications}
|
506 |
|
507 |
detr_button.click(
|
508 |
+
fn=process_detr_with_state,
|
509 |
inputs=input_image,
|
510 |
+
outputs=[detr_output, detr_text, detr_detection_state]
|
511 |
)
|
512 |
|
513 |
vit_button.click(
|
514 |
+
fn=process_vit_with_state,
|
515 |
inputs=input_image,
|
516 |
+
outputs=[vit_text, vit_classification_state]
|
517 |
+
)
|
518 |
+
|
519 |
+
# 벡터 DB 저장 버튼 이벤트 핸들러
|
520 |
+
save_yolo_button.click(
|
521 |
+
fn=lambda img, det: save_objects_to_vector_db(img, det, 'yolo'),
|
522 |
+
inputs=[input_image, yolo_detection_state],
|
523 |
+
outputs=save_result
|
524 |
+
)
|
525 |
+
|
526 |
+
save_detr_button.click(
|
527 |
+
fn=lambda img, det: save_objects_to_vector_db(img, det, 'detr'),
|
528 |
+
inputs=[input_image, detr_detection_state],
|
529 |
+
outputs=save_result
|
530 |
+
)
|
531 |
+
|
532 |
+
save_vit_button.click(
|
533 |
+
fn=lambda img, det: save_objects_to_vector_db(img, det, 'vit'),
|
534 |
+
inputs=[input_image, vit_classification_state],
|
535 |
+
outputs=save_result
|
536 |
+
)
|
537 |
+
|
538 |
+
# 검색 버튼 이벤트 핸들러
|
539 |
+
def format_search_results(result_text, results):
|
540 |
+
images = []
|
541 |
+
captions = []
|
542 |
+
|
543 |
+
for result in results:
|
544 |
+
if result.get('image'):
|
545 |
+
images.append(result['image'])
|
546 |
+
captions.append(result['info'])
|
547 |
+
|
548 |
+
return result_text, [(img, cap) for img, cap in zip(images, captions)]
|
549 |
+
|
550 |
+
search_button.click(
|
551 |
+
fn=lambda class_name: search_similar_objects(class_name=class_name),
|
552 |
+
inputs=search_class,
|
553 |
+
outputs=[search_result_text, search_result_gallery]
|
554 |
)
|
555 |
|
556 |
|