bharadwaj-m commited on
Commit
09aa2b8
·
0 Parent(s):

First Commit

Browse files
Files changed (15) hide show
  1. .gitattributes +35 -0
  2. .gitignore +93 -0
  3. README.md +13 -0
  4. api/dependencies.py +126 -0
  5. api/main.py +344 -0
  6. api/schemas.py +61 -0
  7. app.py +210 -0
  8. config/config.py +112 -0
  9. core/data_loader.py +185 -0
  10. core/rag_engine.py +151 -0
  11. core/user_profile.py +464 -0
  12. data/.gitkeep +1 -0
  13. docs/API.md +294 -0
  14. huggingface.yaml +113 -0
  15. requirements.txt +44 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ env/
8
+ build/
9
+ develop-eggs/
10
+ dist/
11
+ downloads/
12
+ eggs/
13
+ .eggs/
14
+ lib/
15
+ lib64/
16
+ parts/
17
+ sdist/
18
+ var/
19
+ wheels/
20
+ *.egg-info/
21
+ .installed.cfg
22
+ *.egg
23
+ .pytest_cache/
24
+ .coverage
25
+ htmlcov/
26
+
27
+ # Virtual Environment
28
+ venv/
29
+ ENV/
30
+ .env/
31
+ .venv/
32
+
33
+ # IDE
34
+ .idea/
35
+ .vscode/
36
+ *.swp
37
+ *.swo
38
+ .DS_Store
39
+ *.sublime-workspace
40
+ *.sublime-project
41
+
42
+ # Project specific
43
+ data/vector_store/
44
+ data/travel_guides.json
45
+ data/user_profiles/
46
+ data/cache/
47
+ .cache/
48
+ *.log
49
+ logs/
50
+ .env
51
+ .env.*
52
+ !.env.example
53
+ secrets.json
54
+ secret_key.py
55
+
56
+ # Model files
57
+ models/
58
+ *.bin
59
+ *.pt
60
+ *.pth
61
+ *.onnx
62
+ *.h5
63
+ *.hdf5
64
+ *.ckpt
65
+ *.safetensors
66
+
67
+ # Hugging Face
68
+ .huggingface/
69
+ transformers/
70
+ datasets/
71
+ hub/
72
+
73
+ # Temporary files
74
+ tmp/
75
+ temp/
76
+ *.tmp
77
+ *.temp
78
+ *.bak
79
+ *.swp
80
+ *~
81
+
82
+ # System files
83
+ .DS_Store
84
+ Thumbs.db
85
+ desktop.ini
86
+
87
+ # Docker
88
+ .docker/
89
+ docker-compose.override.yml
90
+
91
+ # Documentation
92
+ docs/_build/
93
+ site/
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: TravelMate AI
3
+ emoji: 🌍
4
+ colorFrom: red
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.33.1
8
+ app_file: app.py
9
+ pinned: false
10
+ short_description: AI-Powered Customer Support Chatbot using RAG
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
api/dependencies.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Depends, HTTPException, status
2
+ from fastapi.security import OAuth2PasswordBearer
3
+ from jose import JWTError, jwt
4
+ from datetime import datetime, timedelta
5
+ from typing import Dict, Any, Optional
6
+ import time
7
+ from functools import wraps
8
+ import logging
9
+ from cachetools import TTLCache
10
+
11
+ # BaseSettings import removed – unused
12
+
13
+ from config.config import settings
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Rate limiting cache with adjusted size
18
+ rate_limit_cache = TTLCache(
19
+ maxsize=settings.MAX_CACHE_SIZE, ttl=settings.RATE_LIMIT_WINDOW
20
+ )
21
+
22
+ # JWT settings from environment
23
+ SECRET_KEY = settings.JWT_SECRET_KEY
24
+ ALGORITHM = "HS256"
25
+ ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES
26
+
27
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
28
+
29
+
30
+ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
31
+ """Create JWT access token with validation"""
32
+ if not isinstance(data, dict):
33
+ raise ValueError("Token data must be a dictionary")
34
+
35
+ to_encode = data.copy()
36
+ if expires_delta:
37
+ expire = datetime.utcnow() + expires_delta
38
+ else:
39
+ expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
40
+
41
+ to_encode.update({"exp": expire})
42
+ try:
43
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
44
+ return encoded_jwt
45
+ except Exception as e:
46
+ logger.error(f"Error creating access token: {str(e)}", exc_info=True)
47
+ raise HTTPException(
48
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
49
+ detail="Error creating access token",
50
+ )
51
+
52
+
53
+ async def get_current_user(token: str = Depends(oauth2_scheme)) -> Dict[str, Any]:
54
+ """Get current user with enhanced validation"""
55
+ credentials_exception = HTTPException(
56
+ status_code=status.HTTP_401_UNAUTHORIZED,
57
+ detail="Could not validate credentials",
58
+ headers={"WWW-Authenticate": "Bearer"},
59
+ )
60
+
61
+ try:
62
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
63
+ user_id: str = payload.get("sub")
64
+ if not user_id or not isinstance(user_id, str):
65
+ raise credentials_exception
66
+
67
+ # Validate token expiration
68
+ exp = payload.get("exp")
69
+ if not exp or datetime.fromtimestamp(exp) < datetime.utcnow():
70
+ raise HTTPException(
71
+ status_code=status.HTTP_401_UNAUTHORIZED,
72
+ detail="Token has expired",
73
+ headers={"WWW-Authenticate": "Bearer"},
74
+ )
75
+
76
+ return {"user_id": user_id}
77
+ except JWTError as e:
78
+ logger.error(f"JWT validation error: {str(e)}", exc_info=True)
79
+ raise credentials_exception
80
+
81
+
82
+ def rate_limit(func):
83
+ """Rate limit decorator with enhanced validation"""
84
+
85
+ @wraps(func)
86
+ async def wrapper(*args, **kwargs):
87
+ current_user = kwargs.get("current_user")
88
+ if not current_user or "user_id" not in current_user:
89
+ raise HTTPException(
90
+ status_code=status.HTTP_400_BAD_REQUEST,
91
+ detail="User ID is required for rate limiting",
92
+ )
93
+ user_id = current_user["user_id"]
94
+
95
+ # Check rate limit with enhanced validation
96
+ current_time = time.time()
97
+ key = f"{user_id}:{current_time // settings.RATE_LIMIT_WINDOW}"
98
+
99
+ try:
100
+ with rate_limit_cache._lock:
101
+ if key in rate_limit_cache:
102
+ count = rate_limit_cache[key]
103
+ if count >= settings.RATE_LIMIT_REQUESTS:
104
+ raise HTTPException(
105
+ status_code=status.HTTP_429_TOO_MANY_REQUESTS,
106
+ detail=f"Rate limit exceeded. Try again in {settings.RATE_LIMIT_WINDOW} seconds",
107
+ )
108
+ rate_limit_cache[key] = count + 1
109
+ else:
110
+ rate_limit_cache[key] = 1
111
+
112
+ return await func(*args, **kwargs)
113
+ except Exception as e:
114
+ logger.error(f"Rate limit error: {str(e)}", exc_info=True)
115
+ raise HTTPException(
116
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
117
+ detail="Error processing rate limit",
118
+ )
119
+
120
+ return wrapper
121
+
122
+
123
+ async def cleanup():
124
+ """Cleanup resources"""
125
+ # Add any necessary cleanup here, e.g., closing database connections
126
+ pass
api/main.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Depends, Request, status, BackgroundTasks
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse
4
+ from fastapi.security import OAuth2PasswordBearer
5
+ from fastapi.middleware.gzip import GZipMiddleware
6
+ from typing import Dict, Any, Optional, List
7
+ import time
8
+ import logging
9
+ from datetime import datetime
10
+ from pydantic import BaseModel, Field
11
+ import os
12
+ import asyncio
13
+ from tenacity import retry, stop_after_attempt, wait_exponential
14
+
15
+ from config.config import settings
16
+ from core.rag_engine import RAGEngine
17
+ from core.user_profile import UserProfile, UserPreferences
18
+
19
+ # Define missing types
20
+ class ChatRequest(BaseModel):
21
+ message: str
22
+ chat_history: Optional[List[Dict[str, str]]] = None
23
+
24
+ class ChatResponse(BaseModel):
25
+ answer: str
26
+ sources: Optional[List[str]] = None
27
+ suggested_questions: Optional[List[str]] = None
28
+
29
+ class ErrorResponse(BaseModel):
30
+ error: str
31
+ detail: Optional[str] = None
32
+ timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
33
+ request_id: Optional[str] = None
34
+
35
+ class UserProfileResponse(BaseModel):
36
+ profile: Dict[str, Any]
37
+
38
+ class UserPreferencesUpdate(BaseModel):
39
+ preferences: Dict[str, Any]
40
+
41
+ # Setup logging with rotation
42
+ from logging.handlers import RotatingFileHandler
43
+
44
+ logging.basicConfig(
45
+ level=getattr(logging, settings.LOG_LEVEL),
46
+ format=settings.LOG_FORMAT,
47
+ handlers=[
48
+ logging.StreamHandler(),
49
+ RotatingFileHandler(
50
+ "api.log",
51
+ maxBytes=10 * 1024 * 1024, # 10MB
52
+ backupCount=5,
53
+ ),
54
+ ],
55
+ )
56
+ logger = logging.getLogger(__name__)
57
+
58
+ app = FastAPI(
59
+ title=settings.PROJECT_NAME,
60
+ description="AI-powered travel assistant API",
61
+ version=settings.VERSION,
62
+ docs_url="/docs", # Always show docs on HF Spaces
63
+ redoc_url="/redoc",
64
+ )
65
+
66
+ # Add security headers middleware
67
+ @app.middleware("http")
68
+ async def add_security_headers(request: Request, call_next):
69
+ response = await call_next(request)
70
+ response.headers["X-Content-Type-Options"] = "nosniff"
71
+ response.headers["X-Frame-Options"] = "DENY"
72
+ response.headers["X-XSS-Protection"] = "1; mode=block"
73
+ response.headers["Strict-Transport-Security"] = (
74
+ "max-age=31536000; includeSubDomains"
75
+ )
76
+ return response
77
+
78
+ # Add CORS middleware with validation
79
+ app.add_middleware(
80
+ CORSMiddleware,
81
+ allow_origins=["*"], # Allow all origins for Hugging Face Spaces
82
+ allow_credentials=True,
83
+ allow_methods=["GET", "POST", "PUT", "DELETE"],
84
+ allow_headers=["*"],
85
+ max_age=3600,
86
+ )
87
+
88
+ # Add Gzip compression
89
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
90
+
91
+ # Initialize core components with retry
92
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
93
+ async def initialize_components():
94
+ try:
95
+ global rag_engine, user_profile
96
+ rag_engine = RAGEngine()
97
+ user_profile = UserProfile()
98
+ logger.info("Core components initialized successfully")
99
+ except Exception as e:
100
+ logger.error(f"Failed to initialize core components: {str(e)}", exc_info=True)
101
+ raise
102
+
103
+ # Initialize components asynchronously
104
+ asyncio.create_task(initialize_components())
105
+
106
+ # OAuth2 scheme for token authentication
107
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
108
+
109
+ from api.dependencies import (
110
+ get_current_user,
111
+ rate_limit,
112
+ cleanup,
113
+ )
114
+
115
+ @app.exception_handler(Exception)
116
+ async def global_exception_handler(request: Request, exc: Exception):
117
+ """Global exception handler with request ID"""
118
+ request_id = request.headers.get("X-Request-ID", "unknown")
119
+ logger.error(
120
+ f"Unhandled exception: {str(exc)}",
121
+ exc_info=True,
122
+ extra={"request_id": request_id},
123
+ )
124
+ return JSONResponse(
125
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
126
+ content=ErrorResponse(
127
+ error="Internal Server Error", detail=str(exc), request_id=request_id
128
+ ).dict(),
129
+ )
130
+
131
+ @app.middleware("http")
132
+ async def add_process_time_header(request: Request, call_next):
133
+ """Add processing time header to response"""
134
+ start_time = time.time()
135
+ try:
136
+ response = await call_next(request)
137
+ process_time = time.time() - start_time
138
+ response.headers["X-Process-Time"] = str(process_time)
139
+ return response
140
+ except Exception as e:
141
+ logger.error(f"Error in middleware: {str(e)}", exc_info=True)
142
+ raise
143
+
144
+ @app.get("/")
145
+ async def root():
146
+ """Root endpoint with version info"""
147
+ return {
148
+ "message": "Welcome to TravelMate AI Assistant API",
149
+ "version": settings.VERSION,
150
+ "environment": settings.DEBUG, # Use DEBUG setting for environment
151
+ }
152
+
153
+ @app.post(
154
+ "/chat",
155
+ response_model=ChatResponse,
156
+ responses={
157
+ 400: {"model": ErrorResponse},
158
+ 401: {"model": ErrorResponse},
159
+ 429: {"model": ErrorResponse},
160
+ 500: {"model": ErrorResponse},
161
+ },
162
+ )
163
+ @rate_limit
164
+ async def chat(
165
+ request: ChatRequest,
166
+ background_tasks: BackgroundTasks,
167
+ current_user: Dict[str, Any] = Depends(get_current_user),
168
+ ):
169
+ """Process chat request with enhanced validation"""
170
+ try:
171
+ # Validate request size
172
+ if len(request.message) > settings.MAX_MESSAGE_LENGTH:
173
+ raise HTTPException(
174
+ status_code=status.HTTP_400_BAD_REQUEST,
175
+ detail=f"Message too long. Maximum length is {settings.MAX_MESSAGE_LENGTH} characters",
176
+ )
177
+
178
+ # Validate chat history
179
+ if request.chat_history:
180
+ if len(request.chat_history) > settings.MAX_CHAT_HISTORY:
181
+ raise HTTPException(
182
+ status_code=status.HTTP_400_BAD_REQUEST,
183
+ detail=f"Chat history too long. Maximum length is {settings.MAX_CHAT_HISTORY} messages",
184
+ )
185
+ for msg in request.chat_history:
186
+ if not isinstance(msg, dict) or not all(
187
+ k in msg for k in ["user", "assistant"]
188
+ ):
189
+ raise HTTPException(
190
+ status_code=status.HTTP_400_BAD_REQUEST,
191
+ detail="Invalid chat history format",
192
+ )
193
+
194
+ # Process query with RAG engine
195
+ result = await asyncio.wait_for(
196
+ rag_engine.process_query(
197
+ query=request.message,
198
+ chat_history=request.chat_history,
199
+ user_id=current_user["user_id"],
200
+ ),
201
+ timeout=settings.QUERY_TIMEOUT,
202
+ )
203
+
204
+ # Add cleanup task
205
+ background_tasks.add_task(cleanup)
206
+
207
+ return ChatResponse(
208
+ answer=result["answer"],
209
+ sources=result.get("metadata", {}).get("sources", []),
210
+ suggested_questions=result.get("suggested_questions", []),
211
+ )
212
+ except asyncio.TimeoutError:
213
+ raise HTTPException(
214
+ status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Request timed out"
215
+ )
216
+ except HTTPException:
217
+ raise
218
+ except Exception as e:
219
+ logger.error(f"Error processing chat request: {str(e)}", exc_info=True)
220
+ raise HTTPException(
221
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
222
+ detail="Error processing chat request",
223
+ )
224
+
225
+ @app.get(
226
+ "/profile",
227
+ response_model=UserProfileResponse,
228
+ responses={401: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
229
+ )
230
+ async def get_profile(current_user: Dict[str, Any] = Depends(get_current_user)):
231
+ """Get user profile with enhanced error handling"""
232
+ try:
233
+ profile = await asyncio.wait_for(
234
+ user_profile.get_profile(current_user["user_id"]),
235
+ timeout=settings.PROFILE_TIMEOUT,
236
+ )
237
+ if not profile:
238
+ raise HTTPException(
239
+ status_code=status.HTTP_404_NOT_FOUND, detail="Profile not found"
240
+ )
241
+ return UserProfileResponse(**profile)
242
+ except asyncio.TimeoutError:
243
+ raise HTTPException(
244
+ status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Request timed out"
245
+ )
246
+ except HTTPException:
247
+ raise
248
+ except Exception as e:
249
+ logger.error(f"Error getting user profile: {str(e)}", exc_info=True)
250
+ raise HTTPException(
251
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
252
+ detail="Error retrieving profile",
253
+ )
254
+
255
+ @app.put(
256
+ "/profile/preferences",
257
+ responses={
258
+ 400: {"model": ErrorResponse},
259
+ 401: {"model": ErrorResponse},
260
+ 500: {"model": ErrorResponse},
261
+ },
262
+ )
263
+ async def update_preferences(
264
+ preferences: UserPreferencesUpdate,
265
+ current_user: Dict[str, Any] = Depends(get_current_user),
266
+ ):
267
+ """Update user preferences with validation"""
268
+ try:
269
+ # Validate preferences
270
+ try:
271
+ UserPreferences(**preferences.preferences)
272
+ except Exception as e:
273
+ raise HTTPException(
274
+ status_code=status.HTTP_400_BAD_REQUEST,
275
+ detail=f"Invalid preferences: {str(e)}",
276
+ )
277
+
278
+ success = await asyncio.wait_for(
279
+ user_profile.update_profile(
280
+ current_user["user_id"], {"preferences": preferences.preferences}
281
+ ),
282
+ timeout=settings.PROFILE_TIMEOUT,
283
+ )
284
+
285
+ if not success:
286
+ raise HTTPException(
287
+ status_code=status.HTTP_400_BAD_REQUEST,
288
+ detail="Failed to update preferences",
289
+ )
290
+ return {"message": "Preferences updated successfully"}
291
+ except asyncio.TimeoutError:
292
+ raise HTTPException(
293
+ status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail="Request timed out"
294
+ )
295
+ except HTTPException:
296
+ raise
297
+ except Exception as e:
298
+ logger.error(f"Error updating preferences: {str(e)}", exc_info=True)
299
+ raise HTTPException(
300
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
301
+ detail="Error updating preferences",
302
+ )
303
+
304
+ @app.get("/health", responses={500: {"model": ErrorResponse}})
305
+ async def health_check():
306
+ """Health check endpoint with detailed status"""
307
+ try:
308
+ # Check core components
309
+ if not rag_engine or not user_profile:
310
+ raise HTTPException(
311
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
312
+ detail="Core components not initialized",
313
+ )
314
+
315
+ return {
316
+ "status": "healthy",
317
+ "timestamp": datetime.utcnow().isoformat(),
318
+ "version": settings.VERSION,
319
+ "environment": settings.DEBUG, # Use DEBUG setting for environment
320
+ "components": {
321
+ "rag_engine": "ok",
322
+ "user_profile": "ok",
323
+ },
324
+ }
325
+ except Exception as e:
326
+ logger.error(f"Health check failed: {str(e)}", exc_info=True)
327
+ raise HTTPException(
328
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Service unhealthy"
329
+ )
330
+
331
+ @app.on_event("shutdown")
332
+ async def shutdown_event():
333
+ """Cleanup on shutdown"""
334
+ await cleanup()
335
+
336
+ if __name__ == "__main__":
337
+ import uvicorn
338
+
339
+ uvicorn.run(
340
+ "api.main:app",
341
+ host="0.0.0.0",
342
+ port=int(os.getenv("PORT", 7860)),
343
+ reload=False, # Set reload to False for production
344
+ )
api/schemas.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import List, Dict, Any, Optional
3
+ from datetime import datetime
4
+
5
+
6
+ class ChatRequest(BaseModel):
7
+ message: str = Field(..., min_length=1, max_length=1000)
8
+ chat_history: List[Dict[str, str]] = Field(default_factory=list)
9
+
10
+
11
+ class Source(BaseModel):
12
+ title: str
13
+ url: str
14
+ relevance_score: float
15
+
16
+
17
+ class ChatResponse(BaseModel):
18
+ answer: str
19
+ sources: List[Source] = Field(default_factory=list)
20
+ suggested_questions: List[str] = Field(default_factory=list)
21
+ processing_time: Optional[float] = None
22
+
23
+
24
+ class UserPreferences(BaseModel):
25
+ favorite_destinations: List[str] = Field(default_factory=list)
26
+ travel_style: str = Field(default="balanced")
27
+ preferred_seasons: List[str] = Field(default_factory=list)
28
+ interests: List[str] = Field(default_factory=list)
29
+ dietary_restrictions: List[str] = Field(default_factory=list)
30
+ accessibility_needs: List[str] = Field(default_factory=list)
31
+ language: str = Field(default="en")
32
+ currency: str = Field(default="USD")
33
+ temperature_unit: str = Field(default="C")
34
+ timezone: str = Field(default="UTC")
35
+
36
+
37
+ class UserProfileResponse(BaseModel):
38
+ user_id: str
39
+ preferences: UserPreferences
40
+ created_at: datetime
41
+ updated_at: datetime
42
+
43
+
44
+ class UserPreferencesUpdate(BaseModel):
45
+ favorite_destinations: Optional[List[str]] = None
46
+ travel_style: Optional[str] = None
47
+ preferred_seasons: Optional[List[str]] = None
48
+ interests: Optional[List[str]] = None
49
+ dietary_restrictions: Optional[List[str]] = None
50
+ accessibility_needs: Optional[List[str]] = None
51
+ language: Optional[str] = None
52
+ currency: Optional[str] = None
53
+ temperature_unit: Optional[str] = None
54
+ timezone: Optional[str] = None
55
+
56
+
57
+ class ErrorResponse(BaseModel):
58
+ error: str
59
+ detail: Optional[str] = None
60
+ timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
61
+ request_id: Optional[str] = None
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import logging
4
+ import uuid
5
+ from typing import List, Dict, Any, Tuple
6
+ from logging.handlers import RotatingFileHandler
7
+
8
+ import gradio as gr
9
+ from tenacity import retry, stop_after_attempt, wait_exponential
10
+
11
+ from core.rag_engine import RAGEngine
12
+ from core.user_profile import UserProfile
13
+ from config.config import settings
14
+
15
+ # ======================================================================================
16
+ # Logging Setup
17
+ # ======================================================================================
18
+
19
+ os.makedirs("logs", exist_ok=True)
20
+
21
+ logging.basicConfig(
22
+ level=getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO),
23
+ format=settings.LOG_FORMAT,
24
+ handlers=[
25
+ logging.StreamHandler(sys.stdout),
26
+ RotatingFileHandler(
27
+ settings.LOG_FILE_PATH,
28
+ maxBytes=settings.LOG_FILE_MAX_BYTES,
29
+ backupCount=settings.LOG_FILE_BACKUP_COUNT
30
+ ),
31
+ ],
32
+ )
33
+ logger = logging.getLogger(__name__)
34
+
35
+ # ======================================================================================
36
+ # Core Module Initialization
37
+ # ======================================================================================
38
+
39
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
40
+ def initialize_with_retry(func):
41
+ """Initializes a component with retry logic."""
42
+ try:
43
+ return func()
44
+ except Exception as e:
45
+ logger.error(f"Initialization failed: {e}", exc_info=True)
46
+ raise
47
+
48
+ try:
49
+ user_profile = initialize_with_retry(UserProfile)
50
+ rag_engine = initialize_with_retry(lambda: RAGEngine(user_profile=user_profile))
51
+ logger.info("Core modules initialized successfully.")
52
+ except Exception as e:
53
+ logger.critical(f"Fatal: Could not initialize core modules: {e}. Exiting.", exc_info=True)
54
+ sys.exit(1)
55
+
56
+ # ======================================================================================
57
+ # Business Logic
58
+ # ======================================================================================
59
+
60
+ async def handle_chat_interaction(
61
+ message: str, chat_history: List[List[str]], user_id: str, categories: List[str]
62
+ ) -> List[List[str]]:
63
+ """Handles the user's chat message, processes it, and updates the history."""
64
+ if not message.strip():
65
+ gr.Warning("Message cannot be empty. Please type a question.")
66
+ return chat_history
67
+
68
+ try:
69
+ profile = user_profile.get_profile(user_id)
70
+ profile["preferences"]["favorite_categories"] = categories
71
+ user_profile.update_profile(user_id, profile)
72
+ logger.info(f"Updated preferences for user {user_id}: {categories}")
73
+
74
+ result = await rag_engine.process_query(query=message, user_id=user_id)
75
+ response = result.get("answer", "Sorry, I could not find an answer.")
76
+
77
+ sources = result.get("sources")
78
+ if sources:
79
+ response += "\n\n**Sources:**\n" + format_sources(sources)
80
+
81
+ chat_history.append((message, response))
82
+ logger.info(f"User {user_id} received response.")
83
+ return chat_history
84
+
85
+ except Exception as e:
86
+ error_message = f"An unexpected error occurred: {str(e)}"
87
+ logger.error(f"Error for user {user_id}: {error_message}", exc_info=True)
88
+ gr.Warning("Sorry, I encountered a problem. Please try again or rephrase your question.")
89
+ return chat_history
90
+
91
+ def format_sources(sources: List[Dict[str, Any]]) -> str:
92
+ """Formats the source documents into a readable string."""
93
+ if not sources:
94
+ return ""
95
+ formatted_list = [f"- **{source.get('title', 'Unknown Source')}** (Category: {source.get('category', 'N/A')})" for source in sources]
96
+ return "\n".join(formatted_list)
97
+
98
+ # ======================================================================================
99
+ # Gradio UI Definition
100
+ # ======================================================================================
101
+
102
+ def handle_slider_change(value: int) -> None:
103
+ """
104
+ Handles the change event for the document loader slider.
105
+ Note: This currently only shows a notification. A restart is required.
106
+ """
107
+ gr.Info(f"Document limit set to {int(value)}. Please restart the app for changes to take effect.")
108
+
109
+
110
+ def create_interface() -> gr.Blocks:
111
+ """Creates and configures the Gradio web interface."""
112
+
113
+ with gr.Blocks(
114
+ title="TravelMate - Your AI Travel Assistant",
115
+ theme=gr.themes.Base(),
116
+ ) as demo:
117
+
118
+ user_id = gr.State(lambda: str(uuid.uuid4()))
119
+
120
+ gr.Markdown("""
121
+ <div style="text-align: center;">
122
+ <h1 style="font-size: 2.5em;">✈️ TravelMate</h1>
123
+ <p style="font-size: 1.1em; color: #333;">Your AI-powered travel assistant. Ask me anything to plan your next trip!</p>
124
+ </div>
125
+ """)
126
+
127
+ with gr.Accordion("Advanced Settings", open=False):
128
+ doc_load_slider = gr.Slider(
129
+ minimum=100,
130
+ maximum=5000,
131
+ value=settings.MAX_DOCUMENTS_TO_LOAD,
132
+ step=100,
133
+ label="Documents to Load",
134
+ info="Controls how many documents are loaded for the RAG engine. Higher values may increase startup time.",
135
+ )
136
+ doc_load_slider.change(
137
+ fn=handle_slider_change, inputs=[doc_load_slider], outputs=None
138
+ )
139
+
140
+ with gr.Row():
141
+ with gr.Column(scale=2):
142
+ chatbot = gr.Chatbot(
143
+ elem_id="chatbot",
144
+ label="TravelMate Chat",
145
+ height=600,
146
+ show_label=False,
147
+ show_copy_button=True,
148
+ bubble_full_width=False,
149
+ avatar_images=("assets/user_avatar.png", "assets/bot_avatar.png"),
150
+ )
151
+ with gr.Row():
152
+ msg = gr.Textbox(
153
+ placeholder="Ask me about destinations, flights, hotels...",
154
+ show_label=False,
155
+ container=False,
156
+ scale=8,
157
+ )
158
+ submit_btn = gr.Button("Send", variant="primary", scale=1)
159
+
160
+ with gr.Column(scale=1):
161
+ gr.Markdown("### Select Your Interests")
162
+ categories = gr.CheckboxGroup(
163
+ choices=[
164
+ "Flights", "Hotels", "Destinations", "Activities",
165
+ "Transportation", "Food & Dining", "Shopping",
166
+ "Health & Safety", "Budget Planning"
167
+ ],
168
+ value=["Flights", "Hotels"],
169
+ label="Travel Categories",
170
+ )
171
+ gr.Markdown("### Example Questions")
172
+ gr.Examples(
173
+ examples=[
174
+ "What are the best places to visit in Japan?",
175
+ "How do I find cheap flights to Europe?",
176
+ "What should I pack for a beach vacation?",
177
+ "Tell me about local customs in Thailand",
178
+ "What's the best time to visit Paris?",
179
+ ],
180
+ inputs=msg,
181
+ )
182
+
183
+ async def on_submit(message: str, history: List[List[str]], uid: str, cats: List[str]) -> Tuple[str, List[List[str]]]:
184
+ """Handles submission and returns updated values for the message box and chatbot."""
185
+ updated_history = await handle_chat_interaction(message, history, uid, cats)
186
+ return "", updated_history
187
+
188
+ submit_btn.click(on_submit, [msg, chatbot, user_id, categories], [msg, chatbot])
189
+ msg.submit(on_submit, [msg, chatbot, user_id, categories], [msg, chatbot])
190
+
191
+ return demo
192
+
193
+ # ======================================================================================
194
+ # Application Launch
195
+ # ======================================================================================
196
+
197
+ if __name__ == "__main__":
198
+ try:
199
+ app = create_interface()
200
+ app.queue(default_concurrency_limit=settings.GRADIO_CONCURRENCY_COUNT)
201
+ app.launch(
202
+ server_name=settings.GRADIO_SERVER_NAME,
203
+ server_port=settings.GRADIO_SERVER_PORT,
204
+ share=settings.GRADIO_SHARE,
205
+ show_error=True,
206
+ show_api=False,
207
+ )
208
+ except Exception as e:
209
+ logger.critical(f"Failed to launch Gradio app: {e}", exc_info=True)
210
+ raise
config/config.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ from pydantic import model_validator
6
+ from pydantic_settings import BaseSettings
7
+
8
+ class Settings(BaseSettings):
9
+ """Manages all application settings using Pydantic for robust configuration.
10
+
11
+ Attributes:
12
+ BASE_DIR (Path): The root directory of the project.
13
+ DATA_DIR (Path): The directory for storing data.
14
+ # ... other attributes
15
+ """
16
+ # ----------------------------------------------------------------------------------
17
+ # Path Settings
18
+ # ----------------------------------------------------------------------------------
19
+ BASE_DIR: Path = Path(__file__).resolve().parent.parent
20
+ DATA_DIR: Path = BASE_DIR / "data"
21
+ VECTOR_STORE_DIR: Path = DATA_DIR / "vector_store"
22
+ USER_PROFILES_DIR: Path = DATA_DIR / "user_profiles"
23
+ CACHE_DIR: Path = DATA_DIR / "cache"
24
+ LOGS_DIR: Path = BASE_DIR / "logs"
25
+
26
+ # ----------------------------------------------------------------------------------
27
+ # Model & RAG Settings
28
+ # ----------------------------------------------------------------------------------
29
+ MODEL_NAME: str = "google/gemma-2b-it"
30
+ EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")
31
+ DATASET_ID: str = "bitext/Bitext-travel-llm-chatbot-training-dataset"
32
+ HUGGINGFACE_API_TOKEN: Optional[str] = os.getenv("HUGGINGFACE_API_TOKEN")
33
+
34
+ # RAG pipeline settings
35
+ CHUNK_SIZE: int = 1000
36
+ CHUNK_OVERLAP: int = 100
37
+ MAX_DOCUMENTS_TO_LOAD: int = 50 # Drastically reduced for performance
38
+ TOP_K_RESULTS: int = 3
39
+ SIMILARITY_THRESHOLD: float = 0.7
40
+
41
+ # Model configuration for HuggingFaceEndpoint
42
+ TEMPERATURE: float = 0.7
43
+ MAX_NEW_TOKENS: int = 512 # Drastically reduced to combat latency
44
+ REPETITION_PENALTY: float = 1.2
45
+
46
+ # Production-Grade Prompt Template
47
+ QA_PROMPT_TEMPLATE: str = """You are TravelMate, an expert AI travel assistant.
48
+ Use the following context to answer the user's question concisely and helpfully.
49
+ If you don't know the answer, simply say that you don't know. Do not make up information.
50
+
51
+ Context:
52
+ {context}
53
+
54
+ Question: {input}
55
+
56
+ Answer:"""
57
+
58
+ # Cache settings
59
+ MAX_CACHE_SIZE: int = 1000
60
+ CACHE_TTL: int = 3600 # Time-to-live in seconds (1 hour)
61
+
62
+ # ----------------------------------------------------------------------------------
63
+ # Application Behavior Settings
64
+ # ----------------------------------------------------------------------------------
65
+ QUERY_TIMEOUT: int = 30 # seconds
66
+ MAX_MESSAGE_LENGTH: int = 500 # characters
67
+ MAX_CHAT_HISTORY: int = 20 # messages
68
+
69
+ # ----------------------------------------------------------------------------------
70
+ # API & Security Settings
71
+ # ----------------------------------------------------------------------------------
72
+ API_V1_STR: str = "/api/v1"
73
+ PROJECT_NAME: str = "TravelMate AI Assistant"
74
+ VERSION: str = "1.0.0"
75
+ DEBUG: bool = False
76
+ SECRET_KEY: str = os.getenv("SECRET_KEY", "a_very_secret_key")
77
+ JWT_SECRET_KEY: str = os.getenv("JWT_SECRET_KEY", "a_very_secret_jwt_key")
78
+ ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 # 24 hours
79
+
80
+ # ----------------------------------------------------------------------------------
81
+ # Logging Settings
82
+ # ----------------------------------------------------------------------------------
83
+ LOG_LEVEL: str = "INFO"
84
+ LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
85
+ LOG_FILE_PATH: Path = LOGS_DIR / "app.log"
86
+ LOG_FILE_MAX_BYTES: int = 10 * 1024 * 1024 # 10MB
87
+ LOG_FILE_BACKUP_COUNT: int = 5
88
+
89
+ # ----------------------------------------------------------------------------------
90
+ # Gradio UI Settings
91
+ # ----------------------------------------------------------------------------------
92
+ GRADIO_SERVER_NAME: str = "0.0.0.0"
93
+ GRADIO_SERVER_PORT: int = 7860
94
+ GRADIO_SHARE: bool = True
95
+ GRADIO_CONCURRENCY_COUNT: int = 5
96
+
97
+ class Config:
98
+ env_file = ".env"
99
+ case_sensitive = True
100
+
101
+ @model_validator(mode='after')
102
+ def create_directories(self) -> 'Settings':
103
+ """Ensures that necessary directories exist upon settings initialization."""
104
+ self.DATA_DIR.mkdir(parents=True, exist_ok=True)
105
+ self.VECTOR_STORE_DIR.mkdir(parents=True, exist_ok=True)
106
+ self.USER_PROFILES_DIR.mkdir(parents=True, exist_ok=True)
107
+ self.CACHE_DIR.mkdir(parents=True, exist_ok=True)
108
+ self.LOGS_DIR.mkdir(parents=True, exist_ok=True)
109
+ return self
110
+
111
+
112
+ settings = Settings()
core/data_loader.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ import stat
5
+ import time
6
+ from typing import Any, List
7
+
8
+ from config.config import settings
9
+ from datasets import load_dataset
10
+ from langchain.schema import Document
11
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
12
+ from langchain_community.vectorstores import FAISS
13
+ from langchain_community.embeddings import HuggingFaceEmbeddings
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class DataLoader:
19
+ """Handles loading and processing of data for the RAG engine."""
20
+
21
+ def __init__(self):
22
+ """Initialize the data loader."""
23
+ self.data_dir = os.path.abspath("data")
24
+ self.travel_guides_path = os.path.join(self.data_dir, "travel_guides.json")
25
+ self.vector_store_path = os.path.join(self.data_dir, "vector_store", "faiss_index")
26
+ self._ensure_data_directories()
27
+ self._set_directory_permissions()
28
+
29
+ self.text_splitter = RecursiveCharacterTextSplitter(
30
+ chunk_size=settings.CHUNK_SIZE,
31
+ chunk_overlap=settings.CHUNK_OVERLAP,
32
+ length_function=len,
33
+ separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
34
+ )
35
+ self.max_file_size = 10 * 1024 * 1024 # 10MB
36
+
37
+ def _ensure_data_directories(self):
38
+ """Ensure necessary data directories exist."""
39
+ os.makedirs(self.data_dir, exist_ok=True)
40
+ os.makedirs(os.path.dirname(self.vector_store_path), exist_ok=True)
41
+ os.makedirs(os.path.join(self.data_dir, "cache"), exist_ok=True)
42
+
43
+ def _set_directory_permissions(self):
44
+ """Set secure permissions for data directories (755)."""
45
+ try:
46
+ for dir_path in [
47
+ self.data_dir,
48
+ os.path.dirname(self.vector_store_path),
49
+ os.path.join(self.data_dir, "cache"),
50
+ ]:
51
+ os.chmod(
52
+ dir_path,
53
+ stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH,
54
+ )
55
+ except Exception as e:
56
+ logger.error(f"Error setting directory permissions: {e}", exc_info=True)
57
+
58
+ def _validate_file_permissions(self, file_path: str) -> bool:
59
+ """Validate file permissions to ensure security."""
60
+ try:
61
+ if not os.path.exists(file_path):
62
+ return False
63
+ file_stat = os.stat(file_path)
64
+ if file_stat.st_mode & stat.S_IWOTH: # Disallow world-writable
65
+ logger.warning(f"File {file_path} is world-writable. Skipping.")
66
+ return False
67
+ if file_stat.st_size > self.max_file_size:
68
+ logger.warning(f"File {file_path} exceeds size limit. Skipping.")
69
+ return False
70
+ return True
71
+ except Exception as e:
72
+ logger.error(f"Error validating file permissions for {file_path}: {e}", exc_info=True)
73
+ return False
74
+
75
+ def _load_dataset_with_retry(self, max_retries: int = 3) -> Any:
76
+ """Load dataset from Hugging Face with an exponential backoff retry mechanism."""
77
+ for attempt in range(max_retries):
78
+ try:
79
+ return load_dataset(
80
+ settings.DATASET_ID,
81
+ split="train",
82
+ cache_dir=os.path.join(self.data_dir, "cache"),
83
+ )
84
+ except Exception as e:
85
+ logger.warning(f"Dataset loading attempt {attempt + 1} failed: {e}")
86
+ if attempt == max_retries - 1:
87
+ logger.error("All attempts to load dataset failed.")
88
+ return None
89
+ time.sleep(2 ** attempt)
90
+ return None
91
+
92
+ def load_documents(self) -> List[Document]:
93
+ """Load and process all documents for the knowledge base."""
94
+ documents = []
95
+ try:
96
+ # 1. Load Bitext Travel Dataset
97
+ logger.info(f"Loading dataset: {settings.DATASET_ID}")
98
+ dataset = self._load_dataset_with_retry()
99
+ if dataset:
100
+ max_docs = settings.MAX_DOCUMENTS_TO_LOAD
101
+ logger.info(f"Loading up to {max_docs} documents from the dataset.")
102
+ for i, item in enumerate(dataset):
103
+ if i >= max_docs:
104
+ logger.info(f"Reached document limit ({max_docs}).")
105
+ break
106
+ instruction = item.get("instruction")
107
+ response = item.get("response")
108
+
109
+ if not instruction or not response:
110
+ logger.warning(f"Skipping item with missing instruction or response: {item}")
111
+ continue
112
+
113
+ page_content = f"User query: {instruction}\n\nChatbot response: {response}"
114
+ metadata = {
115
+ "source": "huggingface",
116
+ "intent": item.get("intent"),
117
+ "category": item.get("category"),
118
+ "tags": item.get("tags"),
119
+ }
120
+
121
+ documents.append(Document(page_content=page_content, metadata=metadata))
122
+
123
+ # 2. Load Local Travel Guides
124
+ logger.info("Loading local travel guides...")
125
+ if os.path.exists(self.travel_guides_path) and self._validate_file_permissions(self.travel_guides_path):
126
+ with open(self.travel_guides_path, "r", encoding="utf-8") as f:
127
+ guides = json.load(f)
128
+ for guide in guides:
129
+ if not all(k in guide for k in ["title", "content", "category"]):
130
+ logger.warning(f"Skipping malformed guide: {guide}")
131
+ continue
132
+ doc = Document(
133
+ page_content=guide["content"],
134
+ metadata={
135
+ "title": guide["title"],
136
+ "category": guide["category"],
137
+ "source": "travel_guide",
138
+ },
139
+ )
140
+ documents.append(doc)
141
+ else:
142
+ logger.info("Travel guides file not found or invalid. Skipping.")
143
+
144
+ logger.info(f"Loaded {len(documents)} documents in total.")
145
+ return documents
146
+
147
+ except Exception as e:
148
+ logger.error(f"A critical error occurred while loading documents: {e}", exc_info=True)
149
+ return []
150
+
151
+ def create_vector_store(self, documents: List[Document]):
152
+ """Create a FAISS vector store from documents."""
153
+ try:
154
+ logger.info("Creating vector store...")
155
+ embeddings = HuggingFaceEmbeddings(
156
+ model_name=settings.EMBEDDING_MODEL_NAME,
157
+ model_kwargs={"device": "cpu"},
158
+ encode_kwargs={"normalize_embeddings": True},
159
+ )
160
+
161
+ split_docs = self.text_splitter.split_documents(documents)
162
+
163
+ vector_store = FAISS.from_documents(
164
+ documents=split_docs,
165
+ embedding=embeddings,
166
+ )
167
+ vector_store.save_local(self.vector_store_path)
168
+ logger.info(f"Vector store created and saved to {self.vector_store_path} with {len(split_docs)} chunks.")
169
+ except Exception as e:
170
+ logger.error(f"Error creating vector store: {e}", exc_info=True)
171
+ raise
172
+
173
+ def initialize_knowledge_base(self):
174
+ """Initialize the complete knowledge base."""
175
+ try:
176
+ logger.info("Initializing knowledge base...")
177
+ documents = self.load_documents()
178
+ if not documents:
179
+ logger.error("No documents were loaded. Aborting knowledge base initialization.")
180
+ return
181
+ self.create_vector_store(documents)
182
+ logger.info("Knowledge base initialized successfully.")
183
+ except Exception as e:
184
+ logger.critical(f"Failed to initialize knowledge base: {e}", exc_info=True)
185
+ raise
core/rag_engine.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Dict, Any
3
+
4
+ from langchain.prompts import PromptTemplate
5
+ from langchain_community.vectorstores import FAISS
6
+ from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
8
+ from langchain.chains import create_retrieval_chain
9
+ from langchain.chains.combine_documents import create_stuff_documents_chain
10
+ from langchain_core.documents import Document
11
+ from expiringdict import ExpiringDict
12
+
13
+ from core.data_loader import DataLoader
14
+ from core.user_profile import UserProfile
15
+ from config.config import settings
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ class RAGEngine:
20
+ """
21
+ The core Retrieval-Augmented Generation engine for the TravelMate chatbot.
22
+ This class handles model initialization, vector store creation, and query processing.
23
+ """
24
+
25
+ def __init__(self, user_profile: UserProfile):
26
+ """
27
+ Initializes the RAG engine, loading models and setting up the QA chain.
28
+ """
29
+ self.user_profile = user_profile
30
+ self.query_cache = ExpiringDict(max_len=settings.MAX_CACHE_SIZE, max_age_seconds=settings.CACHE_TTL)
31
+
32
+ try:
33
+ self.embeddings = self._initialize_embeddings()
34
+ self.vector_store = self._initialize_vector_store()
35
+ self.llm = self._initialize_llm()
36
+ self.qa_chain = self._create_rag_chain()
37
+ logger.info("RAG Engine initialized successfully.")
38
+ except Exception as e:
39
+ logger.critical(f"Failed to initialize RAG Engine: {e}", exc_info=True)
40
+ raise
41
+
42
+ def _initialize_embeddings(self) -> HuggingFaceEmbeddings:
43
+ """Initializes the sentence-transformer embeddings model."""
44
+ return HuggingFaceEmbeddings(
45
+ model_name=settings.EMBEDDING_MODEL_NAME,
46
+ model_kwargs={'device': 'cpu'}
47
+ )
48
+
49
+ def _initialize_vector_store(self) -> FAISS:
50
+ """
51
+ Initializes the FAISS vector store.
52
+ Loads from disk if it exists, otherwise creates it from the data loader.
53
+ """
54
+ if settings.VECTOR_STORE_DIR.exists() and any(settings.VECTOR_STORE_DIR.iterdir()):
55
+ logger.info(f"Loading existing vector store from {settings.VECTOR_STORE_DIR}...")
56
+ return FAISS.load_local(
57
+ folder_path=str(settings.VECTOR_STORE_DIR),
58
+ embeddings=self.embeddings,
59
+ allow_dangerous_deserialization=True
60
+ )
61
+ else:
62
+ logger.info("Creating new vector store from scratch.")
63
+ data_loader = DataLoader()
64
+ documents = data_loader.load_documents()
65
+
66
+ if not documents:
67
+ raise ValueError("No documents were loaded. Cannot create vector store.")
68
+
69
+ vector_store = FAISS.from_documents(documents, self.embeddings)
70
+ logger.info(f"Saving new vector store to {settings.VECTOR_STORE_DIR}...")
71
+ vector_store.save_local(str(settings.VECTOR_STORE_DIR))
72
+ return vector_store
73
+
74
+ def _initialize_llm(self) -> HuggingFaceEndpoint:
75
+ """Initializes the Hugging Face Inference Endpoint for the LLM."""
76
+ if not settings.HUGGINGFACE_API_TOKEN:
77
+ raise ValueError("HUGGINGFACE_API_TOKEN is not set.")
78
+
79
+ return HuggingFaceEndpoint(
80
+ repo_id=settings.MODEL_NAME,
81
+ huggingfacehub_api_token=settings.HUGGINGFACE_API_TOKEN,
82
+ temperature=settings.TEMPERATURE,
83
+ max_new_tokens=settings.MAX_NEW_TOKENS,
84
+ repetition_penalty=settings.REPETITION_PENALTY,
85
+ )
86
+
87
+ def _create_rag_chain(self):
88
+ """Creates a modern, streamlined RAG chain for question answering."""
89
+ qa_prompt = PromptTemplate.from_template(settings.QA_PROMPT_TEMPLATE)
90
+
91
+ question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
92
+
93
+ retriever = self.vector_store.as_retriever(
94
+ search_type="similarity_score_threshold",
95
+ search_kwargs={'k': settings.TOP_K_RESULTS, 'score_threshold': settings.SIMILARITY_THRESHOLD}
96
+ )
97
+
98
+ rag_chain = create_retrieval_chain(retriever, question_answer_chain)
99
+ return rag_chain
100
+
101
+ def _format_sources(self, sources: List[Document]) -> List[Dict[str, Any]]:
102
+ """Formats source documents into a serializable list of dictionaries."""
103
+ if not sources:
104
+ return []
105
+
106
+ formatted_list = []
107
+ for source in sources:
108
+ metadata = source.metadata
109
+ source_name = metadata.get('source', 'Unknown Source')
110
+
111
+ if source_name == 'huggingface':
112
+ title = f"Dataset: {metadata.get('intent', 'N/A')}"
113
+ category = metadata.get('category', 'N/A')
114
+ elif source_name == 'local_guides':
115
+ title = f"Guide: {metadata.get('title', 'N/A')}"
116
+ category = metadata.get('category', 'N/A')
117
+ else:
118
+ title = "Unknown Source"
119
+ category = "N/A"
120
+
121
+ formatted_list.append({"title": title, "category": category})
122
+
123
+ return formatted_list
124
+
125
+ async def process_query(self, query: str, user_id: str) -> Dict[str, Any]:
126
+ """Processes a user query asynchronously using the streamlined RAG chain."""
127
+ cache_key = f"{user_id}:{query}"
128
+ if cache_key in self.query_cache:
129
+ logger.info(f"Returning cached response for query: {query}")
130
+ return self.query_cache[cache_key]
131
+
132
+ logger.info(f"Processing query for user {user_id}: {query}")
133
+
134
+ # The new chain expects 'input' instead of 'question'
135
+ chain_input = {"input": query}
136
+
137
+ try:
138
+ result = await self.qa_chain.ainvoke(chain_input)
139
+
140
+ answer = result.get("answer", "Sorry, I couldn't find an answer.")
141
+ # The new chain returns retrieved documents in the 'context' key
142
+ sources = self._format_sources(result.get("context", []))
143
+
144
+ response = {"answer": answer, "sources": sources}
145
+ self.query_cache[cache_key] = response
146
+
147
+ logger.info(f"Successfully processed query for user {user_id}")
148
+ return response
149
+ except Exception as e:
150
+ logger.error(f"Error processing query for user {user_id}: {e}", exc_info=True)
151
+ return {"answer": "I'm sorry, but I encountered an error while processing your request.", "sources": []}
core/user_profile.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, Optional, List
2
+ import json
3
+ import os
4
+ from datetime import datetime, timedelta
5
+ import logging
6
+ from pydantic import BaseModel, Field, validator
7
+ from enum import Enum
8
+ import time
9
+ import re
10
+ from config.config import settings
11
+ import threading
12
+ import shutil
13
+ from pathlib import Path
14
+ import hashlib
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Global lock for profile operations
19
+ _profile_lock = threading.Lock()
20
+
21
+
22
+ class TravelStyle(str, Enum):
23
+ BUDGET = "budget"
24
+ LUXURY = "luxury"
25
+ BALANCED = "balanced"
26
+
27
+
28
+ class UserPreferences(BaseModel):
29
+ """User preferences model with validation."""
30
+
31
+ travel_style: str = Field(
32
+ default="balanced",
33
+ description="Preferred travel style (budget, luxury, balanced)",
34
+ )
35
+ preferred_destinations: list = Field(
36
+ default_factory=list,
37
+ description="List of preferred travel destinations",
38
+ )
39
+ dietary_restrictions: list = Field(
40
+ default_factory=list,
41
+ description="List of dietary restrictions",
42
+ )
43
+ accessibility_needs: list = Field(
44
+ default_factory=list,
45
+ description="List of accessibility requirements",
46
+ )
47
+ preferred_activities: list = Field(
48
+ default_factory=list,
49
+ description="List of preferred activities",
50
+ )
51
+ budget_range: Dict[str, float] = Field(
52
+ default_factory=lambda: {"min": 0, "max": float("inf")},
53
+ description="Budget range for travel",
54
+ )
55
+ preferred_accommodation: str = Field(
56
+ default="hotel",
57
+ description="Preferred type of accommodation",
58
+ )
59
+ preferred_transportation: str = Field(
60
+ default="flexible",
61
+ description="Preferred mode of transportation",
62
+ )
63
+ travel_frequency: str = Field(
64
+ default="occasional",
65
+ description="How often the user travels",
66
+ )
67
+ preferred_seasons: list = Field(
68
+ default_factory=list,
69
+ description="Preferred travel seasons",
70
+ )
71
+ special_requirements: list = Field(
72
+ default_factory=list,
73
+ description="Any special travel requirements",
74
+ )
75
+
76
+ @validator("travel_style")
77
+ def validate_travel_style(cls, v):
78
+ allowed_styles = ["budget", "luxury", "balanced"]
79
+ if v not in allowed_styles:
80
+ raise ValueError(f"Travel style must be one of {allowed_styles}")
81
+ return v
82
+
83
+ @validator("preferred_accommodation")
84
+ def validate_accommodation(cls, v):
85
+ allowed_types = [
86
+ "hotel",
87
+ "hostel",
88
+ "apartment",
89
+ "resort",
90
+ "camping",
91
+ "flexible",
92
+ ]
93
+ if v not in allowed_types:
94
+ raise ValueError(f"Accommodation type must be one of {allowed_types}")
95
+ return v
96
+
97
+ @validator("preferred_transportation")
98
+ def validate_transportation(cls, v):
99
+ allowed_types = [
100
+ "car",
101
+ "train",
102
+ "bus",
103
+ "plane",
104
+ "flexible",
105
+ ]
106
+ if v not in allowed_types:
107
+ raise ValueError(f"Transportation type must be one of {allowed_types}")
108
+ return v
109
+
110
+ @validator("travel_frequency")
111
+ def validate_frequency(cls, v):
112
+ allowed_frequencies = [
113
+ "rarely",
114
+ "occasional",
115
+ "frequent",
116
+ "very_frequent",
117
+ ]
118
+ if v not in allowed_frequencies:
119
+ raise ValueError(f"Travel frequency must be one of {allowed_frequencies}")
120
+ return v
121
+
122
+ @validator("budget_range")
123
+ def validate_budget(cls, v):
124
+ if v["min"] < 0:
125
+ raise ValueError("Minimum budget cannot be negative")
126
+ if v["max"] < v["min"]:
127
+ raise ValueError("Maximum budget must be greater than minimum budget")
128
+ return v
129
+
130
+
131
+ class UserProfile:
132
+ def __init__(self):
133
+ """Initialize the user profile manager."""
134
+ self.profiles_dir = os.path.join("data", "user_profiles")
135
+ self.backup_dir = os.path.join("data", "user_profiles_backup")
136
+ self._ensure_directories()
137
+ self.rate_limit_window = 3600 # 1 hour
138
+ self.max_updates_per_window = 10
139
+ self.update_history: Dict[str, list] = {}
140
+ self.max_profile_size = 1024 * 1024 # 1MB
141
+
142
+ def _ensure_directories(self):
143
+ """Ensure necessary directories exist."""
144
+ os.makedirs(self.profiles_dir, exist_ok=True)
145
+ os.makedirs(self.backup_dir, exist_ok=True)
146
+
147
+ def _validate_user_id(self, user_id: str) -> bool:
148
+ """Validate user ID format."""
149
+ if not user_id or not isinstance(user_id, str):
150
+ return False
151
+ # Allow alphanumeric characters, hyphens, and underscores
152
+ return bool(re.match(r"^[a-zA-Z0-9-_]+$", user_id))
153
+
154
+ def _check_rate_limit(self, user_id: str) -> bool:
155
+ """Check if user has exceeded rate limit."""
156
+ current_time = time.time()
157
+ if user_id not in self.update_history:
158
+ self.update_history[user_id] = []
159
+
160
+ # Remove old entries
161
+ self.update_history[user_id] = [
162
+ t
163
+ for t in self.update_history[user_id]
164
+ if current_time - t < self.rate_limit_window
165
+ ]
166
+
167
+ # Check if limit exceeded
168
+ if len(self.update_history[user_id]) >= self.max_updates_per_window:
169
+ return False
170
+
171
+ # Add new entry
172
+ self.update_history[user_id].append(current_time)
173
+ return True
174
+
175
+ def _create_backup(self, user_id: str) -> None:
176
+ """Create a backup of the user profile."""
177
+ try:
178
+ profile_path = os.path.join(self.profiles_dir, f"{user_id}.json")
179
+ if not os.path.exists(profile_path):
180
+ return
181
+
182
+ # Create backup with timestamp
183
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
184
+ backup_path = os.path.join(self.backup_dir, f"{user_id}_{timestamp}.json")
185
+ shutil.copy2(profile_path, backup_path)
186
+
187
+ # Keep only the last 5 backups
188
+ backups = sorted(
189
+ Path(self.backup_dir).glob(f"{user_id}_*.json"),
190
+ key=lambda x: x.stat().st_mtime,
191
+ reverse=True,
192
+ )
193
+ for old_backup in backups[5:]:
194
+ old_backup.unlink()
195
+
196
+ except Exception as e:
197
+ logger.error(
198
+ f"Error creating backup for {user_id}: {str(e)}", exc_info=True
199
+ )
200
+
201
+ def _cleanup_old_profiles(self):
202
+ """Clean up profiles older than 30 days."""
203
+ try:
204
+ current_time = time.time()
205
+ for filename in os.listdir(self.profiles_dir):
206
+ if not filename.endswith(".json"):
207
+ continue
208
+
209
+ file_path = os.path.join(self.profiles_dir, filename)
210
+ file_time = os.path.getmtime(file_path)
211
+
212
+ if current_time - file_time > 30 * 24 * 3600: # 30 days
213
+ try:
214
+ # Create final backup before deletion
215
+ user_id = filename[:-5] # Remove .json extension
216
+ self._create_backup(user_id)
217
+ os.remove(file_path)
218
+ logger.info(f"Removed old profile: {filename}")
219
+ except Exception as e:
220
+ logger.warning(
221
+ f"Error removing old profile {filename}: {str(e)}"
222
+ )
223
+ except Exception as e:
224
+ logger.error(f"Error cleaning up old profiles: {str(e)}", exc_info=True)
225
+
226
+ def get_profile(self, user_id: str) -> Dict[str, Any]:
227
+ """Get user profile with validation."""
228
+ try:
229
+ if not self._validate_user_id(user_id):
230
+ raise ValueError("Invalid user ID format")
231
+
232
+ profile_path = os.path.join(self.profiles_dir, f"{user_id}.json")
233
+
234
+ with _profile_lock:
235
+ if not os.path.exists(profile_path):
236
+ return self._create_default_profile(user_id)
237
+
238
+ # Check file size
239
+ if os.path.getsize(profile_path) > self.max_profile_size:
240
+ raise ValueError("Profile file size exceeds limit")
241
+
242
+ with open(profile_path, "r", encoding="utf-8") as f:
243
+ profile = json.load(f)
244
+
245
+ # Validate profile structure
246
+ if not isinstance(profile, dict):
247
+ raise ValueError("Invalid profile format")
248
+
249
+ # Ensure all required fields exist
250
+ required_fields = ["user_id", "preferences", "created_at", "updated_at"]
251
+ if not all(field in profile for field in required_fields):
252
+ raise ValueError("Missing required profile fields")
253
+
254
+ return profile
255
+
256
+ except Exception as e:
257
+ logger.error(
258
+ f"Error getting profile for {user_id}: {str(e)}", exc_info=True
259
+ )
260
+ raise
261
+
262
+ def update_profile(
263
+ self, user_id: str, preferences: Dict[str, Any]
264
+ ) -> Dict[str, Any]:
265
+ """Update user profile with validation and rate limiting."""
266
+ try:
267
+ if not self._validate_user_id(user_id):
268
+ raise ValueError("Invalid user ID format")
269
+
270
+ if not self._check_rate_limit(user_id):
271
+ raise ValueError("Rate limit exceeded")
272
+
273
+ # Validate preferences
274
+ try:
275
+ validated_preferences = UserPreferences(**preferences)
276
+ except Exception as e:
277
+ raise ValueError(f"Invalid preferences: {str(e)}")
278
+
279
+ profile_path = os.path.join(self.profiles_dir, f"{user_id}.json")
280
+
281
+ with _profile_lock:
282
+ # Create backup before update
283
+ self._create_backup(user_id)
284
+
285
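+ # Note: get_profile() below also acquires _profile_lock, so this assumes the lock is reentrant (e.g. threading.RLock); a plain threading.Lock would deadlock here.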
+ current_profile = self.get_profile(user_id)
286
+
287
+ # Update profile
288
+ current_profile["preferences"] = validated_preferences.dict()
289
+ current_profile["updated_at"] = datetime.utcnow().isoformat()
290
+
291
+ # Save updated profile
292
+ with open(profile_path, "w", encoding="utf-8") as f:
293
+ json.dump(current_profile, f, indent=2)
294
+
295
+ logger.info(f"Updated profile for user {user_id}")
296
+ return current_profile
297
+
298
+ except Exception as e:
299
+ logger.error(
300
+ f"Error updating profile for {user_id}: {str(e)}", exc_info=True
301
+ )
302
+ raise
303
+
304
+ def _create_default_profile(self, user_id: str) -> Dict[str, Any]:
305
+ """Create a default profile with validation."""
306
+ try:
307
+ if not self._validate_user_id(user_id):
308
+ raise ValueError("Invalid user ID format")
309
+
310
+ default_preferences = UserPreferences().dict()
311
+ profile = {
312
+ "user_id": user_id,
313
+ "preferences": default_preferences,
314
+ "created_at": datetime.utcnow().isoformat(),
315
+ "updated_at": datetime.utcnow().isoformat(),
316
+ }
317
+
318
+ profile_path = os.path.join(self.profiles_dir, f"{user_id}.json")
319
+
320
+ with _profile_lock:
321
+ with open(profile_path, "w", encoding="utf-8") as f:
322
+ json.dump(profile, f, indent=2)
323
+
324
+ logger.info(f"Created default profile for user {user_id}")
325
+ return profile
326
+
327
+ except Exception as e:
328
+ logger.error(
329
+ f"Error creating default profile for {user_id}: {str(e)}", exc_info=True
330
+ )
331
+ raise
332
+
333
+ def delete_profile(self, user_id: str) -> None:
334
+ """Delete user profile with validation."""
335
+ try:
336
+ if not self._validate_user_id(user_id):
337
+ raise ValueError("Invalid user ID format")
338
+
339
+ profile_path = os.path.join(self.profiles_dir, f"{user_id}.json")
340
+
341
+ with _profile_lock:
342
+ if os.path.exists(profile_path):
343
+ # Create final backup before deletion
344
+ self._create_backup(user_id)
345
+ os.remove(profile_path)
346
+ logger.info(f"Deleted profile for user {user_id}")
347
+ else:
348
+ logger.warning(f"Profile not found for user {user_id}")
349
+
350
+ except Exception as e:
351
+ logger.error(
352
+ f"Error deleting profile for {user_id}: {str(e)}", exc_info=True
353
+ )
354
+ raise
355
+
356
+ def get_recommendations(self, user_id: str) -> Dict[str, Any]:
357
+ """Get personalized recommendations based on user profile with validation."""
358
+ try:
359
+ profile = self.get_profile(user_id)
360
+ if not profile or "preferences" not in profile:
361
+ return {}
362
+
363
+ preferences = UserPreferences(**profile["preferences"])
364
+ recommendations = {
365
+ "destinations": self._get_destination_recommendations(preferences),
366
+ "activities": self._get_activity_recommendations(preferences),
367
+ "tips": self._get_personalized_tips(preferences),
368
+ "generated_at": datetime.now().isoformat(),
369
+ }
370
+
371
+ return recommendations
372
+ except Exception as e:
373
+ logger.error(f"Error getting recommendations: {str(e)}", exc_info=True)
374
+ return {}
375
+
376
+ def _get_destination_recommendations(self, profile: UserPreferences) -> List[str]:
377
+ """Get destination recommendations based on preferences."""
378
+ try:
379
+ recommendations = []
380
+
381
+ # Add recommendations based on favorite destinations
382
+ if profile.preferred_destinations:
383
+ recommendations.extend(profile.preferred_destinations[:3])
384
+
385
+ # Add recommendations based on interests
386
+ if "beach" in profile.preferred_activities:
387
+ recommendations.append("Bali, Indonesia")
388
+ if "culture" in profile.preferred_activities:
389
+ recommendations.append("Kyoto, Japan")
390
+ if "food" in profile.preferred_activities:
391
+ recommendations.append("Bangkok, Thailand")
392
+
393
+ # Add recommendations based on travel style
394
+ if profile.travel_style == TravelStyle.LUXURY:
395
+ recommendations.append("Dubai, UAE")
396
+ elif profile.travel_style == TravelStyle.BUDGET:
397
+ recommendations.append("Bangkok, Thailand")
398
+
399
+ return list(set(recommendations))[:5] # Return top 5 unique recommendations
400
+ except Exception as e:
401
+ logger.error(
402
+ f"Error getting destination recommendations: {str(e)}", exc_info=True
403
+ )
404
+ return []
405
+
406
+ def _get_activity_recommendations(self, profile: UserPreferences) -> List[str]:
407
+ """Get activity recommendations based on preferences."""
408
+ try:
409
+ activities = []
410
+
411
+ # Add activities based on interests
412
+ if "culture" in profile.preferred_activities:
413
+ activities.append("Visit local museums and historical sites")
414
+ if "food" in profile.preferred_activities:
415
+ activities.append("Try local cuisine and food tours")
416
+ if "nature" in profile.preferred_activities:
417
+ activities.append("Explore national parks and hiking trails")
418
+ if "adventure" in profile.preferred_activities:
419
+ activities.append("Try adventure sports and activities")
420
+
421
+ # Add activities based on travel style
422
+ if profile.travel_style == TravelStyle.LUXURY:
423
+ activities.append("Book private guided tours")
424
+ elif profile.travel_style == TravelStyle.BUDGET:
425
+ activities.append("Explore local markets and street food")
426
+
427
+ return list(set(activities))[:5] # Return top 5 unique activities
428
+ except Exception as e:
429
+ logger.error(
430
+ f"Error getting activity recommendations: {str(e)}", exc_info=True
431
+ )
432
+ return []
433
+
434
+ def _get_personalized_tips(self, profile: UserPreferences) -> List[str]:
435
+ """Get personalized travel tips based on preferences."""
436
+ try:
437
+ tips = []
438
+
439
+ # Add tips based on travel style
440
+ if profile.travel_style == TravelStyle.BUDGET:
441
+ tips.append(
442
+ "Look for local markets and street food for affordable meals"
443
+ )
444
+ tips.append("Consider staying in hostels or guesthouses")
445
+ elif profile.travel_style == TravelStyle.LUXURY:
446
+ tips.append("Book premium experiences and private tours in advance")
447
+ tips.append("Consider luxury resorts and boutique hotels")
448
+
449
+ # Add tips based on dietary restrictions
450
+ if profile.dietary_restrictions:
451
+ tips.append(
452
+ f"Research restaurants that accommodate {', '.join(profile.dietary_restrictions)}"
453
+ )
454
+
455
+ # Add tips based on accessibility needs
456
+ if profile.accessibility_needs:
457
+ tips.append(
458
+ f"Research accessibility features for {', '.join(profile.accessibility_needs)}"
459
+ )
460
+
461
+ return list(set(tips))[:5] # Return top 5 unique tips
462
+ except Exception as e:
463
+ logger.error(f"Error getting personalized tips: {str(e)}", exc_info=True)
464
+ return []
data/.gitkeep ADDED
@@ -0,0 +1 @@
1
+
docs/API.md ADDED
@@ -0,0 +1,294 @@
1
+ # TravelMate AI Assistant API Documentation
2
+
3
+ ## Overview
4
+
5
+ The TravelMate AI Assistant API provides a comprehensive set of endpoints for interacting with an AI-powered travel assistant. The API uses RAG (Retrieval-Augmented Generation) to provide accurate and contextually relevant travel information.
6
+
7
+ ## Base URL
8
+
9
+ ```
10
+ https://api.travelmate.ai/v1
11
+ ```
12
+
13
+ ## Authentication
14
+
15
+ All API endpoints require authentication using JWT (JSON Web Tokens). Include the token in the Authorization header:
16
+
17
+ ```
18
+ Authorization: Bearer <your_token>
19
+ ```
20
+
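+ For example, a minimal authenticated request in Python might look like the sketch below. The token value is a placeholder, and how tokens are issued is outside the scope of this document.
+
+ ```python
+ import requests
+
+ TOKEN = "<your_token>"  # placeholder JWT obtained out of band
+
+ response = requests.post(
+     "https://api.travelmate.ai/v1/chat",
+     headers={
+         "Authorization": f"Bearer {TOKEN}",
+         "Content-Type": "application/json",
+     },
+     json={"message": "What are the best places to visit in Paris?", "chat_history": []},
+ )
+ response.raise_for_status()
+ print(response.json()["answer"])
+ ```
+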
21
+ ## Rate Limiting
22
+
23
+ The API implements rate limiting to ensure fair usage (a client-side handling sketch follows this list):
24
+ - 100 requests per hour per user
25
+ - Rate limit headers are included in responses:
26
+ - `X-RateLimit-Limit`: Maximum requests per window
27
+ - `X-RateLimit-Remaining`: Remaining requests in current window
28
+ - `X-RateLimit-Reset`: Time until rate limit resets
29
+
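+ A minimal sketch of honoring these headers on the client side; it assumes `X-RateLimit-Reset` carries the number of seconds until the window resets (adjust if it is an absolute timestamp):
+
+ ```python
+ import time
+
+ import requests
+
+ def post_with_rate_limit(session: requests.Session, url: str, **kwargs) -> requests.Response:
+     """Send a POST request and pause when the current rate-limit window is exhausted."""
+     response = session.post(url, **kwargs)
+     remaining = int(response.headers.get("X-RateLimit-Remaining", "1"))
+     if remaining == 0:
+         # Assumption: the reset header holds seconds until the window resets.
+         wait_seconds = int(response.headers.get("X-RateLimit-Reset", "60"))
+         time.sleep(wait_seconds)
+     return response
+ ```
+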
30
+ ## Endpoints
31
+
32
+ ### Chat
33
+
34
+ #### POST /chat
35
+
36
+ Process a chat message and get an AI-generated response.
37
+
38
+ **Request Body:**
39
+ ```json
40
+ {
41
+ "message": "What are the best places to visit in Paris?",
42
+ "chat_history": [
43
+ {
44
+ "user": "Hello",
45
+ "assistant": "Hi! How can I help you with your travel plans?"
46
+ }
47
+ ]
48
+ }
49
+ ```
50
+
51
+ **Response:**
52
+ ```json
53
+ {
54
+ "answer": "Here are some must-visit places in Paris...",
55
+ "sources": [
56
+ {
57
+ "title": "Paris Travel Guide",
58
+ "url": "https://example.com/paris-guide",
59
+ "relevance_score": 0.95
60
+ }
61
+ ],
62
+ "suggested_questions": [
63
+ "What's the best time to visit the Eiffel Tower?",
64
+ "Are there any hidden gems in Paris?"
65
+ ]
66
+ }
67
+ ```
68
+
69
+ ### User Profile
70
+
71
+ #### GET /profile
72
+
73
+ Get the current user's profile.
74
+
75
+ **Response:**
76
+ ```json
77
+ {
78
+ "user_id": "user_123",
79
+ "preferences": {
80
+ "travel_style": "balanced",
81
+ "preferred_destinations": ["Paris", "Tokyo"],
82
+ "dietary_restrictions": [],
83
+ "accessibility_needs": [],
84
+ "preferred_activities": ["sightseeing", "food"],
85
+ "budget_range": {
86
+ "min": 1000,
87
+ "max": 5000
88
+ },
89
+ "preferred_accommodation": "hotel",
90
+ "preferred_transportation": "flexible",
91
+ "travel_frequency": "occasional",
92
+ "preferred_seasons": ["spring", "fall"],
93
+ "special_requirements": []
94
+ }
95
+ }
96
+ ```
97
+
98
+ #### PUT /profile/preferences
99
+
100
+ Update user preferences.
101
+
102
+ **Request Body:**
103
+ ```json
104
+ {
105
+ "travel_style": "luxury",
106
+ "preferred_destinations": ["Paris", "Tokyo", "New York"],
107
+ "budget_range": {
108
+ "min": 2000,
109
+ "max": 10000
110
+ }
111
+ }
112
+ ```
113
+
114
+ **Response:**
115
+ ```json
116
+ {
117
+ "message": "Preferences updated successfully"
118
+ }
119
+ ```
120
+
121
+ ### Health Check
122
+
123
+ #### GET /health
124
+
125
+ Check the health status of the API and its components.
126
+
127
+ **Response:**
128
+ ```json
129
+ {
130
+ "status": "healthy",
131
+ "timestamp": "2024-02-20T12:00:00Z",
132
+ "version": "1.0.0",
133
+ "environment": "production",
134
+ "components": {
135
+ "rag_engine": "ok",
136
+ "user_profile": "ok"
137
+ }
138
+ }
139
+ ```
140
+
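+ A simple readiness probe built on this endpoint might look like the sketch below (the bearer token and timeout are placeholders):
+
+ ```python
+ import requests
+
+ TOKEN = "<your_token>"  # placeholder JWT
+
+ health = requests.get(
+     "https://api.travelmate.ai/v1/health",
+     headers={"Authorization": f"Bearer {TOKEN}"},
+     timeout=10,
+ ).json()
+
+ if health["status"] != "healthy" or any(v != "ok" for v in health["components"].values()):
+     raise RuntimeError(f"TravelMate API degraded: {health}")
+ ```
+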
141
+ ## Error Handling
142
+
143
+ The API uses standard HTTP status codes and returns error responses in the following format:
144
+
145
+ ```json
146
+ {
147
+ "error": "Error type",
148
+ "detail": "Detailed error message",
149
+ "timestamp": "2024-02-20T12:00:00Z",
150
+ "request_id": "req_123"
151
+ }
152
+ ```
153
+
154
+ Common error codes (a retry sketch for rate-limit and transient server errors follows this list):
155
+ - 400: Bad Request
156
+ - 401: Unauthorized
157
+ - 403: Forbidden
158
+ - 404: Not Found
159
+ - 429: Too Many Requests
160
+ - 500: Internal Server Error
161
+ - 503: Service Unavailable
162
+
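+ A minimal retry sketch for 429 and transient server errors, complementing the retry guidance in the Best Practices section below (the attempt count and backoff base are illustrative, not API requirements):
+
+ ```python
+ import time
+
+ import requests
+
+ def post_with_retries(url: str, payload: dict, headers: dict, max_attempts: int = 5) -> dict:
+     """Retry on 429/5xx responses with exponential backoff (1s, 2s, 4s, ...)."""
+     for attempt in range(max_attempts):
+         response = requests.post(url, json=payload, headers=headers)
+         if response.status_code not in (429, 500, 503):
+             response.raise_for_status()  # surface 4xx errors other than 429
+             return response.json()
+         time.sleep(2 ** attempt)
+     raise RuntimeError("Giving up after repeated rate-limit or server errors")
+ ```
+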
163
+ ## Best Practices
164
+
165
+ 1. **Error Handling**
166
+ - Always check response status codes
167
+ - Implement exponential backoff for retries
168
+ - Handle rate limiting gracefully
169
+
170
+ 2. **Performance**
171
+ - Cache responses when appropriate (see the caching sketch after this list)
172
+ - Minimize chat history size
173
+ - Use compression for large requests
174
+
175
+ 3. **Security**
176
+ - Keep tokens secure
177
+ - Use HTTPS for all requests
178
+ - Validate all input data
179
+
180
+ 4. **Rate Limiting**
181
+ - Monitor rate limit headers
182
+ - Implement request queuing
183
+ - Handle 429 responses appropriately
184
+
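+ A caching sketch for repeated identical questions, using `cachetools` (already a dependency of this project); the client object and the TTL/size values are illustrative assumptions:
+
+ ```python
+ from cachetools import TTLCache
+
+ # Illustrative limits: 1,000 entries kept for one hour.
+ _chat_cache: TTLCache = TTLCache(maxsize=1000, ttl=3600)
+
+ def cached_chat(client, message: str) -> dict:
+     """Return a cached answer for repeated identical questions.
+
+     `client` is any object with a chat(message) method, e.g. the Python
+     client shown in the SDKs section below.
+     """
+     if message not in _chat_cache:
+         _chat_cache[message] = client.chat(message)
+     return _chat_cache[message]
+ ```
+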
185
+ ## SDKs and Examples
186
+
187
+ ### Python
188
+
189
+ ```python
190
+ import requests
191
+
192
+ class TravelMateClient:
193
+ def __init__(self, api_key, base_url="https://api.travelmate.ai/v1"):
194
+ self.api_key = api_key
195
+ self.base_url = base_url
196
+ self.session = requests.Session()
197
+ self.session.headers.update({
198
+ "Authorization": f"Bearer {api_key}",
199
+ "Content-Type": "application/json"
200
+ })
201
+
202
+ def chat(self, message, chat_history=None):
203
+ response = self.session.post(
204
+ f"{self.base_url}/chat",
205
+ json={
206
+ "message": message,
207
+ "chat_history": chat_history or []
208
+ }
209
+ )
210
+ response.raise_for_status()
211
+ return response.json()
212
+
213
+ def get_profile(self):
214
+ response = self.session.get(f"{self.base_url}/profile")
215
+ response.raise_for_status()
216
+ return response.json()
217
+
218
+ def update_preferences(self, preferences):
219
+ response = self.session.put(
220
+ f"{self.base_url}/profile/preferences",
221
+ json=preferences
222
+ )
223
+ response.raise_for_status()
224
+ return response.json()
225
+ ```
226
+
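+ A brief usage sketch of the client above (the API key is a placeholder; the payloads mirror the endpoint examples earlier in this document):
+
+ ```python
+ client = TravelMateClient(api_key="<your_api_key>")
+
+ reply = client.chat("What are the best places to visit in Paris?")
+ print(reply["answer"])
+ for question in reply.get("suggested_questions", []):
+     print("-", question)
+
+ client.update_preferences({
+     "travel_style": "luxury",
+     "budget_range": {"min": 2000, "max": 10000},
+ })
+ ```
+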
227
+ ### JavaScript
228
+
229
+ ```javascript
230
+ class TravelMateClient {
231
+ constructor(apiKey, baseUrl = 'https://api.travelmate.ai/v1') {
232
+ this.apiKey = apiKey;
233
+ this.baseUrl = baseUrl;
234
+ }
235
+
236
+ async chat(message, chatHistory = []) {
237
+ const response = await fetch(`${this.baseUrl}/chat`, {
238
+ method: 'POST',
239
+ headers: {
240
+ 'Authorization': `Bearer ${this.apiKey}`,
241
+ 'Content-Type': 'application/json'
242
+ },
243
+ body: JSON.stringify({
244
+ message,
245
+ chat_history: chatHistory
246
+ })
247
+ });
248
+
249
+ if (!response.ok) {
250
+ throw new Error(`API error: ${response.statusText}`);
251
+ }
252
+
253
+ return response.json();
254
+ }
255
+
256
+ async getProfile() {
257
+ const response = await fetch(`${this.baseUrl}/profile`, {
258
+ headers: {
259
+ 'Authorization': `Bearer ${this.apiKey}`
260
+ }
261
+ });
262
+
263
+ if (!response.ok) {
264
+ throw new Error(`API error: ${response.statusText}`);
265
+ }
266
+
267
+ return response.json();
268
+ }
269
+
270
+ async updatePreferences(preferences) {
271
+ const response = await fetch(`${this.baseUrl}/profile/preferences`, {
272
+ method: 'PUT',
273
+ headers: {
274
+ 'Authorization': `Bearer ${this.apiKey}`,
275
+ 'Content-Type': 'application/json'
276
+ },
277
+ body: JSON.stringify(preferences)
278
+ });
279
+
280
+ if (!response.ok) {
281
+ throw new Error(`API error: ${response.statusText}`);
282
+ }
283
+
284
+ return response.json();
285
+ }
286
+ }
287
+ ```
288
+
289
+ ## Support
290
+
291
+ For API support, please contact:
292
+ - Email: [email protected]
293
+ - Documentation: https://docs.travelmate.ai
294
+ - Status Page: https://status.travelmate.ai
huggingface.yaml ADDED
@@ -0,0 +1,113 @@
1
+ sdk: gradio
2
+ sdk_version: 4.19.2
3
+ app_file: app.py
4
+ python_version: "3.10"
5
+
6
+ # Hardware requirements
7
+ hardware:
8
+ cpu: 2
9
+ memory: 16GB
10
+
11
+ # Build settings
12
+ build:
13
+ cuda: "None" # No CUDA needed for CPU-only
14
+ system_packages:
15
+ - build-essential
16
+ - python3-dev
17
+ - cmake
18
+ - pkg-config
19
+ - libopenblas-dev
20
+ - libomp-dev
21
+
22
+ # Environment variables
23
+ env:
24
+ - MODEL_NAME=meta-llama/Llama-2-7b-chat-hf
25
+ - EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
26
+ - SECRET_KEY=${SECRET_KEY}
27
+ - JWT_SECRET_KEY=${JWT_SECRET_KEY}
28
+ - RATE_LIMIT_REQUESTS=100
29
+ - RATE_LIMIT_WINDOW=3600
30
+ - LOG_LEVEL=INFO
31
+
32
+ # Dependencies
33
+ dependencies:
34
+ - gradio==4.19.2
35
+ - langchain==0.1.9
36
+ - langchain-core>=0.1.52,<0.2
37
+ - langchain-community==0.0.27
38
+ - langchain-text-splitters==0.0.1
39
+ - langchain-huggingface==0.0.3
40
+ - transformers==4.38.2
41
+ - torch==2.2.1
42
+ - accelerate==0.27.2
43
+ - bitsandbytes==0.42.0
44
+ - safetensors==0.4.2
45
+ - sentence-transformers==2.6.1
46
+ - faiss-cpu==1.7.4
47
+ - pydantic==2.5.3
48
+ - pydantic-settings==2.1.0
49
+ - python-dotenv==1.0.0
50
+ - fastapi==0.109.2
51
+ - uvicorn==0.27.1
52
+ - python-jose==3.3.0
53
+ - passlib==1.7.4
54
+ - python-multipart
55
+ - bcrypt==4.1.2
56
+ - httpx==0.26.0
57
+ - aiohttp==3.9.5
58
+ - tenacity==8.2.3
59
+ - cachetools==5.3.2
60
+ - numpy==1.26.3
61
+ - tqdm==4.66.1
62
+ - loguru==0.7.2
63
+ - datasets==2.16.1
64
+ - huggingface-hub==0.24.1
65
+ - circuitbreaker==1.4.0
66
+
67
+ # Health check
68
+ health_check:
69
+ path: /health
70
+ interval: 300
71
+ timeout: 10
72
+ retries: 3
73
+
74
+ # Resource limits
75
+ resources:
76
+ cpu: 2
77
+ memory: 16GB
78
+
79
+ # Cache settings
80
+ cache:
81
+ enabled: true
82
+ ttl: 3600
83
+ max_size: 1000
84
+
85
+ # Logging
86
+ logging:
87
+ level: INFO
88
+ format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
89
+ handlers:
90
+ - type: file
91
+ filename: app.log
92
+ max_bytes: 10485760
93
+ backup_count: 5
94
+ - type: stream
95
+ stream: ext://sys.stdout
96
+
97
+ # Space settings
98
+ space:
99
+ title: "TravelMate - AI Travel Assistant"
100
+ description: "An AI-powered travel assistant using Llama-2 and RAG to help plan trips and provide travel information"
101
+ license: mit
102
+ sdk: gradio
103
+ app_port: 7860
104
+ app_url: "https://huggingface.co/spaces/bharadwaj-m/TravelMate-AI"
105
+
106
+ # Build commands (note: a `build:` mapping is already defined above; duplicate top-level keys are not valid YAML, so many parsers will keep only one of these sections)
107
+ build:
108
+ - pip install -r requirements.txt
109
+ - mkdir -p data/vector_store data/user_profiles data/cache
110
+ - python -c "from core.data_loader import DataLoader; DataLoader().initialize_knowledge_base()"
111
+
112
+ # Run command
113
+ run: python app.py
requirements.txt ADDED
@@ -0,0 +1,44 @@
1
+ # Core frameworks and UI
2
+ gradio==4.19.2
3
+ fastapi==0.109.2
4
+ uvicorn==0.27.1
5
+
6
+ # LangChain and ecosystem (aligned versions)
7
+ langchain==0.1.9
8
+ langchain-core>=0.1.52,<0.2
9
+ langchain-community==0.0.27
10
+ langchain-text-splitters==0.0.1
11
+ langchain-huggingface==0.0.3
12
+
13
+ # LLM and Transformers
14
+ transformers==4.41.2
15
+ torch==2.2.1
16
+ accelerate==0.27.2
17
+
18
+ # Embeddings / similarity search
19
+ sentence-transformers==2.6.1
20
+ faiss-cpu==1.7.4
21
+
22
+ datasets==2.16.1
23
+
24
+ # Security and auth
25
+ python-jose==3.3.0
26
+ passlib==1.7.4
27
+ bcrypt==4.1.2
28
+ python-multipart
29
+
30
+ # Data & utils
31
+ pydantic==2.5.3
32
+ pydantic-settings==2.1.0
33
+ python-dotenv==1.0.0
34
+ httpx==0.26.0
35
+ aiohttp==3.9.5
36
+ tenacity==8.2.3
37
+ expiringdict==1.2.1
38
+ numpy==1.26.3
39
+ tqdm==4.66.1
40
+ loguru==0.7.2
41
+ huggingface-hub==0.24.1
42
+
43
+ # Resilience / patterns
44
+ circuitbreaker==1.4.0