|
import os
|
|
import uuid
|
|
import datetime
|
|
from typing import Dict, List, Optional, Any
|
|
from pymongo import MongoClient
|
|
from bson.objectid import ObjectId
|
|
|
|
class MongoDBHelper:
    """Helper class for MongoDB operations.

    Wraps a single ``MongoClient`` and exposes helpers for the ``api_keys``,
    ``usage``, ``users``, ``conversations`` and ``messages`` collections of
    one database. Indexes are ensured on construction. Instances can be used
    as a context manager (``with MongoDBHelper() as db: ...``) so the client
    is closed when the block exits.
    """

    # Limits written into newly generated API keys, and the fallback used by
    # check_rate_limit when a stored key document lacks explicit values.
    DEFAULT_REQUESTS_PER_DAY = 1000
    DEFAULT_TOKENS_PER_DAY = 1_000_000

    def __init__(self, connection_string: Optional[str] = None,
                 database_name: str = "pyscout_ai"):
        """Initialize the MongoDB client.

        Args:
            connection_string: MongoDB URI. When omitted, the ``MONGODB_URI``
                environment variable is used.
            database_name: Name of the database to operate on.

        Raises:
            ValueError: If no connection string is given or set in the
                environment.
        """
        self.connection_string = connection_string or os.getenv('MONGODB_URI')

        if not self.connection_string:
            raise ValueError("MongoDB connection string not provided. Set MONGODB_URI environment variable or pass it to constructor.")

        self.client = MongoClient(self.connection_string)
        self.db = self.client.get_database(database_name)

        self.api_keys_collection = self.db.api_keys
        self.usage_collection = self.db.usage
        self.users_collection = self.db.users
        self.conversations_collection = self.db.conversations
        self.messages_collection = self.db.messages

        try:
            self._create_indexes()
        except Exception:
            # Don't leak the client's connection pool if setup fails.
            self.client.close()
            raise

    def close(self) -> None:
        """Close the underlying MongoClient."""
        self.client.close()

    def __enter__(self) -> "MongoDBHelper":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    @staticmethod
    def _utcnow() -> datetime.datetime:
        """Return the current UTC time as a naive datetime.

        Equivalent to the deprecated ``datetime.datetime.utcnow()``. The
        result stays naive so that new values match what is already stored.
        """
        return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    def _create_indexes(self) -> None:
        """Ensure the indexes used by the lookups in this class exist."""
        self.api_keys_collection.create_index("key", unique=True)
        self.api_keys_collection.create_index("user_id")
        self.api_keys_collection.create_index("created_at")

        self.usage_collection.create_index("api_key")
        self.usage_collection.create_index("timestamp")

        self.users_collection.create_index("email", unique=True)

        self.conversations_collection.create_index("user_id")
        self.conversations_collection.create_index("created_at")

        self.messages_collection.create_index("conversation_id")
        self.messages_collection.create_index("timestamp")

    def create_user(self, email: str, name: str, organization: Optional[str] = None) -> str:
        """Insert a new user document and return its id as a string."""
        user_oid = ObjectId()
        now = self._utcnow()
        self.users_collection.insert_one({
            "_id": user_oid,
            "email": email,
            "name": name,
            "organization": organization,
            "created_at": now,
            "last_active": now,
        })
        return str(user_oid)

    def create_conversation(self, user_id: str, system_prompt: Optional[str] = None) -> str:
        """Insert a new active conversation for ``user_id`` and return its id as a string."""
        conversation_oid = ObjectId()
        now = self._utcnow()
        self.conversations_collection.insert_one({
            "_id": conversation_oid,
            "user_id": user_id,
            "system_prompt": system_prompt,
            "created_at": now,
            "last_message_at": now,
            "is_active": True,
        })
        return str(conversation_oid)

    def add_message(self, conversation_id: str, role: str, content: str,
                    model: Optional[str] = None, tokens: int = 0) -> str:
        """Append a message to a conversation and bump its ``last_message_at``.

        Args:
            conversation_id: String id returned by ``create_conversation``.
            role: Message role label, stored as given.
            content: Message text.
            model: Optional model name stored with the message.
            tokens: Token count stored with the message.

        Returns:
            The new message's id as a string.
        """
        # Convert first so that an invalid id raises before anything is
        # written, instead of leaving an orphan message behind.
        conversation_oid = ObjectId(conversation_id)
        message_oid = ObjectId()
        now = self._utcnow()

        self.messages_collection.insert_one({
            "_id": message_oid,
            "conversation_id": conversation_id,
            "role": role,
            "content": content,
            "model": model,
            "tokens": tokens,
            "timestamp": now,
        })

        self.conversations_collection.update_one(
            {"_id": conversation_oid},
            {"$set": {"last_message_at": now}}
        )

        return str(message_oid)

    def get_conversation_history(self, conversation_id: str) -> List[Dict]:
        """Return a conversation's messages oldest-first, without their ``_id``."""
        return list(self.messages_collection.find(
            {"conversation_id": conversation_id},
            {"_id": 0}
        ).sort("timestamp", 1))

    def get_user_conversations(self, user_id: str, limit: int = 10) -> List[Dict]:
        """Return up to ``limit`` of a user's conversations, most recently active first.

        Each returned document's ``_id`` is converted to a string.
        """
        conversations = list(self.conversations_collection.find(
            {"user_id": user_id},
            {"_id": 1, "system_prompt": 1, "created_at": 1, "last_message_at": 1}
        ).sort("last_message_at", -1).limit(limit))

        for conv in conversations:
            conv["_id"] = str(conv["_id"])
        return conversations

    def generate_api_key(self, user_id: str, name: str = "Default API Key",
                         requests_per_day: Optional[int] = None,
                         tokens_per_day: Optional[int] = None) -> str:
        """Generate a new API key for a user.

        Args:
            user_id: Owner of the key.
            name: Human-readable label for the key.
            requests_per_day: Daily request limit. Defaults to
                ``DEFAULT_REQUESTS_PER_DAY``.
            tokens_per_day: Daily token limit. Defaults to
                ``DEFAULT_TOKENS_PER_DAY``.

        Returns:
            The new key string, of the form ``PyScoutAI-<32 hex chars>``.
        """
        if requests_per_day is None:
            requests_per_day = self.DEFAULT_REQUESTS_PER_DAY
        if tokens_per_day is None:
            tokens_per_day = self.DEFAULT_TOKENS_PER_DAY

        api_key = f"PyScoutAI-{uuid.uuid4().hex}"

        self.api_keys_collection.insert_one({
            "key": api_key,
            "user_id": user_id,
            "name": name,
            "created_at": self._utcnow(),
            "last_used": None,
            "is_active": True,
            "rate_limit": {
                "requests_per_day": requests_per_day,
                "tokens_per_day": tokens_per_day,
            },
        })

        return api_key

    def validate_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Validate an API key and record that it was used.

        Returns:
            The active key document, with ``last_used`` set to the time just
            written, or ``None`` if the key is empty, unknown or revoked.
        """
        if not api_key:
            return None

        key_data = self.api_keys_collection.find_one({"key": api_key, "is_active": True})
        if not key_data:
            return None

        now = self._utcnow()
        self.api_keys_collection.update_one(
            {"_id": key_data["_id"]},
            {"$set": {"last_used": now}}
        )
        # Keep the returned document in sync with what was just written.
        key_data["last_used"] = now
        return key_data

    def log_api_usage(self, api_key: str, endpoint: str, tokens: int = 0,
                      model: Optional[str] = None,
                      conversation_id: Optional[str] = None) -> None:
        """Record one API call in the usage collection.

        ``check_rate_limit`` counts these records.
        """
        usage_data = {
            "api_key": api_key,
            "endpoint": endpoint,
            "tokens": tokens,
            "model": model,
            "timestamp": self._utcnow(),
        }
        if conversation_id:
            usage_data["conversation_id"] = conversation_id

        self.usage_collection.insert_one(usage_data)

    def get_user_api_keys(self, user_id: str) -> List[Dict[str, Any]]:
        """Get all API keys for a user, active or not, with ``_id`` as a string."""
        keys = list(self.api_keys_collection.find({"user_id": user_id}))
        for key in keys:
            key["_id"] = str(key["_id"])
        return keys

    def revoke_api_key(self, api_key: str) -> bool:
        """Revoke an API key.

        Returns:
            True if a key document was modified. Returns False if the key
            does not exist or was already inactive.
        """
        result = self.api_keys_collection.update_one(
            {"key": api_key},
            {"$set": {"is_active": False}}
        )
        return result.modified_count > 0

    def check_rate_limit(self, api_key: str) -> Dict[str, Any]:
        """Check whether the API key has exceeded its daily limits.

        Usage is counted from the start of the current UTC day.

        Returns:
            One of three shapes:

            * ``{"allowed": False, "reason": ...}`` for an unknown or
              inactive key.
            * ``{"allowed": False, "reason": ..., "limit": ..., "used": ...}``
              when a limit is hit.
            * ``{"allowed": True, "requests": {...}, "tokens": {...}}``
              otherwise. Each inner dict holds ``limit``, ``used`` and
              ``remaining``.
        """
        key_data = self.api_keys_collection.find_one({"key": api_key, "is_active": True})
        if not key_data:
            return {"allowed": False, "reason": "Invalid API key"}

        rate_limit = key_data.get("rate_limit", {})
        requests_per_day = rate_limit.get("requests_per_day", self.DEFAULT_REQUESTS_PER_DAY)
        tokens_per_day = rate_limit.get("tokens_per_day", self.DEFAULT_TOKENS_PER_DAY)

        today_start = datetime.datetime.combine(self._utcnow().date(), datetime.time.min)

        # One round trip computes both the request count and the token total.
        usage_pipeline = [
            {"$match": {"api_key": api_key, "timestamp": {"$gte": today_start}}},
            {"$group": {
                "_id": None,
                "requests": {"$sum": 1},
                "total_tokens": {"$sum": "$tokens"},
            }},
        ]
        usage_rows = list(self.usage_collection.aggregate(usage_pipeline))
        usage = usage_rows[0] if usage_rows else {}
        requests_today = usage.get("requests", 0)
        tokens_today = usage.get("total_tokens", 0)

        if requests_today >= requests_per_day:
            return {
                "allowed": False,
                "reason": "Daily request limit exceeded",
                "limit": requests_per_day,
                "used": requests_today,
            }

        if tokens_today >= tokens_per_day:
            return {
                "allowed": False,
                "reason": "Daily token limit exceeded",
                "limit": tokens_per_day,
                "used": tokens_today,
            }

        return {
            "allowed": True,
            "requests": {
                "limit": requests_per_day,
                "used": requests_today,
                "remaining": requests_per_day - requests_today,
            },
            "tokens": {
                "limit": tokens_per_day,
                "used": tokens_today,
                "remaining": tokens_per_day - tokens_today,
            },
        }

    def get_user_stats(self, user_id: str) -> Dict[str, int]:
        """Return conversation, message and token totals for a user.

        Message and token totals are computed from the ``messages``
        collection. Nothing in this class maintains per-conversation
        counters, so summing fields on conversation documents always gave 0.

        Returns:
            A dict with exactly ``total_conversations``, ``total_messages``
            and ``total_tokens``.
        """
        # Messages store conversation_id as the string returned by
        # create_conversation, so match on the stringified ObjectIds.
        conversation_ids = [
            str(conv["_id"])
            for conv in self.conversations_collection.find({"user_id": user_id}, {"_id": 1})
        ]
        stats = {
            "total_conversations": len(conversation_ids),
            "total_messages": 0,
            "total_tokens": 0,
        }
        if not conversation_ids:
            return stats

        pipeline = [
            {"$match": {"conversation_id": {"$in": conversation_ids}}},
            {"$group": {
                "_id": None,
                "total_messages": {"$sum": 1},
                "total_tokens": {"$sum": "$tokens"},
            }},
        ]
        rows = list(self.messages_collection.aggregate(pipeline))
        if rows:
            stats["total_messages"] = rows[0].get("total_messages", 0)
            stats["total_tokens"] = rows[0].get("total_tokens", 0)
        return stats
|
|
|