Spaces:
Sleeping
Sleeping
import os | |
import uuid | |
from pathlib import Path | |
from typing import List, Optional | |
import io | |
from contextlib import asynccontextmanager | |
from fastapi import FastAPI, File, UploadFile, Request, WebSocket, WebSocketDisconnect, HTTPException, BackgroundTasks | |
from fastapi.responses import HTMLResponse, FileResponse, StreamingResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
from PIL import Image | |
from image_indexer import ImageIndexer | |
from image_search import ImageSearch | |
from image_database import ImageDatabase | |
# Initialize image indexer, searcher, and database.
# These module-level singletons are shared by the endpoints below.
indexer = ImageIndexer()
searcher = ImageSearch()
image_db = ImageDatabase()

# Lower-case file suffixes (with leading dot) treated as images when
# browsing directories or serving files straight from disk.
image_extensions = ".jpg .jpeg .png .gif".split()
async def lifespan(_: FastAPI):
    """Application lifespan hook handed to ``FastAPI(lifespan=...)``.

    No startup or shutdown work happens here: the indexer, searcher and
    database are constructed at import time, so the generator just yields
    control to the running application once.
    """
    yield None
# Create the ASGI application, wiring in the `lifespan` startup/shutdown hook.
app = FastAPI(
    lifespan=lifespan,
    title="Visual Product Search",
)

# Setup templates and static files: Jinja2 pages are loaded from the
# "templates" directory and the "static" directory is served under /static.
templates = Jinja2Templates(directory="templates")
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
async def home(request: Request):
    """Render ``index.html`` for the landing page.

    The template receives a snapshot of the indexer's progress (status,
    file being processed, counters and a percentage rounded to two
    decimals) together with every folder known to the folder manager.
    """
    folders = indexer.folder_manager.get_all_folders()

    total = indexer.total_files
    processed = indexer.processed_files
    # Avoid dividing by zero before any files have been queued.
    percentage = round(processed / total * 100, 2) if total > 0 else 0

    initial_status = {
        "status": indexer.status.value,
        "current_file": indexer.current_file,
        "total_files": total,
        "processed_files": processed,
        "progress_percentage": percentage,
    }
    context = {"request": request, "initial_status": initial_status, "folders": folders}
    return templates.TemplateResponse("index.html", context)
async def upload_images(
    files: List[UploadFile] = File(...),
    background_tasks: BackgroundTasks = None
):
    """Save uploaded images into ``uploads/`` and index that folder in the background.

    Every file must declare an ``image/*`` content type. All files are checked
    before any is written, so a rejected batch leaves nothing behind on disk.
    Accepted files are stored under a random UUID name that keeps the
    client's file suffix, so uploads do not overwrite one another.

    Returns:
        A dict with ``status``, a human-readable ``message``, the
        ``folder_info`` returned by the folder manager, and the saved file
        paths under ``uploaded_files``.

    Raises:
        HTTPException(400): a file is not declared as an image, or nothing
            was uploaded.
        HTTPException(500): any other failure while saving or registering.
    """
    try:
        # Reject the whole batch up front if any file is not declared as an image.
        for file in files:
            if not (file.content_type and file.content_type.startswith('image/')):
                raise HTTPException(status_code=400, detail=f"File {file.filename} is not a valid image")

        # Create uploads directory if it doesn't exist
        upload_dir = Path("uploads")
        upload_dir.mkdir(exist_ok=True)

        saved_files = []
        for file in files:
            # UploadFile.filename is optional; a missing name means "no suffix".
            file_extension = Path(file.filename or "").suffix
            file_path = upload_dir / f"{uuid.uuid4()}{file_extension}"
            contents = await file.read()
            with open(file_path, "wb") as f:
                f.write(contents)
            saved_files.append(str(file_path))

        if not saved_files:
            raise HTTPException(status_code=400, detail="No valid images were uploaded")

        # Register the upload folder, then schedule indexing as a background task.
        folder_info = indexer.folder_manager.add_folder(str(upload_dir))
        if background_tasks is not None:
            background_tasks.add_task(indexer.index_folder, str(upload_dir))
        return {
            "status": "success",
            "message": f"Uploaded and indexing {len(saved_files)} images",
            "folder_info": folder_info,
            "uploaded_files": saved_files
        }
    except HTTPException:
        # Propagate deliberate client errors unchanged instead of turning them into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def add_folder(folder_path: str, background_tasks: BackgroundTasks):
    """Register ``folder_path`` with the folder manager and index it in the background.

    Returns the folder info produced by the folder manager. Any failure is
    reported to the client as HTTP 400 carrying the error message.
    """
    try:
        # Registration must come first: it creates the collection indexing writes to.
        info = indexer.folder_manager.add_folder(folder_path)
        background_tasks.add_task(indexer.index_folder, folder_path)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return info
