import os
import uuid
import datetime
from typing import Dict, List, Optional, Any

from pymongo import MongoClient
from bson.objectid import ObjectId


class MongoDBHelper:
    """Helper class for MongoDB operations.

    Wraps a ``MongoClient`` bound to the ``pyscout_ai`` database and exposes
    convenience methods for users, conversations, messages, API keys and
    API usage / rate limiting.
    """

    # Fallback daily limits, used both when a key is created and when a stored
    # key document carries no (or an empty) ``rate_limit`` sub-document.
    DEFAULT_REQUESTS_PER_DAY = 1000
    DEFAULT_TOKENS_PER_DAY = 1000000

    def __init__(self, connection_string: Optional[str] = None):
        """Initialize the MongoDB client.

        Args:
            connection_string: MongoDB URI. When omitted (or empty) the
                ``MONGODB_URI`` environment variable is used instead.

        Raises:
            ValueError: If neither the argument nor the environment variable
                provides a connection string.
        """
        # Get connection string from env var or use provided one
        self.connection_string = connection_string or os.getenv('MONGODB_URI')

        if not self.connection_string:
            raise ValueError("MongoDB connection string not provided. Set MONGODB_URI environment variable or pass it to constructor.")

        self.client = MongoClient(self.connection_string)
        self.db = self.client.get_database("pyscout_ai")

        # Collections
        self.api_keys_collection = self.db.api_keys
        self.usage_collection = self.db.usage
        self.users_collection = self.db.users
        self.conversations_collection = self.db.conversations
        self.messages_collection = self.db.messages

        self._create_indexes()

    @staticmethod
    def _utcnow() -> datetime.datetime:
        """Return the current time as a timezone-aware UTC datetime.

        Replaces the deprecated naive ``datetime.utcnow()``. PyMongo converts
        aware datetimes to UTC on write, so the stored value is the same.
        """
        return datetime.datetime.now(datetime.timezone.utc)

    def _create_indexes(self):
        """Create the indexes this helper relies on (called from ``__init__``)."""
        # API Keys indexes
        self.api_keys_collection.create_index("key", unique=True)
        self.api_keys_collection.create_index("user_id")
        self.api_keys_collection.create_index("created_at")

        # Usage indexes
        self.usage_collection.create_index("api_key")
        self.usage_collection.create_index("timestamp")

        # Users indexes
        self.users_collection.create_index("email", unique=True)

        # Conversations indexes
        self.conversations_collection.create_index("user_id")
        self.conversations_collection.create_index("created_at")

        # Messages indexes
        self.messages_collection.create_index("conversation_id")
        self.messages_collection.create_index("timestamp")

    def create_user(self, email: str, name: str, organization: Optional[str] = None) -> str:
        """Insert a new user document and return its id as a string.

        ``created_at`` and ``last_active`` are set to the same timestamp.
        The ``email`` index is unique, so inserting a duplicate email makes
        the underlying ``insert_one`` raise.
        """
        user_oid = ObjectId()
        now = self._utcnow()
        self.users_collection.insert_one({
            "_id": user_oid,
            "email": email,
            "name": name,
            "organization": organization,
            "created_at": now,
            "last_active": now,
        })
        return str(user_oid)

    def create_conversation(self, user_id: str, system_prompt: Optional[str] = None) -> str:
        """Insert a new active conversation for ``user_id`` and return its id as a string."""
        conversation_oid = ObjectId()
        now = self._utcnow()
        self.conversations_collection.insert_one({
            "_id": conversation_oid,
            "user_id": user_id,
            "system_prompt": system_prompt,
            "created_at": now,
            "last_message_at": now,
            "is_active": True,
        })
        return str(conversation_oid)

    def add_message(self, conversation_id: str, role: str, content: str,
                    model: Optional[str] = None, tokens: int = 0) -> str:
        """Store a message and bump the conversation's ``last_message_at``.

        Args:
            conversation_id: String form of the conversation's ObjectId.
            role: Role label stored with the message.
            content: Message text.
            model: Optional model name stored with the message.
            tokens: Token count stored with the message.

        Returns:
            The new message id as a string.

        Raises:
            bson.errors.InvalidId: If ``conversation_id`` is not a valid
                ObjectId string. The check runs before any write, so no
                orphan message is left behind.
        """
        # Convert first: a malformed id must fail before the message is inserted.
        conversation_oid = ObjectId(conversation_id)

        message_oid = ObjectId()
        now = self._utcnow()
        self.messages_collection.insert_one({
            "_id": message_oid,
            "conversation_id": conversation_id,
            "role": role,
            "content": content,
            "model": model,
            "tokens": tokens,
            "timestamp": now,
        })

        # Update conversation last_message_at
        self.conversations_collection.update_one(
            {"_id": conversation_oid},
            {"$set": {"last_message_at": now}},
        )
        return str(message_oid)

    def get_conversation_history(self, conversation_id: str) -> List[Dict]:
        """Return the conversation's messages, oldest first, without ``_id``."""
        cursor = self.messages_collection.find(
            {"conversation_id": conversation_id},
            {"_id": 0},
        ).sort("timestamp", 1)
        return list(cursor)

    def get_user_conversations(self, user_id: str, limit: int = 10) -> List[Dict]:
        """Return up to ``limit`` conversations of a user, most recently active first.

        Each entry holds ``_id`` (as a string), ``system_prompt``,
        ``created_at`` and ``last_message_at``.
        """
        cursor = self.conversations_collection.find(
            {"user_id": user_id},
            {"_id": 1, "system_prompt": 1, "created_at": 1, "last_message_at": 1},
        ).sort("last_message_at", -1).limit(limit)

        conversations = list(cursor)
        # Convert ObjectId to string
        for conv in conversations:
            conv["_id"] = str(conv["_id"])
        return conversations

    def generate_api_key(self, user_id: str, name: str = "Default API Key") -> str:
        """Generate a new API key for a user, store it and return it."""
        # Format: PyScoutAI-{uuid4-hex}
        api_key = f"PyScoutAI-{uuid.uuid4().hex}"

        # Store in database
        self.api_keys_collection.insert_one({
            "key": api_key,
            "user_id": user_id,
            "name": name,
            "created_at": self._utcnow(),
            "last_used": None,
            "is_active": True,
            "rate_limit": {
                "requests_per_day": self.DEFAULT_REQUESTS_PER_DAY,
                "tokens_per_day": self.DEFAULT_TOKENS_PER_DAY,
            },
        })
        return api_key

    def validate_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Validate an API key and record that it was used.

        Returns:
            The stored key document as it was before ``last_used`` was
            refreshed, or ``None`` when the key is empty, unknown or inactive.
        """
        if not api_key:
            return None

        # One atomic lookup + timestamp update. By default
        # find_one_and_update returns the pre-update document (or None when
        # nothing matches).
        return self.api_keys_collection.find_one_and_update(
            {"key": api_key, "is_active": True},
            {"$set": {"last_used": self._utcnow()}},
        )

    def log_api_usage(self, api_key: str, endpoint: str, tokens: int = 0,
                      model: Optional[str] = None, conversation_id: Optional[str] = None):
        """Insert one usage record; ``conversation_id`` is stored only when truthy."""
        usage_data = {
            "api_key": api_key,
            "endpoint": endpoint,
            "tokens": tokens,
            "model": model,
            "timestamp": self._utcnow(),
        }
        if conversation_id:
            usage_data["conversation_id"] = conversation_id

        self.usage_collection.insert_one(usage_data)

    def get_user_api_keys(self, user_id: str) -> List[Dict[str, Any]]:
        """Get all API keys for a user, with ``_id`` converted to a string."""
        keys = list(self.api_keys_collection.find({"user_id": user_id}))

        # Convert ObjectId to string for JSON serialization
        for key in keys:
            key["_id"] = str(key["_id"])
        return keys

    def revoke_api_key(self, api_key: str) -> bool:
        """Revoke an API key.

        Returns:
            ``True`` when a document was actually modified; ``False`` when the
            key is unknown or was already inactive.
        """
        result = self.api_keys_collection.update_one(
            {"key": api_key},
            {"$set": {"is_active": False}},
        )
        return result.modified_count > 0

    def check_rate_limit(self, api_key: str) -> Dict[str, Any]:
        """Check if the API key has exceeded its daily rate limits.

        Usage is counted from 00:00 UTC of the current day.

        Returns:
            ``{"allowed": False, "reason": ...}`` (plus ``limit`` / ``used``
            when a limit is exceeded), or ``{"allowed": True, "requests":
            {...}, "tokens": {...}}`` with limit / used / remaining figures.
        """
        key_data = self.api_keys_collection.find_one({"key": api_key, "is_active": True})
        if not key_data:
            return {"allowed": False, "reason": "Invalid API key"}

        # Get rate limit settings. `or {}` also covers a stored explicit None.
        rate_limit = key_data.get("rate_limit") or {}
        requests_per_day = rate_limit.get("requests_per_day", self.DEFAULT_REQUESTS_PER_DAY)
        tokens_per_day = rate_limit.get("tokens_per_day", self.DEFAULT_TOKENS_PER_DAY)

        # Start of the current UTC day
        today_start = self._utcnow().replace(hour=0, minute=0, second=0, microsecond=0)

        # Request count and token sum in one pass, so both numbers describe
        # the same set of usage documents.
        usage_pipeline = [
            {"$match": {"api_key": api_key, "timestamp": {"$gte": today_start}}},
            {"$group": {
                "_id": None,
                "total_requests": {"$sum": 1},
                "total_tokens": {"$sum": "$tokens"},
            }},
        ]
        usage = next(iter(self.usage_collection.aggregate(usage_pipeline)), None)
        requests_today = usage["total_requests"] if usage else 0
        tokens_today = usage["total_tokens"] if usage else 0

        # Check if limits are exceeded
        if requests_today >= requests_per_day:
            return {
                "allowed": False,
                "reason": "Daily request limit exceeded",
                "limit": requests_per_day,
                "used": requests_today,
            }

        if tokens_today >= tokens_per_day:
            return {
                "allowed": False,
                "reason": "Daily token limit exceeded",
                "limit": tokens_per_day,
                "used": tokens_today,
            }

        return {
            "allowed": True,
            "requests": {
                "limit": requests_per_day,
                "used": requests_today,
                "remaining": requests_per_day - requests_today,
            },
            "tokens": {
                "limit": tokens_per_day,
                "used": tokens_today,
                "remaining": tokens_per_day - tokens_today,
            },
        }

    def get_user_stats(self, user_id: str) -> Dict:
        """Return conversation, message and token totals for a user.

        Message and token totals are computed from the messages collection.
        Nothing in this class writes per-conversation ``message_count`` or
        ``total_tokens`` fields, so summing those would always give 0.

        Returns:
            A dict with exactly ``total_conversations``, ``total_messages``
            and ``total_tokens`` (all 0 when the user has no conversations).
        """
        # Messages reference their conversation by the string form of its id.
        conversation_ids = [
            str(doc["_id"])
            for doc in self.conversations_collection.find({"user_id": user_id}, {"_id": 1})
        ]

        stats = {
            "total_conversations": len(conversation_ids),
            "total_messages": 0,
            "total_tokens": 0,
        }
        if not conversation_ids:
            return stats

        pipeline = [
            {"$match": {"conversation_id": {"$in": conversation_ids}}},
            {"$group": {
                "_id": None,
                "total_messages": {"$sum": 1},
                "total_tokens": {"$sum": "$tokens"},
            }},
        ]
        totals = next(iter(self.messages_collection.aggregate(pipeline)), None)
        if totals:
            stats["total_messages"] = totals["total_messages"]
            stats["total_tokens"] = totals["total_tokens"]
        return stats