# -*- coding: utf-8 -*-
# Set matplotlib config directory to avoid permission issues
import os
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'

from flask import Flask, request, jsonify, send_from_directory, redirect, url_for, session, render_template_string, make_response, Response, stream_with_context
from datetime import timedelta
import torch
from PIL import Image
import numpy as np
import io
from io import BytesIO
import base64
import uuid
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import time
from flask_cors import CORS
import json
import sys
import requests
import asyncio
from threading import Thread
import tempfile

try:
    from openai import OpenAI
except Exception as _e:
    OpenAI = None

try:
    # LangChain for RAG answering
    from langchain_openai import ChatOpenAI
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.output_parsers import StrOutputParser
except Exception as _e:
    ChatOpenAI = None
    ChatPromptTemplate = None
    StrOutputParser = None

from flask_login import (
    LoginManager,
    UserMixin,
    login_user,
    logout_user,
    login_required,
    current_user,
    fresh_login_required,
    login_fresh,
)
try:
    # PEFT for LoRA adapters (optional)
    from peft import PeftModel
except Exception as _e:
    PeftModel = None

# Fix for SQLite3 version compatibility with ChromaDB
try:
    import pysqlite3
    sys.modules['sqlite3'] = pysqlite3
except ImportError:
    print("Warning: pysqlite3 not found, using built-in sqlite3")

import chromadb
from chromadb.utils import embedding_functions
app = Flask(__name__, static_folder='static')

# Import product comparison coordinator with detailed debugging
print("=" * 80)
print("[STARTUP DEBUG] Testing product_comparison import at startup...")
print("=" * 80)
try:
    print("[DEBUG] Attempting to import product_comparison module...")
    from product_comparison import get_product_comparison_coordinator, decode_base64_image
    print("[DEBUG] Product comparison module imported successfully!")
    print(f"[DEBUG] get_product_comparison_coordinator: {get_product_comparison_coordinator}")
    print(f"[DEBUG] decode_base64_image: {decode_base64_image}")
    # Test coordinator creation
    print("[DEBUG] Testing coordinator creation...")
    test_coordinator = get_product_comparison_coordinator()
    print(f"[DEBUG] Test coordinator created: {type(test_coordinator).__name__}")
except ImportError as e:
    print(f"[DEBUG] Product comparison import failed: {e}")
    print(f"[DEBUG] Import error type: {type(e).__name__}")
    print(f"[DEBUG] Import error args: {e.args}")
    import traceback
    print("[DEBUG] Full import traceback:")
    traceback.print_exc()
    print("Warning: Product comparison module not available")
    get_product_comparison_coordinator = None
    decode_base64_image = None
except Exception as e:
    print(f"[DEBUG] Unexpected error during import: {e}")
    print(f"[DEBUG] Error type: {type(e).__name__}")
    import traceback
    print("[DEBUG] Full traceback:")
    traceback.print_exc()
    get_product_comparison_coordinator = None
    decode_base64_image = None

print("=" * 80)
print(f"[STARTUP DEBUG] Import test completed. Coordinator available: {get_product_comparison_coordinator is not None}")
print(f"[STARTUP DEBUG] Current working directory: {os.getcwd()}")
print(f"[STARTUP DEBUG] Files in current directory: {os.listdir('.')}")
print("=" * 80)
# Read the secret key from the environment; if absent, generate a safe random key
secret_key = os.environ.get('FLASK_SECRET_KEY')
if not secret_key:
    import secrets
    secret_key = secrets.token_hex(16)  # generate a random 32-character hex string
    print("WARNING: The FLASK_SECRET_KEY environment variable is not set. A random key was generated.")
    print("All sessions will expire on server restart. Set the environment variable in production.")
app.secret_key = secret_key  # secret key used to sign/encrypt the session
app.config['CORS_HEADERS'] = 'Content-Type'

# Remember cookie (Flask-Login) - minimize duration to prevent auto re-login
app.config['REMEMBER_COOKIE_DURATION'] = timedelta(seconds=1)
# Cookie settings for Hugging Face Spaces compatibility
app.config['REMEMBER_COOKIE_SECURE'] = False
app.config['REMEMBER_COOKIE_HTTPONLY'] = False  # Allow JS access for debugging
app.config['REMEMBER_COOKIE_SAMESITE'] = None  # Most permissive for iframe
# Session cookie settings - most permissive for HF Spaces
app.config['SESSION_COOKIE_SECURE'] = False
app.config['SESSION_COOKIE_HTTPONLY'] = False  # Allow JS access
app.config['SESSION_COOKIE_SAMESITE'] = None  # Most permissive for iframe
app.config['SESSION_COOKIE_PATH'] = '/'
app.config['SESSION_COOKIE_DOMAIN'] = None

CORS(app, supports_credentials=True)  # Enable CORS for all routes with credentials
# Dynamic HTTPS detection and cookie security settings
def configure_cookies_for_https():
    """Dynamically configure cookie security based on HTTPS detection.

    Must run in a request context (e.g., from a before_request hook),
    since it inspects the incoming request.
    """
    is_https = (
        request.is_secure or
        request.headers.get('X-Forwarded-Proto') == 'https' or
        request.headers.get('X-Forwarded-Ssl') == 'on' or
        os.environ.get('HTTPS', '').lower() == 'true'
    )
    if is_https:
        app.config['SESSION_COOKIE_SECURE'] = True
        app.config['REMEMBER_COOKIE_SECURE'] = True

# Secret key setting (used for session encryption); fall back to the key
# generated above rather than a weak hardcoded default
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', secret_key)
app.config['SESSION_TYPE'] = 'filesystem'
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(seconds=120)  # session lifetime (2 minutes)
app.config['SESSION_REFRESH_EACH_REQUEST'] = False  # absolute expiry (2 minutes after login)

# Fix session cookie issues in Hugging Face Spaces
app.config['SESSION_COOKIE_NAME'] = 'session'
app.config['SESSION_COOKIE_DOMAIN'] = None  # Allow any domain
app.config['SESSION_KEY_PREFIX'] = 'session:'
# Flask-Login setup
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = 'login'
login_manager.session_protection = 'strong'
# When authentication is required or session is not fresh, redirect to login instead of 401
login_manager.refresh_view = 'login'

@login_manager.unauthorized_handler
def handle_unauthorized():
    # For non-authenticated access, send user to login
    return redirect(url_for('login'))

@login_manager.needs_refresh_handler
def handle_needs_refresh():
    # For non-fresh sessions (e.g., after expiry or only remember-cookie), send to login
    return redirect(url_for('login'))
# Session setup
from flask_session import Session

# Use a temporary directory to avoid permission issues
session_dir = tempfile.gettempdir()
app.config['SESSION_TYPE'] = 'filesystem'
app.config['SESSION_PERMANENT'] = True
app.config['SESSION_USE_SIGNER'] = True
app.config['SESSION_FILE_DIR'] = session_dir
print(f"Using session directory: {session_dir}")
Session(app)
# User class definition
class User(UserMixin):
    def __init__(self, id, username, password):
        self.id = id
        self.username = username
        self.password = password

    def get_id(self):
        return str(self.id)  # Flask-Login requires a string ID

# Test users (using a database is recommended in production)
# Read user account info from environment variables (no hardcoded password defaults)
admin_username = os.environ.get('ADMIN_USERNAME', 'admin')
admin_password = os.environ.get('ADMIN_PASSWORD')
user_username = os.environ.get('USER_USERNAME', 'user')
user_password = os.environ.get('USER_PASSWORD')

# Warn when the environment variables are not set
if not admin_password or not user_password:
    print("ERROR: The ADMIN_PASSWORD or USER_PASSWORD environment variable is not set.")
    print("These environment variables must be set on Hugging Face Spaces.")
    print("Add them under Settings > Repository secrets.")
    # Generate temporary passwords when missing (development only)
    import secrets
    if not admin_password:
        admin_password = secrets.token_hex(8)  # temporary password
        print(f"WARNING: A temporary admin password was generated: {admin_password}")
    if not user_password:
        user_password = secrets.token_hex(8)  # temporary password
        print(f"WARNING: A temporary user password was generated: {user_password}")

users = {
    admin_username: User('1', admin_username, admin_password),
    user_username: User('2', user_username, user_password)
}
# Manual authentication check function
def check_authentication():
    """Check authentication via session, cookies, or headers. Returns (user_id, username) or (None, None)"""
    # Check session first
    user_id = session.get('user_id')
    username = session.get('username')
    # Fallback to cookies if session is empty
    if not user_id or not username:
        user_id = request.cookies.get('auth_user_id')
        username = request.cookies.get('auth_username')
    # Fallback to headers (for localStorage-based auth)
    if not user_id or not username:
        user_id = request.headers.get('X-Auth-User-ID')
        username = request.headers.get('X-Auth-Username')
        print(f"[DEBUG] Auth from headers: user_id={user_id}, username={username}")
    # Fallback to query params (for SSE/EventSource where custom headers aren't supported)
    if not user_id or not username:
        user_id = request.args.get('uid')
        username = request.args.get('uname')
        if user_id or username:
            print(f"[DEBUG] Auth from query params: user_id={user_id}, username={username}")
    if not user_id or not username:
        return None, None
    # Verify user exists
    for stored_username, user_obj in users.items():
        # users maps username -> User instance; compare using attributes
        if str(getattr(user_obj, 'id', None)) == str(user_id) and stored_username == username:
            return user_id, username
    return None, None
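# Hypothetical client calls for the fallbacks above (the endpoint paths are
# illustrative, not from this file; the header and query-param names are the
# ones check_authentication reads):
#   curl -H "X-Auth-User-ID: 1" -H "X-Auth-Username: admin" \
#        https://<space-url>/api/some-protected-endpoint
#   # SSE/EventSource clients pass the same identity as query params:
#   #   /api/some-stream/<session_id>?uid=1&uname=admin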
def require_auth():
    """Decorator replacement for @login_required that supports cookie fallback"""
    def decorator(f):
        def decorated_function(*args, **kwargs):
            user_id, username = check_authentication()
            if not user_id:
                return jsonify({"error": "Authentication required"}), 401
            return f(*args, **kwargs)
        decorated_function.__name__ = f.__name__
        return decorated_function
    return decorator
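# A minimal sketch of how a view would opt in (the route path and view name
# are illustrative only):
#   @app.route('/api/protected')
#   @require_auth()
#   def protected_view():
#       return jsonify({"ok": True})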
# User loader function
@login_manager.user_loader
def load_user(user_id):
    print(f"Loading user with ID: {user_id}")
    # Session debug info
    print(f"Session data in user_loader: {dict(session)}")
    print(f"Current request cookies: {request.cookies}")
    # Check session first, then fall back to cookies
    session_user_id = session.get('user_id')
    if not session_user_id:
        # Try to get from cookies
        cookie_user_id = request.cookies.get('auth_user_id')
        cookie_username = request.cookies.get('auth_username')
        print(f"Session empty, checking cookies: user_id={cookie_user_id}, username={cookie_username}")
        if cookie_user_id and cookie_username:
            # Verify cookie user exists
            for username, user in users.items():
                if str(user.id) == str(cookie_user_id) and username == cookie_username:
                    print(f"User found via cookies: {username}, ID: {user.id}")
                    # Restore session from cookies
                    session['user_id'] = user.id
                    session['username'] = username
                    session.modified = True
                    return user
    # Original session-based lookup
    for username, user in users.items():
        if str(user.id) == str(user_id):  # compare as strings to be safe
            print(f"User found: {username}, ID: {user.id}")
            # Update session info
            session['user_id'] = user.id
            session['username'] = username
            session.modified = True
            return user
    print(f"User not found with ID: {user_id}")
    return None
# Model initialization
print("Loading models... This may take a moment.")

# Image embedding model (CLIP) for vector search
clip_model = None
clip_processor = None
try:
    from transformers import CLIPProcessor, CLIPModel
    # Use a temporary directory for the model cache
    temp_dir = tempfile.gettempdir()
    os.environ["TRANSFORMERS_CACHE"] = temp_dir
    # Load the CLIP model (for image embeddings)
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    print("CLIP model loaded successfully")
except Exception as e:
    print("Error loading CLIP model:", e)
    clip_model = None
    clip_processor = None
# Vector DB initialization
vector_db = None
image_collection = None
object_collection = None
try:
    # Initialize the ChromaDB client (in-memory DB)
    vector_db = chromadb.Client()
    # Set up the embedding function
    ef = embedding_functions.DefaultEmbeddingFunction()
    # Create the image collection
    image_collection = vector_db.create_collection(
        name="image_collection",
        embedding_function=ef,
        get_or_create=True
    )
    # Create the object-detection results collection
    object_collection = vector_db.create_collection(
        name="object_collection",
        embedding_function=ef,
        get_or_create=True
    )
    print("Vector DB initialized successfully")
except Exception as e:
    print("Error initializing Vector DB:", e)
    vector_db = None
    image_collection = None
    object_collection = None
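# Note: chromadb.Client() is ephemeral, so both collections are lost on
# restart. A sketch of the persistent alternative, assuming a writable path
# (not what this app currently does):
#   vector_db = chromadb.PersistentClient(path=os.path.join(tempfile.gettempdir(), "chroma"))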
# YOLOv8 model
yolo_model = None
try:
    from ultralytics import YOLO
    # Model file path - use a temporary directory
    temp_dir = tempfile.gettempdir()
    model_path = os.path.join(temp_dir, "yolov8n.pt")
    # Download the model file directly if it is missing
    if not os.path.exists(model_path):
        print(f"Downloading YOLOv8 model to {model_path}...")
        # os.system does not raise on failure, so check its return code and
        # fall through to alternative sources when a download fails
        ret = os.system(f"wget -q https://ultralytics.com/assets/yolov8n.pt -O {model_path}")
        if ret == 0:
            print("YOLOv8 model downloaded successfully")
        else:
            print(f"Error downloading YOLOv8 model (exit code {ret})")
            # Try an alternative URL
            ret = os.system(f"wget -q https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt -O {model_path}")
            if ret == 0:
                print("YOLOv8 model downloaded from alternative source")
            else:
                print(f"Error downloading from alternative source (exit code {ret})")
                # As a last resort, try curl against the direct model URL
                ret = os.system(f"curl -L https://ultralytics.com/assets/yolov8n.pt --output {model_path}")
                if ret == 0:
                    print("YOLOv8 model downloaded using curl")
                else:
                    print(f"All download attempts failed (exit code {ret})")
    # Environment variables - point config files at a writable path
    os.environ["YOLO_CONFIG_DIR"] = temp_dir
    os.environ["MPLCONFIGDIR"] = temp_dir
    yolo_model = YOLO(model_path)  # Using the nano model for faster inference
    print("YOLOv8 model loaded successfully")
except Exception as e:
    print("Error loading YOLOv8 model:", e)
    yolo_model = None
# DETR model (DEtection TRansformer)
detr_processor = None
detr_model = None
try:
    from transformers import DetrImageProcessor, DetrForObjectDetection
    detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
    detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
    print("DETR model loaded successfully")
except Exception as e:
    print("Error loading DETR model:", e)
    detr_processor = None
    detr_model = None

# ViT model
vit_processor = None
vit_model = None
try:
    from transformers import ViTImageProcessor, ViTForImageClassification
    vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    vit_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
    print("ViT model loaded successfully")
except Exception as e:
    print("Error loading ViT model:", e)
    vit_processor = None
    vit_model = None

# Get device information
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# llama.cpp (GGUF) support
llama_cpp = None
llama_cpp_model = None
gguf_model_path = None
try:
    import llama_cpp

    def ensure_q4_gguf_model():
        """Download a TinyLlama Q4 GGUF model if not present and return the local path."""
        global gguf_model_path
        cache_dir = os.path.join(tempfile.gettempdir(), "gguf_models")
        os.makedirs(cache_dir, exist_ok=True)
        # Use a small, permissively accessible TinyLlama GGUF
        filename = "TinyLlama-1.1B-Chat-v1.0.Q4_K_M.gguf"
        gguf_model_path = os.path.join(cache_dir, filename)
        if not os.path.exists(gguf_model_path):
            try:
                url = (
                    "https://huggingface.co/TinyLlama/"
                    "TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/"
                    + filename
                )
                print(f"[GGUF] Downloading model from {url} -> {gguf_model_path}")
                with requests.get(url, stream=True, timeout=60) as r:
                    r.raise_for_status()
                    with open(gguf_model_path, 'wb') as f:
                        for chunk in r.iter_content(chunk_size=8192):
                            if chunk:
                                f.write(chunk)
                print("[GGUF] Download complete")
            except Exception as e:
                print(f"[GGUF] Failed to download GGUF model: {e}")
                return None
        return gguf_model_path

    def get_llama_cpp_model():
        """Lazy-load llama.cpp model from local GGUF path."""
        global llama_cpp_model
        if llama_cpp_model is not None:
            return llama_cpp_model
        model_path = ensure_q4_gguf_model()
        if not model_path:
            return None
        try:
            print(f"[GGUF] Loading llama.cpp model: {model_path}")
            llama_cpp_model = llama_cpp.Llama(
                model_path=model_path,
                n_ctx=4096,
                n_threads=max(1, os.cpu_count() or 1),
                n_gpu_layers=0,  # CPU-friendly default; adjust if GPU offload available
                verbose=False,
            )
            print("[GGUF] llama.cpp model loaded")
        except Exception as e:
            print(f"[GGUF] Failed to load llama.cpp model: {e}")
            llama_cpp_model = None
        return llama_cpp_model
except Exception as _e:
    llama_cpp = None
    llama_cpp_model = None
    gguf_model_path = None
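# A minimal usage sketch for the lazy loader above, using llama-cpp-python's
# __call__ API (the same pattern the LoRA compare worker uses below):
#   model = get_llama_cpp_model()
#   if model is not None:
#       out = model(prompt="User: Hi\nAssistant:", max_tokens=32)
#       print(out["choices"][0]["text"])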
# LLM model (using an open-access model instead of Llama 4 which requires authentication)
llm_model = None
llm_tokenizer = None
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    print("Loading LLM model... This may take a moment.")
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # Using TinyLlama as an open-access alternative
    llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
    llm_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        # Removing options that require the accelerate package
        # device_map="auto",
        # load_in_8bit=True
    ).to(device)
    print("LLM model loaded successfully")
except Exception as e:
    print(f"Error loading LLM model: {e}")
    llm_model = None
    llm_tokenizer = None
def process_llm_query(vision_results, user_query):
    """Process a query with the LLM model using vision results and user text"""
    if llm_model is None or llm_tokenizer is None:
        return {"error": "LLM model not available"}
    # Summarize the result data (to stay within the token limit)
    summarized_results = []
    # Summarize object detection results
    if isinstance(vision_results, list):
        # Include at most 10 objects
        for i, obj in enumerate(vision_results[:10]):
            if isinstance(obj, dict):
                # Extract only the fields we need
                summary = {
                    "label": obj.get("label", "unknown"),
                    "confidence": obj.get("confidence", 0),
                }
                summarized_results.append(summary)
    # Create a prompt combining vision results and user query
    prompt = f"""You are an AI assistant analyzing image detection results.
Here are the objects detected in the image: {json.dumps(summarized_results, indent=2)}

User question: {user_query}

Please provide a detailed analysis based on the detected objects and the user's question.
"""
    # Tokenize and generate response
    try:
        start_time = time.time()
        # Check and cap the token length
        tokens = llm_tokenizer.encode(prompt)
        if len(tokens) > 1500:  # conservative margin
            prompt = f"""You are an AI assistant analyzing image detection results.
The image contains {len(summarized_results)} detected objects.

User question: {user_query}

Please provide a general analysis based on the user's question.
"""
        inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            output = llm_model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )
        response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True)
        # Remove the prompt from the response
        if response_text.startswith(prompt):
            response_text = response_text[len(prompt):].strip()
        inference_time = time.time() - start_time
        return {
            "response": response_text,
            "performance": {
                "inference_time": round(inference_time, 3),
                "device": "GPU" if torch.cuda.is_available() else "CPU"
            }
        }
    except Exception as e:
        return {"error": f"Error processing LLM query: {str(e)}"}
def image_to_base64(img):
    """Convert PIL Image to base64 string"""
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return img_str
def process_yolo(image):
    if yolo_model is None:
        return {"error": "YOLOv8 model not loaded"}
    # Measure inference time
    start_time = time.time()
    # Convert to numpy if it's a PIL image
    if isinstance(image, Image.Image):
        image_np = np.array(image)
    else:
        image_np = image
    # Run inference
    results = yolo_model(image_np)
    # Process results
    result_image = results[0].plot()
    result_image = Image.fromarray(result_image)
    # Get detection information
    boxes = results[0].boxes
    class_names = results[0].names
    # Format detection results
    detections = []
    for box in boxes:
        class_id = int(box.cls[0].item())
        class_name = class_names[class_id]
        confidence = round(box.conf[0].item(), 2)
        bbox = box.xyxy[0].tolist()
        bbox = [round(x) for x in bbox]
        detections.append({
            "class": class_name,
            "confidence": confidence,
            "bbox": bbox
        })
    # Calculate inference time
    inference_time = time.time() - start_time
    # Add inference time and device info
    device_info = "GPU" if torch.cuda.is_available() else "CPU"
    return {
        "image": image_to_base64(result_image),
        "detections": detections,
        "performance": {
            "inference_time": round(inference_time, 3),
            "device": device_info
        }
    }
def process_detr(image):
    if detr_model is None or detr_processor is None:
        return {"error": "DETR model not loaded"}
    # Measure inference time
    start_time = time.time()
    # Prepare image for the model
    inputs = detr_processor(images=image, return_tensors="pt")
    # Run inference
    with torch.no_grad():
        outputs = detr_model(**inputs)
    # Process results
    target_sizes = torch.tensor([image.size[::-1]])
    results = detr_processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]
    # Create a copy of the image to draw on
    result_image = image.copy()
    fig, ax = plt.subplots(1)
    ax.imshow(result_image)
    # Format detection results
    detections = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i) for i in box.tolist()]
        class_name = detr_model.config.id2label[label.item()]
        confidence = round(score.item(), 2)
        # Draw rectangle
        rect = Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                         linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        # Add label
        plt.text(box[0], box[1], "{}: {}".format(class_name, confidence),
                 bbox=dict(facecolor='white', alpha=0.8))
        detections.append({
            "class": class_name,
            "confidence": confidence,
            "bbox": box
        })
    # Save figure to image
    buf = io.BytesIO()
    plt.tight_layout()
    plt.axis('off')
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    result_image = Image.open(buf)
    plt.close(fig)
    # Calculate inference time
    inference_time = time.time() - start_time
    # Add inference time and device info
    device_info = "GPU" if torch.cuda.is_available() else "CPU"
    return {
        "image": image_to_base64(result_image),
        "detections": detections,
        "performance": {
            "inference_time": round(inference_time, 3),
            "device": device_info
        }
    }
def process_vit(image):
    if vit_model is None or vit_processor is None:
        return {"error": "ViT model not loaded"}
    # Measure inference time
    start_time = time.time()
    # Prepare image for the model
    inputs = vit_processor(images=image, return_tensors="pt")
    # Run inference
    with torch.no_grad():
        outputs = vit_model(**inputs)
        logits = outputs.logits
    # Get the predicted class
    predicted_class_idx = logits.argmax(-1).item()
    prediction = vit_model.config.id2label[predicted_class_idx]
    # Get top 5 predictions
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    top5_prob, top5_indices = torch.topk(probs, 5)
    results = []
    for i, (prob, idx) in enumerate(zip(top5_prob, top5_indices)):
        class_name = vit_model.config.id2label[idx.item()]
        results.append({
            "rank": i + 1,
            "class": class_name,
            "probability": round(prob.item(), 3)
        })
    # Calculate inference time
    inference_time = time.time() - start_time
    # Add inference time and device info
    device_info = "GPU" if torch.cuda.is_available() else "CPU"
    return {
        "top_predictions": results,
        "performance": {
            "inference_time": round(inference_time, 3),
            "device": device_info
        }
    }
def yolo_detect():
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    image = Image.open(file.stream)
    result = process_yolo(image)
    return jsonify(result)
# --- HF CausalLM + LoRA helper caches ---
hf_tokenizers = {}
hf_base_models = {}
hf_lora_models = {}

def _preferred_dtype():
    try:
        return torch.float16 if torch.cuda.is_available() else torch.float32
    except Exception:
        return torch.float32

def load_hf_base_and_tokenizer(base_id: str, tok_id: str = None):
    """Load and cache base CausalLM and tokenizer. tok_id can point to adapter repo to use its tokenizer."""
    global hf_tokenizers, hf_base_models
    tok_key = tok_id or base_id
    if tok_key not in hf_tokenizers:
        hf_tokenizers[tok_key] = AutoTokenizer.from_pretrained(tok_key, use_fast=True)
    if base_id not in hf_base_models:
        hf_base_models[base_id] = AutoModelForCausalLM.from_pretrained(
            base_id,
            torch_dtype=_preferred_dtype(),
            use_safetensors=False,
        ).to(device)
    return hf_tokenizers[tok_key], hf_base_models[base_id]

def load_hf_lora_model(base_id: str, adapter_id: str):
    """Load and cache a PEFT-wrapped LoRA model from base + adapter. Returns model or raises if PEFT unavailable."""
    if PeftModel is None:
        raise RuntimeError("PEFT (peft) is not installed; cannot load LoRA adapter")
    key = f"{base_id}::{adapter_id}"
    if key in hf_lora_models:
        return hf_lora_models[key]
    # Create a fresh base to avoid mutating the cached base used for non-LoRA runs
    base = AutoModelForCausalLM.from_pretrained(
        base_id,
        torch_dtype=_preferred_dtype(),
        use_safetensors=False,
    ).to(device)
    lora_model = PeftModel.from_pretrained(base, adapter_id).eval().to(device)
    hf_lora_models[key] = lora_model
    return lora_model
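# Usage sketch with the default base/adapter pair used by the UI
# (openlm-research/open_llama_3b with HGKo/openllama3b-lora-adapter);
# the prompt is illustrative:
#   tok, _ = load_hf_base_and_tokenizer("openlm-research/open_llama_3b",
#                                       "HGKo/openllama3b-lora-adapter")
#   lora = load_hf_lora_model("openlm-research/open_llama_3b",
#                             "HGKo/openllama3b-lora-adapter")
#   ids = tok("Hello", return_tensors="pt").to(device)
#   print(tok.decode(lora.generate(**ids, max_new_tokens=16)[0], skip_special_tokens=True))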
def detr_detect():
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    image = Image.open(file.stream)
    result = process_detr(image)
    return jsonify(result)

def vit_classify():
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    image = Image.open(file.stream)
    result = process_vit(image)
    return jsonify(result)

def analyze_with_llm():
    # Check if required data is in the request
    if not request.json:
        return jsonify({"error": "No JSON data provided"}), 400
    # Extract vision results and user query from request
    data = request.json
    if 'visionResults' not in data or 'userQuery' not in data:
        return jsonify({"error": "Missing required fields: visionResults or userQuery"}), 400
    vision_results = data['visionResults']
    user_query = data['userQuery']
    # Process the query with LLM
    result = process_llm_query(vision_results, user_query)
    return jsonify(result)
def generate_image_embedding(image):
    """Generate an image embedding with the CLIP model"""
    if clip_model is None or clip_processor is None:
        return None
    try:
        # Preprocess the image
        inputs = clip_processor(images=image, return_tensors="pt")
        # Generate the image embedding
        with torch.no_grad():
            image_features = clip_model.get_image_features(**inputs)
        # Normalize the embedding and convert to a list
        image_embedding = image_features.squeeze().cpu().numpy()
        normalized_embedding = image_embedding / np.linalg.norm(image_embedding)
        return normalized_embedding.tolist()
    except Exception as e:
        print(f"Error generating image embedding: {e}")
        return None
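# Because the embeddings are L2-normalized, cosine similarity reduces to a dot
# product. A small self-contained check (img_a and img_b are hypothetical
# PIL images):
#   a = np.array(generate_image_embedding(img_a))
#   b = np.array(generate_image_embedding(img_b))
#   cosine = float(np.dot(a, b))  # equals cosine similarity for unit vectors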
def find_similar_images():
    """Similar-image search API"""
    if clip_model is None or clip_processor is None or image_collection is None:
        return jsonify({"error": "Image embedding model or vector DB not available"})
    try:
        # Extract the image data from the request
        if 'image' not in request.files and 'image' not in request.form:
            return jsonify({"error": "No image provided"})
        if 'image' in request.files:
            # Uploaded as a file
            image_file = request.files['image']
            image = Image.open(image_file).convert('RGB')
        else:
            # Base64-encoded
            image_data = request.form['image']
            if image_data.startswith('data:image'):
                # Remove the data URL prefix if present
                image_data = image_data.split(',')[1]
            image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB')
        # Generate an image ID (temporary)
        image_id = str(uuid.uuid4())
        # Generate the image embedding
        embedding = generate_image_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate image embedding"})
        # Optionally add the current image to the DB
        # image_collection.add(
        #     ids=[image_id],
        #     embeddings=[embedding]
        # )
        # Search for similar images
        results = image_collection.query(
            query_embeddings=[embedding],
            n_results=5  # return the top 5 results
        )
        # Format the results
        similar_images = []
        if len(results['ids'][0]) > 0:
            for i, img_id in enumerate(results['ids'][0]):
                similar_images.append({
                    "id": img_id,
                    "distance": float(results['distances'][0][i]) if 'distances' in results else 0.0,
                    "metadata": results['metadatas'][0][i] if 'metadatas' in results else {}
                })
        return jsonify({
            "query_image_id": image_id,
            "similar_images": similar_images
        })
    except Exception as e:
        print(f"Error in similar-images API: {e}")
        return jsonify({"error": str(e)}), 500
def add_to_collection():
    """API that adds an image to the vector DB"""
    if clip_model is None or clip_processor is None or image_collection is None:
        return jsonify({"error": "Image embedding model or vector DB not available"})
    try:
        # Extract the image data from the request
        if 'image' not in request.files and 'image' not in request.form:
            return jsonify({"error": "No image provided"})
        # Extract metadata
        metadata = {}
        if 'metadata' in request.form:
            metadata = json.loads(request.form['metadata'])
        # Image ID (auto-generated when not provided)
        image_id = request.form.get('id', str(uuid.uuid4()))
        if 'image' in request.files:
            # Uploaded as a file
            image_file = request.files['image']
            image = Image.open(image_file).convert('RGB')
        else:
            # Base64-encoded
            image_data = request.form['image']
            if image_data.startswith('data:image'):
                # Remove the data URL prefix if present
                image_data = image_data.split(',')[1]
            image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB')
        # Generate the image embedding
        embedding = generate_image_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate image embedding"})
        # Encode the image data as base64 and add it to the metadata
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
        metadata['image_data'] = img_str
        # Add the image to the DB
        image_collection.add(
            ids=[image_id],
            embeddings=[embedding],
            metadatas=[metadata]
        )
        return jsonify({
            "success": True,
            "image_id": image_id,
            "message": "Image added to collection"
        })
    except Exception as e:
        print(f"Error in add-to-collection API: {e}")
        return jsonify({"error": str(e)}), 500
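# Hypothetical curl call for this handler (the route path is illustrative;
# the handler reads multipart fields named image, metadata, and id):
#   curl -X POST -F "image=@photo.jpg" \
#        -F 'metadata={"source": "demo"}' \
#        https://<space-url>/api/add-to-collection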
def add_detected_objects():
    """API that adds object-detection results to the vector DB"""
    if clip_model is None or object_collection is None:
        return jsonify({"error": "Image embedding model or vector DB not available"})
    try:
        # Debugging: log the request data
        print("[DEBUG] Received request in add-detected-objects")
        # Extract the image and the detection results from the request
        data = request.json
        print(f"[DEBUG] Request data keys: {list(data.keys()) if data else 'None'}")
        if not data:
            print("[DEBUG] Error: No data received in request")
            return jsonify({"error": "No data received"})
        if 'image' not in data:
            print("[DEBUG] Error: 'image' key not found in request data")
            return jsonify({"error": "Missing image data"})
        if 'objects' not in data:
            print("[DEBUG] Error: 'objects' key not found in request data")
            return jsonify({"error": "Missing objects data"})
        # Debug the image data
        print(f"[DEBUG] Image data type: {type(data['image'])}")
        print(f"[DEBUG] Image data starts with: {data['image'][:50]}...")  # print only the first 50 characters
        # Debug the object data
        print(f"[DEBUG] Objects data type: {type(data['objects'])}")
        print(f"[DEBUG] Objects count: {len(data['objects']) if isinstance(data['objects'], list) else 'Not a list'}")
        if isinstance(data['objects'], list) and len(data['objects']) > 0:
            print(f"[DEBUG] First object keys: {list(data['objects'][0].keys()) if isinstance(data['objects'][0], dict) else 'Not a dict'}")
        # Process the image data
        image_data = data['image']
        if image_data.startswith('data:image'):
            image_data = image_data.split(',')[1]
        image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB')
        image_width, image_height = image.size
        # Image ID
        image_id = data.get('imageId', str(uuid.uuid4()))
        # Process the object data
        objects = data['objects']
        object_ids = []
        object_embeddings = []
        object_metadatas = []
        for obj in objects:
            # Generate an object ID
            object_id = f"{image_id}_{str(uuid.uuid4())[:8]}"
            # Extract the bounding-box info
            bbox = obj.get('bbox', [])
            # Handle list-form bbox [x1, y1, x2, y2]
            if isinstance(bbox, list) and len(bbox) >= 4:
                x1 = bbox[0] / image_width  # convert to normalized coordinates
                y1 = bbox[1] / image_height
                x2 = bbox[2] / image_width
                y2 = bbox[3] / image_height
                width = x2 - x1
                height = y2 - y1
            # Handle dict-form bbox {'x': x, 'y': y, 'width': width, 'height': height}
            elif isinstance(bbox, dict):
                x1 = bbox.get('x', 0)
                y1 = bbox.get('y', 0)
                width = bbox.get('width', 0)
                height = bbox.get('height', 0)
            else:
                # Default values
                x1, y1, width, height = 0, 0, 0, 0
            # Convert the bounding box back to pixel coordinates
            x1_px = int(x1 * image_width)
            y1_px = int(y1 * image_height)
            width_px = int(width * image_width)
            height_px = int(height * image_height)
            # Crop the object image
            try:
                object_image = image.crop((x1_px, y1_px, x1_px + width_px, y1_px + height_px))
                # Generate the embedding
                embedding = generate_image_embedding(object_image)
                if embedding is None:
                    continue
                # Build the metadata
                # Serialize bbox to a JSON string to work around ChromaDB's metadata type limits
                bbox_json = json.dumps({
                    "x": x1,
                    "y": y1,
                    "width": width,
                    "height": height
                })
                # Encode the object image as base64
                buffered = BytesIO()
                object_image.save(buffered, format="JPEG")
                img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
                metadata = {
                    "image_id": image_id,
                    "class": obj.get('class', ''),
                    "confidence": obj.get('confidence', 0),
                    "bbox": bbox_json,  # stored as a JSON string
                    "image_data": img_str  # include the cropped image data
                }
                object_ids.append(object_id)
                object_embeddings.append(embedding)
                object_metadatas.append(metadata)
            except Exception as e:
                print(f"Error processing object: {e}")
                continue
        # Add to vector DB
        if object_ids and object_embeddings and object_metadatas:
            object_collection.add(
                ids=object_ids,
                embeddings=object_embeddings,
                metadatas=object_metadatas
            )
            return jsonify({
                "success": True,
                "message": f"Added {len(object_ids)} objects to collection",
                "object_ids": object_ids
            })
        else:
            return jsonify({
                "warning": "No valid objects to add"
            })
    except Exception as e:
        print(f"Error in add-detected-objects API: {e}")
        return jsonify({"error": str(e)}), 500
# Product Comparison API Endpoints
def start_product_comparison():
    """Start a new product comparison session"""
    print("[DEBUG] start_product_comparison endpoint called")
    print(f"[DEBUG] get_product_comparison_coordinator: {get_product_comparison_coordinator}")
    if get_product_comparison_coordinator is None:
        print("[DEBUG] Product comparison coordinator is None - returning 500")
        # Try to import again and show the error
        print("[DEBUG] Attempting emergency import test...")
        try:
            from product_comparison import get_product_comparison_coordinator as test_coordinator
            print(f"[DEBUG] Emergency import succeeded: {test_coordinator}")
        except Exception as e:
            print(f"[DEBUG] Emergency import failed: {e}")
            import traceback
            traceback.print_exc()
        return jsonify({"error": "Product comparison module not available"}), 500
    try:
        print("[DEBUG] Processing request data...")
        # Use the session ID from form or query params if provided, otherwise create a new one
        session_id = request.form.get('session_id') or request.args.get('session_id') or str(uuid.uuid4())
        print(f"[DEBUG] Session ID: {session_id}")
        # Get analysis type if provided (info, compare, value, recommend)
        analysis_type = request.form.get('analysisType') or request.args.get('analysisType', 'info')
        print(f"[DEBUG] Analysis type: {analysis_type}")
        # Process images from FormData or JSON
        images = []
        print("[DEBUG] Processing images...")
        print(f"[DEBUG] Request files: {list(request.files.keys())}")
        print(f"[DEBUG] Request form: {dict(request.form)}")
        # Check if request is multipart form data
        if request.files:
            print("[DEBUG] Processing multipart form data...")
            # Handle FormData with file uploads (from frontend)
            if 'image1' in request.files and request.files['image1']:
                img1 = request.files['image1']
                print(f"[DEBUG] Processing image1: {img1.filename}")
                try:
                    # Load image and convert to RGB to ensure it's fully loaded in memory
                    image = Image.open(img1.stream).convert('RGB')
                    images.append(image)
                    print("[DEBUG] Image1 processed successfully")
                except Exception as e:
                    print(f"[DEBUG] Error processing image1: {e}")
            if 'image2' in request.files and request.files['image2']:
                img2 = request.files['image2']
                print(f"[DEBUG] Processing image2: {img2.filename}")
                try:
                    # Load image and convert to RGB to ensure it's fully loaded in memory
                    image = Image.open(img2.stream).convert('RGB')
                    images.append(image)
                    print("[DEBUG] Image2 processed successfully")
                except Exception as e:
                    print(f"[DEBUG] Error processing image2: {e}")
        # Fallback to JSON with base64 images (for API testing)
        elif request.json and 'images' in request.json:
            print("[DEBUG] Processing JSON with base64 images...")
            image_data_list = request.json.get('images', [])
            for i, image_data in enumerate(image_data_list):
                print(f"[DEBUG] Processing base64 image {i+1}")
                img = decode_base64_image(image_data)
                if img is not None:
                    images.append(img)
                    print(f"[DEBUG] Base64 image {i+1} processed successfully")
                else:
                    print(f"[DEBUG] Failed to decode base64 image {i+1}")
        print(f"[DEBUG] Total images processed: {len(images)}")
        if not images:
            print("[DEBUG] No valid images provided - returning 400")
            return jsonify({"error": "No valid images provided"}), 400
        # Get coordinator instance
        print("[DEBUG] Getting coordinator instance...")
        coordinator = get_product_comparison_coordinator()
        print(f"[DEBUG] Coordinator obtained: {type(coordinator).__name__}")
        # Pass the analysis type and session metadata to the coordinator
        session_metadata = {
            'analysis_type': analysis_type,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
        }
        print(f"[DEBUG] Session metadata: {session_metadata}")
        # Start processing in a background thread
        print("[DEBUG] Starting background processing thread...")
        def run_async_task(loop):
            try:
                print("[DEBUG] Setting event loop and starting coordinator.process_images...")
                asyncio.set_event_loop(loop)
                loop.run_until_complete(coordinator.process_images(session_id, images, session_metadata))
                print("[DEBUG] coordinator.process_images completed successfully")
            except Exception as e:
                print(f"[DEBUG] Error in async task: {e}")
                import traceback
                traceback.print_exc()
        loop = asyncio.new_event_loop()
        thread = Thread(target=run_async_task, args=(loop,))
        thread.daemon = True
        thread.start()
        print("[DEBUG] Background thread started")
        # Return session ID for client to use with streaming endpoint
        response_data = {
            "session_id": session_id,
            "message": "Product comparison started",
            "status": "processing"
        }
        print(f"[DEBUG] Returning success response: {response_data}")
        return jsonify(response_data)
    except Exception as e:
        print(f"[DEBUG] Exception in start_product_comparison: {e}")
        print(f"[DEBUG] Exception type: {type(e).__name__}")
        import traceback
        print("[DEBUG] Full traceback:")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
def stream_product_comparison(session_id):
    """Stream updates from a product comparison session"""
    if get_product_comparison_coordinator is None:
        return jsonify({"error": "Product comparison module not available"}), 500

    def generate():
        """Generate SSE events for streaming"""
        print(f"[DEBUG] Starting SSE stream for session: {session_id}")
        coordinator = get_product_comparison_coordinator()
        last_message_index = 0
        retry_count = 0
        max_retries = 300  # 5 minutes at 1 second intervals
        while retry_count < max_retries:
            # Get current status
            status = coordinator.get_session_status(session_id)
            print(f"[DEBUG] Session {session_id} status: {status}")
            if status is None:
                # Session not found
                print(f"[DEBUG] Session {session_id} not found")
                yield f"data: {json.dumps({'error': 'Session not found'})}\n\n"
                break
            # Get all messages
            messages = coordinator.get_session_messages(session_id)
            print(f"[DEBUG] Session {session_id} has {len(messages) if messages else 0} messages")
            # Send any new messages
            if messages and len(messages) > last_message_index:
                new_messages = messages[last_message_index:]
                print(f"[DEBUG] Sending {len(new_messages)} new messages")
                for msg in new_messages:
                    print(f"[DEBUG] Message: {msg}")
                    # msg is already a dict with message, agent, timestamp
                    yield f"data: {json.dumps(msg)}\n\n"
                last_message_index = len(messages)
            # Send current status
            yield f"data: {json.dumps({'status': status})}\n\n"
            # If completed or error, send final result and end stream
            if status in ['completed', 'error']:
                result = coordinator.get_session_result(session_id)
                print(f"[DEBUG] Session {session_id} finished with status: {status}")
                print(f"[DEBUG] Final result: {result}")
                yield f"data: {json.dumps({'final_result': result})}\n\n"
                break
            # Wait before next update
            time.sleep(1)
            retry_count += 1
        # End the stream if we've reached max retries
        if retry_count >= max_retries:
            print(f"[DEBUG] Session {session_id} timed out after {max_retries} retries")
            yield f"data: {json.dumps({'error': 'Timeout waiting for results'})}\n\n"

    return Response(
        stream_with_context(generate()),
        mimetype="text/event-stream",
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no',
            'Content-Type': 'text/event-stream',
        }
    )
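# A minimal Python consumer sketch for this SSE stream (the endpoint path is
# illustrative; the payload keys match what generate() emits above):
#   with requests.get(f"{base_url}/stream/{session_id}", stream=True) as r:
#       for line in r.iter_lines():
#           if line.startswith(b"data: "):
#               event = json.loads(line[len(b"data: "):])
#               if 'final_result' in event or 'error' in event:
#                   break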
# ============================
# LLM LoRA Compare Endpoints
# ============================

# Simple in-memory session store for LoRA compare
lora_sessions = {}

def lora_add_message(session_id, message, msg_type="info"):
    sess = lora_sessions.get(session_id)
    if not sess:
        return
    ts = time.strftime('%Y-%m-%d %H:%M:%S')
    sess['messages'].append({
        'message': message,
        'type': msg_type,
        'timestamp': ts
    })
def start_llama_lora_compare():
    """Start a LoRA-vs-Base comparison session (text or image+text prompt)."""
    session_id = request.form.get('session_id') or str(uuid.uuid4())
    prompt = request.form.get('prompt', '')
    # Default to local GGUF TinyLlama Q4 model via llama.cpp
    base_model_id = request.form.get('baseModel', 'gguf:tinyllama-q4km')
    lora_path = request.form.get('loraPath', '')
    image_b64 = None
    if 'image' in request.files:
        try:
            img = Image.open(request.files['image'].stream).convert('RGB')
            buffer = BytesIO()
            img.save(buffer, format='PNG')
            image_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        except Exception as _e:
            pass
    # Initialize session
    lora_sessions[session_id] = {
        'status': 'processing',
        'messages': [],
        'result': None,
    }
    lora_add_message(session_id, 'LoRA comparison started', 'system')

    def worker():
        try:
            lora_add_message(session_id, f"Base model: {base_model_id}")
            if lora_path:
                lora_add_message(session_id, f"LoRA adapter: {lora_path}")
            else:
                lora_add_message(session_id, "No LoRA adapter provided; using mock output.")
            # Prepare prompt
            full_prompt = prompt or 'Describe the content.'
            if image_b64:
                lora_add_message(session_id, 'Image provided; running vision+language prompt.')
            # Run base inference (best-effort)
            start_base = time.time()
            base_output = None
            try:
                # If base_model_id indicates GGUF, use llama.cpp
                if base_model_id.startswith('gguf:') and llama_cpp is not None:
                    model = get_llama_cpp_model()
                    if model is None:
                        raise RuntimeError('GGUF model unavailable')
                    # Simple chat-style prompt
                    prompt_text = f"You are a helpful assistant.\nUser: {full_prompt}\nAssistant:"
                    res = model(
                        prompt=prompt_text,
                        max_tokens=128,
                        temperature=0.7,
                        top_p=0.9,
                        stop=["User:", "\n\n"],
                    )
                    text = res.get('choices', [{}])[0].get('text', '').strip()
                    # Ensure we don't return an empty base output
                    base_output = text or "[empty]"
                elif '/' in base_model_id:
                    # HF Transformers path (e.g., openlm-research/open_llama_3b)
                    tok_id = lora_path or base_model_id  # user expects tokenizer from adapter repo when provided
                    tokenizer_hf, base_model_hf = load_hf_base_and_tokenizer(base_model_id, tok_id)
                    inputs = tokenizer_hf(full_prompt, return_tensors='pt').to(device)
                    with torch.no_grad():
                        out = base_model_hf.generate(**inputs, max_new_tokens=128, temperature=0.7, top_p=0.9)
                    text = tokenizer_hf.decode(out[0], skip_special_tokens=True)
                    if text.startswith(full_prompt):
                        stripped = text[len(full_prompt):].strip()
                        text = stripped if stripped else text
                    base_output = text or "[empty]"
                elif llm_model is not None and llm_tokenizer is not None:
                    inputs = llm_tokenizer(full_prompt, return_tensors='pt').to(device)
                    with torch.no_grad():
                        out = llm_model.generate(**inputs, max_new_tokens=128, temperature=0.7, top_p=0.9)
                    text = llm_tokenizer.decode(out[0], skip_special_tokens=True)
                    if text.startswith(full_prompt):
                        stripped = text[len(full_prompt):].strip()
                        text = stripped if stripped else text
                    base_output = text or "[empty]"
                else:
                    base_output = f"[mock] Base response for: {full_prompt[:80]}..."
            except Exception as e:
                base_output = f"[error] Base inference failed: {e}"
            base_latency = int((time.time() - start_base) * 1000)
            lora_add_message(session_id, f"Base inference done in {base_latency} ms")
            # Run LoRA inference (prefer HF+PEFT when adapter provided)
            start_lora = time.time()
            try:
                if lora_path and '/' in base_model_id:
                    tokenizer_hf, _ = load_hf_base_and_tokenizer(base_model_id, lora_path)
                    lora_model = load_hf_lora_model(base_model_id, lora_path)
                    inputs = tokenizer_hf(full_prompt, return_tensors='pt').to(device)
                    with torch.no_grad():
                        out = lora_model.generate(**inputs, max_new_tokens=128, temperature=0.7, top_p=0.9)
                    text = tokenizer_hf.decode(out[0], skip_special_tokens=True)
                    if text.startswith(full_prompt):
                        stripped = text[len(full_prompt):].strip()
                        text = stripped if stripped else text
                    lora_output = text or "[empty]"
                else:
                    lora_output = f"[mock] LoRA response (no adapter) for: {full_prompt[:80]}..."
            except Exception as e:
                lora_output = f"[error] LoRA inference failed: {e}"
            lora_latency = int((time.time() - start_lora) * 1000)
            lora_add_message(session_id, f"LoRA inference done in {lora_latency} ms")
            lora_sessions[session_id]['result'] = {
                'prompt': full_prompt,
                'image': image_b64,
                'base': {'output': base_output, 'latency_ms': base_latency},
                'lora': {'output': lora_output, 'latency_ms': lora_latency},
            }
            lora_sessions[session_id]['status'] = 'completed'
            lora_add_message(session_id, 'Comparison completed', 'system')
        except Exception as e:
            lora_sessions[session_id]['status'] = 'error'
            lora_sessions[session_id]['result'] = {
                'error': str(e)
            }
            lora_add_message(session_id, f"Error: {e}", 'error')

    Thread(target=worker, daemon=True).start()
    return jsonify({'session_id': session_id, 'status': 'processing'})
def stream_llama_lora_compare(session_id):
    """SSE stream for LoRA comparison progress and final result."""
    def generate():
        last_idx = 0
        retries = 0
        max_retries = 300
        while retries < max_retries:
            sess = lora_sessions.get(session_id)
            if not sess:
                yield f"data: {json.dumps({'error': 'Session not found'})}\n\n"
                break
            msgs = sess['messages']
            if len(msgs) > last_idx:
                for m in msgs[last_idx:]:
                    yield f"data: {json.dumps(m)}\n\n"
                last_idx = len(msgs)
            yield f"data: {json.dumps({'status': sess['status']})}\n\n"
            if sess['status'] in ('completed', 'error'):
                yield f"data: {json.dumps({'final_result': sess['result']})}\n\n"
                break
            time.sleep(1)
            retries += 1
        if retries >= max_retries:
            yield f"data: {json.dumps({'error': 'Timeout waiting for results'})}\n\n"
    return Response(
        stream_with_context(generate()),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no',
            'Content-Type': 'text/event-stream',
        }
    )
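# The stream can be smoke-tested from a shell with curl's no-buffering flag
# (the path is illustrative; uid/uname match the query-param auth fallback):
#   curl -N "https://<space-url>/api/llama-lora/stream/<session_id>?uid=1&uname=admin"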
def search_similar_objects():
    """Similar-object search API"""
    print("[DEBUG] Received request in search-similar-objects")
    if clip_model is None or object_collection is None:
        print("[DEBUG] Error: Image embedding model or vector DB not available")
        return jsonify({"error": "Image embedding model or vector DB not available"})
    try:
        # Extract the request data
        data = request.json
        print(f"[DEBUG] Request data keys: {list(data.keys()) if data else 'None'}")
        if not data:
            print("[DEBUG] Error: Missing request data")
            return jsonify({"error": "Missing request data"})
        # Determine the search type
        search_type = data.get('searchType', 'image')
        n_results = int(data.get('n_results', 5))  # number of results
        print(f"[DEBUG] Search type: {search_type}, n_results: {n_results}")
        query_embedding = None
        if search_type == 'image' and 'image' in data:
            # Search by image
            print("[DEBUG] Searching by image")
            image_data = data['image']
            if image_data.startswith('data:image'):
                image_data = image_data.split(',')[1]
            try:
                image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB')
                query_embedding = generate_image_embedding(image)
                print(f"[DEBUG] Generated image embedding: {type(query_embedding)}, length: {len(query_embedding) if query_embedding is not None else 'None'}")
            except Exception as e:
                print(f"[DEBUG] Error generating image embedding: {e}")
                return jsonify({"error": f"Error processing image: {str(e)}"}), 500
        elif search_type == 'object' and 'objectId' in data:
            # Search by object ID
            object_id = data['objectId']
            result = object_collection.get(ids=[object_id], include=["embeddings"])
            if result and "embeddings" in result and len(result["embeddings"]) > 0:
                query_embedding = result["embeddings"][0]
        elif search_type == 'class' and 'class_name' in data:
            # Search by class name
            print("[DEBUG] Searching by class name")
            class_name = data['class_name']
            print(f"[DEBUG] Class name: {class_name}")
            filter_query = {"class": {"$eq": class_name}}
            try:
                # Search with a class filter
                print(f"[DEBUG] Querying with filter: {filter_query}")
                # Use get method instead of query for class-based filtering
                results = object_collection.get(
                    where=filter_query,
                    limit=n_results,
                    include=["metadatas", "embeddings", "documents"]
                )
                print(f"[DEBUG] Query results: {results['ids'] if 'ids' in results else 'No results'}")
                formatted_results = format_object_results(results)
                print(f"[DEBUG] Formatted results count: {len(formatted_results)}")
                return jsonify({
                    "success": True,
                    "searchType": "class",
                    "results": formatted_results
                })
            except Exception as e:
                print(f"[DEBUG] Error in class search: {e}")
                return jsonify({"error": f"Error in class search: {str(e)}"}), 500
        else:
            print(f"[DEBUG] Invalid search parameters: {data}")
            return jsonify({"error": "Invalid search parameters"})
        if query_embedding is None:
            print("[DEBUG] Error: Failed to generate query embedding")
            return jsonify({"error": "Failed to generate query embedding"})
        try:
            # Run the similarity search
            print(f"[DEBUG] Running similarity search with embedding of length {len(query_embedding)}")
            results = object_collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
                include=["metadatas", "distances"]
            )
            print(f"[DEBUG] Query results: {results['ids'][0] if 'ids' in results and len(results['ids']) > 0 else 'No results'}")
            formatted_results = format_object_results(results)
            print(f"[DEBUG] Formatted results count: {len(formatted_results)}")
            return jsonify({
                "success": True,
                "searchType": search_type,
                "results": formatted_results
            })
        except Exception as e:
            print(f"[DEBUG] Error in similarity search: {e}")
            return jsonify({"error": f"Error in similarity search: {str(e)}"}), 500
    except Exception as e:
        print(f"Error in search-similar-objects API: {e}")
        return jsonify({"error": str(e)}), 500
def format_object_results(results): | |
"""κ²μ κ²°κ³Ό ν¬λ§·ν - ChromaDB query λ° get λ©μλ κ²°κ³Ό λͺ¨λ μ²λ¦¬""" | |
formatted_results = [] | |
print(f"[DEBUG] Formatting results: {results.keys() if results else 'None'}") | |
if not results: | |
print("[DEBUG] No results to format") | |
return formatted_results | |
try: | |
        # 'get' returns flat lists (ids: [str, ...]) while 'query' returns nested
        # lists (ids: [[str, ...]]); detect the shape without risking an
        # IndexError on an empty result
        ids = results.get('ids')
        is_get_result = bool(ids) and not isinstance(ids[0], list)
if is_get_result: | |
# Handle results from 'get' method (flat structure) | |
print("[DEBUG] Processing results from get method (class search)") | |
if len(results['ids']) == 0: | |
return formatted_results | |
for i, obj_id in enumerate(results['ids']): | |
try: | |
# Extract object info | |
metadata = results['metadatas'][i] if 'metadatas' in results else {} | |
# Deserialize bbox if stored as JSON string | |
if 'bbox' in metadata and isinstance(metadata['bbox'], str): | |
try: | |
metadata['bbox'] = json.loads(metadata['bbox']) | |
                        except Exception:
                            pass  # Keep as-is if deserialization fails
result_item = { | |
"id": obj_id, | |
"metadata": metadata | |
} | |
# No distance in get results | |
# Check if image data is already in metadata | |
if 'image_data' not in metadata: | |
print(f"[DEBUG] Image data not found in metadata for object {obj_id}") | |
else: | |
print(f"[DEBUG] Image data found in metadata for object {obj_id}") | |
formatted_results.append(result_item) | |
except Exception as e: | |
print(f"[DEBUG] Error formatting get result {i}: {e}") | |
else: | |
# Handle results from 'query' method (nested structure) | |
print("[DEBUG] Processing results from query method (similarity search)") | |
if 'ids' not in results or len(results['ids']) == 0 or len(results['ids'][0]) == 0: | |
return formatted_results | |
for i, obj_id in enumerate(results['ids'][0]): | |
try: | |
# Extract object info | |
metadata = results['metadatas'][0][i] if 'metadatas' in results and len(results['metadatas']) > 0 else {} | |
# Deserialize bbox if stored as JSON string | |
if 'bbox' in metadata and isinstance(metadata['bbox'], str): | |
try: | |
metadata['bbox'] = json.loads(metadata['bbox']) | |
                        except Exception:
                            pass  # Keep as-is if deserialization fails
result_item = { | |
"id": obj_id, | |
"metadata": metadata | |
} | |
if 'distances' in results and len(results['distances']) > 0: | |
result_item["distance"] = float(results['distances'][0][i]) | |
# Check if image data is already in metadata | |
if 'image_data' not in metadata: | |
try: | |
# Try to get original image via image ID | |
image_id = metadata.get('image_id') | |
if image_id: | |
print(f"[DEBUG] Image data not found in metadata for object {obj_id} with image_id {image_id}") | |
except Exception as e: | |
print(f"[DEBUG] Error checking image data for result {i}: {e}") | |
else: | |
print(f"[DEBUG] Image data found in metadata for object {obj_id}") | |
formatted_results.append(result_item) | |
except Exception as e: | |
print(f"[DEBUG] Error formatting query result {i}: {e}") | |
except Exception as e: | |
print(f"[DEBUG] Error in format_object_results: {e}") | |
return formatted_results | |
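# For reference, the two ChromaDB result shapes format_object_results() accepts
# (a sketch; field values are illustrative):
#   get()   -> {"ids": ["obj-1"],   "metadatas": [{"class": "cup", ...}]}
#   query() -> {"ids": [["obj-1"]], "metadatas": [[{"class": "cup", ...}]],
#               "distances": [[0.12]]}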
# Login page HTML template
LOGIN_TEMPLATE = ''' | |
<!DOCTYPE html> | |
<html lang="ko"> | |
<head> | |
<meta charset="UTF-8"> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
    <title>Vision LLM Agent - Login</title>
<style> | |
body { | |
font-family: Arial, sans-serif; | |
background-color: #f5f5f5; | |
display: flex; | |
justify-content: center; | |
align-items: center; | |
height: 100vh; | |
margin: 0; | |
} | |
.login-container { | |
background-color: white; | |
padding: 2rem; | |
border-radius: 8px; | |
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); | |
width: 100%; | |
max-width: 400px; | |
} | |
h1 { | |
text-align: center; | |
color: #4a6cf7; | |
margin-bottom: 1.5rem; | |
} | |
.form-group { | |
margin-bottom: 1rem; | |
} | |
label { | |
display: block; | |
margin-bottom: 0.5rem; | |
font-weight: bold; | |
} | |
input { | |
width: 100%; | |
padding: 0.75rem; | |
border: 1px solid #ddd; | |
border-radius: 4px; | |
font-size: 1rem; | |
} | |
button { | |
width: 100%; | |
padding: 0.75rem; | |
background-color: #4a6cf7; | |
color: white; | |
border: none; | |
border-radius: 4px; | |
font-size: 1rem; | |
cursor: pointer; | |
margin-top: 1rem; | |
} | |
button:hover { | |
background-color: #3a5cd8; | |
} | |
.error-message { | |
color: #e74c3c; | |
margin-top: 1rem; | |
text-align: center; | |
} | |
</style> | |
</head> | |
<body> | |
<div class="login-container"> | |
<h1>Vision LLM Agent</h1> | |
<form action="/login" method="post" autocomplete="off"> | |
<!-- hidden dummy fields to discourage Chrome autofill --> | |
<input type="text" name="fakeusernameremembered" style="display:none" tabindex="-1" autocomplete="off"> | |
<input type="password" name="fakepasswordremembered" style="display:none" tabindex="-1" autocomplete="off"> | |
<div class="form-group"> | |
<label for="username">Username</label> | |
<input type="text" id="username" name="username" value="" placeholder="Enter username" required autocomplete="username" autocapitalize="none" autocorrect="off" spellcheck="false"> | |
</div> | |
<div class="form-group"> | |
<label for="password">Password</label> | |
<input type="password" id="password" name="password" value="" placeholder="Enter password" required autocomplete="current-password" autocapitalize="none" autocorrect="off" spellcheck="false"> | |
</div> | |
<button type="submit">Login</button> | |
{% if error %} | |
<p class="error-message">{{ error }}</p> | |
{% endif %} | |
</form> | |
</div> | |
</body> | |
</html> | |
''' | |
@app.route('/login', methods=['GET', 'POST'])
def login():
    # Already-authenticated users are redirected to the main page
    # (fresh-login requirement removed for HF Spaces)
if current_user.is_authenticated: | |
print(f"User already authenticated as: {current_user.username}, redirecting to index") | |
return redirect('/index.html') | |
error = None | |
if request.method == 'POST': | |
username = request.form.get('username') | |
password = request.form.get('password') | |
print(f"Login attempt: username={username}") | |
if username in users and users[username].password == password: | |
            # On successful login, store user info in the session
user = users[username] | |
            login_user(user, remember=False)  # remember disabled so the session expires after 2 minutes
session['user_id'] = user.id | |
session['username'] = username | |
session.permanent = True | |
            session.modified = True  # Apply session changes immediately
# Debug session data after setting | |
print(f"Login successful for user: {username}, ID: {user.id}") | |
print(f"Session data after login: {dict(session)}") | |
# Force session save by accessing it | |
_ = session.get('user_id') | |
            # Handle the post-login redirect
next_page = request.args.get('next') | |
if next_page and next_page.startswith('/') and next_page != '/login': | |
print(f"Redirecting to: {next_page}") | |
return redirect(next_page) | |
print("Redirecting to index.html") | |
# Serve index.html directly with cookies to avoid redirect issues | |
print("Serving index.html directly with auth cookies") | |
# Read index.html file | |
index_path = os.path.join(app.static_folder, 'index.html') | |
try: | |
with open(index_path, 'r', encoding='utf-8') as f: | |
html = f.read() | |
except Exception as e: | |
print(f"[DEBUG] Failed to read index.html: {e}") | |
return "Error loading page", 500 | |
            # Inject auth-persistence script
            # (uses localStorage instead of cookies, which HF Spaces iframes restrict)
auth_script = f""" | |
<script> | |
console.log('[DEBUG] HF Spaces detected - using localStorage for auth persistence'); | |
// Store auth data in localStorage (not affected by iframe restrictions) | |
localStorage.setItem('auth_user_id', '{user.id}'); | |
localStorage.setItem('auth_username', '{username}'); | |
localStorage.setItem('auth_timestamp', Date.now().toString()); | |
console.log('[DEBUG] Auth data stored in localStorage'); | |
console.log('[DEBUG] user_id:', localStorage.getItem('auth_user_id')); | |
console.log('[DEBUG] username:', localStorage.getItem('auth_username')); | |
// Override fetch to include auth headers | |
const originalFetch = window.fetch; | |
window.fetch = function(url, options = {{}}) {{ | |
// Add auth headers for API calls | |
if (url.startsWith('/api/')) {{ | |
options.headers = options.headers || {{}}; | |
options.headers['X-Auth-User-ID'] = localStorage.getItem('auth_user_id'); | |
options.headers['X-Auth-Username'] = localStorage.getItem('auth_username'); | |
console.log('[DEBUG] Added auth headers to API call:', url); | |
}} | |
return originalFetch(url, options); | |
}}; | |
// Heartbeat to keep session alive | |
setInterval(() => {{ | |
const userId = localStorage.getItem('auth_user_id'); | |
const username = localStorage.getItem('auth_username'); | |
if (userId && username) {{ | |
fetch('/api/heartbeat', {{ | |
method: 'POST', | |
headers: {{ | |
'X-Auth-User-ID': userId, | |
'X-Auth-Username': username, | |
'Content-Type': 'application/json' | |
}} | |
}}).then(response => {{ | |
console.log('Heartbeat sent, status:', response.status); | |
}}).catch(error => {{ | |
console.log('Heartbeat failed:', error); | |
}}); | |
}} | |
}}, 30000); | |
</script> | |
""" | |
# Insert script before </body> | |
if '</body>' in html: | |
html = html.replace('</body>', auth_script + '\n</body>') | |
else: | |
html += auth_script | |
response = make_response(html) | |
response.headers['Content-Type'] = 'text/html; charset=utf-8' | |
# Set cookies in response headers with maximum compatibility | |
response.set_cookie('auth_user_id', str(user.id), | |
max_age=7200, # 2 hours | |
secure=False, # Force insecure for HF Spaces | |
httponly=False, # Allow JavaScript access | |
                                samesite=None,   # Omit SameSite (the string 'None' would require secure=True)
path='/') # Ensure path is set | |
response.set_cookie('auth_username', username, | |
max_age=7200, # 2 hours | |
secure=False, # Force insecure for HF Spaces | |
httponly=False, # Allow JavaScript access | |
                                samesite=None,   # Omit SameSite (the string 'None' would require secure=True)
path='/') # Ensure path is set | |
print(f"[DEBUG] Set cookies: auth_user_id={user.id}, auth_username={username}") | |
print(f"[DEBUG] Cookie settings: secure=False, httponly=False, samesite=None, path=/") | |
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate' | |
response.headers['Pragma'] = 'no-cache' | |
response.headers['Expires'] = '0' | |
return response | |
else: | |
error = 'Invalid username or password' | |
print(f"Login failed: {error}") | |
return render_template_string(LOGIN_TEMPLATE, error=error) | |
@app.route('/logout')
def logout():
    # Clear server-side session
logout_user() | |
session.clear() | |
# Return a small HTML that clears client-side localStorage and expires cookies, then redirects | |
html = """ | |
<html><head><meta charset="utf-8"><title>Logging out...</title></head> | |
<body> | |
<script> | |
try { | |
localStorage.removeItem('auth_user_id'); | |
localStorage.removeItem('auth_username'); | |
localStorage.removeItem('auth_timestamp'); | |
} catch (e) {} | |
window.location.href = '/login'; | |
</script> | |
</body></html> | |
""" | |
resp = make_response(html) | |
# Expire auth cookies explicitly | |
resp.set_cookie('auth_user_id', '', expires=0, path='/') | |
resp.set_cookie('auth_username', '', expires=0, path='/') | |
resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' | |
resp.headers['Pragma'] = 'no-cache' | |
resp.headers['Expires'] = '0' | |
return resp | |
@app.route('/api/heartbeat', methods=['POST'])
def heartbeat():
"""Keep session alive; support header-based auth for HF Spaces""" | |
uid, uname = check_authentication() | |
if uid and uname: | |
return jsonify({"status": "alive", "user_id": uid, "username": uname}) | |
return jsonify({"status": "no_session"}), 401 | |
@app.route('/product-comparison-lite')
def product_comparison_lite_page():
"""Serve a lightweight two-image Product Comparison page (no React build required).""" | |
html = '''<!DOCTYPE html> | |
<html lang="en"> | |
<head> | |
<meta charset="UTF-8"> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
<title>Product Comparison (Lite)</title> | |
<style> | |
body { font-family: Arial, sans-serif; margin: 16px; } | |
.row { display: flex; gap: 16px; flex-wrap: wrap; } | |
.col { flex: 1 1 320px; min-width: 280px; } | |
.card { border: 1px solid #ddd; border-radius: 8px; padding: 12px; } | |
.preview-box { width: 100%; height: 45vh; max-height: 520px; border: 1px dashed #ccc; display:flex; align-items:center; justify-content:center; overflow:hidden; border-radius:6px; background:#fafafa; } | |
.preview-box img { max-width: 100%; max-height: 100%; width: auto; height: auto; object-fit: contain; } | |
input[type="file"] { display: none; } | |
.file-button { padding: 8px 12px; background: #3f51b5; color: white; border: none; border-radius: 4px; cursor: pointer; font-size: 14px; } | |
.file-button:hover { background: #303f9f; } | |
.controls { margin-top: 12px; display:flex; gap: 8px; align-items:center; } | |
button { padding: 10px 14px; background: #3f51b5; color: #fff; border: none; border-radius: 4px; cursor: pointer; } | |
button:disabled { background: #9aa0c3; cursor: not-allowed; } | |
.log { white-space: pre-wrap; background:#fff; color:#333; border: 1px solid #ddd; padding:10px; border-radius:6px; height:180px; overflow:auto; font-size: 12px; } | |
.result { margin-top: 12px; } | |
</style> | |
</head> | |
<body> | |
<h2>Product Comparison (Lite)</h2> | |
<div class="row"> | |
<div class="col"> | |
<div class="card"> | |
<h3>Product Image 1</h3> | |
<div class="preview-box"><img id="img1" alt="Product 1 Preview" style="display:none;" /></div> | |
<div class="controls"> | |
<input type="file" id="file1" accept="image/*" /> | |
<button type="button" class="file-button" onclick="document.getElementById('file1').click()">Choose Image File</button> | |
</div> | |
</div> | |
</div> | |
<div class="col"> | |
<div class="card"> | |
<h3>Product Image 2</h3> | |
<div class="preview-box"><img id="img2" alt="Product 2 Preview" style="display:none;" /></div> | |
<div class="controls"> | |
<input type="file" id="file2" accept="image/*" /> | |
<button type="button" class="file-button" onclick="document.getElementById('file2').click()">Choose Image File</button> | |
</div> | |
</div> | |
</div> | |
</div> | |
<div class="controls" style="margin-top:16px;"> | |
<button id="compareBtn" disabled onclick="startCompare()">Compare Products</button> | |
<a href="/index.html" style="margin-left:8px;">Back to App</a> | |
</div> | |
<h3>Analysis Progress</h3> | |
<div id="log" class="log"></div> | |
<h3>Comparison Results</h3> | |
<pre id="result" class="result"></pre> | |
<script> | |
// Proactively unregister any active service workers | |
(function(){ | |
if ('serviceWorker' in navigator) { | |
try { | |
navigator.serviceWorker.getRegistrations().then(function(regs){ | |
regs.forEach(function(r){ r.unregister(); }); | |
}); | |
} catch (e) { /* ignore */ } | |
} | |
})(); | |
let file1 = null, file2 = null; | |
function handleFile(i, input) { | |
console.log('handleFile called for image', i, 'with input:', input); | |
const f = input.files && input.files[0]; | |
if (!f) { | |
console.log('No file selected'); | |
return; | |
} | |
console.log('File selected:', f.name, 'size:', f.size, 'type:', f.type); | |
const url = URL.createObjectURL(f); | |
console.log('Object URL created:', url); | |
const img = document.getElementById('img'+i); | |
console.log('Image element found:', img); | |
if (img) { | |
img.src = url; | |
img.style.display = 'block'; | |
img.style.visibility = 'visible'; | |
console.log('Image src set and display changed to block'); | |
img.onload = function() { | |
console.log('Image ' + i + ' loaded successfully, dimensions:', img.naturalWidth, 'x', img.naturalHeight); | |
}; | |
img.onerror = function() { | |
console.error('Failed to load image ' + i); | |
}; | |
} | |
if (i === 1) { | |
file1 = f; | |
console.log('File1 set:', f.name); | |
} else { | |
file2 = f; | |
console.log('File2 set:', f.name); | |
} | |
updateButton(); | |
} | |
function updateButton(){ | |
const btn = document.getElementById('compareBtn'); | |
const enabled = file1 && file2; | |
btn.disabled = !enabled; | |
console.log('Button state updated. File1:', !!file1, 'File2:', !!file2, 'Enabled:', enabled); | |
} | |
async function startCompare() { | |
if (!file1 || !file2) return; | |
const formData = new FormData(); | |
formData.append('image1', file1); | |
formData.append('image2', file2); | |
try { | |
const response = await fetch('/api/product/compare/start', { | |
method: 'POST', | |
body: formData | |
}); | |
const data = await response.json(); | |
if (data.session_id) { | |
streamResults(data.session_id); | |
} | |
} catch (error) { | |
console.error('Error starting comparison:', error); | |
} | |
} | |
function streamResults(sessionId) { | |
const eventSource = new EventSource('/api/product/compare/stream/' + sessionId); | |
const logDiv = document.getElementById('log'); | |
const resultDiv = document.getElementById('result'); | |
eventSource.onmessage = function(event) { | |
const data = JSON.parse(event.data); | |
console.log('Received SSE data:', data); | |
if (data.message) { | |
logDiv.textContent += data.message + '\\n'; | |
logDiv.scrollTop = logDiv.scrollHeight; | |
} else if (data.status) { | |
logDiv.textContent += 'Status: ' + data.status + '\\n'; | |
logDiv.scrollTop = logDiv.scrollHeight; | |
} else if (data.final_result) { | |
resultDiv.textContent = JSON.stringify(data.final_result, null, 2); | |
eventSource.close(); | |
} else if (data.error) { | |
logDiv.textContent += 'Error: ' + data.error + '\\n'; | |
logDiv.scrollTop = logDiv.scrollHeight; | |
eventSource.close(); | |
} | |
}; | |
eventSource.onerror = function() { | |
eventSource.close(); | |
}; | |
} | |
// Add event listeners after DOM loads | |
document.addEventListener('DOMContentLoaded', function() { | |
document.getElementById('file1').addEventListener('change', function() { | |
handleFile(1, this); | |
}); | |
document.getElementById('file2').addEventListener('change', function() { | |
handleFile(2, this); | |
}); | |
}); | |
</script> | |
</body> | |
</html>''' | |
resp = make_response(html) | |
resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' | |
resp.headers['Pragma'] = 'no-cache' | |
resp.headers['Expires'] = '0' | |
return resp | |
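# SSE event shapes consumed by the Lite page's streamResults() above (a sketch;
# the producer lives in the product_comparison module):
#   data: {"message": "..."}       progress log line
#   data: {"status": "..."}        status update
#   data: {"final_result": {...}}  terminal payload; the client closes the stream
#   data: {"error": "..."}         terminal error; the client closes the stream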
# Route for serving static files (no login required)
def serve_static(filename): | |
print(f"Serving static file: {filename}") | |
resp = send_from_directory(app.static_folder, filename) | |
# Prevent caching of static assets to reflect latest frontend changes | |
resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' | |
resp.headers['Pragma'] = 'no-cache' | |
resp.headers['Expires'] = '0' | |
return resp | |
# Serve index HTML directly (login required)
@app.route('/index.html')
def serve_index_html():
    # Session and cookie debug info
print(f"Request to /index.html - Session data: {dict(session)}") | |
print(f"Request to /index.html - Cookies: {request.cookies}") | |
print(f"Request to /index.html - User authenticated: {current_user.is_authenticated}") | |
# Check session data first, then fallback to auth cookies | |
user_id = session.get('user_id') | |
username = session.get('username') | |
# Fallback to auth cookies if session is empty | |
if not user_id or not username: | |
user_id = request.cookies.get('auth_user_id') | |
username = request.cookies.get('auth_username') | |
print(f"Session empty, checking cookies: user_id={user_id}, username={username}") | |
if not user_id or not username: | |
print("No session or cookie data found, redirecting to login") | |
return redirect(url_for('login')) | |
# Verify user exists (users maps username -> User instance) | |
user_found = False | |
for stored_username, user_obj in users.items(): | |
if str(getattr(user_obj, 'id', None)) == str(user_id) and stored_username == username: | |
user_found = True | |
break | |
if not user_found: | |
print(f"User not found: {username} (ID: {user_id}), redirecting to login") | |
return redirect(url_for('login')) | |
print(f"Serving index.html for authenticated user: {username} (ID: {user_id})") | |
    # Session state debug
print(f"Session data: user_id={user_id}, username={username}, is_permanent={session.get('permanent', False)}") | |
    # Deliberately avoid refreshing the session here so that it expires on schedule.
    # Note: writing to the session (or setting session.modified=True) would extend
    # the expiry under Flask-Session.
    # Read index.html and inject the heartbeat script
index_path = os.path.join(app.static_folder, 'index.html') | |
try: | |
with open(index_path, 'r', encoding='utf-8') as f: | |
html = f.read() | |
except Exception as e: | |
print(f"[DEBUG] Failed to read index.html for injection: {e}") | |
return send_from_directory(app.static_folder, 'index.html') | |
heartbeat_script = """ | |
<style> | |
/* Override preview image sizing from old builds without rebuild */ | |
.preview-image { max-width: 100% !important; max-height: 100% !important; width: auto !important; height: auto !important; object-fit: contain !important; } | |
/* Ensure container is not too tall on desktop */ | |
.preview-image-container, .image-container { height: 45vh !important; } | |
@media (max-width: 600px) { .preview-image-container, .image-container { height: 35vh !important; } } | |
/* Product Comparison navigation button */ | |
.product-comparison-nav { position: fixed; top: 20px; right: 20px; z-index: 9999; background: #3f51b5; color: white; padding: 12px 16px; border-radius: 8px; text-decoration: none; font-weight: bold; box-shadow: 0 4px 8px rgba(0,0,0,0.2); transition: background 0.3s; } | |
.product-comparison-nav:hover { background: #303f9f; color: white; text-decoration: none; } | |
</style> | |
<script> | |
(function(){ | |
// Ensure auth headers are attached on all API calls using localStorage (HF Spaces blocks cookies) | |
try { | |
const originalFetch = window.fetch; | |
window.fetch = function(url, options = {}) { | |
if (typeof url === 'string' && url.indexOf('/api/') === 0) { | |
options.headers = options.headers || {}; | |
const uid = localStorage.getItem('auth_user_id'); | |
const uname = localStorage.getItem('auth_username'); | |
if (uid && uname) { | |
options.headers['X-Auth-User-ID'] = uid; | |
options.headers['X-Auth-Username'] = uname; | |
// console.debug('[DEBUG] (index) Added auth headers to API call:', url); | |
} | |
} | |
return originalFetch.apply(this, [url, options]); | |
} | |
} catch (e) { | |
console.warn('Failed to install fetch auth header override:', e); | |
} | |
    // 1) Periodically check the session (redirect to login on expiry)
function checkSession(){ | |
fetch('/api/status', {credentials: 'include', redirect: 'manual'}).then(function(res){ | |
var redirected = res.redirected || (res.url && res.url.indexOf('/login') !== -1); | |
if(res.status !== 200 || redirected){ | |
window.location.href = '/login'; | |
} | |
}).catch(function(){ | |
        // Network errors etc. also send the user back to login
window.location.href = '/login'; | |
}); | |
} | |
checkSession(); | |
setInterval(checkSession, 30000); | |
    // 2) Auto-logout after 2 minutes of user inactivity
    var idleMs = 120000; // 2 minutes
var idleTimer; | |
function triggerLogout(){ | |
      // Clear the server session, then return to the login screen
window.location.href = '/logout'; | |
} | |
function resetIdle(){ | |
if (idleTimer) clearTimeout(idleTimer); | |
idleTimer = setTimeout(triggerLogout, idleMs); | |
} | |
['click','mousemove','keydown','scroll','touchstart','visibilitychange'].forEach(function(evt){ | |
window.addEventListener(evt, resetIdle, {passive:true}); | |
}); | |
resetIdle(); | |
// 3) Add Product Comparison navigation button - DISABLED | |
function addProductComparisonButton() { | |
// Disabled - no longer adding Product Comparison button to main UI | |
return; | |
} | |
// Add button after DOM is ready | |
if (document.readyState === 'loading') { | |
document.addEventListener('DOMContentLoaded', addProductComparisonButton); | |
} else { | |
addProductComparisonButton(); | |
} | |
})(); | |
</script> | |
""" | |
try: | |
if '</body>' in html: | |
html = html.replace('</body>', heartbeat_script + '</body>') | |
else: | |
html = html + heartbeat_script | |
except Exception as e: | |
print(f"[DEBUG] Failed to inject heartbeat script: {e}") | |
return send_from_directory(app.static_folder, 'index.html') | |
resp = make_response(html) | |
# Prevent sensitive pages from being cached | |
resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' | |
resp.headers['Pragma'] = 'no-cache' | |
resp.headers['Expires'] = '0' | |
return resp | |
# Static files should be accessible without login requirements | |
def static_files(filename): | |
print(f"Serving static file: {filename}") | |
# Two possible locations after CRA build copy: | |
# 1) Top-level: static/<filename> | |
# 2) Nested build: static/static/<filename> | |
top_level_path = os.path.join(app.static_folder, filename) | |
nested_dir = os.path.join(app.static_folder, 'static') | |
nested_path = os.path.join(nested_dir, filename) | |
try: | |
if os.path.exists(top_level_path): | |
return send_from_directory(app.static_folder, filename) | |
elif os.path.exists(nested_path): | |
# Serve from nested build directory | |
return send_from_directory(nested_dir, filename) | |
else: | |
# Fallback: try as-is (may help in some edge cases) | |
return send_from_directory(app.static_folder, filename) | |
except Exception as e: | |
print(f"[DEBUG] Error serving static file '{filename}': {e}") | |
# Final fallback to avoid leaking stack traces | |
return ('Not Found', 404) | |
# Add explicit handlers for JS files that are failing | |
def static_js_files(filename): | |
print(f"Serving JS file: {filename}") | |
# Try top-level static/js and nested static/static/js | |
top_js_dir = os.path.join(app.static_folder, 'js') | |
nested_js_dir = os.path.join(app.static_folder, 'static', 'js') | |
top_js_path = os.path.join(top_js_dir, filename) | |
nested_js_path = os.path.join(nested_js_dir, filename) | |
try: | |
if os.path.exists(top_js_path): | |
return send_from_directory(top_js_dir, filename) | |
elif os.path.exists(nested_js_path): | |
return send_from_directory(nested_js_dir, filename) | |
else: | |
# As a fallback, let the generic static handler try | |
return static_files(os.path.join('js', filename)) | |
except Exception as e: | |
print(f"[DEBUG] Error serving JS file '{filename}': {e}") | |
return ('Not Found', 404) | |
# Default route and catch-all for all other paths (login required)
@app.route('/', defaults={'path': ''})
@app.route('/<path:path>')
def serve_react(path):
"""Serve React frontend""" | |
print(f"Serving React frontend for path: {path}, user: {current_user.username if current_user.is_authenticated else 'not authenticated'}") | |
    # These paths (index.html excluded) have dedicated route handlers, which
    # Flask matches before this catch-all; if a request still lands here,
    # return 404 rather than falling through to index.html
    if path in ['product-comparison-lite', 'login', 'logout', 'similar-images', 'object-detection-search', 'model-vector-db', 'openai-chat']:
        from flask import abort
        abort(404)
# For root path, redirect to index.html to ensure consistent behavior | |
if path == "": | |
return redirect('/index.html') | |
    # Static files are now handled by dedicated routes
if path != "" and os.path.exists(os.path.join(app.static_folder, path)): | |
resp = send_from_directory(app.static_folder, path) | |
resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' | |
resp.headers['Pragma'] = 'no-cache' | |
resp.headers['Expires'] = '0' | |
return resp | |
else: | |
        # Serve the React app's index.html (with the heartbeat script injected)
index_path = os.path.join(app.static_folder, 'index.html') | |
try: | |
with open(index_path, 'r', encoding='utf-8') as f: | |
html = f.read() | |
except Exception as e: | |
print(f"[DEBUG] Failed to read index.html for injection (serve_react): {e}") | |
resp = send_from_directory(app.static_folder, 'index.html') | |
resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' | |
resp.headers['Pragma'] = 'no-cache' | |
resp.headers['Expires'] = '0' | |
return resp | |
heartbeat_script = """ | |
<style> | |
/* Override preview image sizing from old builds without rebuild */ | |
.preview-image { max-width: 100% !important; max-height: 100% !important; width: auto !important; height: auto !important; object-fit: contain !important; } | |
.preview-image-container, .image-container { height: 45vh !important; } | |
@media (max-width: 600px) { .preview-image-container, .image-container { height: 35vh !important; } } | |
/* Product Comparison navigation button */ | |
.product-comparison-nav { position: fixed; top: 20px; right: 20px; z-index: 9999; background: #3f51b5; color: white; padding: 12px 16px; border-radius: 8px; text-decoration: none; font-weight: bold; box-shadow: 0 4px 8px rgba(0,0,0,0.2); transition: background 0.3s; } | |
.product-comparison-nav:hover { background: #303f9f; color: white; text-decoration: none; } | |
</style> | |
<script> | |
(function(){ | |
    // 1) Periodically check the session (redirect to login on expiry)
function checkSession(){ | |
fetch('/api/status', {credentials: 'include', redirect: 'manual'}).then(function(res){ | |
var redirected = res.redirected || (res.url && res.url.indexOf('/login') !== -1); | |
if(res.status !== 200 || redirected){ | |
window.location.href = '/login'; | |
} | |
}).catch(function(){ | |
window.location.href = '/login'; | |
}); | |
} | |
checkSession(); | |
setInterval(checkSession, 30000); | |
    // 2) Auto-logout after 2 minutes of user inactivity
    var idleMs = 120000; // 2 minutes
var idleTimer; | |
function triggerLogout(){ | |
window.location.href = '/logout'; | |
} | |
function resetIdle(){ | |
if (idleTimer) clearTimeout(idleTimer); | |
idleTimer = setTimeout(triggerLogout, idleMs); | |
} | |
['click','mousemove','keydown','scroll','touchstart','visibilitychange'].forEach(function(evt){ | |
window.addEventListener(evt, resetIdle, {passive:true}); | |
}); | |
resetIdle(); | |
// 3) Add Product Comparison navigation button - DISABLED | |
function addProductComparisonButton() { | |
// Disabled - no longer adding Product Comparison button to main UI | |
return; | |
} | |
// Add button after DOM is ready | |
if (document.readyState === 'loading') { | |
document.addEventListener('DOMContentLoaded', addProductComparisonButton); | |
} else { | |
addProductComparisonButton(); | |
} | |
})(); | |
</script> | |
""" | |
try: | |
if '</body>' in html: | |
html = html.replace('</body>', heartbeat_script + '</body>') | |
else: | |
html = html + heartbeat_script | |
except Exception as e: | |
print(f"[DEBUG] Failed to inject heartbeat script (serve_react): {e}") | |
resp = send_from_directory(app.static_folder, 'index.html') | |
resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' | |
resp.headers['Pragma'] = 'no-cache' | |
resp.headers['Expires'] = '0' | |
return resp | |
resp = make_response(html) | |
resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' | |
resp.headers['Pragma'] = 'no-cache' | |
resp.headers['Expires'] = '0' | |
return resp | |
@app.route('/similar-images')
def similar_images_page():
"""Serve similar images search page""" | |
resp = send_from_directory(app.static_folder, 'similar-images.html') | |
resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' | |
resp.headers['Pragma'] = 'no-cache' | |
resp.headers['Expires'] = '0' | |
return resp | |
@app.route('/object-detection-search')
def object_detection_search_page():
"""Serve object detection search page""" | |
resp = send_from_directory(app.static_folder, 'object-detection-search.html') | |
resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' | |
resp.headers['Pragma'] = 'no-cache' | |
resp.headers['Expires'] = '0' | |
return resp | |
@app.route('/model-vector-db')
def model_vector_db_page():
"""Serve model vector DB UI page""" | |
resp = send_from_directory(app.static_folder, 'model-vector-db.html') | |
resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' | |
resp.headers['Pragma'] = 'no-cache' | |
resp.headers['Expires'] = '0' | |
return resp | |
@app.route('/openai-chat')
def openai_chat_page():
"""Serve OpenAI chat UI page""" | |
resp = send_from_directory(app.static_folder, 'openai-chat.html') | |
resp.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' | |
resp.headers['Pragma'] = 'no-cache' | |
resp.headers['Expires'] = '0' | |
return resp | |
def openai_chat_api(): | |
"""Forward chat request to OpenAI Chat Completions API. | |
Expects JSON: { prompt: string, model?: string, api_key?: string, system?: string } | |
Uses OPENAI_API_KEY from environment if api_key not provided. | |
""" | |
try: | |
data = request.get_json(force=True) | |
except Exception: | |
return jsonify({"error": "Invalid JSON body"}), 400 | |
    prompt = ((data or {}).get('prompt') or '').strip()
model = (data or {}).get('model') or os.environ.get('OPENAI_MODEL', 'gpt-5-mini') | |
system = (data or {}).get('system') or 'You are a helpful assistant.' | |
api_key = (data or {}).get('api_key') or os.environ.get('OPENAI_API_KEY') | |
if not prompt: | |
return jsonify({"error": "Missing 'prompt'"}), 400 | |
if not api_key: | |
return jsonify({"error": "Missing OpenAI API key. Provide in request or set OPENAI_API_KEY env."}), 400 | |
# Prefer official Python SDK if available | |
if OpenAI is None: | |
return jsonify({"error": "OpenAI Python package not installed on server"}), 500 | |
try: | |
start = time.time() | |
client = OpenAI(api_key=api_key) | |
# Perform chat completion | |
chat = client.chat.completions.create( | |
model=model, | |
messages=[ | |
{"role": "system", "content": system}, | |
{"role": "user", "content": prompt}, | |
], | |
) | |
latency = round(time.time() - start, 3) | |
except Exception as e: | |
# Log detailed error for diagnostics | |
try: | |
import traceback | |
traceback.print_exc() | |
except Exception: | |
pass | |
err_msg = str(e) | |
print(f"[OpenAI Chat Error] model={model} err={err_msg}") | |
return jsonify({ | |
"error": "OpenAI SDK call failed", | |
"detail": err_msg, | |
"model": model | |
}), 502 | |
try: | |
content = chat.choices[0].message.content if chat and chat.choices else '' | |
usage = getattr(chat, 'usage', None) | |
usage = usage.model_dump() if hasattr(usage, 'model_dump') else (usage or {}) | |
except Exception as e: | |
return jsonify({"error": f"Failed to parse SDK response: {str(e)}"}), 500 | |
return jsonify({ | |
'response': content, | |
'model': model, | |
'usage': usage, | |
'latency_sec': latency | |
}) | |
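# Example client call for the chat endpoint above (a sketch; the URL path is an
# assumption -- this handler's route registration is not shown in this file):
#   import requests
#   r = requests.post("http://localhost:7860/api/openai-chat",
#                     json={"prompt": "Summarize LoRA in one sentence."})
#   print(r.json().get("response"))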
def vision_rag_query(): | |
"""Vision RAG endpoint. | |
Expects JSON with one of the following query modes and a user question: | |
- { userQuery, searchType: 'image', image, n_results? } | |
- { userQuery, searchType: 'object', objectId, n_results? } | |
- { userQuery, searchType: 'class', class_name, n_results? } | |
Returns: { answer, retrieved: [...], model, latency_sec } | |
""" | |
if ChatOpenAI is None: | |
return jsonify({"error": "LangChain not installed on server"}), 500 | |
data = request.get_json(silent=True) or {} | |
# Debug: log incoming payload keys and basic info (without sensitive data) | |
try: | |
print("[VRAG][REQ] keys=", list(data.keys())) | |
print("[VRAG][REQ] has_api_key=", bool(data.get('api_key') or os.environ.get('OPENAI_API_KEY'))) | |
_img = data.get('image') | |
if isinstance(_img, str): | |
print("[VRAG][REQ] image_str_len=", len(_img), "prefix=", _img[:30] if len(_img) > 30 else _img) | |
print("[VRAG][REQ] searchType=", data.get('searchType'), "objectId=", data.get('objectId'), "class_name=", data.get('class_name')) | |
except Exception as _e: | |
print("[VRAG][WARN] failed to log request payload:", _e) | |
user_query = (data.get('userQuery') or '').strip() | |
if not user_query: | |
return jsonify({"error": "Missing 'userQuery'"}), 400 | |
api_key = data.get('api_key') or os.environ.get('OPENAI_API_KEY') | |
if not api_key: | |
return jsonify({"error": "Missing OpenAI API key. Provide in request or set OPENAI_API_KEY env."}), 400 | |
search_type = data.get('searchType', 'image') | |
n_results = int(data.get('n_results', 5)) | |
print(f"[VRAG] user_query='{user_query}' | search_type={search_type} | n_results={n_results}") | |
# Build query embedding or filtered fetch similar to /api/search-similar-objects | |
results = None | |
try: | |
if search_type == 'image' and 'image' in data: | |
image_data = data['image'] | |
if isinstance(image_data, str) and image_data.startswith('data:image'): | |
image_data = image_data.split(',')[1] | |
image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB') | |
query_embedding = generate_image_embedding(image) | |
if query_embedding is None: | |
return jsonify({"error": "Failed to generate image embedding"}), 500 | |
results = object_collection.query( | |
query_embeddings=[query_embedding], | |
n_results=n_results, | |
include=["metadatas", "distances"] | |
) if object_collection is not None else None | |
elif search_type == 'object' and 'objectId' in data: | |
obj_id = data['objectId'] | |
base = object_collection.get(ids=[obj_id], include=["embeddings"]) if object_collection is not None else None | |
emb = base["embeddings"][0] if base and "embeddings" in base and base["embeddings"] else None | |
if emb is None: | |
return jsonify({"error": "objectId not found or has no embedding"}), 400 | |
results = object_collection.query( | |
query_embeddings=[emb], | |
n_results=n_results, | |
include=["metadatas", "distances"] | |
) | |
elif search_type == 'class' and 'class_name' in data: | |
filter_query = {"class": {"$eq": data['class_name']}} | |
results = object_collection.get( | |
where=filter_query, | |
limit=n_results, | |
include=["metadatas", "embeddings", "documents"] | |
) if object_collection is not None else None | |
else: | |
return jsonify({"error": "Invalid search parameters"}), 400 | |
except Exception as e: | |
return jsonify({"error": f"Retrieval failed: {str(e)}"}), 500 | |
# Format results using existing helper | |
formatted = format_object_results(results) if results else [] | |
# Debug: log retrieval summary | |
try: | |
cnt = len(formatted) | |
print(f"[VRAG][RETRIEVE] items={cnt}") | |
if cnt: | |
print("[VRAG][RETRIEVE] first_item=", { | |
'id': formatted[0].get('id'), | |
'distance': formatted[0].get('distance'), | |
'meta_keys': list((formatted[0].get('metadata') or {}).keys()) | |
}) | |
except Exception as _e: | |
print("[VRAG][WARN] failed to log retrieval summary:", _e) | |
# Build concise context for LLM | |
def _shorten(md): | |
try: | |
bbox = md.get('bbox') if isinstance(md, dict) else None | |
if isinstance(bbox, dict): | |
bbox = {k: round(float(v), 3) for k, v in bbox.items() if isinstance(v, (int, float))} | |
return { | |
'image_id': md.get('image_id'), | |
'class': md.get('class'), | |
'confidence': md.get('confidence'), | |
'bbox': bbox, | |
} | |
        except Exception:
            if not isinstance(md, dict):
                return {}
            return {k: md.get(k) for k in ('image_id', 'class', 'confidence') if k in md}
context_items = [] | |
for r in formatted[:n_results]: | |
md = r.get('metadata', {}) | |
item = { | |
'id': r.get('id'), | |
'distance': r.get('distance'), | |
'meta': _shorten(md) | |
} | |
context_items.append(item) | |
# Compose prompt | |
system_text = ( | |
"You are a vision assistant. Use ONLY the provided detected object context to answer. " | |
"Be concise and state uncertainty if context is insufficient." | |
) | |
# Provide the minimal JSON-like context to the model | |
context_text = json.dumps(context_items, ensure_ascii=False, indent=2) | |
user_text = f"User question: {user_query}\n\nDetected context (top {len(context_items)}):\n{context_text}" | |
# Debug: show compact context preview | |
try: | |
print("[VRAG][CTX] context_items_count=", len(context_items)) | |
if context_items: | |
print("[VRAG][CTX] sample=", json.dumps(context_items[0], ensure_ascii=False)) | |
except Exception as _e: | |
print("[VRAG][WARN] failed to log context:", _e) | |
# Attempt multimodal call (text + top-1 image) if available; otherwise fallback to text-only LangChain. | |
answer = None | |
model_used = None | |
try: | |
start = time.time() | |
top_data_url = None | |
try: | |
if formatted: | |
md0 = (formatted[0] or {}).get('metadata') or {} | |
img_b64 = md0.get('image_data') | |
if isinstance(img_b64, str) and len(img_b64) > 50: | |
# Construct data URL without logging raw base64 | |
top_data_url = 'data:image/jpeg;base64,' + img_b64 | |
except Exception: | |
top_data_url = None | |
# Prefer OpenAI SDK for multimodal if available and we have an image | |
if OpenAI is not None and top_data_url is not None: | |
client = OpenAI(api_key=api_key) | |
model_used = os.environ.get('OPENAI_MODEL', 'gpt-5-mini') | |
chat = client.chat.completions.create( | |
model=model_used, | |
messages=[ | |
{"role": "system", "content": system_text}, | |
{ | |
"role": "user", | |
"content": [ | |
{"type": "text", "text": user_text}, | |
{"type": "image_url", "image_url": {"url": top_data_url}}, | |
], | |
}, | |
], | |
) | |
answer = chat.choices[0].message.content if chat and chat.choices else '' | |
else: | |
# Fallback to existing LangChain text-only flow | |
llm = ChatOpenAI(api_key=api_key, model=os.environ.get('OPENAI_MODEL', 'gpt-5-mini')) | |
prompt = ChatPromptTemplate.from_messages([ | |
("system", system_text), | |
("human", "{input}") | |
]) | |
chain = prompt | llm | StrOutputParser() | |
answer = chain.invoke({"input": user_text}) | |
model_used = getattr(llm, 'model', None) | |
latency = round(time.time() - start, 3) | |
except Exception as e: | |
return jsonify({"error": f"LLM call failed: {str(e)}"}), 502 | |
return jsonify({ | |
"answer": answer, | |
"retrieved": context_items, | |
"model": model_used, | |
"latency_sec": latency | |
}) | |
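# Example request body for the Vision RAG endpoint above (a sketch; values are
# illustrative):
#   {
#     "userQuery": "What objects appear in this photo?",
#     "searchType": "class",
#     "class_name": "cup",
#     "n_results": 3
#   }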
@app.route('/api/status')
def status():
# Public endpoint - no authentication required | |
return jsonify({ | |
"status": "online", | |
"models": { | |
"yolo": yolo_model is not None, | |
"detr": detr_model is not None and detr_processor is not None, | |
"vit": vit_model is not None and vit_processor is not None | |
}, | |
"device": "GPU" if torch.cuda.is_available() else "CPU" | |
}) | |
# Root route is now handled by serve_react function | |
# This route is removed to prevent conflicts | |
@app.route('/index')
def index_page():
    # Redirect /index to index.html
print("Index route redirecting to index.html") | |
return redirect('/index.html') | |
if __name__ == "__main__": | |
    # Hugging Face Spaces provides the port via the PORT environment variable
port = int(os.environ.get("FLASK_PORT", os.environ.get("PORT", 7860))) | |
app.run(debug=False, host='0.0.0.0', port=port) | |