Commit 09aa2b8: First Commit (0 parents)
Files changed:
- .gitattributes +35 -0
- .gitignore +93 -0
- README.md +13 -0
- api/dependencies.py +126 -0
- api/main.py +344 -0
- api/schemas.py +61 -0
- app.py +210 -0
- config/config.py +112 -0
- core/data_loader.py +185 -0
- core/rag_engine.py +151 -0
- core/user_profile.py +464 -0
- data/.gitkeep +1 -0
- docs/API.md +294 -0
- huggingface.yaml +113 -0
- requirements.txt +44 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,93 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
.pytest_cache/
.coverage
htmlcov/

# Virtual Environment
venv/
ENV/
.env/
.venv/

# IDE
.idea/
.vscode/
*.swp
*.swo
.DS_Store
*.sublime-workspace
*.sublime-project

# Project specific
data/vector_store/
data/travel_guides.json
data/user_profiles/
data/cache/
.cache/
*.log
logs/
.env
.env.*
!.env.example
secrets.json
secret_key.py

# Model files
models/
*.bin
*.pt
*.pth
*.onnx
*.h5
*.hdf5
*.ckpt
*.safetensors

# Hugging Face
.huggingface/
transformers/
datasets/
hub/

# Temporary files
tmp/
temp/
*.tmp
*.temp
*.bak
*.swp
*~

# System files
.DS_Store
Thumbs.db
desktop.ini

# Docker
.docker/
docker-compose.override.yml

# Documentation
docs/_build/
site/
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: TravelMate AI
emoji: 🌍
colorFrom: red
colorTo: blue
sdk: gradio
sdk_version: 5.33.1
app_file: app.py
pinned: false
short_description: AI-Powered Customer Support Chatbot using RAG
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
api/dependencies.py
ADDED
@@ -0,0 +1,126 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from datetime import datetime, timedelta
from typing import Dict, Any, Optional
import threading
import time
from functools import wraps
import logging
from cachetools import TTLCache

# BaseSettings import removed – unused

from config.config import settings

logger = logging.getLogger(__name__)

# Rate limiting cache with adjusted size. cachetools caches are not
# thread-safe and expose no public lock, so access is guarded explicitly.
rate_limit_cache = TTLCache(
    maxsize=settings.MAX_CACHE_SIZE, ttl=settings.RATE_LIMIT_WINDOW
)
rate_limit_lock = threading.Lock()

# JWT settings from environment
SECRET_KEY = settings.JWT_SECRET_KEY
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Create JWT access token with validation"""
    if not isinstance(data, dict):
        raise ValueError("Token data must be a dictionary")

    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode.update({"exp": expire})
    try:
        encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
        return encoded_jwt
    except Exception as e:
        logger.error(f"Error creating access token: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error creating access token",
        )


async def get_current_user(token: str = Depends(oauth2_scheme)) -> Dict[str, Any]:
    """Get current user with enhanced validation"""
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        user_id: str = payload.get("sub")
        if not user_id or not isinstance(user_id, str):
            raise credentials_exception

        # Validate token expiration
        exp = payload.get("exp")
        if not exp or datetime.fromtimestamp(exp) < datetime.utcnow():
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired",
                headers={"WWW-Authenticate": "Bearer"},
            )

        return {"user_id": user_id}
    except JWTError as e:
        logger.error(f"JWT validation error: {str(e)}", exc_info=True)
        raise credentials_exception


def rate_limit(func):
    """Rate limit decorator with enhanced validation"""

    @wraps(func)
    async def wrapper(*args, **kwargs):
        current_user = kwargs.get("current_user")
        if not current_user or "user_id" not in current_user:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="User ID is required for rate limiting",
            )
        user_id = current_user["user_id"]

        # Check rate limit with enhanced validation
        current_time = time.time()
        key = f"{user_id}:{current_time // settings.RATE_LIMIT_WINDOW}"

        try:
            with rate_limit_lock:
                if key in rate_limit_cache:
                    count = rate_limit_cache[key]
                    if count >= settings.RATE_LIMIT_REQUESTS:
                        raise HTTPException(
                            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                            detail=f"Rate limit exceeded. Try again in {settings.RATE_LIMIT_WINDOW} seconds",
                        )
                    rate_limit_cache[key] = count + 1
                else:
                    rate_limit_cache[key] = 1

            return await func(*args, **kwargs)
        except HTTPException:
            # Propagate intentional HTTP errors (such as the 429 above) unchanged
            raise
        except Exception as e:
            logger.error(f"Rate limit error: {str(e)}", exc_info=True)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Error processing rate limit",
            )

    return wrapper


async def cleanup():
    """Cleanup resources"""
    # Add any necessary cleanup here, e.g., closing database connections
    pass
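
Usage sketch (illustrative, not part of this commit): mint a token with create_access_token() and call the protected /chat endpoint. The "demo-user" subject, the localhost URL/port, and the use of httpx are assumptions for this example only.

# Illustrative client script; names and URL are assumptions, not project code.
from datetime import timedelta

import httpx  # any HTTP client works; httpx is assumed here

from api.dependencies import create_access_token

token = create_access_token({"sub": "demo-user"}, expires_delta=timedelta(hours=1))
resp = httpx.post(
    "http://localhost:7860/chat",  # assumed local deployment URL
    headers={"Authorization": f"Bearer {token}"},
    json={"message": "What are the best places to visit in Japan?"},
    timeout=60,
)
print(resp.status_code, resp.json())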
api/main.py
ADDED
@@ -0,0 +1,344 @@
from fastapi import FastAPI, HTTPException, Depends, Request, status, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer
from fastapi.middleware.gzip import GZipMiddleware
from typing import Dict, Any, Optional, List
import time
import logging
from datetime import datetime
from pydantic import BaseModel, Field
import os
import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential

from config.config import settings
from core.rag_engine import RAGEngine
from core.user_profile import UserProfile, UserPreferences

# Define missing types
class ChatRequest(BaseModel):
    message: str
    chat_history: Optional[List[Dict[str, str]]] = None

class ChatResponse(BaseModel):
    answer: str
    sources: Optional[List[str]] = None
    suggested_questions: Optional[List[str]] = None

class ErrorResponse(BaseModel):
    error: str
    detail: Optional[str] = None
    timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
    request_id: Optional[str] = None

class UserProfileResponse(BaseModel):
    profile: Dict[str, Any]

class UserPreferencesUpdate(BaseModel):
    preferences: Dict[str, Any]

# Setup logging with rotation
from logging.handlers import RotatingFileHandler

logging.basicConfig(
    level=getattr(logging, settings.LOG_LEVEL),
    format=settings.LOG_FORMAT,
    handlers=[
        logging.StreamHandler(),
        RotatingFileHandler(
            "api.log",
            maxBytes=10 * 1024 * 1024,  # 10MB
            backupCount=5,
        ),
    ],
)
logger = logging.getLogger(__name__)

app = FastAPI(
    title=settings.PROJECT_NAME,
    description="AI-powered travel assistant API",
    version=settings.VERSION,
    docs_url="/docs",  # Always show docs on HF Spaces
    redoc_url="/redoc",
)

# Add security headers middleware
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    response = await call_next(request)
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["X-Frame-Options"] = "DENY"
    response.headers["X-XSS-Protection"] = "1; mode=block"
    response.headers["Strict-Transport-Security"] = (
        "max-age=31536000; includeSubDomains"
    )
    return response

# Add CORS middleware with validation
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for Hugging Face Spaces
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_headers=["*"],
    max_age=3600,
)

# Add Gzip compression
app.add_middleware(GZipMiddleware, minimum_size=1000)

# Core components are created on startup; keep module-level placeholders so
# the health check can report "not initialized" instead of raising NameError.
rag_engine = None
user_profile = None

# Initialize core components with retry
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
async def initialize_components():
    try:
        global rag_engine, user_profile
        user_profile = UserProfile()
        # RAGEngine expects the profile store (see core/rag_engine.py)
        rag_engine = RAGEngine(user_profile=user_profile)
        logger.info("Core components initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize core components: {str(e)}", exc_info=True)
        raise

@app.on_event("startup")
async def startup_event():
    # asyncio.create_task() at import time fails because no event loop is
    # running yet, so initialization is deferred to FastAPI's startup hook.
    await initialize_components()

# OAuth2 scheme for token authentication
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

from api.dependencies import (
    get_current_user,
    rate_limit,
    cleanup,
)

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Global exception handler with request ID"""
    request_id = request.headers.get("X-Request-ID", "unknown")
    logger.error(
        f"Unhandled exception: {str(exc)}",
        exc_info=True,
        extra={"request_id": request_id},
    )
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=ErrorResponse(
            error="Internal Server Error", detail=str(exc), request_id=request_id
        ).dict(),
    )

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Add processing time header to response"""
    start_time = time.time()
    try:
        response = await call_next(request)
        process_time = time.time() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        return response
    except Exception as e:
        logger.error(f"Error in middleware: {str(e)}", exc_info=True)
        raise

@app.get("/")
async def root():
    """Root endpoint with version info"""
    return {
        "message": "Welcome to TravelMate AI Assistant API",
        "version": settings.VERSION,
        "environment": settings.DEBUG,  # Use DEBUG setting for environment
    }

@app.post(
    "/chat",
    response_model=ChatResponse,
    responses={
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        429: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
@rate_limit
async def chat(
    request: ChatRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Process chat request with enhanced validation"""
    try:
        # Validate request size
        if len(request.message) > settings.MAX_MESSAGE_LENGTH:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Message too long. Maximum length is {settings.MAX_MESSAGE_LENGTH} characters",
            )

        # Validate chat history
        if request.chat_history:
            if len(request.chat_history) > settings.MAX_CHAT_HISTORY:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Chat history too long. Maximum length is {settings.MAX_CHAT_HISTORY} messages",
                )
            for msg in request.chat_history:
                if not isinstance(msg, dict) or not all(
                    k in msg for k in ["user", "assistant"]
                ):
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail="Invalid chat history format",
                    )

        # Process query with RAG engine
        result = await asyncio.wait_for(
            rag_engine.process_query(
                query=request.message,
                chat_history=request.chat_history,
                user_id=current_user["user_id"],
            ),
            timeout=settings.QUERY_TIMEOUT,
        )

        # Add cleanup task
        background_tasks.add_task(cleanup)

        return ChatResponse(
            answer=result["answer"],
            sources=result.get("metadata", {}).get("sources", []),
            suggested_questions=result.get("suggested_questions", []),
        )
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Request timed out"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing chat request: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing chat request",
        )

@app.get(
    "/profile",
    response_model=UserProfileResponse,
    responses={401: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
)
async def get_profile(current_user: Dict[str, Any] = Depends(get_current_user)):
    """Get user profile with enhanced error handling"""
    try:
        profile = await asyncio.wait_for(
            user_profile.get_profile(current_user["user_id"]),
            timeout=settings.PROFILE_TIMEOUT,
        )
        if not profile:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Profile not found"
            )
        return UserProfileResponse(**profile)
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Request timed out"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting user profile: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error retrieving profile",
        )

@app.put(
    "/profile/preferences",
    responses={
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def update_preferences(
    preferences: UserPreferencesUpdate,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Update user preferences with validation"""
    try:
        # Validate preferences
        try:
            UserPreferences(**preferences.preferences)
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid preferences: {str(e)}",
            )

        success = await asyncio.wait_for(
            user_profile.update_profile(
                current_user["user_id"], {"preferences": preferences.preferences}
            ),
            timeout=settings.PROFILE_TIMEOUT,
        )

        if not success:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to update preferences",
            )
        return {"message": "Preferences updated successfully"}
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Request timed out"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating preferences: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error updating preferences",
        )

@app.get("/health", responses={500: {"model": ErrorResponse}})
async def health_check():
    """Health check endpoint with detailed status"""
    try:
        # Check core components
        if not rag_engine or not user_profile:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Core components not initialized",
            )

        return {
            "status": "healthy",
            "timestamp": datetime.utcnow().isoformat(),
            "version": settings.VERSION,
            "environment": settings.DEBUG,  # Use DEBUG setting for environment
            "components": {
                "rag_engine": "ok",
                "user_profile": "ok",
            },
        }
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Service unhealthy"
        )

@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on shutdown"""
    await cleanup()

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "api.main:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", 7860)),
        reload=False,  # Set reload to False for production
    )
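
Quick smoke-test sketch (illustrative, not part of this commit), assuming FastAPI's TestClient is available; without a running startup hook the heavy components stay uninitialized, so /health is expected to answer 503 while the root endpoint still works.

# Illustrative only; TestClient does not run startup hooks unless used as a
# context manager, so this exercises the routing without the RAG engine.
from fastapi.testclient import TestClient

from api.main import app

client = TestClient(app)
print(client.get("/").json())              # welcome message and version
print(client.get("/health").status_code)   # 503 until the core components initialize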
api/schemas.py
ADDED
@@ -0,0 +1,61 @@
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
from datetime import datetime


class ChatRequest(BaseModel):
    message: str = Field(..., min_length=1, max_length=1000)
    chat_history: List[Dict[str, str]] = Field(default_factory=list)


class Source(BaseModel):
    title: str
    url: str
    relevance_score: float


class ChatResponse(BaseModel):
    answer: str
    sources: List[Source] = Field(default_factory=list)
    suggested_questions: List[str] = Field(default_factory=list)
    processing_time: Optional[float] = None


class UserPreferences(BaseModel):
    favorite_destinations: List[str] = Field(default_factory=list)
    travel_style: str = Field(default="balanced")
    preferred_seasons: List[str] = Field(default_factory=list)
    interests: List[str] = Field(default_factory=list)
    dietary_restrictions: List[str] = Field(default_factory=list)
    accessibility_needs: List[str] = Field(default_factory=list)
    language: str = Field(default="en")
    currency: str = Field(default="USD")
    temperature_unit: str = Field(default="C")
    timezone: str = Field(default="UTC")


class UserProfileResponse(BaseModel):
    user_id: str
    preferences: UserPreferences
    created_at: datetime
    updated_at: datetime


class UserPreferencesUpdate(BaseModel):
    favorite_destinations: Optional[List[str]] = None
    travel_style: Optional[str] = None
    preferred_seasons: Optional[List[str]] = None
    interests: Optional[List[str]] = None
    dietary_restrictions: Optional[List[str]] = None
    accessibility_needs: Optional[List[str]] = None
    language: Optional[str] = None
    currency: Optional[str] = None
    temperature_unit: Optional[str] = None
    timezone: Optional[str] = None


class ErrorResponse(BaseModel):
    error: str
    detail: Optional[str] = None
    timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
    request_id: Optional[str] = None
app.py
ADDED
@@ -0,0 +1,210 @@
import os
import sys
import logging
import uuid
from typing import List, Dict, Any, Tuple
from logging.handlers import RotatingFileHandler

import gradio as gr
from tenacity import retry, stop_after_attempt, wait_exponential

from core.rag_engine import RAGEngine
from core.user_profile import UserProfile
from config.config import settings

# ======================================================================================
# Logging Setup
# ======================================================================================

os.makedirs("logs", exist_ok=True)

logging.basicConfig(
    level=getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO),
    format=settings.LOG_FORMAT,
    handlers=[
        logging.StreamHandler(sys.stdout),
        RotatingFileHandler(
            settings.LOG_FILE_PATH,
            maxBytes=settings.LOG_FILE_MAX_BYTES,
            backupCount=settings.LOG_FILE_BACKUP_COUNT
        ),
    ],
)
logger = logging.getLogger(__name__)

# ======================================================================================
# Core Module Initialization
# ======================================================================================

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def initialize_with_retry(func):
    """Initializes a component with retry logic."""
    try:
        return func()
    except Exception as e:
        logger.error(f"Initialization failed: {e}", exc_info=True)
        raise

try:
    user_profile = initialize_with_retry(UserProfile)
    rag_engine = initialize_with_retry(lambda: RAGEngine(user_profile=user_profile))
    logger.info("Core modules initialized successfully.")
except Exception as e:
    logger.critical(f"Fatal: Could not initialize core modules: {e}. Exiting.", exc_info=True)
    sys.exit(1)

# ======================================================================================
# Business Logic
# ======================================================================================

async def handle_chat_interaction(
    message: str, chat_history: List[List[str]], user_id: str, categories: List[str]
) -> List[List[str]]:
    """Handles the user's chat message, processes it, and updates the history."""
    if not message.strip():
        gr.Warning("Message cannot be empty. Please type a question.")
        return chat_history

    try:
        profile = user_profile.get_profile(user_id)
        profile["preferences"]["favorite_categories"] = categories
        user_profile.update_profile(user_id, profile)
        logger.info(f"Updated preferences for user {user_id}: {categories}")

        result = await rag_engine.process_query(query=message, user_id=user_id)
        response = result.get("answer", "Sorry, I could not find an answer.")

        sources = result.get("sources")
        if sources:
            response += "\n\n**Sources:**\n" + format_sources(sources)

        chat_history.append((message, response))
        logger.info(f"User {user_id} received response.")
        return chat_history

    except Exception as e:
        error_message = f"An unexpected error occurred: {str(e)}"
        logger.error(f"Error for user {user_id}: {error_message}", exc_info=True)
        gr.Warning("Sorry, I encountered a problem. Please try again or rephrase your question.")
        return chat_history

def format_sources(sources: List[Dict[str, Any]]) -> str:
    """Formats the source documents into a readable string."""
    if not sources:
        return ""
    formatted_list = [f"- **{source.get('title', 'Unknown Source')}** (Category: {source.get('category', 'N/A')})" for source in sources]
    return "\n".join(formatted_list)

# ======================================================================================
# Gradio UI Definition
# ======================================================================================

def handle_slider_change(value: int) -> None:
    """
    Handles the change event for the document loader slider.
    Note: This currently only shows a notification. A restart is required.
    """
    gr.Info(f"Document limit set to {int(value)}. Please restart the app for changes to take effect.")


def create_interface() -> gr.Blocks:
    """Creates and configures the Gradio web interface."""

    with gr.Blocks(
        title="TravelMate - Your AI Travel Assistant",
        theme=gr.themes.Base(),
    ) as demo:

        user_id = gr.State(lambda: str(uuid.uuid4()))

        gr.Markdown("""
        <div style="text-align: center;">
            <h1 style="font-size: 2.5em;">✈️ TravelMate</h1>
            <p style="font-size: 1.1em; color: #333;">Your AI-powered travel assistant. Ask me anything to plan your next trip!</p>
        </div>
        """)

        with gr.Accordion("Advanced Settings", open=False):
            doc_load_slider = gr.Slider(
                minimum=100,
                maximum=5000,
                value=settings.MAX_DOCUMENTS_TO_LOAD,
                step=100,
                label="Documents to Load",
                info="Controls how many documents are loaded for the RAG engine. Higher values may increase startup time.",
            )
            doc_load_slider.change(
                fn=handle_slider_change, inputs=[doc_load_slider], outputs=None
            )

        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="TravelMate Chat",
                    height=600,
                    show_label=False,
                    show_copy_button=True,
                    bubble_full_width=False,
                    avatar_images=("assets/user_avatar.png", "assets/bot_avatar.png"),
                )
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask me about destinations, flights, hotels...",
                        show_label=False,
                        container=False,
                        scale=8,
                    )
                    submit_btn = gr.Button("Send", variant="primary", scale=1)

            with gr.Column(scale=1):
                gr.Markdown("### Select Your Interests")
                categories = gr.CheckboxGroup(
                    choices=[
                        "Flights", "Hotels", "Destinations", "Activities",
                        "Transportation", "Food & Dining", "Shopping",
                        "Health & Safety", "Budget Planning"
                    ],
                    value=["Flights", "Hotels"],
                    label="Travel Categories",
                )
                gr.Markdown("### Example Questions")
                gr.Examples(
                    examples=[
                        "What are the best places to visit in Japan?",
                        "How do I find cheap flights to Europe?",
                        "What should I pack for a beach vacation?",
                        "Tell me about local customs in Thailand",
                        "What's the best time to visit Paris?",
                    ],
                    inputs=msg,
                )

        async def on_submit(message: str, history: List[List[str]], uid: str, cats: List[str]) -> Tuple[str, List[List[str]]]:
            """Handles submission and returns updated values for the message box and chatbot."""
            updated_history = await handle_chat_interaction(message, history, uid, cats)
            return "", updated_history

        submit_btn.click(on_submit, [msg, chatbot, user_id, categories], [msg, chatbot])
        msg.submit(on_submit, [msg, chatbot, user_id, categories], [msg, chatbot])

    return demo

# ======================================================================================
# Application Launch
# ======================================================================================

if __name__ == "__main__":
    try:
        app = create_interface()
        app.queue(default_concurrency_limit=settings.GRADIO_CONCURRENCY_COUNT)
        app.launch(
            server_name=settings.GRADIO_SERVER_NAME,
            server_port=settings.GRADIO_SERVER_PORT,
            share=settings.GRADIO_SHARE,
            show_error=True,
            show_api=False,
        )
    except Exception as e:
        logger.critical(f"Failed to launch Gradio app: {e}", exc_info=True)
        raise
config/config.py
ADDED
@@ -0,0 +1,112 @@
import os
from pathlib import Path
from typing import Optional

from pydantic import model_validator
from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    """Manages all application settings using Pydantic for robust configuration.

    Attributes:
        BASE_DIR (Path): The root directory of the project.
        DATA_DIR (Path): The directory for storing data.
        # ... other attributes
    """
    # ----------------------------------------------------------------------------------
    # Path Settings
    # ----------------------------------------------------------------------------------
    BASE_DIR: Path = Path(__file__).resolve().parent.parent
    DATA_DIR: Path = BASE_DIR / "data"
    VECTOR_STORE_DIR: Path = DATA_DIR / "vector_store"
    USER_PROFILES_DIR: Path = DATA_DIR / "user_profiles"
    CACHE_DIR: Path = DATA_DIR / "cache"
    LOGS_DIR: Path = BASE_DIR / "logs"

    # ----------------------------------------------------------------------------------
    # Model & RAG Settings
    # ----------------------------------------------------------------------------------
    MODEL_NAME: str = "google/gemma-2b-it"
    EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")
    DATASET_ID: str = "bitext/Bitext-travel-llm-chatbot-training-dataset"
    HUGGINGFACE_API_TOKEN: Optional[str] = os.getenv("HUGGINGFACE_API_TOKEN")

    # RAG pipeline settings
    CHUNK_SIZE: int = 1000
    CHUNK_OVERLAP: int = 100
    MAX_DOCUMENTS_TO_LOAD: int = 50  # Drastically reduced for performance
    TOP_K_RESULTS: int = 3
    SIMILARITY_THRESHOLD: float = 0.7

    # Model configuration for HuggingFaceEndpoint
    TEMPERATURE: float = 0.7
    MAX_NEW_TOKENS: int = 512  # Drastically reduced to combat latency
    REPETITION_PENALTY: float = 1.2

    # Production-Grade Prompt Template
    QA_PROMPT_TEMPLATE: str = """You are TravelMate, an expert AI travel assistant.
Use the following context to answer the user's question concisely and helpfully.
If you don't know the answer, simply say that you don't know. Do not make up information.

Context:
{context}

Question: {input}

Answer:"""

    # Cache settings
    MAX_CACHE_SIZE: int = 1000
    CACHE_TTL: int = 3600  # Time-to-live in seconds (1 hour)

    # ----------------------------------------------------------------------------------
    # Application Behavior Settings
    # ----------------------------------------------------------------------------------
    QUERY_TIMEOUT: int = 30  # seconds
    MAX_MESSAGE_LENGTH: int = 500  # characters
    MAX_CHAT_HISTORY: int = 20  # messages

    # Rate limiting and profile timeouts. These fields are referenced by
    # api/dependencies.py and api/main.py but were missing from this file;
    # the default values below are assumptions, not part of the original commit.
    RATE_LIMIT_REQUESTS: int = 60  # requests allowed per window (assumed default)
    RATE_LIMIT_WINDOW: int = 60    # window length in seconds (assumed default)
    PROFILE_TIMEOUT: int = 10      # seconds (assumed default)

    # ----------------------------------------------------------------------------------
    # API & Security Settings
    # ----------------------------------------------------------------------------------
    API_V1_STR: str = "/api/v1"
    PROJECT_NAME: str = "TravelMate AI Assistant"
    VERSION: str = "1.0.0"
    DEBUG: bool = False
    SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key")
    JWT_SECRET_KEY: str = os.getenv("JWT_SECRET_KEY", "a_very_secret_jwt_key")
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24  # 24 hours

    # ----------------------------------------------------------------------------------
    # Logging Settings
    # ----------------------------------------------------------------------------------
    LOG_LEVEL: str = "INFO"
    LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    LOG_FILE_PATH: Path = LOGS_DIR / "app.log"
    LOG_FILE_MAX_BYTES: int = 10 * 1024 * 1024  # 10MB
    LOG_FILE_BACKUP_COUNT: int = 5

    # ----------------------------------------------------------------------------------
    # Gradio UI Settings
    # ----------------------------------------------------------------------------------
    GRADIO_SERVER_NAME: str = "0.0.0.0"
    GRADIO_SERVER_PORT: int = 7860
    GRADIO_SHARE: bool = True
    GRADIO_CONCURRENCY_COUNT: int = 5

    class Config:
        env_file = ".env"
        case_sensitive = True

    @model_validator(mode='after')
    def create_directories(self) -> 'Settings':
        """Ensures that necessary directories exist upon settings initialization."""
        self.DATA_DIR.mkdir(parents=True, exist_ok=True)
        self.VECTOR_STORE_DIR.mkdir(parents=True, exist_ok=True)
        self.USER_PROFILES_DIR.mkdir(parents=True, exist_ok=True)
        self.CACHE_DIR.mkdir(parents=True, exist_ok=True)
        self.LOGS_DIR.mkdir(parents=True, exist_ok=True)
        return self


settings = Settings()
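
Because Settings is a pydantic-settings class, any field can be overridden through environment variables or the .env file without code changes. A small illustration (the override values below are arbitrary examples, not project defaults):

# Illustration only: variable names match the field names exactly because
# case_sensitive=True; values are parsed into the declared field types.
import os

os.environ["MAX_NEW_TOKENS"] = "256"   # arbitrary example override
os.environ["GRADIO_SHARE"] = "false"

from config.config import Settings

overridden = Settings()
print(overridden.MAX_NEW_TOKENS)  # 256
print(overridden.GRADIO_SHARE)    # False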
core/data_loader.py
ADDED
@@ -0,0 +1,185 @@
import os
import json
import logging
import stat
import time
from typing import Any, List

from config.config import settings
from datasets import load_dataset
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

logger = logging.getLogger(__name__)


class DataLoader:
    """Handles loading and processing of data for the RAG engine."""

    def __init__(self):
        """Initialize the data loader."""
        self.data_dir = os.path.abspath("data")
        self.travel_guides_path = os.path.join(self.data_dir, "travel_guides.json")
        self.vector_store_path = os.path.join(self.data_dir, "vector_store", "faiss_index")
        self._ensure_data_directories()
        self._set_directory_permissions()

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=settings.CHUNK_SIZE,
            chunk_overlap=settings.CHUNK_OVERLAP,
            length_function=len,
            separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
        )
        self.max_file_size = 10 * 1024 * 1024  # 10MB

    def _ensure_data_directories(self):
        """Ensure necessary data directories exist."""
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(os.path.dirname(self.vector_store_path), exist_ok=True)
        os.makedirs(os.path.join(self.data_dir, "cache"), exist_ok=True)

    def _set_directory_permissions(self):
        """Set secure permissions for data directories (755)."""
        try:
            for dir_path in [
                self.data_dir,
                os.path.dirname(self.vector_store_path),
                os.path.join(self.data_dir, "cache"),
            ]:
                os.chmod(
                    dir_path,
                    stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH,
                )
        except Exception as e:
            logger.error(f"Error setting directory permissions: {e}", exc_info=True)

    def _validate_file_permissions(self, file_path: str) -> bool:
        """Validate file permissions to ensure security."""
        try:
            if not os.path.exists(file_path):
                return False
            file_stat = os.stat(file_path)
            if file_stat.st_mode & stat.S_IWOTH:  # Disallow world-writable
                logger.warning(f"File {file_path} is world-writable. Skipping.")
                return False
            if file_stat.st_size > self.max_file_size:
                logger.warning(f"File {file_path} exceeds size limit. Skipping.")
                return False
            return True
        except Exception as e:
            logger.error(f"Error validating file permissions for {file_path}: {e}", exc_info=True)
            return False

    def _load_dataset_with_retry(self, max_retries: int = 3) -> Any:
        """Load dataset from Hugging Face with an exponential backoff retry mechanism."""
        for attempt in range(max_retries):
            try:
                return load_dataset(
                    settings.DATASET_ID,
                    split="train",
                    cache_dir=os.path.join(self.data_dir, "cache"),
                )
            except Exception as e:
                logger.warning(f"Dataset loading attempt {attempt + 1} failed: {e}")
                if attempt == max_retries - 1:
                    logger.error("All attempts to load dataset failed.")
                    return None
                time.sleep(2 ** attempt)
        return None

    def load_documents(self) -> List[Document]:
        """Load and process all documents for the knowledge base."""
        documents = []
        try:
            # 1. Load Bitext Travel Dataset
            logger.info(f"Loading dataset: {settings.DATASET_ID}")
            dataset = self._load_dataset_with_retry()
            if dataset:
                max_docs = settings.MAX_DOCUMENTS_TO_LOAD
                logger.info(f"Loading up to {max_docs} documents from the dataset.")
                for i, item in enumerate(dataset):
                    if i >= max_docs:
                        logger.info(f"Reached document limit ({max_docs}).")
                        break
                    instruction = item.get("instruction")
                    response = item.get("response")

                    if not instruction or not response:
                        logger.warning(f"Skipping item with missing instruction or response: {item}")
                        continue

                    page_content = f"User query: {instruction}\n\nChatbot response: {response}"
                    metadata = {
                        "source": "huggingface",
                        "intent": item.get("intent"),
                        "category": item.get("category"),
                        "tags": item.get("tags"),
                    }

                    documents.append(Document(page_content=page_content, metadata=metadata))

            # 2. Load Local Travel Guides
            logger.info("Loading local travel guides...")
            if os.path.exists(self.travel_guides_path) and self._validate_file_permissions(self.travel_guides_path):
                with open(self.travel_guides_path, "r", encoding="utf-8") as f:
                    guides = json.load(f)
                for guide in guides:
                    if not all(k in guide for k in ["title", "content", "category"]):
                        logger.warning(f"Skipping malformed guide: {guide}")
                        continue
                    doc = Document(
                        page_content=guide["content"],
                        metadata={
                            "title": guide["title"],
                            "category": guide["category"],
                            "source": "travel_guide",
                        },
                    )
                    documents.append(doc)
            else:
                logger.info("Travel guides file not found or invalid. Skipping.")

            logger.info(f"Loaded {len(documents)} documents in total.")
            return documents

        except Exception as e:
            logger.error(f"A critical error occurred while loading documents: {e}", exc_info=True)
            return []

    def create_vector_store(self, documents: List[Document]):
        """Create a FAISS vector store from documents."""
        try:
            logger.info("Creating vector store...")
            embeddings = HuggingFaceEmbeddings(
                model_name=settings.EMBEDDING_MODEL_NAME,
                model_kwargs={"device": "cpu"},
                encode_kwargs={"normalize_embeddings": True},
            )

            split_docs = self.text_splitter.split_documents(documents)

            vector_store = FAISS.from_documents(
                documents=split_docs,
                embedding=embeddings,
            )
            vector_store.save_local(self.vector_store_path)
            logger.info(f"Vector store created and saved to {self.vector_store_path} with {len(split_docs)} chunks.")
        except Exception as e:
            logger.error(f"Error creating vector store: {e}", exc_info=True)
            raise

    def initialize_knowledge_base(self):
        """Initialize the complete knowledge base."""
        try:
            logger.info("Initializing knowledge base...")
            documents = self.load_documents()
            if not documents:
                logger.error("No documents were loaded. Aborting knowledge base initialization.")
                return
            self.create_vector_store(documents)
            logger.info("Knowledge base initialized successfully.")
        except Exception as e:
            logger.critical(f"Failed to initialize knowledge base: {e}", exc_info=True)
            raise
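
The loader can also be run ahead of time so the Space does not have to download the dataset and build embeddings on first startup. A one-off build script might look like this (illustrative sketch, not part of this commit):

# Illustrative pre-build script: creates and saves the FAISS index under
# data/vector_store/faiss_index using the settings defined in config/config.py.
import logging

from core.data_loader import DataLoader

logging.basicConfig(level=logging.INFO)

if __name__ == "__main__":
    loader = DataLoader()
    loader.initialize_knowledge_base()  # load_documents() then create_vector_store()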
core/rag_engine.py
ADDED
@@ -0,0 +1,151 @@
import logging
from typing import List, Dict, Any

from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.documents import Document
from expiringdict import ExpiringDict

from core.data_loader import DataLoader
from core.user_profile import UserProfile
from config.config import settings

logger = logging.getLogger(__name__)

class RAGEngine:
    """
    The core Retrieval-Augmented Generation engine for the TravelMate chatbot.
    This class handles model initialization, vector store creation, and query processing.
    """

    def __init__(self, user_profile: UserProfile):
        """
        Initializes the RAG engine, loading models and setting up the QA chain.
        """
        self.user_profile = user_profile
        self.query_cache = ExpiringDict(max_len=settings.MAX_CACHE_SIZE, max_age_seconds=settings.CACHE_TTL)

        try:
            self.embeddings = self._initialize_embeddings()
            self.vector_store = self._initialize_vector_store()
            self.llm = self._initialize_llm()
            self.qa_chain = self._create_rag_chain()
            logger.info("RAG Engine initialized successfully.")
        except Exception as e:
            logger.critical(f"Failed to initialize RAG Engine: {e}", exc_info=True)
            raise

    def _initialize_embeddings(self) -> HuggingFaceEmbeddings:
        """Initializes the sentence-transformer embeddings model."""
        return HuggingFaceEmbeddings(
            model_name=settings.EMBEDDING_MODEL_NAME,
            model_kwargs={'device': 'cpu'}
        )

    def _initialize_vector_store(self) -> FAISS:
        """
        Initializes the FAISS vector store.
        Loads from disk if it exists, otherwise creates it from the data loader.
        """
        if settings.VECTOR_STORE_DIR.exists() and any(settings.VECTOR_STORE_DIR.iterdir()):
            logger.info(f"Loading existing vector store from {settings.VECTOR_STORE_DIR}...")
            return FAISS.load_local(
                folder_path=str(settings.VECTOR_STORE_DIR),
                embeddings=self.embeddings,
                allow_dangerous_deserialization=True
            )
        else:
            logger.info("Creating new vector store from scratch.")
            data_loader = DataLoader()
            documents = data_loader.load_documents()

            if not documents:
                raise ValueError("No documents were loaded. Cannot create vector store.")

            vector_store = FAISS.from_documents(documents, self.embeddings)
            logger.info(f"Saving new vector store to {settings.VECTOR_STORE_DIR}...")
            vector_store.save_local(str(settings.VECTOR_STORE_DIR))
            return vector_store

    def _initialize_llm(self) -> HuggingFaceEndpoint:
        """Initializes the Hugging Face Inference Endpoint for the LLM."""
        if not settings.HUGGINGFACE_API_TOKEN:
            raise ValueError("HUGGINGFACE_API_TOKEN is not set.")

        return HuggingFaceEndpoint(
            repo_id=settings.MODEL_NAME,
            huggingfacehub_api_token=settings.HUGGINGFACE_API_TOKEN,
            temperature=settings.TEMPERATURE,
            max_new_tokens=settings.MAX_NEW_TOKENS,
            repetition_penalty=settings.REPETITION_PENALTY,
        )

    def _create_rag_chain(self):
        """Creates a modern, streamlined RAG chain for question answering."""
        qa_prompt = PromptTemplate.from_template(settings.QA_PROMPT_TEMPLATE)
|
| 90 |
+
|
| 91 |
+
question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
|
| 92 |
+
|
| 93 |
+
retriever = self.vector_store.as_retriever(
|
| 94 |
+
search_type="similarity_score_threshold",
|
| 95 |
+
search_kwargs={'k': settings.TOP_K_RESULTS, 'score_threshold': settings.SIMILARITY_THRESHOLD}
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
rag_chain = create_retrieval_chain(retriever, question_answer_chain)
|
| 99 |
+
return rag_chain
|
| 100 |
+
|
| 101 |
+
def _format_sources(self, sources: List[Document]) -> List[Dict[str, Any]]:
|
| 102 |
+
"""Formats source documents into a serializable list of dictionaries."""
|
| 103 |
+
if not sources:
|
| 104 |
+
return []
|
| 105 |
+
|
| 106 |
+
formatted_list = []
|
| 107 |
+
for source in sources:
|
| 108 |
+
metadata = source.metadata
|
| 109 |
+
source_name = metadata.get('source', 'Unknown Source')
|
| 110 |
+
|
| 111 |
+
if source_name == 'huggingface':
|
| 112 |
+
title = f"Dataset: {metadata.get('intent', 'N/A')}"
|
| 113 |
+
category = metadata.get('category', 'N/A')
|
| 114 |
+
elif source_name == 'local_guides':
|
| 115 |
+
title = f"Guide: {metadata.get('title', 'N/A')}"
|
| 116 |
+
category = metadata.get('category', 'N/A')
|
| 117 |
+
else:
|
| 118 |
+
title = "Unknown Source"
|
| 119 |
+
category = "N/A"
|
| 120 |
+
|
| 121 |
+
formatted_list.append({"title": title, "category": category})
|
| 122 |
+
|
| 123 |
+
return formatted_list
|
| 124 |
+
|
| 125 |
+
async def process_query(self, query: str, user_id: str) -> Dict[str, Any]:
|
| 126 |
+
"""Processes a user query asynchronously using the streamlined RAG chain."""
|
| 127 |
+
cache_key = f"{user_id}:{query}"
|
| 128 |
+
if cache_key in self.query_cache:
|
| 129 |
+
logger.info(f"Returning cached response for query: {query}")
|
| 130 |
+
return self.query_cache[cache_key]
|
| 131 |
+
|
| 132 |
+
logger.info(f"Processing query for user {user_id}: {query}")
|
| 133 |
+
|
| 134 |
+
# The new chain expects 'input' instead of 'question'
|
| 135 |
+
chain_input = {"input": query}
|
| 136 |
+
|
| 137 |
+
try:
|
| 138 |
+
result = await self.qa_chain.ainvoke(chain_input)
|
| 139 |
+
|
| 140 |
+
answer = result.get("answer", "Sorry, I couldn't find an answer.")
|
| 141 |
+
# The new chain returns retrieved documents in the 'context' key
|
| 142 |
+
sources = self._format_sources(result.get("context", []))
|
| 143 |
+
|
| 144 |
+
response = {"answer": answer, "sources": sources}
|
| 145 |
+
self.query_cache[cache_key] = response
|
| 146 |
+
|
| 147 |
+
logger.info(f"Successfully processed query for user {user_id}")
|
| 148 |
+
return response
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.error(f"Error processing query for user {user_id}: {e}", exc_info=True)
|
| 151 |
+
return {"answer": "I'm sorry, but I encountered an error while processing your request.", "sources": []}
|
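`RAGEngine.process_query` is a coroutine, so callers need an event loop. A minimal, hypothetical driver for the engine above (the question and user id are placeholders, and `HUGGINGFACE_API_TOKEN` must be set or `_initialize_llm` raises):

```python
# Hypothetical driver for RAGEngine.process_query; not part of the committed code.
import asyncio

from core.rag_engine import RAGEngine
from core.user_profile import UserProfile


async def main() -> None:
    engine = RAGEngine(user_profile=UserProfile())
    result = await engine.process_query(
        "What should I pack for Kyoto in spring?", user_id="demo-user"
    )
    print(result["answer"])
    for source in result["sources"]:
        print(f"- {source['title']} ({source['category']})")


if __name__ == "__main__":
    asyncio.run(main())
```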
core/user_profile.py
ADDED
|
@@ -0,0 +1,464 @@
|
| 1 |
+
from typing import Dict, Any, Optional, List
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from datetime import datetime, timedelta
|
| 5 |
+
import logging
|
| 6 |
+
from pydantic import BaseModel, Field, validator
|
| 7 |
+
from enum import Enum
|
| 8 |
+
import time
|
| 9 |
+
import re
|
| 10 |
+
from config.config import settings
|
| 11 |
+
import threading
|
| 12 |
+
import shutil
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import hashlib
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
# Global lock for profile operations
|
| 19 |
+
_profile_lock = threading.RLock()  # re-entrant: update_profile -> get_profile -> _create_default_profile nest under this lock
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TravelStyle(str, Enum):
|
| 23 |
+
BUDGET = "budget"
|
| 24 |
+
LUXURY = "luxury"
|
| 25 |
+
BALANCED = "balanced"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class UserPreferences(BaseModel):
|
| 29 |
+
"""User preferences model with validation."""
|
| 30 |
+
|
| 31 |
+
travel_style: str = Field(
|
| 32 |
+
default="balanced",
|
| 33 |
+
description="Preferred travel style (budget, luxury, balanced)",
|
| 34 |
+
)
|
| 35 |
+
preferred_destinations: list = Field(
|
| 36 |
+
default_factory=list,
|
| 37 |
+
description="List of preferred travel destinations",
|
| 38 |
+
)
|
| 39 |
+
dietary_restrictions: list = Field(
|
| 40 |
+
default_factory=list,
|
| 41 |
+
description="List of dietary restrictions",
|
| 42 |
+
)
|
| 43 |
+
accessibility_needs: list = Field(
|
| 44 |
+
default_factory=list,
|
| 45 |
+
description="List of accessibility requirements",
|
| 46 |
+
)
|
| 47 |
+
preferred_activities: list = Field(
|
| 48 |
+
default_factory=list,
|
| 49 |
+
description="List of preferred activities",
|
| 50 |
+
)
|
| 51 |
+
budget_range: Dict[str, float] = Field(
|
| 52 |
+
default_factory=lambda: {"min": 0, "max": float("inf")},
|
| 53 |
+
description="Budget range for travel",
|
| 54 |
+
)
|
| 55 |
+
preferred_accommodation: str = Field(
|
| 56 |
+
default="hotel",
|
| 57 |
+
description="Preferred type of accommodation",
|
| 58 |
+
)
|
| 59 |
+
preferred_transportation: str = Field(
|
| 60 |
+
default="flexible",
|
| 61 |
+
description="Preferred mode of transportation",
|
| 62 |
+
)
|
| 63 |
+
travel_frequency: str = Field(
|
| 64 |
+
default="occasional",
|
| 65 |
+
description="How often the user travels",
|
| 66 |
+
)
|
| 67 |
+
preferred_seasons: list = Field(
|
| 68 |
+
default_factory=list,
|
| 69 |
+
description="Preferred travel seasons",
|
| 70 |
+
)
|
| 71 |
+
special_requirements: list = Field(
|
| 72 |
+
default_factory=list,
|
| 73 |
+
description="Any special travel requirements",
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
@validator("travel_style")
|
| 77 |
+
def validate_travel_style(cls, v):
|
| 78 |
+
allowed_styles = ["budget", "luxury", "balanced"]
|
| 79 |
+
if v not in allowed_styles:
|
| 80 |
+
raise ValueError(f"Travel style must be one of {allowed_styles}")
|
| 81 |
+
return v
|
| 82 |
+
|
| 83 |
+
@validator("preferred_accommodation")
|
| 84 |
+
def validate_accommodation(cls, v):
|
| 85 |
+
allowed_types = [
|
| 86 |
+
"hotel",
|
| 87 |
+
"hostel",
|
| 88 |
+
"apartment",
|
| 89 |
+
"resort",
|
| 90 |
+
"camping",
|
| 91 |
+
"flexible",
|
| 92 |
+
]
|
| 93 |
+
if v not in allowed_types:
|
| 94 |
+
raise ValueError(f"Accommodation type must be one of {allowed_types}")
|
| 95 |
+
return v
|
| 96 |
+
|
| 97 |
+
@validator("preferred_transportation")
|
| 98 |
+
def validate_transportation(cls, v):
|
| 99 |
+
allowed_types = [
|
| 100 |
+
"car",
|
| 101 |
+
"train",
|
| 102 |
+
"bus",
|
| 103 |
+
"plane",
|
| 104 |
+
"flexible",
|
| 105 |
+
]
|
| 106 |
+
if v not in allowed_types:
|
| 107 |
+
raise ValueError(f"Transportation type must be one of {allowed_types}")
|
| 108 |
+
return v
|
| 109 |
+
|
| 110 |
+
@validator("travel_frequency")
|
| 111 |
+
def validate_frequency(cls, v):
|
| 112 |
+
allowed_frequencies = [
|
| 113 |
+
"rarely",
|
| 114 |
+
"occasional",
|
| 115 |
+
"frequent",
|
| 116 |
+
"very_frequent",
|
| 117 |
+
]
|
| 118 |
+
if v not in allowed_frequencies:
|
| 119 |
+
raise ValueError(f"Travel frequency must be one of {allowed_frequencies}")
|
| 120 |
+
return v
|
| 121 |
+
|
| 122 |
+
@validator("budget_range")
|
| 123 |
+
def validate_budget(cls, v):
|
| 124 |
+
if v["min"] < 0:
|
| 125 |
+
raise ValueError("Minimum budget cannot be negative")
|
| 126 |
+
if v["max"] < v["min"]:
|
| 127 |
+
raise ValueError("Maximum budget must be greater than minimum budget")
|
| 128 |
+
return v
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class UserProfile:
|
| 132 |
+
def __init__(self):
|
| 133 |
+
"""Initialize the user profile manager."""
|
| 134 |
+
self.profiles_dir = os.path.join("data", "user_profiles")
|
| 135 |
+
self.backup_dir = os.path.join("data", "user_profiles_backup")
|
| 136 |
+
self._ensure_directories()
|
| 137 |
+
self.rate_limit_window = 3600 # 1 hour
|
| 138 |
+
self.max_updates_per_window = 10
|
| 139 |
+
self.update_history: Dict[str, list] = {}
|
| 140 |
+
self.max_profile_size = 1024 * 1024 # 1MB
|
| 141 |
+
|
| 142 |
+
def _ensure_directories(self):
|
| 143 |
+
"""Ensure necessary directories exist."""
|
| 144 |
+
os.makedirs(self.profiles_dir, exist_ok=True)
|
| 145 |
+
os.makedirs(self.backup_dir, exist_ok=True)
|
| 146 |
+
|
| 147 |
+
def _validate_user_id(self, user_id: str) -> bool:
|
| 148 |
+
"""Validate user ID format."""
|
| 149 |
+
if not user_id or not isinstance(user_id, str):
|
| 150 |
+
return False
|
| 151 |
+
# Allow alphanumeric characters, hyphens, and underscores
|
| 152 |
+
return bool(re.match(r"^[a-zA-Z0-9-_]+$", user_id))
|
| 153 |
+
|
| 154 |
+
def _check_rate_limit(self, user_id: str) -> bool:
|
| 155 |
+
"""Check if user has exceeded rate limit."""
|
| 156 |
+
current_time = time.time()
|
| 157 |
+
if user_id not in self.update_history:
|
| 158 |
+
self.update_history[user_id] = []
|
| 159 |
+
|
| 160 |
+
# Remove old entries
|
| 161 |
+
self.update_history[user_id] = [
|
| 162 |
+
t
|
| 163 |
+
for t in self.update_history[user_id]
|
| 164 |
+
if current_time - t < self.rate_limit_window
|
| 165 |
+
]
|
| 166 |
+
|
| 167 |
+
# Check if limit exceeded
|
| 168 |
+
if len(self.update_history[user_id]) >= self.max_updates_per_window:
|
| 169 |
+
return False
|
| 170 |
+
|
| 171 |
+
# Add new entry
|
| 172 |
+
self.update_history[user_id].append(current_time)
|
| 173 |
+
return True
|
| 174 |
+
|
| 175 |
+
def _create_backup(self, user_id: str) -> None:
|
| 176 |
+
"""Create a backup of the user profile."""
|
| 177 |
+
try:
|
| 178 |
+
profile_path = os.path.join(self.profiles_dir, f"{user_id}.json")
|
| 179 |
+
if not os.path.exists(profile_path):
|
| 180 |
+
return
|
| 181 |
+
|
| 182 |
+
# Create backup with timestamp
|
| 183 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 184 |
+
backup_path = os.path.join(self.backup_dir, f"{user_id}_{timestamp}.json")
|
| 185 |
+
shutil.copy2(profile_path, backup_path)
|
| 186 |
+
|
| 187 |
+
# Keep only the last 5 backups
|
| 188 |
+
backups = sorted(
|
| 189 |
+
Path(self.backup_dir).glob(f"{user_id}_*.json"),
|
| 190 |
+
key=lambda x: x.stat().st_mtime,
|
| 191 |
+
reverse=True,
|
| 192 |
+
)
|
| 193 |
+
for old_backup in backups[5:]:
|
| 194 |
+
old_backup.unlink()
|
| 195 |
+
|
| 196 |
+
except Exception as e:
|
| 197 |
+
logger.error(
|
| 198 |
+
f"Error creating backup for {user_id}: {str(e)}", exc_info=True
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
def _cleanup_old_profiles(self):
|
| 202 |
+
"""Clean up profiles older than 30 days."""
|
| 203 |
+
try:
|
| 204 |
+
current_time = time.time()
|
| 205 |
+
for filename in os.listdir(self.profiles_dir):
|
| 206 |
+
if not filename.endswith(".json"):
|
| 207 |
+
continue
|
| 208 |
+
|
| 209 |
+
file_path = os.path.join(self.profiles_dir, filename)
|
| 210 |
+
file_time = os.path.getmtime(file_path)
|
| 211 |
+
|
| 212 |
+
if current_time - file_time > 30 * 24 * 3600: # 30 days
|
| 213 |
+
try:
|
| 214 |
+
# Create final backup before deletion
|
| 215 |
+
user_id = filename[:-5] # Remove .json extension
|
| 216 |
+
self._create_backup(user_id)
|
| 217 |
+
os.remove(file_path)
|
| 218 |
+
logger.info(f"Removed old profile: {filename}")
|
| 219 |
+
except Exception as e:
|
| 220 |
+
logger.warning(
|
| 221 |
+
f"Error removing old profile {filename}: {str(e)}"
|
| 222 |
+
)
|
| 223 |
+
except Exception as e:
|
| 224 |
+
logger.error(f"Error cleaning up old profiles: {str(e)}", exc_info=True)
|
| 225 |
+
|
| 226 |
+
def get_profile(self, user_id: str) -> Dict[str, Any]:
|
| 227 |
+
"""Get user profile with validation."""
|
| 228 |
+
try:
|
| 229 |
+
if not self._validate_user_id(user_id):
|
| 230 |
+
raise ValueError("Invalid user ID format")
|
| 231 |
+
|
| 232 |
+
profile_path = os.path.join(self.profiles_dir, f"{user_id}.json")
|
| 233 |
+
|
| 234 |
+
with _profile_lock:
|
| 235 |
+
if not os.path.exists(profile_path):
|
| 236 |
+
return self._create_default_profile(user_id)
|
| 237 |
+
|
| 238 |
+
# Check file size
|
| 239 |
+
if os.path.getsize(profile_path) > self.max_profile_size:
|
| 240 |
+
raise ValueError("Profile file size exceeds limit")
|
| 241 |
+
|
| 242 |
+
with open(profile_path, "r", encoding="utf-8") as f:
|
| 243 |
+
profile = json.load(f)
|
| 244 |
+
|
| 245 |
+
# Validate profile structure
|
| 246 |
+
if not isinstance(profile, dict):
|
| 247 |
+
raise ValueError("Invalid profile format")
|
| 248 |
+
|
| 249 |
+
# Ensure all required fields exist
|
| 250 |
+
required_fields = ["user_id", "preferences", "created_at", "updated_at"]
|
| 251 |
+
if not all(field in profile for field in required_fields):
|
| 252 |
+
raise ValueError("Missing required profile fields")
|
| 253 |
+
|
| 254 |
+
return profile
|
| 255 |
+
|
| 256 |
+
except Exception as e:
|
| 257 |
+
logger.error(
|
| 258 |
+
f"Error getting profile for {user_id}: {str(e)}", exc_info=True
|
| 259 |
+
)
|
| 260 |
+
raise
|
| 261 |
+
|
| 262 |
+
def update_profile(
|
| 263 |
+
self, user_id: str, preferences: Dict[str, Any]
|
| 264 |
+
) -> Dict[str, Any]:
|
| 265 |
+
"""Update user profile with validation and rate limiting."""
|
| 266 |
+
try:
|
| 267 |
+
if not self._validate_user_id(user_id):
|
| 268 |
+
raise ValueError("Invalid user ID format")
|
| 269 |
+
|
| 270 |
+
if not self._check_rate_limit(user_id):
|
| 271 |
+
raise ValueError("Rate limit exceeded")
|
| 272 |
+
|
| 273 |
+
# Validate preferences
|
| 274 |
+
try:
|
| 275 |
+
validated_preferences = UserPreferences(**preferences)
|
| 276 |
+
except Exception as e:
|
| 277 |
+
raise ValueError(f"Invalid preferences: {str(e)}")
|
| 278 |
+
|
| 279 |
+
profile_path = os.path.join(self.profiles_dir, f"{user_id}.json")
|
| 280 |
+
|
| 281 |
+
with _profile_lock:
|
| 282 |
+
# Create backup before update
|
| 283 |
+
self._create_backup(user_id)
|
| 284 |
+
|
| 285 |
+
current_profile = self.get_profile(user_id)
|
| 286 |
+
|
| 287 |
+
# Update profile
|
| 288 |
+
current_profile["preferences"] = validated_preferences.dict()
|
| 289 |
+
current_profile["updated_at"] = datetime.utcnow().isoformat()
|
| 290 |
+
|
| 291 |
+
# Save updated profile
|
| 292 |
+
with open(profile_path, "w", encoding="utf-8") as f:
|
| 293 |
+
json.dump(current_profile, f, indent=2)
|
| 294 |
+
|
| 295 |
+
logger.info(f"Updated profile for user {user_id}")
|
| 296 |
+
return current_profile
|
| 297 |
+
|
| 298 |
+
except Exception as e:
|
| 299 |
+
logger.error(
|
| 300 |
+
f"Error updating profile for {user_id}: {str(e)}", exc_info=True
|
| 301 |
+
)
|
| 302 |
+
raise
|
| 303 |
+
|
| 304 |
+
def _create_default_profile(self, user_id: str) -> Dict[str, Any]:
|
| 305 |
+
"""Create a default profile with validation."""
|
| 306 |
+
try:
|
| 307 |
+
if not self._validate_user_id(user_id):
|
| 308 |
+
raise ValueError("Invalid user ID format")
|
| 309 |
+
|
| 310 |
+
default_preferences = UserPreferences().dict()
|
| 311 |
+
profile = {
|
| 312 |
+
"user_id": user_id,
|
| 313 |
+
"preferences": default_preferences,
|
| 314 |
+
"created_at": datetime.utcnow().isoformat(),
|
| 315 |
+
"updated_at": datetime.utcnow().isoformat(),
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
profile_path = os.path.join(self.profiles_dir, f"{user_id}.json")
|
| 319 |
+
|
| 320 |
+
with _profile_lock:
|
| 321 |
+
with open(profile_path, "w", encoding="utf-8") as f:
|
| 322 |
+
json.dump(profile, f, indent=2)
|
| 323 |
+
|
| 324 |
+
logger.info(f"Created default profile for user {user_id}")
|
| 325 |
+
return profile
|
| 326 |
+
|
| 327 |
+
except Exception as e:
|
| 328 |
+
logger.error(
|
| 329 |
+
f"Error creating default profile for {user_id}: {str(e)}", exc_info=True
|
| 330 |
+
)
|
| 331 |
+
raise
|
| 332 |
+
|
| 333 |
+
def delete_profile(self, user_id: str) -> None:
|
| 334 |
+
"""Delete user profile with validation."""
|
| 335 |
+
try:
|
| 336 |
+
if not self._validate_user_id(user_id):
|
| 337 |
+
raise ValueError("Invalid user ID format")
|
| 338 |
+
|
| 339 |
+
profile_path = os.path.join(self.profiles_dir, f"{user_id}.json")
|
| 340 |
+
|
| 341 |
+
with _profile_lock:
|
| 342 |
+
if os.path.exists(profile_path):
|
| 343 |
+
# Create final backup before deletion
|
| 344 |
+
self._create_backup(user_id)
|
| 345 |
+
os.remove(profile_path)
|
| 346 |
+
logger.info(f"Deleted profile for user {user_id}")
|
| 347 |
+
else:
|
| 348 |
+
logger.warning(f"Profile not found for user {user_id}")
|
| 349 |
+
|
| 350 |
+
except Exception as e:
|
| 351 |
+
logger.error(
|
| 352 |
+
f"Error deleting profile for {user_id}: {str(e)}", exc_info=True
|
| 353 |
+
)
|
| 354 |
+
raise
|
| 355 |
+
|
| 356 |
+
def get_recommendations(self, user_id: str) -> Dict[str, Any]:
|
| 357 |
+
"""Get personalized recommendations based on user profile with validation."""
|
| 358 |
+
try:
|
| 359 |
+
profile = self.get_profile(user_id)
|
| 360 |
+
if not profile or "preferences" not in profile:
|
| 361 |
+
return {}
|
| 362 |
+
|
| 363 |
+
preferences = UserPreferences(**profile["preferences"])
|
| 364 |
+
recommendations = {
|
| 365 |
+
"destinations": self._get_destination_recommendations(preferences),
|
| 366 |
+
"activities": self._get_activity_recommendations(preferences),
|
| 367 |
+
"tips": self._get_personalized_tips(preferences),
|
| 368 |
+
"generated_at": datetime.now().isoformat(),
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
return recommendations
|
| 372 |
+
except Exception as e:
|
| 373 |
+
logger.error(f"Error getting recommendations: {str(e)}", exc_info=True)
|
| 374 |
+
return {}
|
| 375 |
+
|
| 376 |
+
def _get_destination_recommendations(self, profile: UserPreferences) -> List[str]:
|
| 377 |
+
"""Get destination recommendations based on preferences."""
|
| 378 |
+
try:
|
| 379 |
+
recommendations = []
|
| 380 |
+
|
| 381 |
+
# Add recommendations based on favorite destinations
|
| 382 |
+
if profile.preferred_destinations:
|
| 383 |
+
recommendations.extend(profile.preferred_destinations[:3])
|
| 384 |
+
|
| 385 |
+
# Add recommendations based on interests
|
| 386 |
+
if "beach" in profile.preferred_activities:
|
| 387 |
+
recommendations.append("Bali, Indonesia")
|
| 388 |
+
if "culture" in profile.preferred_activities:
|
| 389 |
+
recommendations.append("Kyoto, Japan")
|
| 390 |
+
if "food" in profile.preferred_activities:
|
| 391 |
+
recommendations.append("Bangkok, Thailand")
|
| 392 |
+
|
| 393 |
+
# Add recommendations based on travel style
|
| 394 |
+
if profile.travel_style == TravelStyle.LUXURY:
|
| 395 |
+
recommendations.append("Dubai, UAE")
|
| 396 |
+
elif profile.travel_style == TravelStyle.BUDGET:
|
| 397 |
+
recommendations.append("Bangkok, Thailand")
|
| 398 |
+
|
| 399 |
+
return list(set(recommendations))[:5] # Return top 5 unique recommendations
|
| 400 |
+
except Exception as e:
|
| 401 |
+
logger.error(
|
| 402 |
+
f"Error getting destination recommendations: {str(e)}", exc_info=True
|
| 403 |
+
)
|
| 404 |
+
return []
|
| 405 |
+
|
| 406 |
+
def _get_activity_recommendations(self, profile: UserPreferences) -> List[str]:
|
| 407 |
+
"""Get activity recommendations based on preferences."""
|
| 408 |
+
try:
|
| 409 |
+
activities = []
|
| 410 |
+
|
| 411 |
+
# Add activities based on interests
|
| 412 |
+
if "culture" in profile.preferred_activities:
|
| 413 |
+
activities.append("Visit local museums and historical sites")
|
| 414 |
+
if "food" in profile.preferred_activities:
|
| 415 |
+
activities.append("Try local cuisine and food tours")
|
| 416 |
+
if "nature" in profile.preferred_activities:
|
| 417 |
+
activities.append("Explore national parks and hiking trails")
|
| 418 |
+
if "adventure" in profile.preferred_activities:
|
| 419 |
+
activities.append("Try adventure sports and activities")
|
| 420 |
+
|
| 421 |
+
# Add activities based on travel style
|
| 422 |
+
if profile.travel_style == TravelStyle.LUXURY:
|
| 423 |
+
activities.append("Book private guided tours")
|
| 424 |
+
elif profile.travel_style == TravelStyle.BUDGET:
|
| 425 |
+
activities.append("Explore local markets and street food")
|
| 426 |
+
|
| 427 |
+
return list(set(activities))[:5] # Return top 5 unique activities
|
| 428 |
+
except Exception as e:
|
| 429 |
+
logger.error(
|
| 430 |
+
f"Error getting activity recommendations: {str(e)}", exc_info=True
|
| 431 |
+
)
|
| 432 |
+
return []
|
| 433 |
+
|
| 434 |
+
def _get_personalized_tips(self, profile: UserPreferences) -> List[str]:
|
| 435 |
+
"""Get personalized travel tips based on preferences."""
|
| 436 |
+
try:
|
| 437 |
+
tips = []
|
| 438 |
+
|
| 439 |
+
# Add tips based on travel style
|
| 440 |
+
if profile.travel_style == TravelStyle.BUDGET:
|
| 441 |
+
tips.append(
|
| 442 |
+
"Look for local markets and street food for affordable meals"
|
| 443 |
+
)
|
| 444 |
+
tips.append("Consider staying in hostels or guesthouses")
|
| 445 |
+
elif profile.travel_style == TravelStyle.LUXURY:
|
| 446 |
+
tips.append("Book premium experiences and private tours in advance")
|
| 447 |
+
tips.append("Consider luxury resorts and boutique hotels")
|
| 448 |
+
|
| 449 |
+
# Add tips based on dietary restrictions
|
| 450 |
+
if profile.dietary_restrictions:
|
| 451 |
+
tips.append(
|
| 452 |
+
f"Research restaurants that accommodate {', '.join(profile.dietary_restrictions)}"
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
# Add tips based on accessibility needs
|
| 456 |
+
if profile.accessibility_needs:
|
| 457 |
+
tips.append(
|
| 458 |
+
f"Research accessibility features for {', '.join(profile.accessibility_needs)}"
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
return list(set(tips))[:5] # Return top 5 unique tips
|
| 462 |
+
except Exception as e:
|
| 463 |
+
logger.error(f"Error getting personalized tips: {str(e)}", exc_info=True)
|
| 464 |
+
return []
|
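A short usage sketch for the profile manager above (the user id and preference values are illustrative; `update_profile` validates against `UserPreferences` and enforces the 10-updates-per-hour limit):

```python
# Illustrative use of UserProfile; values are invented for the example.
from core.user_profile import UserProfile

profiles = UserProfile()

# First call creates data/user_profiles/demo-user.json with default preferences.
profile = profiles.get_profile("demo-user")

# Partial updates are allowed because every UserPreferences field has a default.
profiles.update_profile(
    "demo-user",
    {
        "travel_style": "budget",
        "preferred_activities": ["food", "culture"],
        "budget_range": {"min": 500, "max": 2000},
    },
)

print(profiles.get_recommendations("demo-user").get("destinations", []))
```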
data/.gitkeep
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
|
docs/API.md
ADDED
|
@@ -0,0 +1,294 @@
|
| 1 |
+
# TravelMate AI Assistant API Documentation
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
The TravelMate AI Assistant API provides a comprehensive set of endpoints for interacting with an AI-powered travel assistant. The API uses RAG (Retrieval-Augmented Generation) to provide accurate and contextually relevant travel information.
|
| 6 |
+
|
| 7 |
+
## Base URL
|
| 8 |
+
|
| 9 |
+
```
|
| 10 |
+
https://api.travelmate.ai/v1
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
## Authentication
|
| 14 |
+
|
| 15 |
+
All API endpoints require authentication using JWT (JSON Web Tokens). Include the token in the Authorization header:
|
| 16 |
+
|
| 17 |
+
```
|
| 18 |
+
Authorization: Bearer <your_token>
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## Rate Limiting
|
| 22 |
+
|
| 23 |
+
The API implements rate limiting to ensure fair usage (a client-side sketch follows the list below):
|
| 24 |
+
- 100 requests per hour per user
|
| 25 |
+
- Rate limit headers are included in responses:
|
| 26 |
+
- `X-RateLimit-Limit`: Maximum requests per window
|
| 27 |
+
- `X-RateLimit-Remaining`: Remaining requests in current window
|
| 28 |
+
- `X-RateLimit-Reset`: Time until rate limit resets
|
| 29 |
+
|
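A client-side sketch of honoring these headers (the header names are as documented above; treating `X-RateLimit-Reset` as seconds until the window resets is an assumption):

```python
# Sketch: pause when the documented rate limit headers say the window is exhausted.
import time

import requests


def get_with_rate_limit(session: requests.Session, url: str) -> requests.Response:
    response = session.get(url)
    remaining = int(response.headers.get("X-RateLimit-Remaining", "1"))
    reset_seconds = int(response.headers.get("X-RateLimit-Reset", "0"))
    if remaining == 0 and reset_seconds > 0:
        time.sleep(reset_seconds)  # assumed to be seconds until the limit resets
    return response
```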
| 30 |
+
## Endpoints
|
| 31 |
+
|
| 32 |
+
### Chat
|
| 33 |
+
|
| 34 |
+
#### POST /chat
|
| 35 |
+
|
| 36 |
+
Process a chat message and get an AI-generated response.
|
| 37 |
+
|
| 38 |
+
**Request Body:**
|
| 39 |
+
```json
|
| 40 |
+
{
|
| 41 |
+
"message": "What are the best places to visit in Paris?",
|
| 42 |
+
"chat_history": [
|
| 43 |
+
{
|
| 44 |
+
"user": "Hello",
|
| 45 |
+
"assistant": "Hi! How can I help you with your travel plans?"
|
| 46 |
+
}
|
| 47 |
+
]
|
| 48 |
+
}
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
**Response:**
|
| 52 |
+
```json
|
| 53 |
+
{
|
| 54 |
+
"answer": "Here are some must-visit places in Paris...",
|
| 55 |
+
"sources": [
|
| 56 |
+
{
|
| 57 |
+
"title": "Paris Travel Guide",
|
| 58 |
+
"url": "https://example.com/paris-guide",
|
| 59 |
+
"relevance_score": 0.95
|
| 60 |
+
}
|
| 61 |
+
],
|
| 62 |
+
"suggested_questions": [
|
| 63 |
+
"What's the best time to visit the Eiffel Tower?",
|
| 64 |
+
"Are there any hidden gems in Paris?"
|
| 65 |
+
]
|
| 66 |
+
}
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### User Profile
|
| 70 |
+
|
| 71 |
+
#### GET /profile
|
| 72 |
+
|
| 73 |
+
Get the current user's profile.
|
| 74 |
+
|
| 75 |
+
**Response:**
|
| 76 |
+
```json
|
| 77 |
+
{
|
| 78 |
+
"user_id": "user_123",
|
| 79 |
+
"preferences": {
|
| 80 |
+
"travel_style": "balanced",
|
| 81 |
+
"preferred_destinations": ["Paris", "Tokyo"],
|
| 82 |
+
"dietary_restrictions": [],
|
| 83 |
+
"accessibility_needs": [],
|
| 84 |
+
"preferred_activities": ["sightseeing", "food"],
|
| 85 |
+
"budget_range": {
|
| 86 |
+
"min": 1000,
|
| 87 |
+
"max": 5000
|
| 88 |
+
},
|
| 89 |
+
"preferred_accommodation": "hotel",
|
| 90 |
+
"preferred_transportation": "flexible",
|
| 91 |
+
"travel_frequency": "occasional",
|
| 92 |
+
"preferred_seasons": ["spring", "fall"],
|
| 93 |
+
"special_requirements": []
|
| 94 |
+
}
|
| 95 |
+
}
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
#### PUT /profile/preferences
|
| 99 |
+
|
| 100 |
+
Update user preferences.
|
| 101 |
+
|
| 102 |
+
**Request Body:**
|
| 103 |
+
```json
|
| 104 |
+
{
|
| 105 |
+
"travel_style": "luxury",
|
| 106 |
+
"preferred_destinations": ["Paris", "Tokyo", "New York"],
|
| 107 |
+
"budget_range": {
|
| 108 |
+
"min": 2000,
|
| 109 |
+
"max": 10000
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
**Response:**
|
| 115 |
+
```json
|
| 116 |
+
{
|
| 117 |
+
"message": "Preferences updated successfully"
|
| 118 |
+
}
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
### Health Check
|
| 122 |
+
|
| 123 |
+
#### GET /health
|
| 124 |
+
|
| 125 |
+
Check the health status of the API and its components.
|
| 126 |
+
|
| 127 |
+
**Response:**
|
| 128 |
+
```json
|
| 129 |
+
{
|
| 130 |
+
"status": "healthy",
|
| 131 |
+
"timestamp": "2024-02-20T12:00:00Z",
|
| 132 |
+
"version": "1.0.0",
|
| 133 |
+
"environment": "production",
|
| 134 |
+
"components": {
|
| 135 |
+
"rag_engine": "ok",
|
| 136 |
+
"user_profile": "ok"
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
## Error Handling
|
| 142 |
+
|
| 143 |
+
The API uses standard HTTP status codes and returns error responses in the following format:
|
| 144 |
+
|
| 145 |
+
```json
|
| 146 |
+
{
|
| 147 |
+
"error": "Error type",
|
| 148 |
+
"detail": "Detailed error message",
|
| 149 |
+
"timestamp": "2024-02-20T12:00:00Z",
|
| 150 |
+
"request_id": "req_123"
|
| 151 |
+
}
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
Common error codes:
|
| 155 |
+
- 400: Bad Request
|
| 156 |
+
- 401: Unauthorized
|
| 157 |
+
- 403: Forbidden
|
| 158 |
+
- 404: Not Found
|
| 159 |
+
- 429: Too Many Requests
|
| 160 |
+
- 500: Internal Server Error
|
| 161 |
+
- 503: Service Unavailable
|
| 162 |
+
|
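A hedged example of surfacing this envelope to callers (the field names follow the format above; the exception class itself is illustrative, not part of the API):

```python
# Sketch: turn the documented error envelope into a Python exception.
import requests


class TravelMateAPIError(Exception):
    """Illustrative wrapper around the documented error fields."""

    def __init__(self, status: int, payload: dict):
        self.status = status
        self.error = payload.get("error")
        self.detail = payload.get("detail")
        self.request_id = payload.get("request_id")
        super().__init__(f"{status} {self.error}: {self.detail} (request_id={self.request_id})")


def raise_for_api_error(response: requests.Response) -> None:
    if response.status_code >= 400:
        raise TravelMateAPIError(response.status_code, response.json())
```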
| 163 |
+
## Best Practices
|
| 164 |
+
|
| 165 |
+
1. **Error Handling**
|
| 166 |
+
- Always check response status codes
|
| 167 |
+
- Implement exponential backoff for retries (see the sketch after this list)
|
| 168 |
+
- Handle rate limiting gracefully
|
| 169 |
+
|
| 170 |
+
2. **Performance**
|
| 171 |
+
- Cache responses when appropriate
|
| 172 |
+
- Minimize chat history size
|
| 173 |
+
- Use compression for large requests
|
| 174 |
+
|
| 175 |
+
3. **Security**
|
| 176 |
+
- Keep tokens secure
|
| 177 |
+
- Use HTTPS for all requests
|
| 178 |
+
- Validate all input data
|
| 179 |
+
|
| 180 |
+
4. **Rate Limiting**
|
| 181 |
+
- Monitor rate limit headers
|
| 182 |
+
- Implement request queuing
|
| 183 |
+
- Handle 429 responses appropriately
|
| 184 |
+
|
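A minimal retry sketch for the backoff and 429 points above (the delays and retry count are suggestions, not API requirements):

```python
# Sketch: exponential backoff on 429 and transient 5xx responses.
import time

import requests


def request_with_backoff(session: requests.Session, method: str, url: str, **kwargs) -> requests.Response:
    delay = 1.0
    for attempt in range(5):
        response = session.request(method, url, **kwargs)
        if response.status_code not in (429, 500, 502, 503):
            return response
        time.sleep(delay)  # back off before retrying
        delay *= 2
    return response
```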
| 185 |
+
## SDKs and Examples
|
| 186 |
+
|
| 187 |
+
### Python
|
| 188 |
+
|
| 189 |
+
```python
|
| 190 |
+
import requests
|
| 191 |
+
|
| 192 |
+
class TravelMateClient:
|
| 193 |
+
def __init__(self, api_key, base_url="https://api.travelmate.ai/v1"):
|
| 194 |
+
self.api_key = api_key
|
| 195 |
+
self.base_url = base_url
|
| 196 |
+
self.session = requests.Session()
|
| 197 |
+
self.session.headers.update({
|
| 198 |
+
"Authorization": f"Bearer {api_key}",
|
| 199 |
+
"Content-Type": "application/json"
|
| 200 |
+
})
|
| 201 |
+
|
| 202 |
+
def chat(self, message, chat_history=None):
|
| 203 |
+
response = self.session.post(
|
| 204 |
+
f"{self.base_url}/chat",
|
| 205 |
+
json={
|
| 206 |
+
"message": message,
|
| 207 |
+
"chat_history": chat_history or []
|
| 208 |
+
}
|
| 209 |
+
)
|
| 210 |
+
response.raise_for_status()
|
| 211 |
+
return response.json()
|
| 212 |
+
|
| 213 |
+
def get_profile(self):
|
| 214 |
+
response = self.session.get(f"{self.base_url}/profile")
|
| 215 |
+
response.raise_for_status()
|
| 216 |
+
return response.json()
|
| 217 |
+
|
| 218 |
+
def update_preferences(self, preferences):
|
| 219 |
+
response = self.session.put(
|
| 220 |
+
f"{self.base_url}/profile/preferences",
|
| 221 |
+
json=preferences
|
| 222 |
+
)
|
| 223 |
+
response.raise_for_status()
|
| 224 |
+
return response.json()
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
### JavaScript
|
| 228 |
+
|
| 229 |
+
```javascript
|
| 230 |
+
class TravelMateClient {
|
| 231 |
+
constructor(apiKey, baseUrl = 'https://api.travelmate.ai/v1') {
|
| 232 |
+
this.apiKey = apiKey;
|
| 233 |
+
this.baseUrl = baseUrl;
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
async chat(message, chatHistory = []) {
|
| 237 |
+
const response = await fetch(`${this.baseUrl}/chat`, {
|
| 238 |
+
method: 'POST',
|
| 239 |
+
headers: {
|
| 240 |
+
'Authorization': `Bearer ${this.apiKey}`,
|
| 241 |
+
'Content-Type': 'application/json'
|
| 242 |
+
},
|
| 243 |
+
body: JSON.stringify({
|
| 244 |
+
message,
|
| 245 |
+
chat_history: chatHistory
|
| 246 |
+
})
|
| 247 |
+
});
|
| 248 |
+
|
| 249 |
+
if (!response.ok) {
|
| 250 |
+
throw new Error(`API error: ${response.statusText}`);
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
return response.json();
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
async getProfile() {
|
| 257 |
+
const response = await fetch(`${this.baseUrl}/profile`, {
|
| 258 |
+
headers: {
|
| 259 |
+
'Authorization': `Bearer ${this.apiKey}`
|
| 260 |
+
}
|
| 261 |
+
});
|
| 262 |
+
|
| 263 |
+
if (!response.ok) {
|
| 264 |
+
throw new Error(`API error: ${response.statusText}`);
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
return response.json();
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
async updatePreferences(preferences) {
|
| 271 |
+
const response = await fetch(`${this.baseUrl}/profile/preferences`, {
|
| 272 |
+
method: 'PUT',
|
| 273 |
+
headers: {
|
| 274 |
+
'Authorization': `Bearer ${this.apiKey}`,
|
| 275 |
+
'Content-Type': 'application/json'
|
| 276 |
+
},
|
| 277 |
+
body: JSON.stringify(preferences)
|
| 278 |
+
});
|
| 279 |
+
|
| 280 |
+
if (!response.ok) {
|
| 281 |
+
throw new Error(`API error: ${response.statusText}`);
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
return response.json();
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
```
|
| 288 |
+
|
| 289 |
+
## Support
|
| 290 |
+
|
| 291 |
+
For API support, please contact:
|
| 292 |
+
- Email: [email protected]
|
| 293 |
+
- Documentation: https://docs.travelmate.ai
|
| 294 |
+
- Status Page: https://status.travelmate.ai
|
huggingface.yaml
ADDED
|
@@ -0,0 +1,113 @@
|
| 1 |
+
sdk: gradio
|
| 2 |
+
sdk_version: 4.19.2
|
| 3 |
+
app_file: app.py
|
| 4 |
+
python_version: "3.10"
|
| 5 |
+
|
| 6 |
+
# Hardware requirements
|
| 7 |
+
hardware:
|
| 8 |
+
cpu: 2
|
| 9 |
+
memory: 16GB
|
| 10 |
+
|
| 11 |
+
# Build settings
|
| 12 |
+
build:
|
| 13 |
+
cuda: "None" # No CUDA needed for CPU-only
|
| 14 |
+
system_packages:
|
| 15 |
+
- build-essential
|
| 16 |
+
- python3-dev
|
| 17 |
+
- cmake
|
| 18 |
+
- pkg-config
|
| 19 |
+
- libopenblas-dev
|
| 20 |
+
- libomp-dev
|
| 21 |
+
|
| 22 |
+
# Environment variables
|
| 23 |
+
env:
|
| 24 |
+
- MODEL_NAME=meta-llama/Llama-2-7b-chat-hf
|
| 25 |
+
- EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
| 26 |
+
- SECRET_KEY=${SECRET_KEY}
|
| 27 |
+
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
|
| 28 |
+
- RATE_LIMIT_REQUESTS=100
|
| 29 |
+
- RATE_LIMIT_WINDOW=3600
|
| 30 |
+
- LOG_LEVEL=INFO
|
| 31 |
+
|
| 32 |
+
# Dependencies
|
| 33 |
+
dependencies:
|
| 34 |
+
- gradio==4.19.2
|
| 35 |
+
- langchain==0.1.9
|
| 36 |
+
- langchain-core>=0.1.52,<0.2
|
| 37 |
+
- langchain-community==0.0.27
|
| 38 |
+
- langchain-text-splitters==0.0.1
|
| 39 |
+
- langchain-huggingface==0.0.3
|
| 40 |
+
- transformers==4.38.2
|
| 41 |
+
- torch==2.2.1
|
| 42 |
+
- accelerate==0.27.2
|
| 43 |
+
- bitsandbytes==0.42.0
|
| 44 |
+
- safetensors==0.4.2
|
| 45 |
+
- sentence-transformers==2.6.1
|
| 46 |
+
- faiss-cpu==1.7.4
|
| 47 |
+
- pydantic==2.5.3
|
| 48 |
+
- pydantic-settings==2.1.0
|
| 49 |
+
- python-dotenv==1.0.0
|
| 50 |
+
- fastapi==0.109.2
|
| 51 |
+
- uvicorn==0.27.1
|
| 52 |
+
- python-jose==3.3.0
|
| 53 |
+
- passlib==1.7.4
|
| 54 |
+
- python-multipart
|
| 55 |
+
- bcrypt==4.1.2
|
| 56 |
+
- httpx==0.26.0
|
| 57 |
+
- aiohttp==3.9.5
|
| 58 |
+
- tenacity==8.2.3
|
| 59 |
+
- cachetools==5.3.2
|
| 60 |
+
- numpy==1.26.3
|
| 61 |
+
- tqdm==4.66.1
|
| 62 |
+
- loguru==0.7.2
|
| 63 |
+
- datasets==2.16.1
|
| 64 |
+
- huggingface-hub==0.24.1
|
| 65 |
+
- circuitbreaker==1.4.0
|
| 66 |
+
|
| 67 |
+
# Health check
|
| 68 |
+
health_check:
|
| 69 |
+
path: /health
|
| 70 |
+
interval: 300
|
| 71 |
+
timeout: 10
|
| 72 |
+
retries: 3
|
| 73 |
+
|
| 74 |
+
# Resource limits
|
| 75 |
+
resources:
|
| 76 |
+
cpu: 2
|
| 77 |
+
memory: 16GB
|
| 78 |
+
|
| 79 |
+
# Cache settings
|
| 80 |
+
cache:
|
| 81 |
+
enabled: true
|
| 82 |
+
ttl: 3600
|
| 83 |
+
max_size: 1000
|
| 84 |
+
|
| 85 |
+
# Logging
|
| 86 |
+
logging:
|
| 87 |
+
level: INFO
|
| 88 |
+
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
| 89 |
+
handlers:
|
| 90 |
+
- type: file
|
| 91 |
+
filename: app.log
|
| 92 |
+
max_bytes: 10485760
|
| 93 |
+
backup_count: 5
|
| 94 |
+
- type: stream
|
| 95 |
+
stream: ext://sys.stdout
|
| 96 |
+
|
| 97 |
+
# Space settings
|
| 98 |
+
space:
|
| 99 |
+
title: "TravelMate - AI Travel Assistant"
|
| 100 |
+
description: "An AI-powered travel assistant using Llama-2 and RAG to help plan trips and provide travel information"
|
| 101 |
+
license: mit
|
| 102 |
+
sdk: gradio
|
| 103 |
+
app_port: 7860
|
| 104 |
+
app_url: "https://huggingface.co/spaces/bharadwaj-m/TravelMate-AI"
|
| 105 |
+
|
| 106 |
+
# Build commands
|
| 107 |
+
build:
|
| 108 |
+
- pip install -r requirements.txt
|
| 109 |
+
- mkdir -p data/vector_store data/user_profiles data/cache
|
| 110 |
+
- python -c "from core.data_loader import DataLoader; DataLoader().initialize_knowledge_base()"
|
| 111 |
+
|
| 112 |
+
# Run command
|
| 113 |
+
run: python app.py
|
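The `env` block above is consumed by `config/config.py` (not reproduced in this excerpt). A minimal pydantic-settings sketch of how such variables are typically read; the field names mirror the variables above, but the committed `Settings` class may differ:

```python
# Hypothetical settings model mirroring the env block above; the real
# config/config.py may use different names, defaults, and validation.
from pydantic_settings import BaseSettings


class SpaceSettings(BaseSettings):
    MODEL_NAME: str = "meta-llama/Llama-2-7b-chat-hf"
    EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
    RATE_LIMIT_REQUESTS: int = 100
    RATE_LIMIT_WINDOW: int = 3600
    LOG_LEVEL: str = "INFO"


settings = SpaceSettings()  # matching environment variables override these defaults
print(settings.MODEL_NAME)
```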
requirements.txt
ADDED
|
@@ -0,0 +1,44 @@
|
| 1 |
+
# Core frameworks and UI
|
| 2 |
+
gradio==4.19.2
|
| 3 |
+
fastapi==0.109.2
|
| 4 |
+
uvicorn==0.27.1
|
| 5 |
+
|
| 6 |
+
# LangChain and ecosystem (aligned versions)
|
| 7 |
+
langchain==0.1.9
|
| 8 |
+
langchain-core>=0.1.52,<0.2
|
| 9 |
+
langchain-community==0.0.27
|
| 10 |
+
langchain-text-splitters==0.0.1
|
| 11 |
+
langchain-huggingface==0.0.3
|
| 12 |
+
|
| 13 |
+
# LLM and Transformers
|
| 14 |
+
transformers==4.41.2
|
| 15 |
+
torch==2.2.1
|
| 16 |
+
accelerate==0.27.2
|
| 17 |
+
|
| 18 |
+
# Embeddings / similarity search
|
| 19 |
+
sentence-transformers==2.6.1
|
| 20 |
+
faiss-cpu==1.7.4
|
| 21 |
+
|
| 22 |
+
datasets==2.16.1
|
| 23 |
+
|
| 24 |
+
# Security and auth
|
| 25 |
+
python-jose==3.3.0
|
| 26 |
+
passlib==1.7.4
|
| 27 |
+
bcrypt==4.1.2
|
| 28 |
+
python-multipart
|
| 29 |
+
|
| 30 |
+
# Data & utils
|
| 31 |
+
pydantic==2.5.3
|
| 32 |
+
pydantic-settings==2.1.0
|
| 33 |
+
python-dotenv==1.0.0
|
| 34 |
+
httpx==0.26.0
|
| 35 |
+
aiohttp==3.9.5
|
| 36 |
+
tenacity==8.2.3
|
| 37 |
+
expiringdict==1.2.1
|
| 38 |
+
numpy==1.26.3
|
| 39 |
+
tqdm==4.66.1
|
| 40 |
+
loguru==0.7.2
|
| 41 |
+
huggingface-hub==0.24.1
|
| 42 |
+
|
| 43 |
+
# Resilience / patterns
|
| 44 |
+
circuitbreaker==1.4.0
|