Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, HTTPException, Depends, Request, status, BackgroundTasks | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse | |
from fastapi.security import OAuth2PasswordBearer | |
from fastapi.middleware.gzip import GZipMiddleware | |
from typing import Dict, Any, Optional, List | |
import time | |
import logging | |
from datetime import datetime | |
from pydantic import BaseModel, Field | |
import os | |
import asyncio | |
from tenacity import retry, stop_after_attempt, wait_exponential | |
from config.config import settings | |
from core.rag_engine import RAGEngine | |
from core.user_profile import UserProfile, UserPreferences | |
# Request/response schemas used by the handlers in this module.
class ChatRequest(BaseModel):
    """Body of a chat request: the new message plus optional earlier turns."""

    # Text of the user's current message.
    message: str
    # Earlier conversation turns; None when no history is supplied.
    chat_history: Optional[List[Dict[str, str]]] = Field(default=None)
class ChatResponse(BaseModel):
    """Reply payload for a chat request."""

    # Generated answer text.
    answer: str
    # Optional list of source identifiers backing the answer.
    sources: Optional[List[str]] = Field(default=None)
    # Optional follow-up questions offered to the user.
    suggested_questions: Optional[List[str]] = Field(default=None)
def _utc_now_iso() -> str:
    """Return the current (naive) UTC time formatted with ``isoformat()``."""
    return datetime.utcnow().isoformat()


class ErrorResponse(BaseModel):
    """Error payload serialized into JSON error responses."""

    # Short error label.
    error: str
    # Optional longer description of the failure.
    detail: Optional[str] = None
    # Creation time of this payload; filled in automatically when omitted.
    timestamp: str = Field(default_factory=_utc_now_iso)
    # Identifier of the request that failed, when one is known.
    request_id: Optional[str] = None
class UserProfileResponse(BaseModel):
    """Response wrapper holding a user's profile data."""

    # Required free-form mapping with the profile contents.
    profile: Dict[str, Any] = Field(...)
class UserPreferencesUpdate(BaseModel):
    """Request body carrying the preference values to store for a user."""

    # Required free-form mapping of preference names to values.
    preferences: Dict[str, Any] = Field(...)
# Logging: emit to the console and to a size-rotated "api.log" file.
from logging.handlers import RotatingFileHandler

# Level name and format string both come from the project settings.
_log_level = getattr(logging, settings.LOG_LEVEL)
_log_format = settings.LOG_FORMAT

_console_handler = logging.StreamHandler()
_file_handler = RotatingFileHandler(
    "api.log",
    maxBytes=10 * 1024 * 1024,  # roll the file over at 10MB
    backupCount=5,  # keep five rotated files
)

logging.basicConfig(
    level=_log_level,
    format=_log_format,
    handlers=[_console_handler, _file_handler],
)

# Module-level logger shared by every handler below.
logger = logging.getLogger(__name__)
# The ASGI application object; name and version are taken from settings.
app = FastAPI(
    title=settings.PROJECT_NAME,
    version=settings.VERSION,
    description="AI-powered travel assistant API",
    # Interactive documentation is always served (intended for HF Spaces).
    docs_url="/docs",
    redoc_url="/redoc",
)
# Security headers middleware
async def add_security_headers(request: Request, call_next):
    """Run the downstream handler and stamp hardening headers on its response.

    Args:
        request: The incoming request, forwarded unchanged to ``call_next``.
        call_next: Awaitable callable that produces the response.

    Returns:
        The downstream response with the four security headers set.

    NOTE(review): no registration (e.g. a middleware decorator) is visible
    for this function in this view of the file -- confirm it is attached to
    the application.
    """
    response = await call_next(request)

    hardening_headers = (
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        ("X-XSS-Protection", "1; mode=block"),
        ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
    )
    for header_name, header_value in hardening_headers:
        response.headers[header_name] = header_value

    return response
# CORS: every origin is accepted (the app is meant for Hugging Face Spaces).
# NOTE(review): wildcard origins/headers are combined with
# allow_credentials=True here. The FastAPI CORS documentation says wildcards
# must not be used together with credentials -- confirm whether credentialed
# cross-origin requests are really needed, or list the origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_headers=["*"],
    allow_credentials=True,
    max_age=3600,
)

# Gzip compression, configured with minimum_size=1000.
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Core components. They stay None until initialize_components() has run, so
# code that checks them (e.g. the health check) sees "not initialized"
# instead of hitting a NameError.
rag_engine: Optional[RAGEngine] = None
user_profile: Optional[UserProfile] = None


async def initialize_components():
    """Create the shared RAGEngine and UserProfile instances.

    The instances are published through the module-level ``rag_engine`` and
    ``user_profile`` globals.

    Raises:
        Exception: Whatever the constructors raise; the error is logged with
            its traceback and re-raised.
    """
    global rag_engine, user_profile
    try:
        rag_engine = RAGEngine()
        user_profile = UserProfile()
        logger.info("Core components initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize core components: {str(e)}", exc_info=True)
        raise


