# PyscoutAI / db_helper.py
# Uploaded by PyScoutAI ("Upload 15 files", commit ead2510, verified).
import datetime
import os
import secrets
import uuid
from typing import Any, Dict, List, Optional

from bson.objectid import ObjectId
from pymongo import MongoClient
class MongoDBHelper:
    """Helper class for MongoDB operations used by PyScoutAI.

    Wraps a single ``MongoClient`` and exposes helpers for users,
    conversations, messages, API keys, and API-usage logging / rate limiting.

    All timestamps written by this class are naive ``datetime`` objects
    holding UTC time (see ``_utcnow``).
    """

    # Daily quotas stored on newly generated keys; also assumed for keys whose
    # stored document has no ``rate_limit`` entry.
    DEFAULT_REQUESTS_PER_DAY = 1000
    DEFAULT_TOKENS_PER_DAY = 1000000

    def __init__(self, connection_string: Optional[str] = None,
                 database_name: str = "pyscout_ai"):
        """Connect to MongoDB and ensure the indexes this helper relies on exist.

        Args:
            connection_string: MongoDB URI. Falls back to the ``MONGODB_URI``
                environment variable when omitted or empty.
            database_name: Name of the database to use.

        Raises:
            ValueError: If no connection string is provided or found in the
                environment.
        """
        self.connection_string = connection_string or os.getenv('MONGODB_URI')
        if not self.connection_string:
            raise ValueError("MongoDB connection string not provided. Set MONGODB_URI environment variable or pass it to constructor.")
        self.client = MongoClient(self.connection_string)
        self.db = self.client.get_database(database_name)
        # Collections
        self.api_keys_collection = self.db.api_keys
        self.usage_collection = self.db.usage
        self.users_collection = self.db.users
        self.conversations_collection = self.db.conversations
        self.messages_collection = self.db.messages
        self._create_indexes()

    def close(self) -> None:
        """Close the underlying ``MongoClient``."""
        self.client.close()

    @staticmethod
    def _utcnow() -> datetime.datetime:
        """Return the current UTC time as a naive datetime.

        Produces the same values as the deprecated ``datetime.utcnow()``
        (deprecated since Python 3.12) so stored data keeps its format.
        """
        return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    def _create_indexes(self) -> None:
        """Create the indexes used by this helper's queries."""
        # API Keys indexes
        self.api_keys_collection.create_index("key", unique=True)
        self.api_keys_collection.create_index("user_id")
        self.api_keys_collection.create_index("created_at")
        # Usage indexes. The compound index serves check_rate_limit, which
        # filters on api_key AND a timestamp range.
        self.usage_collection.create_index("api_key")
        self.usage_collection.create_index("timestamp")
        self.usage_collection.create_index([("api_key", 1), ("timestamp", 1)])
        # Users indexes
        self.users_collection.create_index("email", unique=True)
        # Conversations indexes
        self.conversations_collection.create_index("user_id")
        self.conversations_collection.create_index("created_at")
        # Messages indexes. The compound index serves get_conversation_history,
        # which filters on conversation_id and sorts by timestamp.
        self.messages_collection.create_index("conversation_id")
        self.messages_collection.create_index("timestamp")
        self.messages_collection.create_index([("conversation_id", 1), ("timestamp", 1)])

    def create_user(self, email: str, name: str, organization: Optional[str] = None) -> str:
        """Insert a new user and return its id as a hex string.

        ``email`` carries a unique index, so inserting an already-registered
        email makes ``insert_one`` raise a duplicate-key error.
        """
        user_oid = ObjectId()
        now = self._utcnow()
        self.users_collection.insert_one({
            "_id": user_oid,
            "email": email,
            "name": name,
            "organization": organization,
            "created_at": now,
            "last_active": now,
        })
        return str(user_oid)

    def create_conversation(self, user_id: str, system_prompt: Optional[str] = None) -> str:
        """Insert a new, active conversation for ``user_id`` and return its id string."""
        conversation_oid = ObjectId()
        now = self._utcnow()
        self.conversations_collection.insert_one({
            "_id": conversation_oid,
            "user_id": user_id,
            "system_prompt": system_prompt,
            "created_at": now,
            "last_message_at": now,
            "is_active": True,
        })
        return str(conversation_oid)

    def add_message(self, conversation_id: str, role: str, content: str,
                    model: Optional[str] = None, tokens: int = 0) -> str:
        """Store a message in a conversation and return the message id string.

        Also sets the conversation's ``last_message_at`` to the message time.

        Raises:
            bson.errors.InvalidId: If ``conversation_id`` is not a valid
                ObjectId string. This is raised before anything is written.
        """
        # Convert up front so an invalid id fails before the message is
        # inserted, rather than leaving an orphaned message behind.
        conversation_oid = ObjectId(conversation_id)
        message_oid = ObjectId()
        now = self._utcnow()
        self.messages_collection.insert_one({
            "_id": message_oid,
            # Stored as the caller's string, not an ObjectId; queries by
            # conversation_id elsewhere in this class rely on that.
            "conversation_id": conversation_id,
            "role": role,
            "content": content,
            "model": model,
            "tokens": tokens,
            "timestamp": now,
        })
        self.conversations_collection.update_one(
            {"_id": conversation_oid},
            {"$set": {"last_message_at": now}}
        )
        return str(message_oid)

    def get_conversation_history(self, conversation_id: str) -> List[Dict]:
        """Return all messages of a conversation, oldest first, without ``_id``."""
        return list(self.messages_collection.find(
            {"conversation_id": conversation_id},
            {"_id": 0}
        ).sort("timestamp", 1))

    def get_user_conversations(self, user_id: str, limit: int = 10) -> List[Dict]:
        """Return up to ``limit`` of a user's conversations, most recently active first.

        Each entry has ``_id`` (as a string), ``system_prompt``, ``created_at``
        and ``last_message_at``.
        """
        conversations = list(self.conversations_collection.find(
            {"user_id": user_id},
            {"_id": 1, "system_prompt": 1, "created_at": 1, "last_message_at": 1}
        ).sort("last_message_at", -1).limit(limit))
        # Convert ObjectId to string for JSON serialization
        for conv in conversations:
            conv["_id"] = str(conv["_id"])
        return conversations

    def generate_api_key(self, user_id: str, name: str = "Default API Key") -> str:
        """Create, store and return a new active API key for ``user_id``.

        Key format: ``PyScoutAI-`` followed by 32 hex characters. The new key
        gets the class default daily request and token quotas.
        """
        # secrets is the stdlib module intended for security tokens; 16 bytes
        # give the same 32-hex-char length as uuid4().hex.
        api_key = f"PyScoutAI-{secrets.token_hex(16)}"
        self.api_keys_collection.insert_one({
            "key": api_key,
            "user_id": user_id,
            "name": name,
            "created_at": self._utcnow(),
            "last_used": None,
            "is_active": True,
            "rate_limit": {
                "requests_per_day": self.DEFAULT_REQUESTS_PER_DAY,
                "tokens_per_day": self.DEFAULT_TOKENS_PER_DAY,
            }
        })
        return api_key

    def validate_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Validate an API key and record its use.

        Returns:
            The stored key document (as it was before ``last_used`` was
            refreshed) if the key exists and is active, otherwise ``None``.
        """
        if not api_key:
            return None
        # Single round trip: find the active key and stamp last_used.
        # find_one_and_update returns the pre-update document by default and
        # None when nothing matches.
        return self.api_keys_collection.find_one_and_update(
            {"key": api_key, "is_active": True},
            {"$set": {"last_used": self._utcnow()}}
        )

    def log_api_usage(self, api_key: str, endpoint: str, tokens: int = 0,
                      model: Optional[str] = None, conversation_id: Optional[str] = None) -> None:
        """Record one API call; these records feed ``check_rate_limit``."""
        usage_data = {
            "api_key": api_key,
            "endpoint": endpoint,
            "tokens": tokens,
            "model": model,
            "timestamp": self._utcnow(),
        }
        if conversation_id:
            usage_data["conversation_id"] = conversation_id
        self.usage_collection.insert_one(usage_data)

    def get_user_api_keys(self, user_id: str) -> List[Dict[str, Any]]:
        """Return all API key documents (active or not) belonging to ``user_id``."""
        keys = list(self.api_keys_collection.find({"user_id": user_id}))
        # Convert ObjectId to string for JSON serialization
        for key in keys:
            key["_id"] = str(key["_id"])
        return keys

    def revoke_api_key(self, api_key: str) -> bool:
        """Deactivate an API key.

        Returns:
            True if a document was changed; False if the key is unknown or
            was already inactive.
        """
        result = self.api_keys_collection.update_one(
            {"key": api_key},
            {"$set": {"is_active": False}}
        )
        return result.modified_count > 0

    def check_rate_limit(self, api_key: str) -> Dict[str, Any]:
        """Check whether ``api_key`` is still within its daily quotas.

        The day is the current UTC calendar day.

        Returns:
            ``{"allowed": False, "reason": ...}``, plus ``limit`` and ``used``
            when a quota is exhausted. Otherwise ``{"allowed": True,
            "requests": {...}, "tokens": {...}}``, where each inner dict has
            ``limit``, ``used`` and ``remaining``.
        """
        key_data = self.api_keys_collection.find_one({"key": api_key, "is_active": True})
        if not key_data:
            return {"allowed": False, "reason": "Invalid API key"}
        rate_limit = key_data.get("rate_limit", {})
        requests_per_day = rate_limit.get("requests_per_day", self.DEFAULT_REQUESTS_PER_DAY)
        tokens_per_day = rate_limit.get("tokens_per_day", self.DEFAULT_TOKENS_PER_DAY)
        # Midnight UTC today, naive to match the stored timestamps.
        today_start = datetime.datetime.combine(self._utcnow().date(), datetime.time.min)
        requests_today = self.usage_collection.count_documents({
            "api_key": api_key,
            "timestamp": {"$gte": today_start}
        })
        tokens_pipeline = [
            {"$match": {"api_key": api_key, "timestamp": {"$gte": today_start}}},
            {"$group": {"_id": None, "total_tokens": {"$sum": "$tokens"}}}
        ]
        tokens_result = list(self.usage_collection.aggregate(tokens_pipeline))
        tokens_today = tokens_result[0]["total_tokens"] if tokens_result else 0
        if requests_today >= requests_per_day:
            return {
                "allowed": False,
                "reason": "Daily request limit exceeded",
                "limit": requests_per_day,
                "used": requests_today
            }
        if tokens_today >= tokens_per_day:
            return {
                "allowed": False,
                "reason": "Daily token limit exceeded",
                "limit": tokens_per_day,
                "used": tokens_today
            }
        return {
            "allowed": True,
            "requests": {
                "limit": requests_per_day,
                "used": requests_today,
                "remaining": requests_per_day - requests_today
            },
            "tokens": {
                "limit": tokens_per_day,
                "used": tokens_today,
                "remaining": tokens_per_day - tokens_today
            }
        }

    def get_user_stats(self, user_id: str) -> Dict[str, int]:
        """Return conversation, message and token totals for ``user_id``.

        Messages and tokens are counted from the messages collection. No
        per-conversation counter fields are written anywhere in this class.

        Returns:
            ``{"total_conversations": int, "total_messages": int,
            "total_tokens": int}``. The shape is the same whether or not the
            user has data.
        """
        # Messages reference conversations by the string form of their id.
        conversation_ids = [
            str(conv["_id"])
            for conv in self.conversations_collection.find({"user_id": user_id}, {"_id": 1})
        ]
        stats = {
            "total_conversations": len(conversation_ids),
            "total_messages": 0,
            "total_tokens": 0,
        }
        if not conversation_ids:
            return stats
        pipeline = [
            {"$match": {"conversation_id": {"$in": conversation_ids}}},
            {"$group": {
                "_id": None,
                "total_messages": {"$sum": 1},
                "total_tokens": {"$sum": "$tokens"}
            }}
        ]
        result = list(self.messages_collection.aggregate(pipeline))
        if result:
            stats["total_messages"] = result[0]["total_messages"]
            stats["total_tokens"] = result[0]["total_tokens"]
        return stats