"""Gradio demo comparing YOLOv8 and DETR (object detection) with ViT (classification).

YOLOv8 detections can additionally be forwarded to an external vector-DB API.
"""
import gradio as gr
import torch
from PIL import Image
import numpy as np
import os
import requests
import json
import base64
from io import BytesIO
import uuid
import time

# Pick the compute device once, *before* loading models, so the Hugging Face
# models can actually be moved onto it. (Previously it was computed but never
# used, so DETR/ViT always ran on CPU while the UI reported "GPU".)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shown when a button is pressed before an image has been uploaded.
NO_IMAGE_MESSAGE = "Please upload an image first."

# Vector-DB ingestion endpoint; override with the VECTOR_DB_API_URL env var.
# NOTE(review): 7860 is also Gradio's default port, i.e. the one demo.launch()
# below binds to — confirm this endpoint is really served there.
VECTOR_DB_API_URL = os.environ.get(
    "VECTOR_DB_API_URL", "http://localhost:7860/api/add-detected-objects"
)
# Seconds to wait for the API; requests never times out unless told to.
VECTOR_DB_TIMEOUT_SECONDS = 30

# Model initialization
print("Loading models... This may take a moment.")

# YOLOv8 model. Ultralytics selects its own device at inference time.
yolo_model = None
try:
    from ultralytics import YOLO

    yolo_model = YOLO("yolov8n.pt")  # Using the nano model for faster inference
    print("YOLOv8 model loaded successfully")
except Exception as e:
    print("Error loading YOLOv8 model:", e)
    yolo_model = None

# DETR model (DEtection TRansformer)
detr_processor = None
detr_model = None
try:
    from transformers import DetrImageProcessor, DetrForObjectDetection

    # DetrImageProcessor: preprocessing for DETR — resizes, normalizes pixel
    # values, converts images to tensors and handles batching.
    detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
    # DetrForObjectDetection: ResNet-50 backbone + transformer that predicts
    # bounding boxes and classes; checkpoint pre-trained on COCO by FAIR.
    detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").to(device)
    detr_model.eval()
    print("DETR model loaded successfully")
except Exception as e:
    print("Error loading DETR model:", e)
    detr_processor = None
    detr_model = None

# ViT model
vit_processor = None
vit_model = None
try:
    from transformers import ViTImageProcessor, ViTForImageClassification

    vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    vit_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").to(device)
    vit_model.eval()
    print("ViT model loaded successfully")
except Exception as e:
    print("Error loading ViT model:", e)
    vit_processor = None
    vit_model = None

print(f"Using device: {device}")


def _ensure_rgb(image):
    """Return *image* as a 3-channel RGB PIL image.

    JPEG encoding and the Hugging Face image processors expect RGB; RGBA or
    palette images would otherwise fail.
    """
    return image if image.mode == "RGB" else image.convert("RGB")


def _performance_info(inference_time):
    """Format the timing/device footer appended to every result text."""
    device_info = "GPU" if device.type == "cuda" else "CPU"
    return f"\n\nInference time: {inference_time:.3f} seconds on {device_info}"


def _to_device(inputs):
    """Move every tensor of a processor output onto the model's device."""
    return {name: tensor.to(device) for name, tensor in inputs.items()}