# Run the initializer as an application startup handler.
# asyncio.create_task() needs a running event loop, so calling it at import
# time raises RuntimeError whenever this module is imported outside a loop
# (running this file directly, importing it from tests, ...). It also drops
# the task reference, so a failure would go unnoticed. A startup handler runs
# inside the server's loop and completes before requests are served.
app.add_event_handler("startup", initialize_components)
# Bearer-token scheme for authentication; "token" is the token URL it
# advertises.
# NOTE(review): nothing in the visible part of this module references
# oauth2_scheme -- confirm it is used elsewhere (e.g. by api.dependencies).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
from api.dependencies import ( | |
get_current_user, | |
rate_limit, | |
cleanup, | |
) | |
async def global_exception_handler(request: Request, exc: Exception):
    """Turn an unhandled exception into a JSON 500 response.

    The full exception (message and traceback) is written to the log together
    with the request ID. The client only receives a generic message plus the
    request ID, so internal details carried by the exception text are not
    exposed; the ID lets the log entry be matched to the response.

    Args:
        request: The request being processed; its ``X-Request-ID`` header is
            used as the request ID ("unknown" when absent).
        exc: The exception that was not handled elsewhere.

    Returns:
        A ``JSONResponse`` with status 500 and an ``ErrorResponse`` body.

    NOTE(review): no registration (e.g. an exception-handler decorator) is
    visible for this function in this view of the file -- confirm it is
    attached to the application.
    """
    request_id = request.headers.get("X-Request-ID", "unknown")

    # Pass the exception object explicitly instead of exc_info=True so the
    # traceback is logged even if no exception is "current" at this point.
    logger.error(
        f"Unhandled exception: {str(exc)}",
        exc_info=exc,
        extra={"request_id": request_id},
    )

    payload = ErrorResponse(
        error="Internal Server Error",
        detail="An unexpected error occurred",
        request_id=request_id,
    )
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=payload.dict(),
    )
async def add_process_time_header(request: Request, call_next):
    """Measure how long the downstream handler takes and report it.

    The elapsed time in seconds is written to the ``X-Process-Time`` response
    header. A monotonic clock is used, so the value cannot be skewed (or go
    negative) when the system clock is adjusted during a request.

    Args:
        request: The incoming request, forwarded unchanged to ``call_next``.
        call_next: Awaitable callable that produces the response.

    Returns:
        The downstream response with ``X-Process-Time`` set.

    Raises:
        Exception: Any error from ``call_next``; it is logged and re-raised.

    NOTE(review): no registration (e.g. a middleware decorator) is visible
    for this function in this view of the file -- confirm it is attached to
    the application.
    """
    started = time.perf_counter()
    try:
        response = await call_next(request)
    except Exception as e:
        logger.error(f"Error in middleware: {str(e)}", exc_info=True)
        raise

    response.headers["X-Process-Time"] = str(time.perf_counter() - started)
    return response
async def root():
    """Return a welcome message together with version information.

    NOTE(review): no route registration is visible for this function in this
    view of the file -- confirm the path it is served on.
    """
    info = {"message": "Welcome to TravelMate AI Assistant API"}
    info["version"] = settings.VERSION
    # The DEBUG setting is what gets reported under the "environment" key.
    info["environment"] = settings.DEBUG
    return info
