Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Set matplotlib config directory to avoid permission issues | |
| import os | |
| os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib' | |
| from flask import Flask, request, jsonify, send_from_directory, redirect, url_for, session, render_template_string, make_response | |
| from datetime import timedelta | |
| import torch | |
| from PIL import Image | |
| import numpy as np | |
| import io | |
| from io import BytesIO | |
| import base64 | |
| import uuid | |
| import matplotlib.pyplot as plt | |
| from matplotlib.patches import Rectangle | |
| import time | |
| from flask_cors import CORS | |
| import json | |
| import sys | |
| import requests | |
| try: | |
| from openai import OpenAI | |
| except Exception as _e: | |
| OpenAI = None | |
| from flask_login import ( | |
| LoginManager, | |
| UserMixin, | |
| login_user, | |
| logout_user, | |
| login_required, | |
| current_user, | |
| fresh_login_required, | |
| login_fresh, | |
| ) | |
| # Fix for SQLite3 version compatibility with ChromaDB | |
| try: | |
| import pysqlite3 | |
| sys.modules['sqlite3'] = pysqlite3 | |
| except ImportError: | |
| print("Warning: pysqlite3 not found, using built-in sqlite3") | |
| import chromadb | |
| from chromadb.utils import embedding_functions | |
# Flask application object; static assets are served from ./static.
app = Flask(__name__, static_folder='static')
app.config['CORS_HEADERS'] = 'Content-Type'

# Remember cookie (Flask-Login): keep the lifetime minimal to prevent auto re-login.
app.config['REMEMBER_COOKIE_DURATION'] = timedelta(seconds=1)
app.config['REMEMBER_COOKIE_SECURE'] = True  # Spaces uses HTTPS
app.config['REMEMBER_COOKIE_HTTPONLY'] = True
app.config['REMEMBER_COOKIE_SAMESITE'] = 'None'

# Session cookie (Flask-Session)
app.config['SESSION_COOKIE_SECURE'] = True  # HTTPS
app.config['SESSION_COOKIE_HTTPONLY'] = True
app.config['SESSION_COOKIE_SAMESITE'] = 'None'
app.config['SESSION_COOKIE_PATH'] = '/'

CORS(app)  # Enable CORS for all routes

# Secret key used to sign session data. It is read from the SECRET_KEY
# environment variable; the built-in fallback is public knowledge, so warn
# loudly whenever it is in use. (A previous hard-coded ``app.secret_key``
# assignment was removed: it was always overwritten by this setting.)
if not os.environ.get('SECRET_KEY'):
    print("Warning: SECRET_KEY is not set; falling back to an insecure built-in key")
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'vision_llm_agent_secret_key')
app.config['SESSION_TYPE'] = 'filesystem'
# Sessions are valid for two minutes ...
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(seconds=120)
# ... and the cookie is not re-sent on every request, so the expiry is
# counted from login rather than from the last request.
app.config['SESSION_REFRESH_EACH_REQUEST'] = False

# Flask-Login setup
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = 'login'
login_manager.session_protection = 'strong'
# When authentication is required or the session is not fresh, redirect to login instead of 401
login_manager.refresh_view = 'login'
def handle_unauthorized():
    """Redirect a visitor who is not authenticated to the login page.

    NOTE(review): no registration of this callback with ``login_manager`` is
    visible in this chunk; presumably a decorator line is missing -- confirm.
    """
    login_url = url_for('login')
    return redirect(login_url)
def handle_needs_refresh():
    """Redirect a visitor whose session is not fresh to the login page.

    Intended for non-fresh sessions (e.g. after expiry, or when only the
    remember cookie is present).

    NOTE(review): no registration of this callback with ``login_manager`` is
    visible in this chunk; presumably a decorator line is missing -- confirm.
    """
    login_url = url_for('login')
    return redirect(login_url)
| # ์ธ์ ์ค์ | |
import tempfile
from flask_session import Session

# Server-side sessions are stored as signed, permanent, file-backed sessions
# inside the system temporary directory.
session_dir = tempfile.gettempdir()
app.config.update(
    SESSION_TYPE='filesystem',
    SESSION_PERMANENT=True,
    SESSION_USE_SIGNER=True,
    SESSION_FILE_DIR=session_dir,
)
print(f"Using session directory: {session_dir}")
Session(app)
| # ์ฌ์ฉ์ ํด๋์ค ์ ์ | |
class User(UserMixin):
    """In-memory account record used with Flask-Login.

    Attributes are stored exactly as given: ``id``, ``username`` and
    ``password`` (kept as plain text).
    """

    def __init__(self, id, username, password):
        self.id, self.username, self.password = id, username, password

    def get_id(self):
        """Return the user id as a string (Flask-Login expects string ids)."""
        return f"{self.id}"
| # ํ ์คํธ์ฉ ์ฌ์ฉ์ (์ค์ ํ๊ฒฝ์์๋ ๋ฐ์ดํฐ๋ฒ ์ด์ค ์ฌ์ฉ ๊ถ์ฅ) | |
# Hard-coded accounts keyed by username.
# NOTE(review): credentials are stored in plain text in the source file.
users = {
    account.username: account
    for account in (
        User('1', 'admin', 'admin123'),
        User('2', 'user', 'user123'),
    )
}
| # ์ฌ์ฉ์ ๋ก๋ ํจ์ | |
def load_user(user_id):
    """Return the ``User`` whose id matches ``user_id``, or ``None``.

    Ids are compared as strings. On a match the session's ``user_id`` and
    ``username`` entries are refreshed and the session is marked modified.

    The session contents and request cookies are deliberately not logged:
    they carry credentials and must not end up in application logs.

    NOTE(review): no ``user_loader`` registration is visible in this chunk;
    presumably a decorator line is missing -- confirm.
    """
    print(f"Loading user with ID: {user_id}")
    wanted_id = str(user_id)
    for username, user in users.items():
        if str(user.id) == wanted_id:
            print(f"User found: {username}, ID: {user.id}")
            session['user_id'] = user.id
            session['username'] = username
            session.modified = True
            return user
    print(f"User not found with ID: {user_id}")
    return None
# Model initialization
print("Loading models... This may take a moment.")

# Image embedding model (CLIP) for vector search.
# Both names stay None when loading fails, so callers can test for availability.
clip_model = None
clip_processor = None
try:
    import tempfile
    temp_dir = tempfile.gettempdir()
    # Point the Transformers cache at the temp directory. This has to happen
    # before ``transformers`` is imported, because the library resolves its
    # cache location at import time; setting it afterwards has no effect.
    # NOTE(review): newer Transformers releases prefer HF_HOME -- confirm
    # against the pinned version.
    os.environ["TRANSFORMERS_CACHE"] = temp_dir
    from transformers import CLIPProcessor, CLIPModel
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    print("CLIP model loaded successfully")
except Exception as e:
    print("Error loading CLIP model:", e)
    clip_model = None
    clip_processor = None
| # Vector DB ์ด๊ธฐํ | |
# Vector DB handles; all three stay None when initialization fails.
vector_db = image_collection = object_collection = None
try:
    # ChromaDB client created with default settings.
    vector_db = chromadb.Client()
    # Embedding function attached to both collections.
    ef = embedding_functions.DefaultEmbeddingFunction()
    _collection_kwargs = {"embedding_function": ef, "get_or_create": True}
    # Collection of whole-image embeddings.
    image_collection = vector_db.create_collection(name="image_collection", **_collection_kwargs)
    # Collection of per-object (detection crop) embeddings.
    object_collection = vector_db.create_collection(name="object_collection", **_collection_kwargs)
    print("Vector DB initialized successfully")