# Save detected objects to the vector DB
def save_objects_to_vector_db(image, detection_results):
    """Send *image* and its YOLOv8 detections to the vector-DB API.

    Args:
        image: PIL image the detections were computed on.
        detection_results: dict ``{'objects': [{'class', 'confidence', 'bbox'}, ...]}``
            as produced by :func:`process_yolo`, or None.

    Returns:
        A user-facing status message (in Korean). Errors are reported in the
        message instead of raised because this is a UI callback.
    """
    if image is None or detection_results is None:
        return "이미지나 객체 인식 결과가 없습니다."

    try:
        # Encode the image as base64 JPEG (JPEG cannot hold alpha/palette modes).
        buffered = BytesIO()
        _ensure_rgb(image).save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')

        # Keep only the fields the API expects.
        objects = [
            {"class": obj['class'], "confidence": obj['confidence'], "bbox": obj['bbox']}
            for obj in detection_results['objects']
        ]

        # Build the API request payload.
        data = {
            "image": img_str,
            "objects": objects,
            "image_id": str(uuid.uuid4()),
        }

        # Call the API; the timeout keeps an unreachable server from hanging the UI.
        response = requests.post(
            VECTOR_DB_API_URL, json=data, timeout=VECTOR_DB_TIMEOUT_SECONDS
        )

        if response.status_code == 200:
            result = response.json()
            # str() guards against the API returning non-string (e.g. integer) IDs.
            shown_ids = ', '.join(str(object_id) for object_id in result.get('object_ids', [])[:3])
            return f"벡터 DB에 {len(objects)}개 객체 저장 성공! 저장된 객체 ID: {shown_ids}..."
        else:
            return f"저장 실패: {response.text}"
    except Exception as e:
        return f"오류 발생: {str(e)}"


# Define model inference functions
def process_yolo(image):
    """Detect objects with YOLOv8.

    Returns:
        (annotated PIL image or None, result text, detection dict or None).
        The dict ``{'objects': [{'class', 'confidence', 'bbox'}, ...]}`` is kept
        in ``gr.State`` for :func:`save_objects_to_vector_db`.
    """
    if yolo_model is None:
        return None, "YOLOv8 model not loaded", None
    if image is None:
        return None, NO_IMAGE_MESSAGE, None

    # Ultralytics interprets numpy arrays as BGR but PIL images as RGB, so hand
    # it a PIL image; the old np.array(image) conversion swapped the channels.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))  # assumes RGB uint8 HWC — TODO confirm

    start_time = time.perf_counter()
    result = yolo_model(image)[0]
    class_names = result.names
    boxes = result.boxes

    detections = []
    detection_objects = {'objects': []}
    for class_id, confidence, xyxy in zip(
        boxes.cls.tolist(), boxes.conf.tolist(), boxes.xyxy.tolist()
    ):
        class_name = class_names[int(class_id)]
        confidence = round(confidence, 2)
        bbox = [round(x) for x in xyxy]
        detections.append(f"{class_name}: {confidence} at {bbox}")
        # Object info kept for saving to the vector DB.
        detection_objects['objects'].append({
            'class': class_name,
            'confidence': confidence,
            'bbox': bbox,
        })
    # Stop the clock before visualisation so only inference is timed.
    inference_time = time.perf_counter() - start_time

    # plot() returns a BGR array; reverse channels to get an RGB PIL image.
    annotated_bgr = result.plot()
    result_image = Image.fromarray(np.ascontiguousarray(annotated_bgr[..., ::-1]))

    detection_text = "\n".join(detections) if detections else "No objects detected"
    detection_text += _performance_info(inference_time)
    return result_image, detection_text, detection_objects


def process_detr(image):
    """Detect objects with DETR (confidence threshold 0.9).

    Returns:
        (annotated PIL image or None, result text).
    """
    if detr_model is None or detr_processor is None:
        return None, "DETR model not loaded"
    if image is None:
        return None, NO_IMAGE_MESSAGE

    # Object-oriented Matplotlib (no pyplot): no global figure state and no GUI
    # backend, which matters because Gradio runs callbacks in worker threads.
    from matplotlib.figure import Figure
    from matplotlib.patches import Rectangle

    image = _ensure_rgb(image)

    start_time = time.perf_counter()
    inputs = _to_device(detr_processor(images=image, return_tensors="pt"))
    with torch.no_grad():
        outputs = detr_model(**inputs)

    # Original image size as (height, width); PIL's size is (width, height).
    target_sizes = torch.tensor([image.size[::-1]], device=device)
    # Rescale boxes to the original image, keep scores > 0.9, take the single
    # image of the batch.
    results = detr_processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    found = []
    for score, label, box in zip(
        results["scores"].tolist(), results["labels"].tolist(), results["boxes"].tolist()
    ):
        found.append(
            (detr_model.config.id2label[label], round(score, 2), [round(v) for v in box])
        )
    # Stop the clock before rendering so only inference is timed.
    inference_time = time.perf_counter() - start_time

    # Draw boxes and labels on top of the image.
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(image)
    detections = []
    for class_name, confidence, box in found:
        ax.add_patch(Rectangle(
            (box[0], box[1]), box[2] - box[0], box[3] - box[1],
            linewidth=2, edgecolor='r', facecolor='none',
        ))
        ax.text(box[0], box[1], f"{class_name}: {confidence}",
                bbox=dict(facecolor='white', alpha=0.8))
        detections.append(f"{class_name}: {confidence} at {box}")

    ax.axis('off')
    fig.tight_layout()
    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    result_image = Image.open(buf)

    detection_text = "\n".join(detections) if detections else "No objects detected"
    detection_text += _performance_info(inference_time)
    return result_image, detection_text