async def chat(
    request: ChatRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Validate a chat request, run it through the RAG engine and reply.

    Args:
        request: The chat message and optional history.
        background_tasks: Receives the ``cleanup`` task after a successful
            query.
        current_user: Mapping resolved by ``get_current_user``; its
            ``"user_id"`` entry is passed to the engine.

    Returns:
        A ``ChatResponse`` built from the engine result.

    Raises:
        HTTPException: 400 for an over-long message, an over-long history or
            a history entry lacking the "user"/"assistant" keys; 504 when the
            query exceeds ``settings.QUERY_TIMEOUT``; 500 for any other error.

    NOTE(review): no route registration is visible for this function in this
    view of the file -- confirm the path and method it is served on.
    """

    def _bad_request(reason: str) -> HTTPException:
        # All validation failures share the same status code.
        return HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=reason,
        )

    try:
        # --- message size -------------------------------------------------
        if len(request.message) > settings.MAX_MESSAGE_LENGTH:
            raise _bad_request(
                f"Message too long. Maximum length is {settings.MAX_MESSAGE_LENGTH} characters"
            )

        # --- chat history (skipped when absent or empty) ------------------
        history = request.chat_history
        if history:
            if len(history) > settings.MAX_CHAT_HISTORY:
                raise _bad_request(
                    f"Chat history too long. Maximum length is {settings.MAX_CHAT_HISTORY} messages"
                )
            required_keys = ["user", "assistant"]
            for turn in history:
                well_formed = isinstance(turn, dict) and all(
                    key in turn for key in required_keys
                )
                if not well_formed:
                    raise _bad_request("Invalid chat history format")

        # --- query the RAG engine, bounded by the configured timeout ------
        pending_query = rag_engine.process_query(
            query=request.message,
            chat_history=history,
            user_id=current_user["user_id"],
        )
        result = await asyncio.wait_for(
            pending_query, timeout=settings.QUERY_TIMEOUT
        )

        # Cleanup is scheduled to run after the response has been sent.
        background_tasks.add_task(cleanup)

        answer = result["answer"]
        sources = result.get("metadata", {}).get("sources", [])
        follow_ups = result.get("suggested_questions", [])
        return ChatResponse(
            answer=answer,
            sources=sources,
            suggested_questions=follow_ups,
        )
    except HTTPException:
        # Deliberate HTTP errors (the 400s above) pass through untouched.
        raise
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Request timed out"
        )
    except Exception as e:
        logger.error(f"Error processing chat request: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing chat request",
        )
async def get_profile(current_user: Dict[str, Any] = Depends(get_current_user)):
    """Look up the profile of the authenticated user.

    Args:
        current_user: Mapping resolved by ``get_current_user``; its
            ``"user_id"`` entry selects the profile.

    Returns:
        A ``UserProfileResponse`` built from the stored profile.

    Raises:
        HTTPException: 404 when the lookup yields a falsy value; 504 when it
            exceeds ``settings.PROFILE_TIMEOUT``; 500 for any other error.

    NOTE(review): ``UserProfileResponse`` declares a single ``profile`` field,
    so ``UserProfileResponse(**profile)`` only validates if the stored value
    is a mapping with a "profile" key -- confirm against
    ``UserProfile.get_profile``.
    NOTE(review): no route registration is visible for this function in this
    view of the file -- confirm the path it is served on.
    """
    try:
        lookup = user_profile.get_profile(current_user["user_id"])
        profile = await asyncio.wait_for(lookup, timeout=settings.PROFILE_TIMEOUT)

        if profile:
            return UserProfileResponse(**profile)

        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Profile not found"
        )
    except HTTPException:
        # The 404 above must reach the client unchanged.
        raise
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Request timed out"
        )
    except Exception as e:
        logger.error(f"Error getting user profile: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error retrieving profile",
        )
async def update_preferences(
    preferences: UserPreferencesUpdate,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Validate and store new preferences for the authenticated user.

    Args:
        preferences: Request body holding the preference mapping.
        current_user: Mapping resolved by ``get_current_user``; its
            ``"user_id"`` entry selects the profile to update.

    Returns:
        A dict with a confirmation message.

    Raises:
        HTTPException: 400 when the preferences cannot be turned into a
            ``UserPreferences`` object or the update reports failure; 504
            when the update exceeds ``settings.PROFILE_TIMEOUT``; 500 for any
            other error.

    NOTE(review): no route registration is visible for this function in this
    view of the file -- confirm the path and method it is served on.
    """
    try:
        new_preferences = preferences.preferences

        # Constructing UserPreferences is used purely as a validity check;
        # the resulting object is discarded.
        try:
            UserPreferences(**new_preferences)
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid preferences: {str(e)}",
            )

        pending_update = user_profile.update_profile(
            current_user["user_id"], {"preferences": new_preferences}
        )
        updated = await asyncio.wait_for(
            pending_update, timeout=settings.PROFILE_TIMEOUT
        )

        if updated:
            return {"message": "Preferences updated successfully"}

        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Failed to update preferences",
        )
    except HTTPException:
        # The 400s above must reach the client unchanged.
        raise
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Request timed out"
        )
    except Exception as e:
        logger.error(f"Error updating preferences: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error updating preferences",
        )
async def health_check():
    """Report whether the service and its core components are usable.

    Returns:
        A dict with status, timestamp, version, the DEBUG setting (under
        "environment") and a per-component status map.

    Raises:
        HTTPException: 503 "Core components not initialized" when either core
            component is missing or falsy; 503 "Service unhealthy" when the
            check itself fails unexpectedly.

    NOTE(review): no route registration is visible for this function in this
    view of the file -- confirm the path it is served on.
    """
    try:
        # Check core components
        if not rag_engine or not user_profile:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Core components not initialized",
            )
        return {
            "status": "healthy",
            "timestamp": datetime.utcnow().isoformat(),
            "version": settings.VERSION,
            "environment": settings.DEBUG,  # Use DEBUG setting for environment
            "components": {
                "rag_engine": "ok",
                "user_profile": "ok",
            },
        }
    except HTTPException:
        # Let the specific 503 above through. Without this clause the generic
        # handler below would log it as an unexpected error and replace its
        # detail with "Service unhealthy".
        raise
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Service unhealthy"
        )
async def shutdown_event():
    """Await the shared ``cleanup`` routine; meant to run at shutdown.

    NOTE(review): no registration (e.g. a shutdown-event decorator) is
    visible for this function in this view of the file -- confirm it is
    attached to the application.
    """
    await cleanup()
if __name__ == "__main__":
    import uvicorn

    # Listen on the port given by the PORT environment variable, or 7860
    # when it is unset.
    listen_port = int(os.getenv("PORT", 7860))

    uvicorn.run(
        "api.main:app",
        host="0.0.0.0",
        port=listen_port,
        reload=False,  # auto-reload stays off for production
    )