async def remove_folder(folder_path: str):
    """Ask the indexer to remove ``folder_path``; failures become HTTP 400."""
    try:
        await indexer.remove_folder(folder_path)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {"status": "success"}
async def list_folders():
    """Return every folder currently registered with the folder manager."""
    folders = indexer.folder_manager.get_all_folders()
    return folders
async def search_by_text(query: str, folder: Optional[str] = None) -> List[dict]:
    """Search indexed images with a text query.

    Args:
        query: Text to match against the indexed images.
        folder: When given, restrict the search to this indexed folder.
    """
    return await searcher.search_by_text(query, folder)
async def search_by_image(
    file: UploadFile = File(...),
    folder: Optional[str] = None
) -> List[dict]:
    """Search indexed images using an uploaded example image.

    Args:
        file: Uploaded image used as the query.
        folder: When given, restrict the search to this indexed folder.

    Raises:
        HTTPException(400): the upload cannot be identified as an image.
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
    except OSError as e:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # payloads it cannot recognise; that is a client error, not a 500.
        raise HTTPException(status_code=400, detail=f"File {file.filename} is not a valid image") from e
    return await searcher.search_by_image(image, folder)
async def search_by_url(
    url: str,
    folder: Optional[str] = None
) -> List[dict]:
    """Search indexed images using an example image referenced by ``url``.

    Handling of the URL is delegated to ``searcher.search_by_url``;
    ``folder`` optionally restricts the search to one indexed folder.
    """
    matches = await searcher.search_by_url(url, folder)
    return matches
async def list_images(folder: Optional[str] = None) -> List[dict]:
    """Return the indexer's image records, restricted to ``folder`` when given."""
    images = await indexer.get_all_images(folder)
    return images
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time indexing status updates.

    The socket is registered with the indexer and then held open by reading
    and discarding client messages. It is unregistered on every exit path
    (client disconnect, any other receive error, or task cancellation), so
    a failed connection is never left registered with the indexer.
    """
    await indexer.add_websocket_connection(websocket)
    try:
        while True:
            # Incoming messages are ignored; receiving is how a client
            # disconnect is detected (WebSocketDisconnect is raised).
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass  # Normal client disconnect.
    finally:
        await indexer.remove_websocket_connection(websocket)
async def serve_image(image_id: str):
    """Serve an image from the database by ID.

    Args:
        image_id: Database identifier of the image.

    Returns:
        A streaming response with the stored bytes, a content type derived
        from the stored file extension, a 24-hour cache header and an inline
        ``Content-Disposition`` carrying the original filename.

    Raises:
        HTTPException(404): no image exists for this ID.
        HTTPException(500): any other failure while loading the image.
    """
    from urllib.parse import quote

    try:
        image_data = image_db.get_image(image_id)
        if not image_data:
            raise HTTPException(status_code=404, detail="Image not found")

        extension = image_data["file_extension"].lstrip(".").lower()
        # "image/jpg" is not a registered MIME type; the correct one is image/jpeg.
        if extension == "jpg":
            extension = "jpeg"

        filename = str(image_data["filename"])
        # Header values must be latin-1 encodable and a raw quote or line break
        # would corrupt the header. Send an ASCII-safe fallback name plus the
        # RFC 6266 / RFC 5987 ``filename*`` form carrying the exact UTF-8 name.
        ascii_name = "".join(
            ch if 32 <= ord(ch) < 127 and ch not in '"\\' else "_"
            for ch in filename
        )
        disposition = f"inline; filename=\"{ascii_name}\"; filename*=UTF-8''{quote(filename, safe='')}"

        return StreamingResponse(
            io.BytesIO(image_data["image_data"]),
            media_type=f"image/{extension}",
            headers={
                "Cache-Control": "max-age=86400",  # Cache for 24 hours
                "Content-Disposition": disposition,
            },
        )
    except HTTPException:
        # Let the deliberate 404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def serve_thumbnail_by_id(image_id: str):
    """Serve a thumbnail from the database by ID, as ``image/jpeg``.

    Raises:
        HTTPException(404): no thumbnail exists for this ID.
        HTTPException(500): any other failure while loading it.
    """
    try:
        thumbnail_data = image_db.get_thumbnail(image_id)
        if not thumbnail_data:
            raise HTTPException(status_code=404, detail="Thumbnail not found")
        return StreamingResponse(
            io.BytesIO(thumbnail_data),
            media_type="image/jpeg",
            headers={"Cache-Control": "max-age=86400"}  # Cache for 24 hours
        )
    except HTTPException:
        # Let the deliberate 404 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def get_database_stats():
    """Return the image database's statistics; any failure becomes HTTP 500."""
    try:
        return image_db.get_database_stats()
    except Exception as e:
        # Chain the cause so the original traceback is preserved, matching
        # the other error handlers in this module.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def debug_collections():
    """Debug snapshot comparing Qdrant collections with registered folders.

    Any ``Exception`` raised while gathering the data is not propagated;
    its message is returned in the body as ``{"error": ...}`` instead.
    """
    try:
        collections = indexer.qdrant.get_collections().collections
        folders = indexer.folder_manager.get_all_folders()
        summary = {
            "qdrant_collections": [collection.name for collection in collections],
            "folder_manager_folders": folders,
            "collections_count": len(collections),
            "folders_count": len(folders),
        }
    except Exception as exc:
        return {"error": str(exc)}
    return summary