except Exception as e:
    print("Error initializing Vector DB:", e)
    vector_db = image_collection = object_collection = None
# YOLOv8 model (``yolo_model`` stays None when loading fails).
yolo_model = None
try:
    import tempfile
    import urllib.request

    temp_dir = tempfile.gettempdir()
    # Keep Ultralytics / matplotlib config files in the temp directory.
    # Set before importing ultralytics so the library sees the value.
    os.environ["YOLO_CONFIG_DIR"] = temp_dir
    os.environ["MPLCONFIGDIR"] = temp_dir

    from ultralytics import YOLO

    # Weights are cached in the temp directory.
    model_path = os.path.join(temp_dir, "yolov8n.pt")

    # Download the weights when missing (or left empty by an earlier failed
    # attempt). Each URL is tried in order; a download only counts when the
    # transfer completes, and it is written to a side file first so a partial
    # download never masquerades as a valid model file.
    if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
        model_urls = (
            "https://ultralytics.com/assets/yolov8n.pt",
            "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt",
        )
        partial_path = model_path + ".part"
        for url in model_urls:
            print(f"Downloading YOLOv8 model to {model_path}...")
            try:
                urllib.request.urlretrieve(url, partial_path)
                os.replace(partial_path, model_path)
                print("YOLOv8 model downloaded successfully")
                break
            except Exception as e:
                print(f"Error downloading YOLOv8 model: {e}")
                if os.path.exists(partial_path):
                    os.remove(partial_path)
        else:
            print("All download attempts failed")

    yolo_model = YOLO(model_path)  # Using the nano model for faster inference
    print("YOLOv8 model loaded successfully")
except Exception as e:
    print("Error loading YOLOv8 model:", e)
    yolo_model = None
# DETR model (DEtection TRansformer).
# Processor and model are either both loaded or both None.
detr_processor = detr_model = None
try:
    from transformers import DetrImageProcessor, DetrForObjectDetection
    detr_processor, detr_model = (
        DetrImageProcessor.from_pretrained("facebook/detr-resnet-50"),
        DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50"),
    )
    print("DETR model loaded successfully")
except Exception as e:
    print("Error loading DETR model:", e)
    detr_processor = detr_model = None
# ViT image-classification model.
# Processor and model are either both loaded or both None.
vit_processor = vit_model = None
try:
    from transformers import ViTImageProcessor, ViTForImageClassification
    vit_processor, vit_model = (
        ViTImageProcessor.from_pretrained("google/vit-base-patch16-224"),
        ViTForImageClassification.from_pretrained("google/vit-base-patch16-224"),
    )
    print("ViT model loaded successfully")
except Exception as e:
    print("Error loading ViT model:", e)
    vit_processor = vit_model = None
# Pick the torch device: CUDA when torch reports it available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# LLM model (using an open-access model instead of Llama 4 which requires authentication).
# ``llm_model`` / ``llm_tokenizer`` stay None when loading fails.
llm_model = None
llm_tokenizer = None
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    print("Loading LLM model... This may take a moment.")
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # Using TinyLlama as an open-access alternative
    llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Half precision only pays off on CUDA; on CPU many float16 kernels are
    # slow or unavailable, so use float32 there.
    llm_dtype = torch.float16 if device.type == "cuda" else torch.float32
    llm_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=llm_dtype,
        # Options requiring the accelerate package are intentionally not used:
        # device_map="auto", load_in_8bit=True
    ).to(device)
    print("LLM model loaded successfully")
except Exception as e:
    print(f"Error loading LLM model: {e}")
    llm_model = None
    llm_tokenizer = None
def process_llm_query(vision_results, user_query):
    """Answer ``user_query`` about detected objects using the local LLM.

    Args:
        vision_results: A list of detection dicts, or a result dict carrying
            such a list under ``"detections"``. Each detection is summarized
            by its ``"label"`` (falling back to ``"class"``, the key the
            detectors in this file emit) and its ``"confidence"``. Only the
            first 10 entries are considered, to bound the prompt length.
        user_query: The user's question, inserted verbatim into the prompt.

    Returns:
        ``{"response": str, "performance": {"inference_time", "device"}}`` on
        success, or ``{"error": str}`` when the model is unavailable or
        generation fails.
    """
    if llm_model is None or llm_tokenizer is None:
        return {"error": "LLM model not available"}

    # Accept a detector result dict as well as a bare list of detections.
    if isinstance(vision_results, dict):
        vision_results = vision_results.get("detections", [])

    # Keep only the fields the prompt needs.
    summarized_results = []
    if isinstance(vision_results, list):
        summarized_results = [
            {
                "label": obj.get("label", obj.get("class", "unknown")),
                "confidence": obj.get("confidence", 0),
            }
            for obj in vision_results[:10]
            if isinstance(obj, dict)
        ]

    # Create a prompt combining vision results and user query
    prompt = f"""You are an AI assistant analyzing image detection results.
Here are the objects detected in the image: {json.dumps(summarized_results, indent=2)}
User question: {user_query}
Please provide a detailed analysis based on the detected objects and the user's question.
"""
    try:
        start_time = time.time()
        # Fall back to a shorter prompt when the full one is too long
        # (1500 tokens leaves a safety margin).
        tokens = llm_tokenizer.encode(prompt)
        if len(tokens) > 1500:
            prompt = f"""You are an AI assistant analyzing image detection results.
The image contains {len(summarized_results)} detected objects.
User question: {user_query}
Please provide a general analysis based on the user's question.
"""
        inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[1]
        with torch.no_grad():
            output = llm_model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )
        # Decode only the newly generated tokens. Comparing decoded text with
        # the prompt is unreliable because decoding need not reproduce the
        # prompt string exactly.
        response_text = llm_tokenizer.decode(
            output[0][prompt_length:], skip_special_tokens=True
        ).strip()
        inference_time = time.time() - start_time
        return {
            "response": response_text,
            "performance": {
                "inference_time": round(inference_time, 3),
                "device": "GPU" if torch.cuda.is_available() else "CPU",
            },
        }
    except Exception as e:
        return {"error": f"Error processing LLM query: {str(e)}"}
def image_to_base64(img):
    """Serialize ``img`` as PNG and return the bytes base64-encoded as UTF-8 text."""
    png_buffer = io.BytesIO()
    img.save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode('utf-8')
def process_yolo(image):
    """Run YOLOv8 detection on ``image``.

    Args:
        image: A PIL image, or any other input the YOLO model accepts
            (passed through unchanged).

    Returns:
        ``{"image": <base64 PNG of the annotated image>, "detections":
        [{"class", "confidence", "bbox"}, ...], "performance": {...}}``, or
        ``{"error": str}`` when the model is not loaded.
    """
    if yolo_model is None:
        return {"error": "YOLOv8 model not loaded"}

    # Measure inference time
    start_time = time.time()

    if isinstance(image, Image.Image):
        # Pass the PIL image itself so the library handles channel order.
        # Converting to a numpy array first would hand over RGB data where
        # Ultralytics assumes BGR. ``convert`` also normalizes RGBA or
        # paletted uploads to three channels.
        results = yolo_model(image.convert("RGB"))
        # plot() yields a BGR array; reverse the channel axis for PIL.
        annotated = results[0].plot()[..., ::-1]
    else:
        # Non-PIL input: behavior unchanged.
        results = yolo_model(image)
        annotated = results[0].plot()
    result_image = Image.fromarray(annotated)

    # Get detection information
    class_names = results[0].names
    detections = []
    for box in results[0].boxes:
        class_id = int(box.cls[0].item())
        detections.append({
            "class": class_names[class_id],
            "confidence": round(box.conf[0].item(), 2),
            "bbox": [round(x) for x in box.xyxy[0].tolist()],
        })

    inference_time = time.time() - start_time
    device_info = "GPU" if torch.cuda.is_available() else "CPU"
    return {
        "image": image_to_base64(result_image),
        "detections": detections,
        "performance": {
            "inference_time": round(inference_time, 3),
            "device": device_info,
        },
    }
