from flask import Flask, request, jsonify, send_from_directory
import torch
from PIL import Image
import numpy as np
import os
import io
from io import BytesIO  # used directly when decoding base64 image payloads below
import base64
import uuid  # used to generate image/object IDs below
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import time
from flask_cors import CORS
import json
import chromadb
from chromadb.utils import embedding_functions
app = Flask(__name__, static_folder='static')
CORS(app)  # Enable CORS for all routes
# Model initialization
print("Loading models... This may take a moment.")

# Image embedding model (CLIP) for vector search
clip_model = None
clip_processor = None
try:
    from transformers import CLIPProcessor, CLIPModel

    # Use a temporary directory for the Hugging Face cache
    import tempfile
    temp_dir = tempfile.gettempdir()
    os.environ["TRANSFORMERS_CACHE"] = temp_dir

    # Load the CLIP model (for image embeddings)
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    print("CLIP model loaded successfully")
except Exception as e:
    print("Error loading CLIP model:", e)
    clip_model = None
    clip_processor = None
# Vector DB initialization
vector_db = None
image_collection = None
object_collection = None
try:
    # Initialize the ChromaDB client (in-memory DB)
    vector_db = chromadb.Client()

    # Set up the embedding function
    ef = embedding_functions.DefaultEmbeddingFunction()

    # Create the image collection
    image_collection = vector_db.create_collection(
        name="image_collection",
        embedding_function=ef,
        get_or_create=True
    )

    # Create the collection for object detection results
    object_collection = vector_db.create_collection(
        name="object_collection",
        embedding_function=ef,
        get_or_create=True
    )
    print("Vector DB initialized successfully")
except Exception as e:
    print("Error initializing Vector DB:", e)
    vector_db = None
    image_collection = None
    object_collection = None
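
# Note (sketch, not part of the original listing): chromadb.Client() is an
# in-memory store, so both collections are lost whenever the process restarts.
# Depending on the chromadb version, a persistent client could be swapped in:
#
#   vector_db = chromadb.PersistentClient(path="./chroma_db")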
# YOLOv8 model
yolo_model = None
try:
    from ultralytics import YOLO

    # Model file path -- use a temporary directory
    import tempfile
    temp_dir = tempfile.gettempdir()
    model_path = os.path.join(temp_dir, "yolov8n.pt")

    # Download the model file directly if it is not present.
    # Note: os.system() does not raise on failure, so check its exit code
    # instead of relying on try/except.
    if not os.path.exists(model_path):
        print(f"Downloading YOLOv8 model to {model_path}...")
        exit_code = os.system(f"wget -q https://ultralytics.com/assets/yolov8n.pt -O {model_path}")
        if exit_code == 0:
            print("YOLOv8 model downloaded successfully")
        else:
            # On failure, try an alternative URL
            exit_code = os.system(f"wget -q https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt -O {model_path}")
            if exit_code == 0:
                print("YOLOv8 model downloaded from alternative source")
            else:
                # As a last resort, try curl instead of wget
                exit_code = os.system(f"curl -L https://ultralytics.com/assets/yolov8n.pt --output {model_path}")
                if exit_code == 0:
                    print("YOLOv8 model downloaded using curl")
                else:
                    print("All download attempts failed")

    # Point config paths at a writable directory
    os.environ["YOLO_CONFIG_DIR"] = temp_dir
    os.environ["MPLCONFIGDIR"] = temp_dir

    yolo_model = YOLO(model_path)  # Using the nano model for faster inference
    print("YOLOv8 model loaded successfully")
except Exception as e:
    print("Error loading YOLOv8 model:", e)
    yolo_model = None
# DETR model (DEtection TRansformer)
detr_processor = None
detr_model = None
try:
    from transformers import DetrImageProcessor, DetrForObjectDetection
    detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
    detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
    print("DETR model loaded successfully")
except Exception as e:
    print("Error loading DETR model:", e)
    detr_processor = None
    detr_model = None

# ViT model
vit_processor = None
vit_model = None
try:
    from transformers import ViTImageProcessor, ViTForImageClassification
    vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    vit_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
    print("ViT model loaded successfully")
except Exception as e:
    print("Error loading ViT model:", e)
    vit_processor = None
    vit_model = None
# Get device information
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# LLM model (using an open-access model instead of Llama 4, which requires authentication)
llm_model = None
llm_tokenizer = None
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    print("Loading LLM model... This may take a moment.")
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # open-access alternative
    llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
    llm_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        # The options below require the accelerate package:
        # device_map="auto",
        # load_in_8bit=True
    ).to(device)
    print("LLM model loaded successfully")
except Exception as e:
    print(f"Error loading LLM model: {e}")
    llm_model = None
    llm_tokenizer = None
def process_llm_query(vision_results, user_query):
    """Process a query with the LLM model using vision results and user text"""
    if llm_model is None or llm_tokenizer is None:
        return {"error": "LLM model not available"}

    # Summarize the result data (to stay within the token-length limit)
    summarized_results = []

    # Summarize object detection results
    if isinstance(vision_results, list):
        # Include at most 10 objects
        for i, obj in enumerate(vision_results[:10]):
            if isinstance(obj, dict):
                # Keep only the fields we need
                summary = {
                    "label": obj.get("label", "unknown"),
                    "confidence": obj.get("confidence", 0),
                }
                summarized_results.append(summary)

    # Create a prompt combining vision results and user query
    prompt = f"""You are an AI assistant analyzing image detection results.
Here are the objects detected in the image: {json.dumps(summarized_results, indent=2)}

User question: {user_query}

Please provide a detailed analysis based on the detected objects and the user's question.
"""

    # Tokenize and generate response
    try:
        start_time = time.time()

        # Check the token length and fall back to a shorter prompt if needed
        tokens = llm_tokenizer.encode(prompt)
        if len(tokens) > 1500:  # leave a safety margin
            prompt = f"""You are an AI assistant analyzing image detection results.
The image contains {len(summarized_results)} detected objects.

User question: {user_query}

Please provide a general analysis based on the user's question.
"""

        inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            output = llm_model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )
        response_text = llm_tokenizer.decode(output[0], skip_special_tokens=True)

        # Remove the prompt from the response
        if response_text.startswith(prompt):
            response_text = response_text[len(prompt):].strip()

        inference_time = time.time() - start_time
        return {
            "response": response_text,
            "performance": {
                "inference_time": round(inference_time, 3),
                "device": "GPU" if torch.cuda.is_available() else "CPU"
            }
        }
    except Exception as e:
        return {"error": f"Error processing LLM query: {str(e)}"}
def image_to_base64(img):
    """Convert PIL Image to base64 string"""
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return img_str
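
# For reference, the inverse operation (the decode pattern used by the
# endpoints below) is:
#
#   img = Image.open(BytesIO(base64.b64decode(img_str))).convert('RGB')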
def process_yolo(image):
    if yolo_model is None:
        return {"error": "YOLOv8 model not loaded"}

    # Measure inference time
    start_time = time.time()

    # Convert to numpy if it's a PIL image
    if isinstance(image, Image.Image):
        image_np = np.array(image)
    else:
        image_np = image

    # Run inference
    results = yolo_model(image_np)

    # Process results
    result_image = results[0].plot()
    result_image = Image.fromarray(result_image)

    # Get detection information
    boxes = results[0].boxes
    class_names = results[0].names

    # Format detection results
    detections = []
    for box in boxes:
        class_id = int(box.cls[0].item())
        class_name = class_names[class_id]
        confidence = round(box.conf[0].item(), 2)
        bbox = box.xyxy[0].tolist()
        bbox = [round(x) for x in bbox]
        detections.append({
            "class": class_name,
            "confidence": confidence,
            "bbox": bbox
        })

    # Calculate inference time
    inference_time = time.time() - start_time

    # Add inference time and device info
    device_info = "GPU" if torch.cuda.is_available() else "CPU"

    return {
        "image": image_to_base64(result_image),
        "detections": detections,
        "performance": {
            "inference_time": round(inference_time, 3),
            "device": device_info
        }
    }
def process_detr(image):
    if detr_model is None or detr_processor is None:
        return {"error": "DETR model not loaded"}

    # Measure inference time
    start_time = time.time()

    # Prepare image for the model
    inputs = detr_processor(images=image, return_tensors="pt")

    # Run inference
    with torch.no_grad():
        outputs = detr_model(**inputs)

    # Process results
    target_sizes = torch.tensor([image.size[::-1]])
    results = detr_processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # Create a copy of the image to draw on
    result_image = image.copy()
    fig, ax = plt.subplots(1)
    ax.imshow(result_image)

    # Format detection results
    detections = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i) for i in box.tolist()]
        class_name = detr_model.config.id2label[label.item()]
        confidence = round(score.item(), 2)

        # Draw rectangle
        rect = Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                         linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

        # Add label
        plt.text(box[0], box[1], "{}: {}".format(class_name, confidence),
                 bbox=dict(facecolor='white', alpha=0.8))

        detections.append({
            "class": class_name,
            "confidence": confidence,
            "bbox": box
        })

    # Save figure to image
    buf = io.BytesIO()
    plt.tight_layout()
    plt.axis('off')
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    result_image = Image.open(buf)
    plt.close(fig)

    # Calculate inference time
    inference_time = time.time() - start_time

    # Add inference time and device info
    device_info = "GPU" if torch.cuda.is_available() else "CPU"

    return {
        "image": image_to_base64(result_image),
        "detections": detections,
        "performance": {
            "inference_time": round(inference_time, 3),
            "device": device_info
        }
    }
def process_vit(image):
    if vit_model is None or vit_processor is None:
        return {"error": "ViT model not loaded"}

    # Measure inference time
    start_time = time.time()

    # Prepare image for the model
    inputs = vit_processor(images=image, return_tensors="pt")

    # Run inference
    with torch.no_grad():
        outputs = vit_model(**inputs)
        logits = outputs.logits

    # Get the predicted class
    predicted_class_idx = logits.argmax(-1).item()
    prediction = vit_model.config.id2label[predicted_class_idx]

    # Get top 5 predictions
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    top5_prob, top5_indices = torch.topk(probs, 5)

    results = []
    for i, (prob, idx) in enumerate(zip(top5_prob, top5_indices)):
        class_name = vit_model.config.id2label[idx.item()]
        results.append({
            "rank": i + 1,
            "class": class_name,
            "probability": round(prob.item(), 3)
        })

    # Calculate inference time
    inference_time = time.time() - start_time

    # Add inference time and device info
    device_info = "GPU" if torch.cuda.is_available() else "CPU"

    return {
        "top_predictions": results,
        "performance": {
            "inference_time": round(inference_time, 3),
            "device": device_info
        }
    }
# NOTE: the @app.route decorators were missing from this listing; the paths
# below are assumed reconstructions, not confirmed against the original app.
@app.route('/api/detect/yolo', methods=['POST'])
def yolo_detect():
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    image = Image.open(file.stream)
    result = process_yolo(image)
    return jsonify(result)

@app.route('/api/detect/detr', methods=['POST'])  # assumed path
def detr_detect():
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    image = Image.open(file.stream)
    result = process_detr(image)
    return jsonify(result)

@app.route('/api/classify/vit', methods=['POST'])  # assumed path
def vit_classify():
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400
    file = request.files['image']
    image = Image.open(file.stream)
    result = process_vit(image)
    return jsonify(result)

@app.route('/api/analyze-with-llm', methods=['POST'])  # assumed path
def analyze_with_llm():
    # Check if required data is in the request
    if not request.json:
        return jsonify({"error": "No JSON data provided"}), 400

    # Extract vision results and user query from request
    data = request.json
    if 'visionResults' not in data or 'userQuery' not in data:
        return jsonify({"error": "Missing required fields: visionResults or userQuery"}), 400

    vision_results = data['visionResults']
    user_query = data['userQuery']

    # Process the query with LLM
    result = process_llm_query(vision_results, user_query)
    return jsonify(result)
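
# Example request body (sketch; the endpoint path above is assumed):
#
#   POST /api/analyze-with-llm
#   {
#     "visionResults": [{"label": "dog", "confidence": 0.92}],
#     "userQuery": "What animals are in this picture?"
#   }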
def generate_image_embedding(image):
    """Generate an image embedding using the CLIP model"""
    if clip_model is None or clip_processor is None:
        return None
    try:
        # Preprocess the image
        inputs = clip_processor(images=image, return_tensors="pt")

        # Generate the image embedding
        with torch.no_grad():
            image_features = clip_model.get_image_features(**inputs)

        # Normalize the embedding and convert it to a plain list
        image_embedding = image_features.squeeze().cpu().numpy()
        normalized_embedding = image_embedding / np.linalg.norm(image_embedding)
        return normalized_embedding.tolist()
    except Exception as e:
        print(f"Error generating image embedding: {e}")
        return None
@app.route('/api/similar-images', methods=['POST'])  # assumed path
def find_similar_images():
    """Similar-image search API"""
    if clip_model is None or clip_processor is None or image_collection is None:
        return jsonify({"error": "Image embedding model or vector DB not available"})
    try:
        # Extract the image data from the request
        if 'image' not in request.files and 'image' not in request.form:
            return jsonify({"error": "No image provided"})

        if 'image' in request.files:
            # Uploaded as a file
            image_file = request.files['image']
            image = Image.open(image_file).convert('RGB')
        else:
            # Sent as a base64-encoded string
            image_data = request.form['image']
            if image_data.startswith('data:image'):
                # Remove the data URL prefix if present
                image_data = image_data.split(',')[1]
            image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB')

        # Generate a temporary image ID
        image_id = str(uuid.uuid4())

        # Generate the image embedding
        embedding = generate_image_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate image embedding"})

        # Optionally add the query image itself to the DB
        # image_collection.add(
        #     ids=[image_id],
        #     embeddings=[embedding]
        # )

        # Search for similar images
        results = image_collection.query(
            query_embeddings=[embedding],
            n_results=5  # return the top 5 results
        )

        # Format the results
        similar_images = []
        if len(results['ids'][0]) > 0:
            for i, img_id in enumerate(results['ids'][0]):
                similar_images.append({
                    "id": img_id,
                    "distance": float(results['distances'][0][i]) if 'distances' in results else 0.0,
                    "metadata": results['metadatas'][0][i] if 'metadatas' in results else {}
                })

        return jsonify({
            "query_image_id": image_id,
            "similar_images": similar_images
        })
    except Exception as e:
        print(f"Error in similar-images API: {e}")
        return jsonify({"error": str(e)}), 500
@app.route('/api/add-to-collection', methods=['POST'])  # assumed path
def add_to_collection():
    """API that adds an image to the vector DB"""
    if clip_model is None or clip_processor is None or image_collection is None:
        return jsonify({"error": "Image embedding model or vector DB not available"})
    try:
        # Extract the image data from the request
        if 'image' not in request.files and 'image' not in request.form:
            return jsonify({"error": "No image provided"})

        # Extract metadata
        metadata = {}
        if 'metadata' in request.form:
            metadata = json.loads(request.form['metadata'])

        # Image ID (generated automatically if not provided)
        image_id = request.form.get('id', str(uuid.uuid4()))

        if 'image' in request.files:
            # Uploaded as a file
            image_file = request.files['image']
            image = Image.open(image_file).convert('RGB')
        else:
            # Sent as a base64-encoded string
            image_data = request.form['image']
            if image_data.startswith('data:image'):
                # Remove the data URL prefix if present
                image_data = image_data.split(',')[1]
            image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB')

        # Generate the image embedding
        embedding = generate_image_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate image embedding"})

        # Add the image to the DB
        image_collection.add(
            ids=[image_id],
            embeddings=[embedding],
            metadatas=[metadata]
        )

        return jsonify({
            "success": True,
            "image_id": image_id,
            "message": "Image added to collection"
        })
    except Exception as e:
        print(f"Error in add-to-collection API: {e}")
        return jsonify({"error": str(e)}), 500
@app.route('/api/add-detected-objects', methods=['POST'])  # assumed path
def add_detected_objects():
    """API that adds object detection results to the vector DB"""
    if clip_model is None or object_collection is None:
        return jsonify({"error": "Image embedding model or vector DB not available"})
    try:
        # Extract the image and the object detection results from the request
        data = request.json
        if not data or 'image' not in data or 'objects' not in data:
            return jsonify({"error": "Missing image or objects data"})

        # Process the image data
        image_data = data['image']
        if image_data.startswith('data:image'):
            image_data = image_data.split(',')[1]
        image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB')
        image_width, image_height = image.size

        # Image ID
        image_id = data.get('imageId', str(uuid.uuid4()))

        # Process the object data
        objects = data['objects']
        object_ids = []
        object_embeddings = []
        object_metadatas = []

        for obj in objects:
            # Generate an object ID
            object_id = f"{image_id}_{str(uuid.uuid4())[:8]}"

            # Extract the bounding box (normalized 0-1 coordinates)
            bbox = obj.get('bbox', {})
            x1 = bbox.get('x', 0)
            y1 = bbox.get('y', 0)
            width = bbox.get('width', 0)
            height = bbox.get('height', 0)

            # Convert the bounding box to pixel coordinates
            x1_px = int(x1 * image_width)
            y1_px = int(y1 * image_height)
            width_px = int(width * image_width)
            height_px = int(height * image_height)

            # Crop the object from the image
            try:
                object_image = image.crop((x1_px, y1_px, x1_px + width_px, y1_px + height_px))

                # Generate the embedding
                embedding = generate_image_embedding(object_image)
                if embedding is None:
                    continue

                # Build the metadata. ChromaDB metadata values must be scalar
                # (str/int/float/bool), so the bbox is flattened into separate keys.
                metadata = {
                    "image_id": image_id,
                    "class": obj.get('class', ''),
                    "confidence": obj.get('confidence', 0),
                    "bbox_x": x1,
                    "bbox_y": y1,
                    "bbox_width": width,
                    "bbox_height": height
                }

                object_ids.append(object_id)
                object_embeddings.append(embedding)
                object_metadatas.append(metadata)
            except Exception as e:
                print(f"Error processing object: {e}")
                continue

        # No valid objects found
        if not object_ids:
            return jsonify({"error": "No valid objects to add"})

        # Add the objects to the DB
        object_collection.add(
            ids=object_ids,
            embeddings=object_embeddings,
            metadatas=object_metadatas
        )

        return jsonify({
            "success": True,
            "image_id": image_id,
            "object_count": len(object_ids),
            "object_ids": object_ids
        })
    except Exception as e:
        print(f"Error in add-detected-objects API: {e}")
        return jsonify({"error": str(e)}), 500
@app.route('/api/search-similar-objects', methods=['POST'])  # assumed path
def search_similar_objects():
    """Similar-object search API"""
    if clip_model is None or object_collection is None:
        return jsonify({"error": "Image embedding model or vector DB not available"})
    try:
        # Extract the request data
        data = request.json
        if not data:
            return jsonify({"error": "Missing request data"})

        # Determine the search type
        search_type = data.get('searchType', 'image')
        n_results = int(data.get('nResults', 5))  # number of results

        query_embedding = None

        if search_type == 'image' and 'image' in data:
            # Search by image
            image_data = data['image']
            if image_data.startswith('data:image'):
                image_data = image_data.split(',')[1]
            image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB')
            query_embedding = generate_image_embedding(image)
        elif search_type == 'object' and 'objectId' in data:
            # Search by object ID
            object_id = data['objectId']
            result = object_collection.get(ids=[object_id], include=["embeddings"])
            if result and "embeddings" in result and len(result["embeddings"]) > 0:
                query_embedding = result["embeddings"][0]
        elif search_type == 'class' and 'className' in data:
            # Search by class name: this is a pure metadata filter, and
            # query() requires query embeddings, so use get() with a where
            # filter instead, then nest the flat lists to match the shape
            # that format_object_results() expects.
            class_name = data['className']
            filter_query = {"class": {"$eq": class_name}}
            raw = object_collection.get(
                where=filter_query,
                limit=n_results,
                include=["metadatas"]
            )
            results = {
                "ids": [raw.get("ids", [])],
                "metadatas": [raw.get("metadatas", [])]
            }
            return jsonify({
                "success": True,
                "searchType": "class",
                "results": format_object_results(results)
            })
        else:
            return jsonify({"error": "Invalid search parameters"})

        if query_embedding is None:
            return jsonify({"error": "Failed to generate query embedding"})

        # Run the similarity search
        results = object_collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=["metadatas", "distances"]
        )

        return jsonify({
            "success": True,
            "searchType": search_type,
            "results": format_object_results(results)
        })
    except Exception as e:
        print(f"Error in search-similar-objects API: {e}")
        return jsonify({"error": str(e)}), 500
def format_object_results(results):
    """Format vector-DB search results"""
    formatted_results = []
    if len(results['ids']) > 0 and len(results['ids'][0]) > 0:
        for i, obj_id in enumerate(results['ids'][0]):
            result_item = {
                "id": obj_id,
                "metadata": results['metadatas'][0][i] if 'metadatas' in results else {}
            }
            if 'distances' in results:
                result_item["distance"] = float(results['distances'][0][i])
            formatted_results.append(result_item)
    return formatted_results
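
# For reference, collection.query() returns nested lists, one inner list per
# query embedding, which is the shape this formatter consumes (sketch):
#
#   {
#     "ids": [["img1_ab12cd34", "img1_ef56gh78"]],
#     "metadatas": [[{"class": "dog", ...}, {"class": "cat", ...}]],
#     "distances": [[0.12, 0.34]]
#   }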
@app.route('/<path:path>')  # assumed path
def serve_react(path):
    """Serve React frontend"""
    if path != "" and os.path.exists(os.path.join(app.static_folder, path)):
        return send_from_directory(app.static_folder, path)
    else:
        return send_from_directory(app.static_folder, 'index.html')

@app.route('/similar-images')  # assumed path
def similar_images_page():
    """Serve similar images search page"""
    return send_from_directory(app.static_folder, 'similar-images.html')

@app.route('/object-detection-search')  # assumed path
def object_detection_search_page():
    """Serve object detection search page"""
    return send_from_directory(app.static_folder, 'object-detection-search.html')

@app.route('/model-vector-db')  # assumed path
def model_vector_db_page():
    """Serve model vector DB UI page"""
    return send_from_directory(app.static_folder, 'model-vector-db.html')

@app.route('/api/status')  # assumed path
def status():
    return jsonify({
        "status": "online",
        "models": {
            "yolo": yolo_model is not None,
            "detr": detr_model is not None and detr_processor is not None,
            "vit": vit_model is not None and vit_processor is not None
        },
        "device": "GPU" if torch.cuda.is_available() else "CPU"
    })

@app.route('/')  # assumed path
def index():
    return send_from_directory('static', 'index.html')
if __name__ == "__main__":
    # Hugging Face Spaces provides the port via the PORT environment variable
    port = int(os.environ.get("PORT", 7860))
    app.run(debug=False, host='0.0.0.0', port=port)
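
# Quick smoke test once the server is up (sketch; assumes the /api/status
# route above):
#
#   curl http://localhost:7860/api/status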