# Keep the old endpoints for backward compatibility but mark as deprecated | |
async def serve_thumbnail(folder_path: str, file_path: str):
    """Serve a JPEG thumbnail (fit within 200x200) of a file in an indexed folder.

    DEPRECATED - use /thumbnail/{image_id} instead.

    Args:
        folder_path: A folder registered with the folder manager.
        file_path: Location of the image relative to ``folder_path``.

    Raises:
        HTTPException(404): the folder is not indexed or the file is missing.
        HTTPException(400): ``file_path`` points outside the folder or is
            not an image type.
        HTTPException(500): the image could not be read or converted.
    """
    try:
        # Get folder info to verify it's an indexed folder
        folder_info = indexer.folder_manager.get_folder_info(folder_path)
        if not folder_info:
            raise HTTPException(status_code=404, detail="Folder not found")

        # Normalise lexically (collapsing "..") and require the result to stay
        # inside the folder, so "../../x.png" or an absolute file_path cannot
        # read images outside the indexed folder.
        base = Path(os.path.abspath(folder_path))
        full_path = Path(os.path.normpath(base / file_path))
        try:
            full_path.relative_to(base)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid file path") from None

        if not full_path.is_file():
            raise HTTPException(status_code=404, detail="File not found")
        # Only serve image files
        if full_path.suffix.lower() not in image_extensions:
            raise HTTPException(status_code=400, detail="Invalid file type")

        img_byte_arr = io.BytesIO()
        # The context manager closes the underlying file handle.
        with Image.open(full_path) as img:
            img.thumbnail((200, 200))  # Resize in place, keeping aspect ratio
            # JPEG cannot store alpha/palette modes (e.g. RGBA PNGs, P-mode
            # GIFs), so convert those to RGB before encoding.
            out = img if img.mode in ("RGB", "L") else img.convert("RGB")
            out.save(img_byte_arr, format="JPEG")
        img_byte_arr.seek(0)
        return StreamingResponse(img_byte_arr, media_type="image/jpeg", headers={"Cache-Control": "max-age=3600"})  # Cache for 1 hour
    except HTTPException:
        # Keep deliberate 400/404 responses instead of re-wrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def serve_file(folder_path: str, file_path: str):
    """Serve an image file from an indexed folder.

    DEPRECATED - use /image/{image_id} instead.

    Args:
        folder_path: A folder registered with the folder manager.
        file_path: Location of the image relative to ``folder_path``.

    Raises:
        HTTPException(404): the folder is not indexed or the file is missing.
        HTTPException(400): ``file_path`` points outside the folder or is
            not an image type.
        HTTPException(500): any other failure.
    """
    try:
        # Get folder info to verify it's an indexed folder
        folder_info = indexer.folder_manager.get_folder_info(folder_path)
        if not folder_info:
            raise HTTPException(status_code=404, detail="Folder not found")

        # Normalise lexically (collapsing "..") and require the result to stay
        # inside the folder, so "../../x.png" or an absolute file_path cannot
        # read files outside the indexed folder.
        base = Path(os.path.abspath(folder_path))
        full_path = Path(os.path.normpath(base / file_path))
        try:
            full_path.relative_to(base)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid file path") from None

        if not full_path.is_file():
            raise HTTPException(status_code=404, detail="File not found")
        # Only serve image files
        if full_path.suffix.lower() not in image_extensions:
            raise HTTPException(status_code=400, detail="Invalid file type")
        return FileResponse(full_path)
    except HTTPException:
        # Keep deliberate 400/404 responses instead of re-wrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def get_windows_drives():
    """List the root of every logical drive present on this Windows machine.

    Returns:
        Drive roots such as ``"C:\\"`` in alphabetical order, decoded from the
        ``GetLogicalDrives`` bitmask (bit 0 is A, bit 25 is Z).
    """
    from ctypes import windll

    mask = windll.kernel32.GetLogicalDrives()
    return [
        chr(65 + index) + ":\\"
        for index in range(26)
        if mask & (1 << index)
    ]