def process_detr(image):
    """Run DETR object detection on a PIL image and draw the detections.

    Args:
        image: PIL image to analyze (converted to RGB before inference).

    Returns:
        ``{"image": <base64 PNG of the annotated figure>, "detections":
        [{"class", "confidence", "bbox"}, ...], "performance": {...}}``, or
        ``{"error": str}`` when the model is not loaded. Only detections
        scoring above 0.9 are kept.
    """
    if detr_model is None or detr_processor is None:
        return {"error": "DETR model not loaded"}

    # Measure inference time
    start_time = time.time()

    # Normalize RGBA / grayscale / paletted uploads to three channels.
    image = image.convert("RGB")
    inputs = detr_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = detr_model(**inputs)

    # PIL size is (width, height); the post-processor expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = detr_processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    detections = []
    buf = io.BytesIO()
    fig, ax = plt.subplots(1)
    try:
        # Draw through the figure/axes objects rather than the global pyplot
        # state, so concurrent requests cannot draw into each other's figure.
        ax.imshow(image)
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            box = [round(i) for i in box.tolist()]
            class_name = detr_model.config.id2label[label.item()]
            confidence = round(score.item(), 2)
            ax.add_patch(Rectangle(
                (box[0], box[1]), box[2] - box[0], box[3] - box[1],
                linewidth=2, edgecolor='r', facecolor='none',
            ))
            ax.text(box[0], box[1], "{}: {}".format(class_name, confidence),
                    bbox=dict(facecolor='white', alpha=0.8))
            detections.append({
                "class": class_name,
                "confidence": confidence,
                "bbox": box,
            })
        fig.tight_layout()
        ax.axis('off')
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, including when drawing fails.
        plt.close(fig)
    buf.seek(0)
    result_image = Image.open(buf)

    inference_time = time.time() - start_time
    # Report the device the model actually lives on.
    device_info = "GPU" if detr_model.device.type == "cuda" else "CPU"
    return {
        "image": image_to_base64(result_image),
        "detections": detections,
        "performance": {
            "inference_time": round(inference_time, 3),
            "device": device_info,
        },
    }
def process_vit(image):
    """Classify ``image`` with the ViT model and return the top predictions.

    Args:
        image: Image to classify; PIL images are converted to RGB first.

    Returns:
        ``{"top_predictions": [{"rank", "class", "probability"}, ...],
        "performance": {...}}`` with up to five entries ordered by
        probability, or ``{"error": str}`` when the model is not loaded.
    """
    if vit_model is None or vit_processor is None:
        return {"error": "ViT model not loaded"}

    # Measure inference time
    start_time = time.time()

    if isinstance(image, Image.Image):
        # Normalize RGBA / grayscale / paletted uploads to three channels.
        image = image.convert("RGB")
    inputs = vit_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = vit_model(**inputs).logits

    # Top predictions (at most five, fewer if the model has fewer classes).
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    top_prob, top_indices = torch.topk(probs, min(5, probs.shape[-1]))
    results = [
        {
            "rank": rank,
            "class": vit_model.config.id2label[idx.item()],
            "probability": round(prob.item(), 3),
        }
        for rank, (prob, idx) in enumerate(zip(top_prob, top_indices), start=1)
    ]

    inference_time = time.time() - start_time
    # Report the device the model actually lives on.
    device_info = "GPU" if vit_model.device.type == "cuda" else "CPU"
    return {
        "top_predictions": results,
        "performance": {
            "inference_time": round(inference_time, 3),
            "device": device_info,
        },
    }
def yolo_detect():
    """Handle an image upload and return the YOLOv8 result as JSON.

    Expects the image under the ``image`` multipart field. Responds with
    HTTP 400 when the field is missing or the upload cannot be decoded.

    NOTE(review): no route decorator is visible in this chunk -- confirm.
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    try:
        # convert() forces decoding, so corrupt uploads fail here.
        image = Image.open(file.stream).convert('RGB')
    except (OSError, ValueError):
        # PIL reports unreadable image data as OSError subclasses.
        return jsonify({"error": "Invalid image file"}), 400
    return jsonify(process_yolo(image))
def detr_detect():
    """Handle an image upload and return the DETR result as JSON.

    Expects the image under the ``image`` multipart field. Responds with
    HTTP 400 when the field is missing or the upload cannot be decoded.

    NOTE(review): no route decorator is visible in this chunk -- confirm.
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    try:
        # convert() forces decoding, so corrupt uploads fail here.
        image = Image.open(file.stream).convert('RGB')
    except (OSError, ValueError):
        # PIL reports unreadable image data as OSError subclasses.
        return jsonify({"error": "Invalid image file"}), 400
    return jsonify(process_detr(image))
