import gradio as gr
import torch
from PIL import Image
import numpy as np
import requests
import base64
import time
import uuid
from io import BytesIO

# Model initialization
print("Loading models... This may take a moment.")

# YOLOv8 model
yolo_model = None
try:
    from ultralytics import YOLO

    yolo_model = YOLO("yolov8n.pt")  # Using the nano model for faster inference
    print("YOLOv8 model loaded successfully")
except Exception as e:
    print("Error loading YOLOv8 model:", e)
    yolo_model = None

# DETR model (DEtection TRansformer)
detr_processor = None
detr_model = None
try:
    from transformers import DetrImageProcessor, DetrForObjectDetection

    # Load the DETR image processor.
    # DetrImageProcessor handles preprocessing of images for the DETR model:
    # - Resizes images to appropriate dimensions
    # - Normalizes pixel values
    # - Converts images to tensors
    # - Handles batch processing
    detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")

    # Load the DETR object detection model.
    # DetrForObjectDetection is the actual object detection model:
    # - Uses ResNet-50 as its backbone
    # - Transformer-based architecture for object detection
    # - Predicts bounding boxes and object classes
    # - Pre-trained on COCO by Facebook AI Research
    detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
    print("DETR model loaded successfully")
except Exception as e:
    print("Error loading DETR model:", e)
    detr_processor = None
    detr_model = None

# ViT model
vit_processor = None
vit_model = None
try:
    from transformers import ViTImageProcessor, ViTForImageClassification

    vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    vit_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
    print("ViT model loaded successfully")
except Exception as e:
    print("Error loading ViT model:", e)
    vit_processor = None
    vit_model = None

# Get device information
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# Save detected objects to the vector DB
def save_objects_to_vector_db(image, detection_results, model_type="yolo"):
    if image is None or detection_results is None:
        return "No image or detection results available."

    try:
        # Encode the image as base64
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

        # Call a different API endpoint depending on the model type
        if model_type in ("yolo", "detr"):
            # Extract object information
            objects = [
                {
                    "class": obj["class"],
                    "confidence": obj["confidence"],
                    "bbox": obj["bbox"],
                }
                for obj in detection_results["objects"]
            ]

            # Build the API request payload
            data = {
                "image": img_str,
                "objects": objects,
                "image_id": str(uuid.uuid4()),
            }

            # Call the API
            response = requests.post(
                "http://localhost:7860/api/add-detected-objects",
                json=data,
            )

            if response.status_code != 200:
                return f"API error: {response.status_code}"
            result = response.json()
            if "error" in result:
                return f"Error: {result['error']}"
            return (
                f"Saved {len(objects)} objects to the vector DB! "
                f"ID: {result.get('ids', 'unknown')}"
            )

        elif model_type == "vit":
            # Save ViT classification results
            data = {
                "image": img_str,
                "metadata": {
                    "model": "vit",
                    "classifications": detection_results.get("classifications", []),
                },
            }

            # Call the API
            response = requests.post(
                "http://localhost:7860/api/add-image",
                json=data,
            )

            if response.status_code != 200:
                return f"API error: {response.status_code}"
            result = response.json()
            if "error" in result:
                return f"Error: {result['error']}"
            return (
                "Saved image and classification results to the vector DB! "
                f"ID: {result.get('id', 'unknown')}"
            )

        else:
            return "Unsupported model type."
    except Exception as e:
        return f"Error: {str(e)}"
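
# The vector-DB endpoints called above live outside this file. Inferred solely
# from how they are used here (a sketch, not a verified spec):
#
#   POST /api/add-detected-objects
#     request:  {"image": <base64 JPEG>, "objects": [...], "image_id": <uuid>}
#     response: {"ids": [...]} on success, or {"error": "..."}
#
#   POST /api/add-image
#     request:  {"image": <base64 JPEG>, "metadata": {"model": ..., "classifications": [...]}}
#     response: {"id": "..."} on success, or {"error": "..."}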

# Search the vector DB for similar objects
def search_similar_objects(image=None, class_name=None):
    try:
        data = {}
        if image is not None:
            # Encode the image as base64
            buffered = BytesIO()
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            data["image"] = img_str
            data["n_results"] = 5
        elif class_name is not None and class_name.strip():
            data["class_name"] = class_name.strip()
            data["n_results"] = 5
        else:
            return "Either an image or a class name must be provided.", []

        # Call the API
        response = requests.post(
            "http://localhost:7860/api/search-similar-objects",
            json=data,
        )

        if response.status_code != 200:
            return f"API error: {response.status_code}", []

        results = response.json()
        if isinstance(results, dict) and "error" in results:
            return f"Error: {results['error']}", []

        # Format the results
        formatted_results = []
        for i, result in enumerate(results):
            similarity = (1 - result.get("distance", 0)) * 100
            img_data = result.get("image", "")

            # Decode the image data into a PIL image
            if img_data:
                try:
                    img_bytes = base64.b64decode(img_data)
                    img = Image.open(BytesIO(img_bytes))
                except Exception:
                    img = None
            else:
                img = None

            # Extract metadata
            metadata = result.get("metadata", {})
            class_name = metadata.get("class", "N/A")
            confidence = (
                metadata.get("confidence", 0) * 100
                if metadata.get("confidence")
                else "N/A"
            )

            formatted_results.append(
                {
                    "image": img,
                    "info": (
                        f"Result #{i + 1} | Similarity: {similarity:.2f}% | "
                        f"Class: {class_name} | Confidence: "
                        f"{confidence if isinstance(confidence, str) else f'{confidence:.2f}%'} | "
                        f"ID: {result.get('id', 'N/A')}"
                    ),
                }
            )

        return f"Found {len(formatted_results)} similar objects.", formatted_results
    except Exception as e:
        return f"Error: {str(e)}", []


# Define model inference functions
def process_yolo(image):
    if yolo_model is None:
        return None, "YOLOv8 model not loaded", None

    # Measure inference time
    start_time = time.time()

    # Convert to numpy if it's a PIL image
    if isinstance(image, Image.Image):
        image_np = np.array(image)
    else:
        image_np = image

    # Run inference
    results = yolo_model(image_np)

    # Process results
    result_image = results[0].plot()
    result_image = Image.fromarray(result_image)

    # Get detection information
    boxes = results[0].boxes
    class_names = results[0].names

    # Format detection results
    detections = []
    detection_objects = {"objects": []}
    for box in boxes:
        class_id = int(box.cls[0].item())
        class_name = class_names[class_id]
        confidence = round(box.conf[0].item(), 2)
        bbox = [round(x) for x in box.xyxy[0].tolist()]
        detections.append(f"{class_name}: {confidence} at {bbox}")

        # Object info for saving to the vector DB
        detection_objects["objects"].append(
            {"class": class_name, "confidence": confidence, "bbox": bbox}
        )

    # Calculate inference time
    inference_time = time.time() - start_time

    # Add inference time and device info to detection text
    device_info = "GPU" if torch.cuda.is_available() else "CPU"
    performance_info = f"\n\nInference time: {inference_time:.3f} seconds on {device_info}"
    detection_text = "\n".join(detections) if detections else "No objects detected"
    detection_text += performance_info

    return result_image, detection_text, detection_objects
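
# Shape of the payload stored in yolo_detection_state below (inferred from this
# file; the example values are illustrative):
#   {"objects": [{"class": "person", "confidence": 0.92, "bbox": [34, 50, 210, 480]}]}
# save_objects_to_vector_db() consumes exactly this structure.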

def process_detr(image):
    if detr_model is None or detr_processor is None:
        return None, "DETR model not loaded"

    # Measure inference time
    start_time = time.time()

    # Prepare image for the model
    inputs = detr_processor(images=image, return_tensors="pt")

    # Run inference
    with torch.no_grad():
        outputs = detr_model(**inputs)

    # Convert outputs to an image with bounding boxes.
    # Create a tensor with the original image dimensions (height, width);
    # image.size[::-1] reverses (width, height) to (height, width) as DETR requires.
    target_sizes = torch.tensor([image.size[::-1]])

    # Process raw model outputs into usable detection results:
    # - Maps predictions back to the original image size
    # - Filters detections using a confidence threshold (0.9)
    # - Returns a dictionary with 'scores', 'labels', and 'boxes' keys
    # - [0] extracts results for the first (and only) image in the batch
    results = detr_processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # Create a copy of the image to draw on
    result_image = image.copy()

    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    import io

    # Create figure and axes
    fig, ax = plt.subplots(1)
    ax.imshow(result_image)
    ax.axis("off")

    # Format detection results
    detections = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i) for i in box.tolist()]
        class_name = detr_model.config.id2label[label.item()]
        confidence = round(score.item(), 2)

        # Draw rectangle
        rect = Rectangle(
            (box[0], box[1]),
            box[2] - box[0],
            box[3] - box[1],
            linewidth=2,
            edgecolor="r",
            facecolor="none",
        )
        ax.add_patch(rect)

        # Add label
        ax.text(
            box[0],
            box[1],
            f"{class_name}: {confidence}",
            bbox=dict(facecolor="white", alpha=0.8),
        )

        detections.append(f"{class_name}: {confidence} at {box}")

    # Save the figure to an image
    buf = io.BytesIO()
    plt.tight_layout()
    plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    buf.seek(0)
    result_image = Image.open(buf)
    plt.close(fig)

    # Calculate inference time
    inference_time = time.time() - start_time

    # Add inference time and device info to detection text
    device_info = "GPU" if torch.cuda.is_available() else "CPU"
    performance_info = f"\n\nInference time: {inference_time:.3f} seconds on {device_info}"
    detection_text = "\n".join(detections) if detections else "No objects detected"
    detection_text += performance_info

    return result_image, detection_text
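
# The 0.9 threshold passed to post_process_object_detection above is strict;
# lowering it keeps more, lower-confidence detections. A sketch with an assumed
# looser setting:
#   results = detr_processor.post_process_object_detection(
#       outputs, target_sizes=target_sizes, threshold=0.7
#   )[0]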

def process_vit(image):
    if vit_model is None or vit_processor is None:
        return "ViT model not loaded"

    # Measure inference time
    start_time = time.time()

    # Prepare image for the model
    inputs = vit_processor(images=image, return_tensors="pt")

    # Run inference
    with torch.no_grad():
        outputs = vit_model(**inputs)

    # Extract raw logits (unnormalized scores) from the model output.
    # Hugging Face models return logits directly, not probabilities.
    logits = outputs.logits

    # Get the predicted class:
    # argmax(-1) finds the index with the highest score across the last (class)
    # dimension; item() converts the tensor value to a Python scalar.
    predicted_class_idx = logits.argmax(-1).item()

    # Map the class index to a human-readable label using the model's configuration
    prediction = vit_model.config.id2label[predicted_class_idx]

    # Get the top-5 predictions.
    # softmax converts raw logits to probabilities by normalizing the exponentials
    # of the logits so they sum to 1.0; dim=-1 applies it along the class dimension.
    # Shape before softmax: [1, num_classes] (batch_size=1, num_classes=1000);
    # [0] extracts the only item from the batch, leaving a 1D tensor of 1000
    # class probabilities.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    # Get the values and indices of the 5 highest probabilities
    top5_prob, top5_indices = torch.topk(probs, 5)

    results = []
    for i, (prob, idx) in enumerate(zip(top5_prob, top5_indices)):
        class_name = vit_model.config.id2label[idx.item()]
        results.append(f"{i + 1}. {class_name}: {prob.item():.3f}")

    # Calculate inference time
    inference_time = time.time() - start_time

    # Add inference time and device info to results
    device_info = "GPU" if torch.cuda.is_available() else "CPU"
    performance_info = f"\n\nInference time: {inference_time:.3f} seconds on {device_info}"
    result_text = "\n".join(results)
    result_text += performance_info

    return result_text


# Define Gradio interface
with gr.Blocks(title="Object Detection Demo") as demo:
    gr.Markdown(
        """
        # Multi-Model Object Detection Demo

        This demo showcases three different object detection and image classification models:
        - **YOLOv8**: Fast and accurate object detection
        - **DETR**: DEtection TRansformer for object detection
        - **ViT**: Vision Transformer for image classification

        Upload an image to see how each model performs!
        """
    )

    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")

    with gr.Row():
        yolo_button = gr.Button("Detect with YOLOv8")
        detr_button = gr.Button("Detect with DETR")
        vit_button = gr.Button("Classify with ViT")

    with gr.Row():
        with gr.Column():
            yolo_output = gr.Image(type="pil", label="YOLOv8 Detection")
            yolo_text = gr.Textbox(label="YOLOv8 Results")
        with gr.Column():
            detr_output = gr.Image(type="pil", label="DETR Detection")
            detr_text = gr.Textbox(label="DETR Results")
        with gr.Column():
            vit_text = gr.Textbox(label="ViT Classification Results")

    # Vector-DB save buttons and result display
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Save to Vector DB")
            save_yolo_button = gr.Button("Save YOLOv8 Detections", variant="primary")
            save_detr_button = gr.Button("Save DETR Detections", variant="primary")
            save_vit_button = gr.Button("Save ViT Classifications", variant="primary")
            save_result = gr.Textbox(label="Vector DB Save Result")
        with gr.Column():
            gr.Markdown("### Search Vector DB")
            search_class = gr.Textbox(label="Search by Class Name")
            search_button = gr.Button("Search", variant="secondary")
            search_result_text = gr.Textbox(label="Search Result Info")
            search_result_gallery = gr.Gallery(label="Search Results", columns=5, height=300)

    # State variables holding detection results
    yolo_detection_state = gr.State(None)
    detr_detection_state = gr.State(None)
    vit_classification_state = gr.State(None)

    # Set up event handlers
    yolo_button.click(
        fn=process_yolo,
        inputs=input_image,
        outputs=[yolo_output, yolo_text, yolo_detection_state],
    )

    # Wrap process_detr so the parsed detections are also stored as state
    def process_detr_with_state(image):
        result_image, result_text = process_detr(image)

        # Extract object information from the result text
        detection_results = {"objects": []}
        for line in result_text.split("\n"):
            if ": " in line and " at " in line:
                try:
                    class_conf, location = line.split(" at ")
                    class_name, confidence = class_conf.split(": ")
                    confidence = float(confidence)

                    # Extract bounding-box coordinates
                    bbox_str = location.strip("[]").split(", ")
                    bbox = [int(coord) for coord in bbox_str]

                    detection_results["objects"].append(
                        {"class": class_name, "confidence": confidence, "bbox": bbox}
                    )
                except Exception:
                    pass

        return result_image, result_text, detection_results
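
    # The parser above depends on the exact line format emitted by process_detr
    # and process_yolo, e.g. (illustrative values):
    #   "person: 0.92 at [34, 50, 210, 480]"
    # Performance lines ("Inference time: ...") are skipped because they contain
    # no " at " segment.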

    # Wrap process_vit so the parsed classifications are also stored as state
    def process_vit_with_state(image):
        result_text = process_vit(image)

        # Extract classification information from the result text
        classifications = []
        for line in result_text.split("\n"):
            if ". " in line and ": " in line:
                try:
                    rank_class, confidence = line.split(": ")
                    _, class_name = rank_class.split(". ")
                    confidence = float(confidence)
                    classifications.append(
                        {"class": class_name, "confidence": confidence}
                    )
                except Exception:
                    pass

        return result_text, {"classifications": classifications}

    detr_button.click(
        fn=process_detr_with_state,
        inputs=input_image,
        outputs=[detr_output, detr_text, detr_detection_state],
    )

    vit_button.click(
        fn=process_vit_with_state,
        inputs=input_image,
        outputs=[vit_text, vit_classification_state],
    )

    # Vector-DB save button handlers
    save_yolo_button.click(
        fn=lambda img, det: save_objects_to_vector_db(img, det, "yolo"),
        inputs=[input_image, yolo_detection_state],
        outputs=save_result,
    )

    save_detr_button.click(
        fn=lambda img, det: save_objects_to_vector_db(img, det, "detr"),
        inputs=[input_image, detr_detection_state],
        outputs=save_result,
    )

    save_vit_button.click(
        fn=lambda img, det: save_objects_to_vector_db(img, det, "vit"),
        inputs=[input_image, vit_classification_state],
        outputs=save_result,
    )

    # Search button handler: convert the result dicts into (image, caption)
    # pairs, the format gr.Gallery expects
    def search_and_format(class_name):
        result_text, results = search_similar_objects(class_name=class_name)
        gallery = [
            (result["image"], result["info"])
            for result in results
            if result.get("image") is not None
        ]
        return result_text, gallery

    search_button.click(
        fn=search_and_format,
        inputs=search_class,
        outputs=[search_result_text, search_result_gallery],
    )


# Launch the app
if __name__ == "__main__":
    demo.launch()
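
# Note: demo.launch() serves on port 7860 by default, which is also the port the
# vector-DB API calls above target. That assumes the API endpoints are mounted
# on the same server (e.g. a FastAPI app hosting this Gradio demo); if the API
# is a separate service, one of the two needs another port, e.g. (hypothetical):
#   demo.launch(server_port=7861)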