from flask import Flask, request, jsonify, send_from_directory import torch from PIL import Image import numpy as np import os import io import base64 import matplotlib.pyplot as plt from matplotlib.patches import Rectangle import time from flask_cors import CORS import json import chromadb from chromadb.utils import embedding_functions app = Flask(__name__, static_folder='static') CORS(app) # Enable CORS for all routes # Model initialization print("Loading models... This may take a moment.") # Image embedding model (CLIP) for vector search clip_model = None clip_processor = None try: from transformers import CLIPProcessor, CLIPModel # 임시 디렉토리 사용 import tempfile temp_dir = tempfile.gettempdir() os.environ["TRANSFORMERS_CACHE"] = temp_dir # CLIP 모델 로드 (이미지 임베딩용) clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") print("CLIP model loaded successfully") except Exception as e: print("Error loading CLIP model:", e) clip_model = None clip_processor = None # Vector DB 초기화 vector_db = None image_collection = None try: # ChromaDB 클라이언트 초기화 (인메모리 DB) vector_db = chromadb.Client() # 임베딩 함수 설정 ef = embedding_functions.DefaultEmbeddingFunction() # 이미지 컬렉션 생성 image_collection = vector_db.create_collection( name="image_collection", embedding_function=ef, get_or_create=True ) print("Vector DB initialized successfully") except Exception as e: print("Error initializing Vector DB:", e) vector_db = None image_collection = None # YOLOv8 model yolo_model = None try: import os from ultralytics import YOLO # 모델 파일 경로 - 임시 디렉토리 사용 import tempfile temp_dir = tempfile.gettempdir() model_path = os.path.join(temp_dir, "yolov8n.pt") # 모델 파일이 없으면 직접 다운로드 if not os.path.exists(model_path): print(f"Downloading YOLOv8 model to {model_path}...") try: os.system(f"wget -q https://ultralytics.com/assets/yolov8n.pt -O {model_path}") print("YOLOv8 model downloaded successfully") except Exception as e: print(f"Error downloading YOLOv8 model: {e}") # 다운로드 실패 시 대체 URL 시도 try: os.system(f"wget -q https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt -O {model_path}") print("YOLOv8 model downloaded from alternative source") except Exception as e2: print(f"Error downloading from alternative source: {e2}") # 마지막 대안으로 직접 모델 URL 사용 try: os.system(f"curl -L https://ultralytics.com/assets/yolov8n.pt --output {model_path}") print("YOLOv8 model downloaded using curl") except Exception as e3: print(f"All download attempts failed: {e3}") # 환경 변수 설정 - 설정 파일 경로 지정 os.environ["YOLO_CONFIG_DIR"] = temp_dir os.environ["MPLCONFIGDIR"] = temp_dir yolo_model = YOLO(model_path) # Using the nano model for faster inference print("YOLOv8 model loaded successfully") except Exception as e: print("Error loading YOLOv8 model:", e) yolo_model = None # DETR model (DEtection TRansformer) detr_processor = None detr_model = None try: from transformers import DetrImageProcessor, DetrForObjectDetection detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50") detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50") print("DETR model loaded successfully") except Exception as e: print("Error loading DETR model:", e) detr_processor = None detr_model = None # ViT model vit_processor = None vit_model = None try: from transformers import ViTImageProcessor, ViTForImageClassification vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224") vit_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") print("ViT model loaded successfully") except Exception as e: print("Error loading ViT model:", e) vit_processor = None vit_model = None # Get device information device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # LLM model (using an open-access model instead of Llama 4 which requires authentication) llm_model = None llm_tokenizer = None try: from transformers import AutoModelForCausalLM, AutoTokenizer print("Loading LLM model... This may take a moment.") model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Using TinyLlama as an open-access alternative llm_tokenizer = AutoTokenizer.from_pretrained(model_name) llm_model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16, # Removing options that require accelerate package # device_map="auto", # load_in_8bit=True ).to(device) print("LLM model loaded successfully") except Exception as e: print(f"Error loading LLM model: {e}") llm_model = None llm_tokenizer = None def process_llm_query(vision_results, user_query): """Process a query with the LLM model using vision results and user text""" if llm_model is None or llm_tokenizer is None: return {"error": "LLM model not available"} # 결과 데이터 요약 (토큰 길이 제한을 위해) summarized_results = [] # 객체 탐지 결과 요약 if isinstance(vision_results, list): # 최대 10개 객체만 포함 for i, obj in enumerate(vision_results[:10]): if isinstance(obj, dict): # 필요한 정보만 추출 summary = { "label": obj.get("label", "unknown"), "confidence": obj.get("confidence", 0), } summarized_results.append(summary) # Create a prompt combining vision results and user query prompt = f"""You are an AI assistant analyzing image detection results. Here are the objects detected in the image: {json.dumps(summarized_results, indent=2)} User question: {user_query} Please provide a detailed analysis based on the detected objects and the user's question. """ # Tokenize and generate response try: start_time = time.time() # 토큰 길이 확인 및 제한 tokens = llm_tokenizer.encode(prompt) if len(tokens) > 1500: # 안전 마진 설정 prompt = f"""You are an AI assistant analyzing image detection results. The image contains {len(summarized_results)} detected objects. User question: {user_query} Please provide a general analysis based on the user's question. """ inputs = llm_tokenizer(prompt, return_tensors="pt").to(device) with torch.no_grad(): output = llm_model.generate( **inputs, max_new_tokens=512, temperature=0.7, top_p=0.9, do_sample=True ) response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True) # Remove the prompt from the response if response_text.startswith(prompt): response_text = response_text[len(prompt):].strip() inference_time = time.time() - start_time return { "response": response_text, "performance": { "inference_time": round(inference_time, 3), "device": "GPU" if torch.cuda.is_available() else "CPU" } } except Exception as e: return {"error": f"Error processing LLM query: {str(e)}"} def image_to_base64(img): """Convert PIL Image to base64 string""" buffered = io.BytesIO() img.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode('utf-8') return img_str def process_yolo(image): if yolo_model is None: return {"error": "YOLOv8 model not loaded"} # Measure inference time start_time = time.time() # Convert to numpy if it's a PIL image if isinstance(image, Image.Image): image_np = np.array(image) else: image_np = image # Run inference results = yolo_model(image_np) # Process results result_image = results[0].plot() result_image = Image.fromarray(result_image) # Get detection information boxes = results[0].boxes class_names = results[0].names # Format detection results detections = [] for box in boxes: class_id = int(box.cls[0].item()) class_name = class_names[class_id] confidence = round(box.conf[0].item(), 2) bbox = box.xyxy[0].tolist() bbox = [round(x) for x in bbox] detections.append({ "class": class_name, "confidence": confidence, "bbox": bbox }) # Calculate inference time inference_time = time.time() - start_time # Add inference time and device info device_info = "GPU" if torch.cuda.is_available() else "CPU" return { "image": image_to_base64(result_image), "detections": detections, "performance": { "inference_time": round(inference_time, 3), "device": device_info } } def process_detr(image): if detr_model is None or detr_processor is None: return {"error": "DETR model not loaded"} # Measure inference time start_time = time.time() # Prepare image for the model inputs = detr_processor(images=image, return_tensors="pt") # Run inference with torch.no_grad(): outputs = detr_model(**inputs) # Process results target_sizes = torch.tensor([image.size[::-1]]) results = detr_processor.post_process_object_detection( outputs, target_sizes=target_sizes, threshold=0.9 )[0] # Create a copy of the image to draw on result_image = image.copy() fig, ax = plt.subplots(1) ax.imshow(result_image) # Format detection results detections = [] for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): box = [round(i) for i in box.tolist()] class_name = detr_model.config.id2label[label.item()] confidence = round(score.item(), 2) # Draw rectangle rect = Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], linewidth=2, edgecolor='r', facecolor='none') ax.add_patch(rect) # Add label plt.text(box[0], box[1], "{}: {}".format(class_name, confidence), bbox=dict(facecolor='white', alpha=0.8)) detections.append({ "class": class_name, "confidence": confidence, "bbox": box }) # Save figure to image buf = io.BytesIO() plt.tight_layout() plt.axis('off') plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0) buf.seek(0) result_image = Image.open(buf) plt.close(fig) # Calculate inference time inference_time = time.time() - start_time # Add inference time and device info device_info = "GPU" if torch.cuda.is_available() else "CPU" return { "image": image_to_base64(result_image), "detections": detections, "performance": { "inference_time": round(inference_time, 3), "device": device_info } } def process_vit(image): if vit_model is None or vit_processor is None: return {"error": "ViT model not loaded"} # Measure inference time start_time = time.time() # Prepare image for the model inputs = vit_processor(images=image, return_tensors="pt") # Run inference with torch.no_grad(): outputs = vit_model(**inputs) logits = outputs.logits # Get the predicted class predicted_class_idx = logits.argmax(-1).item() prediction = vit_model.config.id2label[predicted_class_idx] # Get top 5 predictions probs = torch.nn.functional.softmax(logits, dim=-1)[0] top5_prob, top5_indices = torch.topk(probs, 5) results = [] for i, (prob, idx) in enumerate(zip(top5_prob, top5_indices)): class_name = vit_model.config.id2label[idx.item()] results.append({ "rank": i+1, "class": class_name, "probability": round(prob.item(), 3) }) # Calculate inference time inference_time = time.time() - start_time # Add inference time and device info device_info = "GPU" if torch.cuda.is_available() else "CPU" return { "top_predictions": results, "performance": { "inference_time": round(inference_time, 3), "device": device_info } } @app.route('/api/detect/yolo', methods=['POST']) def yolo_detect(): if 'image' not in request.files: return jsonify({"error": "No image provided"}), 400 file = request.files['image'] image = Image.open(file.stream) result = process_yolo(image) return jsonify(result) @app.route('/api/detect/detr', methods=['POST']) def detr_detect(): if 'image' not in request.files: return jsonify({"error": "No image provided"}), 400 file = request.files['image'] image = Image.open(file.stream) result = process_detr(image) return jsonify(result) @app.route('/api/classify/vit', methods=['POST']) def vit_classify(): if 'image' not in request.files: return jsonify({"error": "No image provided"}), 400 file = request.files['image'] image = Image.open(file.stream) result = process_vit(image) return jsonify(result) @app.route('/api/analyze', methods=['POST']) def analyze_with_llm(): # Check if required data is in the request if not request.json: return jsonify({"error": "No JSON data provided"}), 400 # Extract vision results and user query from request data = request.json if 'visionResults' not in data or 'userQuery' not in data: return jsonify({"error": "Missing required fields: visionResults or userQuery"}), 400 vision_results = data['visionResults'] user_query = data['userQuery'] # Process the query with LLM result = process_llm_query(vision_results, user_query) return jsonify(result) def generate_image_embedding(image): """CLIP 모델을 사용하여 이미지 임베딩 생성""" if clip_model is None or clip_processor is None: return None try: # 이미지 전처리 inputs = clip_processor(images=image, return_tensors="pt") # 이미지 임베딩 생성 with torch.no_grad(): image_features = clip_model.get_image_features(**inputs) # 임베딩 정규화 및 numpy 배열로 변환 image_embedding = image_features.squeeze().cpu().numpy() normalized_embedding = image_embedding / np.linalg.norm(image_embedding) return normalized_embedding.tolist() except Exception as e: print(f"Error generating image embedding: {e}") return None @app.route('/api/similar-images', methods=['POST']) def find_similar_images(): """유사 이미지 검색 API""" if clip_model is None or clip_processor is None or image_collection is None: return jsonify({"error": "Image embedding model or vector DB not available"}) try: # 요청에서 이미지 데이터 추출 if 'image' not in request.files and 'image' not in request.form: return jsonify({"error": "No image provided"}) if 'image' in request.files: # 파일로 업로드된 경우 image_file = request.files['image'] image = Image.open(image_file).convert('RGB') else: # base64로 인코딩된 경우 image_data = request.form['image'] if image_data.startswith('data:image'): # Remove the data URL prefix if present image_data = image_data.split(',')[1] image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB') # 이미지 ID 생성 (임시) image_id = str(uuid.uuid4()) # 이미지 임베딩 생성 embedding = generate_image_embedding(image) if embedding is None: return jsonify({"error": "Failed to generate image embedding"}) # 현재 이미지를 DB에 추가 (선택적) # image_collection.add( # ids=[image_id], # embeddings=[embedding] # ) # 유사 이미지 검색 results = image_collection.query( query_embeddings=[embedding], n_results=5 # 상위 5개 결과 반환 ) # 결과 포맷팅 similar_images = [] if len(results['ids'][0]) > 0: for i, img_id in enumerate(results['ids'][0]): similar_images.append({ "id": img_id, "distance": float(results['distances'][0][i]) if 'distances' in results else 0.0, "metadata": results['metadatas'][0][i] if 'metadatas' in results else {} }) return jsonify({ "query_image_id": image_id, "similar_images": similar_images }) except Exception as e: print(f"Error in similar-images API: {e}") return jsonify({"error": str(e)}), 500 @app.route('/api/add-to-collection', methods=['POST']) def add_to_collection(): """이미지를 벡터 DB에 추가하는 API""" if clip_model is None or clip_processor is None or image_collection is None: return jsonify({"error": "Image embedding model or vector DB not available"}) try: # 요청에서 이미지 데이터 추출 if 'image' not in request.files and 'image' not in request.form: return jsonify({"error": "No image provided"}) # 메타데이터 추출 metadata = {} if 'metadata' in request.form: metadata = json.loads(request.form['metadata']) # 이미지 ID (제공되지 않은 경우 자동 생성) image_id = request.form.get('id', str(uuid.uuid4())) if 'image' in request.files: # 파일로 업로드된 경우 image_file = request.files['image'] image = Image.open(image_file).convert('RGB') else: # base64로 인코딩된 경우 image_data = request.form['image'] if image_data.startswith('data:image'): # Remove the data URL prefix if present image_data = image_data.split(',')[1] image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB') # 이미지 임베딩 생성 embedding = generate_image_embedding(image) if embedding is None: return jsonify({"error": "Failed to generate image embedding"}) # 이미지를 DB에 추가 image_collection.add( ids=[image_id], embeddings=[embedding], metadatas=[metadata] ) return jsonify({ "success": True, "image_id": image_id, "message": "Image added to collection" }) except Exception as e: print(f"Error in add-to-collection API: {e}") return jsonify({"error": str(e)}), 500 @app.route('/', defaults={'path': ''}, methods=['GET']) @app.route('/', methods=['GET']) def serve_react(path): """Serve React frontend""" if path != "" and os.path.exists(os.path.join(app.static_folder, path)): return send_from_directory(app.static_folder, path) else: return send_from_directory(app.static_folder, 'index.html') @app.route('/similar-images', methods=['GET']) def similar_images_page(): """Serve similar images search page""" return send_from_directory(app.static_folder, 'similar-images.html') @app.route('/api/status', methods=['GET']) def status(): return jsonify({ "status": "online", "models": { "yolo": yolo_model is not None, "detr": detr_model is not None and detr_processor is not None, "vit": vit_model is not None and vit_processor is not None }, "device": "GPU" if torch.cuda.is_available() else "CPU" }) def index(): return send_from_directory('static', 'index.html') if __name__ == "__main__": # 허깅페이스 Space에서는 PORT 환경 변수를 사용합니다 port = int(os.environ.get("PORT", 7860)) app.run(debug=False, host='0.0.0.0', port=port)