def vit_classify():
    """Handle an image upload and return the ViT classification as JSON.

    Expects the image under the ``image`` multipart field. Responds with
    HTTP 400 when the field is missing or the upload cannot be decoded.

    NOTE(review): no route decorator is visible in this chunk -- confirm.
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    try:
        # convert() forces decoding, so corrupt uploads fail here.
        image = Image.open(file.stream).convert('RGB')
    except (OSError, ValueError):
        # PIL reports unreadable image data as OSError subclasses.
        return jsonify({"error": "Invalid image file"}), 400
    return jsonify(process_vit(image))
def analyze_with_llm():
    """Answer a user question about vision results with the LLM.

    Expects a JSON object with ``visionResults`` and ``userQuery``. Responds
    with HTTP 400 (JSON error body) when the body is missing, is not valid
    JSON, or lacks either field.

    NOTE(review): no route decorator is visible in this chunk -- confirm.
    """
    # silent=True yields None for a missing/invalid JSON body instead of
    # aborting with a framework error page, so the JSON error below is used.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "No JSON data provided"}), 400
    if not isinstance(data, dict) or 'visionResults' not in data or 'userQuery' not in data:
        return jsonify({"error": "Missing required fields: visionResults or userQuery"}), 400

    result = process_llm_query(data['visionResults'], data['userQuery'])
    return jsonify(result)
def generate_image_embedding(image):
    """Return the L2-normalized CLIP embedding of ``image`` as a list of floats.

    Returns ``None`` when the CLIP model or processor is unavailable, or when
    embedding fails (the error is printed).
    """
    if clip_model is None or clip_processor is None:
        return None
    try:
        # Preprocess, then embed without tracking gradients.
        clip_inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = clip_model.get_image_features(**clip_inputs)
        # Scale the vector to unit length before returning it.
        vector = features.squeeze().cpu().numpy()
        unit_vector = vector / np.linalg.norm(vector)
        return unit_vector.tolist()
    except Exception as e:
        print(f"Error generating image embedding: {e}")
        return None
def find_similar_images():
    """Similar-image search API.

    Accepts the query image either as an uploaded file (``image`` in
    ``request.files``) or as a base64 string / data URL (``image`` in
    ``request.form``), embeds it with CLIP and returns up to five nearest
    entries of ``image_collection``.

    Responses: 200 with ``query_image_id`` and ``similar_images``; 400 for a
    missing or undecodable image; 503 when the model or vector DB is
    unavailable; 500 for unexpected failures.

    NOTE(review): no route decorator is visible in this chunk -- confirm.
    """
    if clip_model is None or clip_processor is None or image_collection is None:
        return jsonify({"error": "Image embedding model or vector DB not available"}), 503
    try:
        try:
            if 'image' in request.files:
                # Uploaded as a file
                image_source = request.files['image']
            elif 'image' in request.form:
                # Sent base64-encoded, optionally as a data URL
                image_data = request.form['image']
                if image_data.startswith('data:image'):
                    image_data = image_data.partition(',')[2]
                image_source = BytesIO(base64.b64decode(image_data))
            else:
                return jsonify({"error": "No image provided"}), 400
            image = Image.open(image_source).convert('RGB')
        except (OSError, ValueError):
            # Bad base64 (ValueError) or bytes PIL cannot decode (OSError).
            return jsonify({"error": "Invalid image data"}), 400

        # Temporary id identifying this query image in the response.
        image_id = str(uuid.uuid4())

        embedding = generate_image_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate image embedding"}), 500

        # Top 5 nearest neighbours.
        results = image_collection.query(
            query_embeddings=[embedding],
            n_results=5,
        )

        # Results are nested one list per query; only one query was sent.
        # Distances / metadatas may be absent or None, so default them.
        ids = results['ids'][0]
        distances = (results.get('distances') or [[]])[0]
        metadatas = (results.get('metadatas') or [[]])[0]
        similar_images = [
            {
                "id": img_id,
                "distance": float(distances[i]) if i < len(distances) else 0.0,
                "metadata": metadatas[i] if i < len(metadatas) and metadatas[i] is not None else {},
            }
            for i, img_id in enumerate(ids)
        ]
        return jsonify({
            "query_image_id": image_id,
            "similar_images": similar_images,
        })
    except Exception as e:
        print(f"Error in similar-images API: {e}")
        return jsonify({"error": str(e)}), 500
def add_to_collection():
    """API that adds an image to the vector DB.

    Accepts the image as an uploaded file (``image`` in ``request.files``) or
    as a base64 string / data URL (``image`` in ``request.form``). Optional
    form fields: ``metadata`` (a JSON object) and ``id`` (generated when
    absent or empty). The JPEG-encoded image is stored base64-encoded in the
    metadata under ``image_data``.

    Responses: 200 with ``success``/``image_id``; 400 for a missing or
    undecodable image or invalid metadata; 503 when the model or vector DB is
    unavailable; 500 for unexpected failures.

    NOTE(review): no route decorator is visible in this chunk -- confirm.
    """
    if clip_model is None or clip_processor is None or image_collection is None:
        return jsonify({"error": "Image embedding model or vector DB not available"}), 503
    try:
        if 'image' not in request.files and 'image' not in request.form:
            return jsonify({"error": "No image provided"}), 400

        # Optional metadata: must be a JSON object.
        metadata = {}
        if 'metadata' in request.form:
            try:
                metadata = json.loads(request.form['metadata'])
            except ValueError:
                return jsonify({"error": "Invalid metadata JSON"}), 400
            if not isinstance(metadata, dict):
                return jsonify({"error": "Invalid metadata JSON"}), 400

        # Image id (auto-generated when not supplied or empty).
        image_id = request.form.get('id') or str(uuid.uuid4())

        try:
            if 'image' in request.files:
                # Uploaded as a file
                image_source = request.files['image']
            else:
                # Sent base64-encoded, optionally as a data URL
                image_data = request.form['image']
                if image_data.startswith('data:image'):
                    image_data = image_data.partition(',')[2]
                image_source = BytesIO(base64.b64decode(image_data))
            image = Image.open(image_source).convert('RGB')
        except (OSError, ValueError):
            # Bad base64 (ValueError) or bytes PIL cannot decode (OSError).
            return jsonify({"error": "Invalid image data"}), 400

        embedding = generate_image_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate image embedding"}), 500

        # Store the image itself (JPEG, base64) alongside the embedding.
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        metadata['image_data'] = base64.b64encode(buffered.getvalue()).decode('utf-8')

        image_collection.add(
            ids=[image_id],
            embeddings=[embedding],
            metadatas=[metadata],
        )
        return jsonify({
            "success": True,
            "image_id": image_id,
            "message": "Image added to collection",
        })
    except Exception as e:
        print(f"Error in add-to-collection API: {e}")
        return jsonify({"error": str(e)}), 500
def _normalized_bbox(bbox, image_width, image_height):
    """Return ``(x, y, width, height)`` as fractions of the image size, or ``None`` if ``bbox`` is unusable."""
    try:
        if isinstance(bbox, list) and len(bbox) >= 4:
            # List form [x1, y1, x2, y2] is given in pixels.
            x1 = bbox[0] / image_width
            y1 = bbox[1] / image_height
            x2 = bbox[2] / image_width
            y2 = bbox[3] / image_height
            return x1, y1, x2 - x1, y2 - y1
        if isinstance(bbox, dict):
            # Dict form {'x', 'y', 'width', 'height'} is used as-is, i.e. it
            # is treated as already normalized.
            return (
                float(bbox.get('x', 0)),
                float(bbox.get('y', 0)),
                float(bbox.get('width', 0)),
                float(bbox.get('height', 0)),
            )
    except (TypeError, ValueError):
        return None
    return None


def add_detected_objects():
    """API that adds object-detection results to the vector DB.

    Expects a JSON object with ``image`` (base64 string or data URL),
    ``objects`` (list of detections carrying ``bbox``, ``class`` and
    ``confidence``) and optionally ``imageId``. Each object is cropped from
    the image, embedded with CLIP and stored in ``object_collection`` with
    its normalized bbox (as a JSON string) and its JPEG crop (base64).
    Objects with an unusable or empty bounding box are skipped.

    Responses: 200 with ``success``/``image_id``/``object_count``/
    ``object_ids``; 400 for malformed input or when no object could be added;
    503 when the model or vector DB is unavailable; 500 otherwise.

    NOTE(review): no route decorator is visible in this chunk -- confirm.
    """
    if clip_model is None or object_collection is None:
        return jsonify({"error": "Image embedding model or vector DB not available"}), 503
    try:
        print("[DEBUG] Received request in add-detected-objects")
        data = request.get_json(silent=True)
        print(f"[DEBUG] Request data keys: {list(data.keys()) if isinstance(data, dict) else 'None'}")
        if not data or not isinstance(data, dict):
            print("[DEBUG] Error: No data received in request")
            return jsonify({"error": "No data received"}), 400
        if 'image' not in data:
            print("[DEBUG] Error: 'image' key not found in request data")
            return jsonify({"error": "Missing image data"}), 400
        if 'objects' not in data:
            print("[DEBUG] Error: 'objects' key not found in request data")
            return jsonify({"error": "Missing objects data"}), 400

        image_data = data['image']
        objects = data['objects']
        if not isinstance(image_data, str):
            return jsonify({"error": "Invalid image data"}), 400
        if not isinstance(objects, list):
            return jsonify({"error": "Invalid objects data"}), 400
        # Log sizes only; payloads contain base64 image data.
        print(f"[DEBUG] Objects count: {len(objects)}")

        # Decode the source image.
        if image_data.startswith('data:image'):
            image_data = image_data.partition(',')[2]
        try:
            image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB')
        except (OSError, ValueError):
            return jsonify({"error": "Invalid image data"}), 400
        image_width, image_height = image.size

        image_id = data.get('imageId', str(uuid.uuid4()))

        object_ids = []
        object_embeddings = []
        object_metadatas = []
        for obj in objects:
            if not isinstance(obj, dict):
                continue
            box = _normalized_bbox(obj.get('bbox', []), image_width, image_height)
            if box is None:
                continue
            x1, y1, width, height = box

            # Back to pixel coordinates for cropping.
            x1_px = int(x1 * image_width)
            y1_px = int(y1 * image_height)
            width_px = int(width * image_width)
            height_px = int(height * image_height)
            if width_px <= 0 or height_px <= 0:
                # Nothing to crop.
                continue

            try:
                object_image = image.crop((x1_px, y1_px, x1_px + width_px, y1_px + height_px))
                embedding = generate_image_embedding(object_image)
                if embedding is None:
                    continue

                # bbox is serialized to a JSON string because the metadata
                # store only takes flat values.
                bbox_json = json.dumps({
                    "x": x1,
                    "y": y1,
                    "width": width,
                    "height": height,
                })
                buffered = BytesIO()
                object_image.save(buffered, format="JPEG")
                img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')

                object_ids.append(f"{image_id}_{str(uuid.uuid4())[:8]}")
                object_embeddings.append(embedding)
                object_metadatas.append({
                    "image_id": image_id,
                    "class": obj.get('class', ''),
                    "confidence": obj.get('confidence', 0),
                    "bbox": bbox_json,
                    "image_data": img_str,
                })
            except Exception as e:
                print(f"Error processing object: {e}")
                continue

        if not object_ids:
            return jsonify({"error": "No valid objects to add"}), 400

        print(f"[DEBUG] Adding {len(object_ids)} objects to vector DB")
        try:
            object_collection.add(
                ids=object_ids,
                embeddings=object_embeddings,
                metadatas=object_metadatas,
            )
            print("[DEBUG] Successfully added objects to vector DB")
        except Exception as e:
            print(f"[DEBUG] Error adding to vector DB: {e}")
            raise

        return jsonify({
            "success": True,
            "image_id": image_id,
            "object_count": len(object_ids),
            "object_ids": object_ids,
        })
    except Exception as e:
        print(f"Error in add-detected-objects API: {e}")
        return jsonify({"error": str(e)}), 500
def search_similar_objects():
    """Similar-object search API.

    Expects a JSON object. ``searchType`` (default ``'image'``) selects the
    mode and ``n_results`` (default 5, positive integer) bounds the result
    count:

    * ``'image'`` with ``image``: embed the base64 image / data URL and run a
      similarity query.
    * ``'object'`` with ``objectId``: reuse the stored embedding of that
      object as the query.
    * ``'class'`` with ``class_name``: return objects whose ``class``
      metadata equals the name (no similarity ranking).

    Responses: 200 with ``success``/``searchType``/``results`` (formatted by
    ``format_object_results``); 400 for malformed requests; 404 when no query
    embedding could be obtained; 503 when the model or vector DB is
    unavailable; 500 otherwise.

    NOTE(review): no route decorator is visible in this chunk -- confirm.
    """
    print("[DEBUG] Received request in search-similar-objects")
    if clip_model is None or object_collection is None:
        print("[DEBUG] Error: Image embedding model or vector DB not available")
        return jsonify({"error": "Image embedding model or vector DB not available"}), 503
    try:
        data = request.get_json(silent=True)
        print(f"[DEBUG] Request data keys: {list(data.keys()) if isinstance(data, dict) else 'None'}")
        if not data or not isinstance(data, dict):
            print("[DEBUG] Error: Missing request data")
            return jsonify({"error": "Missing request data"}), 400

        search_type = data.get('searchType', 'image')
        try:
            n_results = int(data.get('n_results', 5))
        except (TypeError, ValueError):
            n_results = 0
        if n_results < 1:
            return jsonify({"error": "n_results must be a positive integer"}), 400
        print(f"[DEBUG] Search type: {search_type}, n_results: {n_results}")

        query_embedding = None
        if search_type == 'image' and 'image' in data:
            # Search by image
            print("[DEBUG] Searching by image")
            image_data = data['image']
            try:
                if image_data.startswith('data:image'):
                    image_data = image_data.partition(',')[2]
                image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB')
                query_embedding = generate_image_embedding(image)
                print(f"[DEBUG] Generated image embedding: {type(query_embedding)}, shape: {len(query_embedding) if query_embedding is not None else 'None'}")
            except Exception as e:
                print(f"[DEBUG] Error generating image embedding: {e}")
                return jsonify({"error": f"Error processing image: {str(e)}"}), 500
        elif search_type == 'object' and 'objectId' in data:
            # Search by the stored embedding of an existing object
            object_id = data['objectId']
            result = object_collection.get(ids=[object_id], include=["embeddings"])
            embeddings = result.get("embeddings") if result else None
            if embeddings is not None and len(embeddings) > 0:
                query_embedding = embeddings[0]
        elif search_type == 'class' and 'class_name' in data:
            # Search by class name
            print("[DEBUG] Searching by class name")
            class_name = data['class_name']
            print(f"[DEBUG] Class name: {class_name}")
            filter_query = {"class": {"$eq": class_name}}
            try:
                print(f"[DEBUG] Querying with filter: {filter_query}")
                # Metadata filtering only, so use get() rather than query().
                results = object_collection.get(
                    where=filter_query,
                    limit=n_results,
                    include=["metadatas", "embeddings", "documents"],
                )
                print(f"[DEBUG] Query results: {results['ids'][0] if 'ids' in results and len(results['ids']) > 0 else 'No results'}")
                formatted_results = format_object_results(results)
                print(f"[DEBUG] Formatted results count: {len(formatted_results)}")
                return jsonify({
                    "success": True,
                    "searchType": "class",
                    "results": formatted_results,
                })
            except Exception as e:
                print(f"[DEBUG] Error in class search: {e}")
                return jsonify({"error": f"Error in class search: {str(e)}"}), 500
        else:
            # Log key names only; the payload may contain base64 image data.
            print(f"[DEBUG] Invalid search parameters: {list(data.keys())}")
            return jsonify({"error": "Invalid search parameters"}), 400

        if query_embedding is None:
            print("[DEBUG] Error: Failed to generate query embedding")
            return jsonify({"error": "Failed to generate query embedding"}), 404

        try:
            print(f"[DEBUG] Running similarity search with embedding of length {len(query_embedding)}")
            results = object_collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
                include=["metadatas", "distances"],
            )
            print(f"[DEBUG] Query results: {results['ids'][0] if 'ids' in results and len(results['ids']) > 0 else 'No results'}")
            formatted_results = format_object_results(results)
            print(f"[DEBUG] Formatted results count: {len(formatted_results)}")
            return jsonify({
                "success": True,
                "searchType": search_type,
                "results": formatted_results,
            })
        except Exception as e:
            print(f"[DEBUG] Error in similarity search: {e}")
            return jsonify({"error": f"Error in similarity search: {str(e)}"}), 500
    except Exception as e:
        print(f"Error in search-similar-objects API: {e}")
        return jsonify({"error": str(e)}), 500
def format_object_results(results):
    """Flatten a ChromaDB ``get`` or ``query`` result into a list of dicts.

    ``get`` results are flat (``ids`` is a list of ids) while ``query``
    results are nested one level per query embedding (``ids`` is a list of
    lists); only the first query's batch is formatted.

    Args:
        results: Mapping returned by ``collection.get`` / ``collection.query``
            (or a falsy value).

    Returns:
        A list of ``{"id", "metadata"}`` dicts, plus ``"distance"`` for query
        results that carry distances. Formatting is best-effort: an entry
        that cannot be formatted is logged and skipped, and any unexpected
        failure yields whatever was formatted so far.
    """
    formatted_results = []
    print(f"[DEBUG] Formatting results: {results.keys() if results else 'None'}")
    if not results:
        print("[DEBUG] No results to format")
        return formatted_results
    try:
        ids = results.get('ids')
        if ids is None or len(ids) == 0:
            # Covers an empty ``get`` result, which previously fell into the
            # error path when probing ids[0].
            return formatted_results
        is_query_result = isinstance(ids[0], list)
        if is_query_result:
            print("[DEBUG] Processing results from query method (similarity search)")
            ids = ids[0]
            metadatas = _first_batch(results.get('metadatas'))
            distances = _first_batch(results.get('distances'))
            label = "query"
        else:
            print("[DEBUG] Processing results from get method (class search)")
            metadatas = results.get('metadatas')
            if metadatas is None:
                metadatas = []
            distances = []
            label = "get"

        for i, obj_id in enumerate(ids):
            try:
                raw_metadata = metadatas[i] if i < len(metadatas) else None
                # Copy so the caller's result structure is not mutated; a
                # missing/None metadata entry becomes an empty dict.
                metadata = dict(raw_metadata) if raw_metadata else {}

                # bbox may have been stored as a JSON string; decode it when
                # possible and keep the raw value otherwise.
                bbox = metadata.get('bbox')
                if isinstance(bbox, str):
                    try:
                        metadata['bbox'] = json.loads(bbox)
                    except (TypeError, ValueError):
                        pass

                result_item = {"id": obj_id, "metadata": metadata}
                if i < len(distances):
                    result_item["distance"] = float(distances[i])

                if 'image_data' in metadata:
                    print(f"[DEBUG] Image data found in metadata for object {obj_id}")
                elif not is_query_result:
                    print(f"[DEBUG] Image data not found in metadata for object {obj_id}")
                elif metadata.get('image_id'):
                    print(
                        f"[DEBUG] Image data not found in metadata for object {obj_id} "
                        f"with image_id {metadata.get('image_id')}"
                    )
                formatted_results.append(result_item)
            except Exception as e:
                print(f"[DEBUG] Error formatting {label} result {i}: {e}")
    except Exception as e:
        print(f"[DEBUG] Error in format_object_results: {e}")
    return formatted_results


def _first_batch(value):
    """Return the first inner list of a nested query field, or [] if absent."""
    if value is None or len(value) == 0 or value[0] is None:
        return []
    return value[0]
# Login page HTML template (Jinja2 source; rendered with an optional
# ``error`` message). The page title's Korean text was mis-decoded mojibake
# and is restored to "로그인" ("login") here.
LOGIN_TEMPLATE = '''
<!DOCTYPE html>
<html lang="ko">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Vision LLM Agent - 로그인</title>
<style>
body {
font-family: Arial, sans-serif;
background-color: #f5f5f5;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
}
.login-container {
background-color: white;
padding: 2rem;
border-radius: 8px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
width: 100%;
max-width: 400px;
}
h1 {
text-align: center;
color: #4a6cf7;
margin-bottom: 1.5rem;
}
.form-group {
margin-bottom: 1rem;
}
label {
display: block;
margin-bottom: 0.5rem;
font-weight: bold;
}
input {
width: 100%;
padding: 0.75rem;
border: 1px solid #ddd;
border-radius: 4px;
font-size: 1rem;
}
button {
width: 100%;
padding: 0.75rem;
background-color: #4a6cf7;
color: white;
border: none;
border-radius: 4px;
font-size: 1rem;
cursor: pointer;
margin-top: 1rem;
}
button:hover {
background-color: #3a5cd8;
}
.error-message {
color: #e74c3c;
margin-top: 1rem;
text-align: center;
}
</style>
</head>
<body>
<div class="login-container">
<h1>Vision LLM Agent</h1>
<form action="/login" method="post" autocomplete="off">
<!-- hidden dummy fields to discourage Chrome autofill -->
<input type="text" name="fakeusernameremembered" style="display:none" tabindex="-1" autocomplete="off">
<input type="password" name="fakepasswordremembered" style="display:none" tabindex="-1" autocomplete="off">
<div class="form-group">
<label for="username">Username</label>
<input type="text" id="username" name="username" required autocomplete="username" autocapitalize="none" autocorrect="off" spellcheck="false">
</div>
<div class="form-group">
<label for="password">Password</label>
<input type="password" id="password" name="password" required autocomplete="current-password" autocapitalize="none" autocorrect="off" spellcheck="false">
</div>
<button type="submit">Login</button>
{% if error %}
<p class="error-message">{{ error }}</p>
{% endif %}
</form>
</div>
</body>
</html>
'''
def login():
    """Render the login form and process credential submissions.

    A user with a fresh authenticated session is redirected to
    ``/index.html``. A user authenticated through a non-fresh session still
    gets the form so they can re-authenticate. On POST the submitted
    ``username``/``password`` are checked against the module-level ``users``
    mapping; on success the user is logged in and redirected to a validated
    ``next`` target or to the index page.

    Returns:
        A redirect response on success / already-fresh session, otherwise the
        rendered login template (with ``error`` set after a failed attempt).
    """
    if current_user.is_authenticated:
        if login_fresh():
            print(f"User already authenticated and fresh as: {current_user.username}, redirecting to index")
            return redirect('/index.html')
        # Non-fresh session (e.g. restored from a remember cookie): show the
        # login page to force re-authentication.
        print("User authenticated but session not fresh; showing login page for reauthentication")

    error = None
    if request.method == 'POST':
        username = request.form.get('username')
        password = request.form.get('password')
        print(f"Login attempt: username={username}")
        if username in users and users[username].password == password:
            user = users[username]
            # remember=False so no long-lived remember cookie is issued and
            # the short session expiry stays in effect.
            login_user(user, remember=False)
            session['user_id'] = user.id
            session['username'] = username
            session.permanent = True
            session.modified = True  # make sure the session changes are persisted
            print(f"Login successful for user: {username}, ID: {user.id}")

            next_page = request.args.get('next')
            # Only follow same-site relative paths. "//host" and "/\host" are
            # interpreted by browsers as protocol-relative URLs, so accepting
            # them would make this an open redirect.
            is_local_path = (
                bool(next_page)
                and next_page.startswith('/')
                and not next_page.startswith(('//', '/\\'))
            )
            if is_local_path and next_page != '/login':
                print(f"Redirecting to: {next_page}")
                return redirect(next_page)
            print("Redirecting to index.html")
            return redirect(url_for('serve_index_html'))
        else:
            error = 'Invalid username or password'
            print(f"Login failed: {error}")
    return render_template_string(LOGIN_TEMPLATE, error=error)
def logout():
    """Log the current user out and send them back to the login page.

    Clears the Flask session and asks the browser to drop the
    ``remember_token`` cookie; failures in either step are logged and do not
    prevent the redirect.
    """
    logout_user()

    # Wipe everything stored in the session.
    try:
        session.clear()
    except Exception as exc:
        print(f"[DEBUG] Error clearing session on logout: {exc}")

    response = redirect(url_for('login'))

    # Expire the remember cookie explicitly, using the same attributes it was
    # configured with so the browser matches and removes it.
    cookie_options = {
        'key': 'remember_token',
        'path': '/',
        'samesite': 'None',
        'secure': True,
        'httponly': True,
    }
    try:
        response.delete_cookie(**cookie_options)
    except Exception as exc:
        print(f"[DEBUG] Error deleting remember_token cookie: {exc}")
    return response
# Route handler for serving static files (original note: no login required;
# the route decorator is not visible in this excerpt).
def serve_static(filename):
    """Send ``filename`` from the app's static folder with caching disabled."""
    print(f"Serving static file: {filename}")
    response = send_from_directory(app.static_folder, filename)
    # Disable caching so the browser always picks up the latest frontend build.
    for header, value in (
        ('Cache-Control', 'no-store, no-cache, must-revalidate, max-age=0'),
        ('Pragma', 'no-cache'),
        ('Expires', '0'),
    ):
        response.headers[header] = value
    return response