def get_directory_item(item):
    """Describe one directory entry for the folder browser.

    Args:
        item: A path object for the entry (e.g. one yielded by ``Path.iterdir()``).

    Returns:
        A dict with ``name``, absolute ``path``, ``type`` (``"directory"`` or
        ``"file"``) and ``size`` (bytes for files, ``None`` for directories).
        Returns ``None`` when the entry is neither a directory nor an image
        file, or when inspecting it raises (e.g. a permission error).
    """
    try:
        if item.is_dir():
            return {
                "name": item.name,
                "path": str(item.absolute()),
                "type": "directory",
                "size": None,
            }
        if item.suffix.lower() in image_extensions:
            return {
                "name": item.name,
                "path": str(item.absolute()),
                "type": "file",
                "size": item.stat().st_size,
            }
    except Exception:
        # Best effort: entries that cannot be inspected are simply skipped.
        return None
    return None
def get_directory_contents(path: str):
    """List a directory's sub-directories and image files for the folder browser.

    Returns:
        ``{"current_path", "parent_path", "contents"}``. ``parent_path`` is
        ``None`` when the path is its own parent (a filesystem root), and
        ``contents`` holds the entries accepted by ``get_directory_item``:
        directories first, then files, each group ordered case-insensitively
        by name. Errors are returned as ``{"error": message}``, never raised.
    """
    try:
        target = Path(path)
        if not target.exists():
            return {"error": "Path does not exist"}

        parent_path = None if target.parent == target else str(target.parent)

        entries = []
        for child in target.iterdir():
            described = get_directory_item(child)
            if described is not None:
                entries.append(described)
        entries.sort(key=lambda entry: (entry["type"] != "directory", entry["name"].lower()))

        return {
            "current_path": str(target.absolute()),
            "parent_path": parent_path,
            "contents": entries,
        }
    except Exception as exc:
        return {"error": str(exc)}
async def browse_folders():
    """Starting point of the folder browser.

    On Windows returns ``{"drives": [...]}`` with every drive root; on other
    systems returns the listing of the filesystem root ``/``.
    """
    if os.name != "nt":  # Unix-like
        return get_directory_contents("/")
    return {"drives": get_windows_drives()}
async def browse_path(path: str):
    """List the sub-directories and image files under ``path`` for the folder browser.

    Returns:
        ``{"current_path", "parent_path", "contents"}``. ``parent_path`` is
        ``None`` when the path is its own parent (a filesystem root), and
        ``contents`` lists directories first, then image files, each group
        sorted case-insensitively by name. Entries that cannot be inspected
        are skipped.

    Raises:
        HTTPException(404): ``path`` does not exist.
        HTTPException(400): ``path`` is not a directory.
        HTTPException(500): the directory could not be listed.
    """
    try:
        path_obj = Path(path)
        if not path_obj.exists():
            raise HTTPException(status_code=404, detail="Path not found")
        if not path_obj.is_dir():
            raise HTTPException(status_code=400, detail="Path is not a directory")

        # Get parent directory for navigation
        parent = str(path_obj.parent) if path_obj.parent != path_obj else None

        # get_directory_item keeps only directories and image files and returns
        # None for entries it cannot inspect, which are skipped here.
        contents = [
            entry for entry in map(get_directory_item, path_obj.iterdir())
            if entry is not None
        ]
        return {
            "current_path": str(path_obj.absolute()),
            "parent_path": parent,
            "contents": sorted(contents, key=lambda x: (x["type"] != "directory", x["name"].lower()))
        }
    except HTTPException:
        # Keep deliberate 400/404 responses instead of re-wrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import uvicorn

    # Pass the already-built app object instead of the "app:app" import
    # string. With reload disabled uvicorn does not need the string, and the
    # string would import this file a second time as module "app". That
    # re-runs the module-level setup (indexer, searcher, database) and fails
    # outright if the file is not named app.py.
    uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)