def process_vit(image):
    """Classify *image* with ViT and return its top-5 labels as text."""
    if vit_model is None or vit_processor is None:
        return "ViT model not loaded"
    if image is None:
        return NO_IMAGE_MESSAGE

    image = _ensure_rgb(image)

    start_time = time.perf_counter()
    inputs = _to_device(vit_processor(images=image, return_tensors="pt"))
    with torch.no_grad():
        logits = vit_model(**inputs).logits  # raw, unnormalised scores

    # Softmax over the class dimension, then drop the batch dimension.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    # Guard against label sets smaller than 5.
    top_probs, top_indices = torch.topk(probs, min(5, probs.numel()))

    results = [
        f"{rank}. {vit_model.config.id2label[idx]}: {prob:.3f}"
        for rank, (prob, idx) in enumerate(zip(top_probs.tolist(), top_indices.tolist()), start=1)
    ]
    inference_time = time.perf_counter() - start_time

    result_text = "\n".join(results)
    result_text += _performance_info(inference_time)
    return result_text


# Define Gradio interface
with gr.Blocks(title="Object Detection Demo") as demo:
    gr.Markdown("""
    # Multi-Model Object Detection Demo

    This demo showcases three different object detection and image classification models:
    - **YOLOv8**: Fast and accurate object detection
    - **DETR**: DEtection TRansformer for object detection
    - **ViT**: Vision Transformer for image classification

    Upload an image to see how each model performs!
    """)

    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")

    with gr.Row():
        yolo_button = gr.Button("Detect with YOLOv8")
        detr_button = gr.Button("Detect with DETR")
        vit_button = gr.Button("Classify with ViT")

    with gr.Row():
        with gr.Column():
            yolo_output = gr.Image(type="pil", label="YOLOv8 Detection")
            yolo_text = gr.Textbox(label="YOLOv8 Results")

        with gr.Column():
            detr_output = gr.Image(type="pil", label="DETR Detection")
            detr_text = gr.Textbox(label="DETR Results")

        with gr.Column():
            vit_text = gr.Textbox(label="ViT Classification Results")

    # Vector-DB save button and result display
    with gr.Row():
        with gr.Column():
            save_to_db_button = gr.Button("YOLOv8 인식 결과를 벡터 DB에 저장", variant="primary")
            save_result = gr.Textbox(label="벡터 DB 저장 결과")

    # Holds the latest YOLOv8 detection dict for the save button.
    detection_state = gr.State(None)

    # Set up event handlers
    yolo_button.click(
        fn=process_yolo,
        inputs=input_image,
        outputs=[yolo_output, yolo_text, detection_state],
    )

    # Forget detections from a previous image so they are never saved
    # together with a newly uploaded one.
    input_image.change(fn=lambda: None, inputs=None, outputs=detection_state)

    # Vector-DB save button handler
    save_to_db_button.click(
        fn=save_objects_to_vector_db,
        inputs=[input_image, detection_state],
        outputs=save_result,
    )

    detr_button.click(
        fn=process_detr,
        inputs=input_image,
        outputs=[detr_output, detr_text],
    )

    vit_button.click(
        fn=process_vit,
        inputs=input_image,
        outputs=vit_text,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()