# Serve index.html directly (original note: login required; the route
# decorator is not visible in this excerpt).
def serve_index_html():
    """Serve ``static/index.html`` with a session heartbeat script injected.

    Users without a fresh authenticated session are redirected to the login
    page. The injected script polls ``/api/status`` every 30 seconds and
    navigates to ``/logout`` after 2 minutes without user activity. If
    ``index.html`` cannot be read, the file is served as-is without the
    script.

    Returns:
        A redirect to the login view, or the (non-cacheable) HTML response.
    """
    # Debug output is limited to cookie names and session keys: the raw
    # values include the session cookie and must not end up in the logs.
    print(f"Request to /index.html - Session keys: {sorted(session.keys())}")
    print(f"Request to /index.html - Cookie names: {sorted(request.cookies.keys())}")
    print(f"Request to /index.html - User authenticated: {current_user.is_authenticated}")

    # Only a fresh session may view the page.
    if not current_user.is_authenticated or not login_fresh():
        print("User not authenticated, redirecting to login")
        return redirect(url_for('login'))

    print(f"Serving index.html for authenticated user: {current_user.username} (ID: {current_user.id})")
    # The permanence flag is exposed as ``session.permanent``; it is not
    # stored under a 'permanent' key, so ``session.get('permanent')`` would
    # always report False.
    print(f"Session data: user_id={session.get('user_id')}, username={session.get('username')}, is_permanent={session.permanent}")

    # The session is deliberately not written to here so the expiry is not
    # extended by simply loading the page.
    index_path = os.path.join(app.static_folder, 'index.html')
    try:
        with open(index_path, 'r', encoding='utf-8') as f:
            html = f.read()
    except Exception as e:
        print(f"[DEBUG] Failed to read index.html for injection: {e}")
        return send_from_directory(app.static_folder, 'index.html')

    heartbeat_script = """
<script>
(function(){
// 1) ์ธ์ ์ํ ์ฃผ๊ธฐ ์ฒดํฌ (๋ง๋ฃ์ ๋ก๊ทธ์ธ์ผ๋ก)
function checkSession(){
fetch('/api/status', {credentials: 'include', redirect: 'manual'}).then(function(res){
var redirected = res.redirected || (res.url && res.url.indexOf('/login') !== -1);
if(res.status !== 200 || redirected){
window.location.href = '/login';
}
}).catch(function(){
// ๋คํธ์ํฌ ์ค๋ฅ ๋ฑ๋ ๋ก๊ทธ์ธ์ผ๋ก ์ ๋
window.location.href = '/login';
});
}
checkSession();
setInterval(checkSession, 30000);
// 2) ์ฌ์ฉ์ ๋นํ์ฑ(๋ฌด๋์) 2๋ถ ํ ์๋ ๋ก๊ทธ์์
var idleMs = 120000; // 2๋ถ
var idleTimer;
function triggerLogout(){
// ์๋ฒ ์ธ์ ์ ๋ฆฌ ํ ๋ก๊ทธ์ธ ํ๋ฉด์ผ๋ก
window.location.href = '/logout';
}
function resetIdle(){
if (idleTimer) clearTimeout(idleTimer);
idleTimer = setTimeout(triggerLogout, idleMs);
}
['click','mousemove','keydown','scroll','touchstart','visibilitychange'].forEach(function(evt){
window.addEventListener(evt, resetIdle, {passive:true});
});
resetIdle();
})();
</script>
"""
    # Inject exactly once, before the last </body>; a plain replace() would
    # duplicate the script if '</body>' also appears earlier in the document
    # (e.g. inside an inline script string).
    head, closing_tag, tail = html.rpartition('</body>')
    if closing_tag:
        html = head + heartbeat_script + closing_tag + tail
    else:
        html = html + heartbeat_script

    resp = make_response(html)
    # Prevent this authenticated page from being cached.
    resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0'
    resp.headers['Pragma'] = 'no-cache'
    resp.headers['Expires'] = '0'
    return resp
# Static files should be accessible without login requirements
def static_files(filename):
    """Serve a static asset from the top-level or nested build directory.

    After the CRA build is copied, an asset may live at
    ``static/<filename>`` or ``static/static/<filename>``. The nested copy is
    used only when the top-level one is absent; in every other case the
    top-level location is tried. Any failure is logged and answered with a
    plain 404.
    """
    print(f"Serving static file: {filename}")
    static_root = app.static_folder
    top_candidate = os.path.join(static_root, filename)
    nested_root = os.path.join(static_root, 'static')
    nested_candidate = os.path.join(nested_root, filename)
    try:
        use_nested = (
            not os.path.exists(top_candidate)
            and os.path.exists(nested_candidate)
        )
        if use_nested:
            # Only the nested build directory has this file.
            return send_from_directory(nested_root, filename)
        # Either the file exists at top level, or it exists nowhere and the
        # top-level lookup is attempted as the last resort.
        return send_from_directory(static_root, filename)
    except Exception as exc:
        print(f"[DEBUG] Error serving static file '{filename}': {exc}")
        # Answer with a bare 404 rather than exposing a stack trace.
        return ('Not Found', 404)
# Add explicit handlers for JS files that are failing
def static_js_files(filename):
    """Serve a JS bundle from ``static/js`` or ``static/static/js``.

    Falls back to the generic ``static_files`` handler (with a ``js/``
    prefix) when neither directory contains the file. Any failure is logged
    and answered with a plain 404.
    """
    print(f"Serving JS file: {filename}")
    # Candidate directories, in lookup order.
    js_dirs = (
        os.path.join(app.static_folder, 'js'),
        os.path.join(app.static_folder, 'static', 'js'),
    )
    js_paths = tuple(os.path.join(js_dir, filename) for js_dir in js_dirs)
    try:
        for js_dir, js_path in zip(js_dirs, js_paths):
            if os.path.exists(js_path):
                return send_from_directory(js_dir, filename)
        # Neither location has it: let the generic static handler try.
        return static_files(os.path.join('js', filename))
    except Exception as exc:
        print(f"[DEBUG] Error serving JS file '{filename}': {exc}")
        return ('Not Found', 404)
# Handles the default path and any other path (original note: login
# required; the route decorator is not visible in this excerpt).
def serve_react(path):
    """Serve the React frontend.

    If ``path`` names something that exists under the static folder it is
    sent directly. Otherwise ``index.html`` is served with the session
    heartbeat script injected (periodic ``/api/status`` check plus logout
    after 2 minutes of inactivity). If ``index.html`` cannot be read, it is
    served unmodified. Every response is marked non-cacheable.

    Args:
        path: Request path relative to the site root ("" for the root).

    Returns:
        The Flask response for the asset or the index page.
    """
    print(f"Serving React frontend for path: {path}, user: {current_user.username if current_user.is_authenticated else 'not authenticated'}")

    # Existing files are served as-is; everything else gets the SPA shell.
    if path != "" and os.path.exists(os.path.join(app.static_folder, path)):
        return _serve_react_no_cache(send_from_directory(app.static_folder, path))

    index_path = os.path.join(app.static_folder, 'index.html')
    try:
        with open(index_path, 'r', encoding='utf-8') as f:
            html = f.read()
    except Exception as e:
        print(f"[DEBUG] Failed to read index.html for injection (serve_react): {e}")
        return _serve_react_no_cache(send_from_directory(app.static_folder, 'index.html'))

    heartbeat_script = """
<script>
(function(){
// 1) ์ธ์ ์ํ ์ฃผ๊ธฐ ์ฒดํฌ (๋ง๋ฃ์ ๋ก๊ทธ์ธ์ผ๋ก)
function checkSession(){
fetch('/api/status', {credentials: 'include', redirect: 'manual'}).then(function(res){
var redirected = res.redirected || (res.url && res.url.indexOf('/login') !== -1);
if(res.status !== 200 || redirected){
window.location.href = '/login';
}
}).catch(function(){
window.location.href = '/login';
});
}
checkSession();
setInterval(checkSession, 30000);
// 2) ์ฌ์ฉ์ ๋นํ์ฑ(๋ฌด๋์) 2๋ถ ํ ์๋ ๋ก๊ทธ์์
var idleMs = 120000; // 2๋ถ
var idleTimer;
function triggerLogout(){
window.location.href = '/logout';
}
function resetIdle(){
if (idleTimer) clearTimeout(idleTimer);
idleTimer = setTimeout(triggerLogout, idleMs);
}
['click','mousemove','keydown','scroll','touchstart','visibilitychange'].forEach(function(evt){
window.addEventListener(evt, resetIdle, {passive:true});
});
resetIdle();
})();
</script>
"""
    # Inject exactly once, before the last </body>; replace() would insert
    # the script at every occurrence of the tag.
    head, closing_tag, tail = html.rpartition('</body>')
    if closing_tag:
        html = head + heartbeat_script + closing_tag + tail
    else:
        html = html + heartbeat_script
    return _serve_react_no_cache(make_response(html))


def _serve_react_no_cache(resp):
    """Mark ``resp`` as non-cacheable and return it."""
    resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0'
    resp.headers['Pragma'] = 'no-cache'
    resp.headers['Expires'] = '0'
    return resp
def similar_images_page():
    """Send the similar-images search page, marked as non-cacheable."""
    response = send_from_directory(app.static_folder, 'similar-images.html')
    for header, value in (
        ('Cache-Control', 'no-store, no-cache, must-revalidate, max-age=0'),
        ('Pragma', 'no-cache'),
        ('Expires', '0'),
    ):
        response.headers[header] = value
    return response
def object_detection_search_page():
    """Send the object-detection search page, marked as non-cacheable."""
    response = send_from_directory(app.static_folder, 'object-detection-search.html')
    for header, value in (
        ('Cache-Control', 'no-store, no-cache, must-revalidate, max-age=0'),
        ('Pragma', 'no-cache'),
        ('Expires', '0'),
    ):
        response.headers[header] = value
    return response
def model_vector_db_page():
    """Send the model vector DB UI page, marked as non-cacheable."""
    response = send_from_directory(app.static_folder, 'model-vector-db.html')
    for header, value in (
        ('Cache-Control', 'no-store, no-cache, must-revalidate, max-age=0'),
        ('Pragma', 'no-cache'),
        ('Expires', '0'),
    ):
        response.headers[header] = value
    return response
def openai_chat_page():
    """Send the OpenAI chat UI page, marked as non-cacheable."""
    response = send_from_directory(app.static_folder, 'openai-chat.html')
    for header, value in (
        ('Cache-Control', 'no-store, no-cache, must-revalidate, max-age=0'),
        ('Pragma', 'no-cache'),
        ('Expires', '0'),
    ):
        response.headers[header] = value
    return response
def openai_chat_api():
    """Forward a chat request to the OpenAI Chat Completions API.

    Expects a JSON object: ``{ prompt: string, model?: string,
    api_key?: string, system?: string }``. ``OPENAI_API_KEY`` from the
    environment is used when ``api_key`` is not provided, and
    ``OPENAI_MODEL`` (default ``gpt-4``) when ``model`` is not provided.

    Returns:
        JSON with ``response``, ``model``, ``usage`` and ``latency_sec`` on
        success. Errors are reported as ``{"error": ...}`` with status 400
        (bad request body / missing prompt or key), 500 (SDK missing or
        response could not be parsed) or 502 (the OpenAI call failed).
    """
    try:
        data = request.get_json(force=True)
    except Exception:
        return jsonify({"error": "Invalid JSON body"}), 400
    if not data:
        data = {}
    elif not isinstance(data, dict):
        # A JSON array/string/number has no fields to read; reject it
        # instead of failing with AttributeError on .get().
        return jsonify({"error": "Invalid JSON body"}), 400

    prompt = data.get('prompt', '')
    # A non-string prompt (number, list, null) is treated as missing rather
    # than crashing on .strip().
    prompt = prompt.strip() if isinstance(prompt, str) else ''
    model = data.get('model') or os.environ.get('OPENAI_MODEL', 'gpt-4')
    system = data.get('system') or 'You are a helpful assistant.'
    api_key = data.get('api_key') or os.environ.get('OPENAI_API_KEY')

    if not prompt:
        return jsonify({"error": "Missing 'prompt'"}), 400
    if not api_key:
        return jsonify({"error": "Missing OpenAI API key. Provide in request or set OPENAI_API_KEY env."}), 400
    # The SDK import is optional at module load time.
    if OpenAI is None:
        return jsonify({"error": "OpenAI Python package not installed on server"}), 500

    try:
        start = time.time()
        client = OpenAI(api_key=api_key)
        # NOTE: the former debug call to client.models.list() was dropped: it
        # cost an extra network round trip per request and made the whole
        # chat fail whenever listing models was not permitted for the key.
        chat = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": prompt},
            ],
        )
        latency = round(time.time() - start, 3)
    except Exception as e:
        return jsonify({"error": f"OpenAI SDK call failed: {str(e)}"}), 502

    try:
        content = chat.choices[0].message.content if chat and chat.choices else ''
        usage = getattr(chat, 'usage', None)
        # Convert the SDK usage object to a plain dict when it supports it.
        usage = usage.model_dump() if hasattr(usage, 'model_dump') else (usage or {})
    except Exception as e:
        return jsonify({"error": f"Failed to parse SDK response: {str(e)}"}), 500

    return jsonify({
        'response': content,
        'model': model,
        'usage': usage,
        'latency_sec': latency
    })
def status():
    """Report service health as JSON.

    The payload lists which models are loaded (module-level ``yolo_model``,
    ``detr_model``/``detr_processor``, ``vit_model``/``vit_processor``),
    whether CUDA is available, and the current user's username.
    """
    loaded_models = {
        "yolo": yolo_model is not None,
        "detr": detr_model is not None and detr_processor is not None,
        "vit": vit_model is not None and vit_processor is not None
    }
    device_name = "GPU" if torch.cuda.is_available() else "CPU"
    payload = {
        "status": "online",
        "models": loaded_models,
        "device": device_name,
        "user": current_user.username
    }
    return jsonify(payload)
# Root route is now handled by serve_react function
# This route is removed to prevent conflicts
def index_page():
    """Redirect the /index path to /index.html."""
    print("Index route redirecting to index.html")
    target = '/index.html'
    return redirect(target)
if __name__ == "__main__":
    # Hugging Face Spaces passes the listening port through the PORT
    # environment variable; fall back to 7860 when it is not set.
    listen_port = int(os.getenv("PORT", 7860))
    app.run(host='0.0.0.0', port=listen_port, debug=False)