# PyscoutAI / pyscout_api.py
# Source captured from the Hugging Face Space page (commit 5bfdf79, "Update pyscout_api.py", verified).
import asyncio
import time
import uuid
import os
import json
from typing import Dict, List, Optional, Union, Any
from fastapi import FastAPI, HTTPException, Depends, Request, status, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field, EmailStr
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
import uvicorn
from db_helper import MongoDBHelper
from deepinfra_client import DeepInfraClient
from hf_utils import HuggingFaceSpaceHelper
# Hugging Face Space integration: reports whether we run inside a Space
# (is_in_space) and supplies the MongoDB URI and Space metadata used below.
hf_helper = HuggingFaceSpaceHelper()

# Inside a Space, have the helper install the runtime packages listed here.
# NOTE(review): fastapi, pydantic, slowapi and uvicorn are imported at the top
# of this module before this call runs, so installing them here cannot rescue
# a missing package on first import — confirm whether this is still needed.
if hf_helper.is_in_space:
    hf_helper.install_dependencies(
        [
            "pymongo",
            "python-dotenv",
            "fastapi",
            "uvicorn",
            "slowapi",
            "fake-useragent",
            "requests-ip-rotator",
            "pydantic[email]",
        ]
    )
# FastAPI application serving the OpenAI-compatible endpoints.
app = FastAPI(
    version="1.0.0",
    title="PyScoutAI API",
    description="An OpenAI-compatible API that provides access to DeepInfra models with enhanced features",
)

# slowapi rate limiting keyed on the client's remote address. The limiter is
# published on app.state and RateLimitExceeded is mapped to slowapi's stock
# handler.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# CORS: every origin, method and header is allowed.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True makes
# Starlette echo the caller's Origin, so any site can make credentialed
# cross-origin requests. Auth here is header/query based, so credentials are
# probably unnecessary — confirm before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Bearer-token scheme. auto_error=False makes it yield None for a missing or
# malformed Authorization header instead of failing the request, so
# get_api_key can fall back to other sources.
security = HTTPBearer(auto_error=False)

# MongoDB-backed store for API keys and usage. Stays None when the connection
# cannot be set up; get_user_info then accepts any key ("development" mode).
db = None
try:
    db = MongoDBHelper(hf_helper.get_mongodb_uri())
except Exception as exc:
    print(f"Warning: MongoDB connection failed: {exc}")
    print("API key authentication will not work!")
# ---- Request / response models ------------------------------------------------
class Message(BaseModel):
    """One entry of an OpenAI-style chat ``messages`` list."""

    role: str
    # Optional: messages whose content is None are skipped by the chat endpoint.
    content: Optional[str] = Field(default=None)
    name: Optional[str] = Field(default=None)
class ChatCompletionRequest(BaseModel):
    """Body of ``POST /v1/chat/completions`` (subset of the OpenAI schema)."""

    model: str
    messages: List[Message]
    # Sampling / generation controls forwarded upstream.
    temperature: Optional[float] = Field(default=0.7)
    top_p: Optional[float] = Field(default=1.0)
    # NOTE(review): `n` is accepted but not forwarded by the endpoint.
    n: Optional[int] = Field(default=1)
    stream: Optional[bool] = Field(default=False)
    max_tokens: Optional[int] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=0.0)
    frequency_penalty: Optional[float] = Field(default=0.0)
    # NOTE(review): `user` is accepted but not forwarded by the endpoint.
    user: Optional[str] = Field(default=None)
class CompletionRequest(BaseModel):
    """Body of ``POST /v1/completions`` (subset of the OpenAI schema)."""

    model: str
    # A single prompt or a list; the endpoint only uses the first list entry.
    prompt: Union[str, List[str]]
    temperature: Optional[float] = Field(default=0.7)
    top_p: Optional[float] = Field(default=1.0)
    # NOTE(review): `n` is accepted but not forwarded by the endpoint.
    n: Optional[int] = Field(default=1)
    stream: Optional[bool] = Field(default=False)
    max_tokens: Optional[int] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=0.0)
    frequency_penalty: Optional[float] = Field(default=0.0)
    # NOTE(review): `user` is accepted but not forwarded by the endpoint.
    user: Optional[str] = Field(default=None)
class UserCreate(BaseModel):
    """Payload describing a new user (validated e-mail, name, optional org)."""

    # NOTE(review): not referenced by any endpoint visible in this file.
    email: EmailStr
    name: str
    organization: Optional[str] = Field(default=None)
class APIKeyCreate(BaseModel):
    """Body of ``POST /v1/api_keys``: owner id plus a human-readable key label."""

    name: str = Field(default="Default API Key")
    user_id: str
class APIKeyResponse(BaseModel):
    """Response of ``POST /v1/api_keys``; ``created_at`` is an ISO-8601 string."""

    key: str
    name: str
    created_at: str
# Upstream DeepInfra clients cached per caller API key (populated by get_client).
clients: Dict[str, DeepInfraClient] = dict()
async def get_api_key(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[str]:
    """Extract the caller's API key from the request, if one was supplied.

    Sources are tried in this order, skipping any that are absent or empty:

    1. ``Authorization: Bearer <key>`` as parsed by the ``HTTPBearer`` scheme.
    2. A raw ``Authorization`` header starting with ``"Bearer "`` (fallback for
       values the scheme rejected).
    3. The ``x-api-key`` header.
    4. The ``api_key`` query parameter.

    Returns:
        The API key, or ``None`` when no source supplied a non-empty value.
        Validation is left to ``get_user_info``.
    """
    if credentials and credentials.credentials:
        return credentials.credentials

    bearer_prefix = "Bearer "
    auth_header = request.headers.get("Authorization", "")
    if auth_header.startswith(bearer_prefix):
        # Remove only the leading scheme; str.replace would also delete any
        # later "Bearer " substring inside the token itself.
        token = auth_header[len(bearer_prefix):].strip()
        if token:
            return token

    header_key = request.headers.get("x-api-key")
    if header_key:
        return header_key

    # NOTE(review): keys sent in the query string tend to leak into access
    # logs and proxies; kept for backward compatibility.
    query_key = request.query_params.get("api_key")
    if query_key:
        return query_key

    # No API key found.
    return None
async def get_user_info(api_key: Optional[str] = Depends(get_api_key)) -> Dict[str, Any]:
    """Authenticate the request's API key and enforce its rate limit.

    Args:
        api_key: Key extracted by ``get_api_key`` (``None`` when absent).

    Returns:
        The user record for the key. When the record is a dict, a ``"key"``
        entry holding the caller's API key is guaranteed, because the endpoints
        read ``user_info["key"]``. Without a database the key is accepted
        unchecked and ``{"user_id": "development", "key": api_key}`` is returned.

    Raises:
        HTTPException: 401 if the key is missing, lacks the ``PyScoutAI-``
            prefix, or is unknown; 429 if the database rate-limit check fails.
    """
    challenge = {"WWW-Authenticate": "Bearer"}

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key is required",
            headers=challenge
        )

    # Skip validation if DB is not connected (development mode).
    if db is None:
        return {"user_id": "development", "key": api_key}

    if not api_key.startswith("PyScoutAI-"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key format",
            headers=challenge
        )

    # The helper's methods are synchronous (called without await) and
    # presumably perform blocking MongoDB I/O; run them in a worker thread so
    # the event loop keeps serving other requests.
    user_info = await asyncio.to_thread(db.validate_api_key, api_key)
    if not user_info:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers=challenge
        )

    rate_limit = await asyncio.to_thread(db.check_rate_limit, api_key)
    if not rate_limit["allowed"]:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=rate_limit["reason"]
        )

    # Endpoints index user_info["key"]; make sure it exists without mutating
    # the record returned by the DB helper.
    if isinstance(user_info, dict) and "key" not in user_info:
        user_info = {**user_info, "key": api_key}
    return user_info
def get_client(api_key: str) -> DeepInfraClient:
    """Return the cached DeepInfra client for ``api_key``, creating it on first use.

    New clients use a random user agent plus proxy and IP rotation.
    """
    # NOTE(review): entries are never evicted, so the cache grows with every
    # distinct key seen (in development mode any string is accepted as a key).
    cached = clients.get(api_key)
    if cached is None:
        cached = DeepInfraClient(
            use_random_user_agent=True,
            use_proxy_rotation=True,
            use_ip_rotation=True,
        )
        clients[api_key] = cached
    return cached
@app.get("/")
async def root():
    """Landing endpoint: welcome text, docs link, environment and Space metadata."""
    space_metadata = hf_helper.get_hf_metadata()
    environment = "Hugging Face Space" if hf_helper.is_in_space else "Local"
    advertised = ["/v1/models", "/v1/chat/completions", "/v1/completions"]
    # Metadata keys are merged last, so they override the fixed keys on clash.
    return {
        "message": "Welcome to PyScoutAI API",
        "documentation": "/docs",
        "environment": environment,
        "endpoints": advertised,
        **space_metadata,
    }
@app.get("/v1/models")
@limiter.limit("20/minute")
async def list_models(
    request: Request,
    user_info: Dict[str, Any] = Depends(get_user_info)
):
    """List the models available through the upstream DeepInfra client."""
    # `request` is unused here but required by slowapi's @limiter.limit.
    api_key = user_info["key"]
    client = get_client(api_key)
    try:
        # The client call is synchronous; run it off the event loop.
        models = await asyncio.to_thread(client.models.list)
        # Log the API usage (no tokens are consumed by this endpoint).
        if db is not None:
            db.log_api_usage(api_key, "/v1/models", 0)
        return models
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=500, detail=f"Error listing models: {str(e)}") from e
@app.post("/v1/chat/completions")
@limiter.limit("60/minute")
async def create_chat_completion(
    request: Request,
    body: ChatCompletionRequest,
    user_info: Dict[str, Any] = Depends(get_user_info)
):
    """Create a chat completion (OpenAI-compatible).

    Non-streaming requests return the upstream response unchanged. Streaming
    requests are relayed as server-sent events (``data: <json>``) ending with
    ``data: [DONE]``.
    """
    # `request` is unused here but required by slowapi's @limiter.limit.
    api_key = user_info["key"]
    client = get_client(api_key)

    # Only role and content are forwarded; messages without content are
    # dropped and the optional `name` field is not sent upstream.
    messages = [
        {"role": msg.role, "content": msg.content}
        for msg in body.messages
        if msg.content is not None
    ]
    # NOTE(review): `n` and `user` from the request body are not forwarded.
    kwargs = {
        "model": body.model,
        "temperature": body.temperature,
        "max_tokens": body.max_tokens,
        "stream": body.stream,
        "top_p": body.top_p,
        "presence_penalty": body.presence_penalty,
        "frequency_penalty": body.frequency_penalty,
    }

    try:
        # Call upstream before any response is started, so a failure here is
        # still reported as a regular HTTP 500 instead of a broken 200 stream.
        upstream = await asyncio.to_thread(
            client.chat.create,
            messages=messages,
            **kwargs
        )
        if not body.stream:
            if db is not None and 'usage' in upstream:
                total_tokens = upstream['usage'].get('total_tokens', 0)
                db.log_api_usage(api_key, "/v1/chat/completions", total_tokens, body.model)
            return upstream
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}") from e

    async def generate_stream():
        total_tokens = 0
        # Sentinel passed as next()'s default: StopIteration must not escape
        # from the worker thread into the awaiting coroutine.
        exhausted = object()
        try:
            chunks = iter(upstream)
            while True:
                # The upstream iterator is synchronous; advance it in a worker
                # thread so the event loop is not blocked between chunks.
                chunk = await asyncio.to_thread(next, chunks, exhausted)
                if chunk is exhausted:
                    break
                # NOTE(review): usage is summed across chunks; if upstream
                # repeats cumulative usage on several chunks this over-counts.
                if 'usage' in chunk and chunk['usage']:
                    total_tokens += chunk['usage'].get('total_tokens', 0)
                yield f"data: {json.dumps(chunk)}\n\n"
        except Exception as e:
            # Status and headers are already sent, so report the failure
            # in-band as an OpenAI-style error event.
            error_event = {
                "error": {
                    "message": f"Error generating chat completion: {str(e)}",
                    "type": "server_error",
                }
            }
            yield f"data: {json.dumps(error_event)}\n\n"
        # Log API usage at the end of streaming (tokens seen so far).
        if db is not None:
            db.log_api_usage(api_key, "/v1/chat/completions", total_tokens, body.model)
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream"
    )
@app.post("/v1/completions")
@limiter.limit("60/minute")
async def create_completion(
    request: Request,
    body: CompletionRequest,
    user_info: Dict[str, Any] = Depends(get_user_info)
):
    """Create a text completion (OpenAI-compatible).

    Non-streaming requests return the upstream response unchanged. Streaming
    requests are relayed as server-sent events (``data: <json>``) ending with
    ``data: [DONE]``.

    Raises:
        HTTPException: 400 for an empty prompt list; 500 when the upstream
            call fails before the response starts.
    """
    # `request` is unused here but required by slowapi's @limiter.limit.
    api_key = user_info["key"]
    client = get_client(api_key)

    # Validate outside the try/except below, which would otherwise turn this
    # client error into a 500.
    prompt = body.prompt
    if isinstance(prompt, list):
        if not prompt:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="prompt must not be an empty list"
            )
        # NOTE(review): only the first prompt is sent upstream; the remaining
        # entries of a list prompt are ignored.
        prompt = prompt[0]

    kwargs = {
        "model": body.model,
        "temperature": body.temperature,
        "max_tokens": body.max_tokens,
        "stream": body.stream,
        "top_p": body.top_p,
        "presence_penalty": body.presence_penalty,
        "frequency_penalty": body.frequency_penalty,
    }

    try:
        # Call upstream before any response is started, so a failure here is
        # still reported as a regular HTTP 500 instead of a broken 200 stream.
        upstream = await asyncio.to_thread(
            client.completions.create,
            prompt=prompt,
            **kwargs
        )
        if not body.stream:
            if db is not None and 'usage' in upstream:
                total_tokens = upstream['usage'].get('total_tokens', 0)
                db.log_api_usage(api_key, "/v1/completions", total_tokens, body.model)
            return upstream
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating completion: {str(e)}") from e

    async def generate_stream():
        total_tokens = 0
        # Sentinel passed as next()'s default: StopIteration must not escape
        # from the worker thread into the awaiting coroutine.
        exhausted = object()
        try:
            chunks = iter(upstream)
            while True:
                # The upstream iterator is synchronous; advance it in a worker
                # thread so the event loop is not blocked between chunks.
                chunk = await asyncio.to_thread(next, chunks, exhausted)
                if chunk is exhausted:
                    break
                # NOTE(review): usage is summed across chunks; if upstream
                # repeats cumulative usage on several chunks this over-counts.
                if 'usage' in chunk and chunk['usage']:
                    total_tokens += chunk['usage'].get('total_tokens', 0)
                yield f"data: {json.dumps(chunk)}\n\n"
        except Exception as e:
            # Status and headers are already sent, so report the failure
            # in-band as an OpenAI-style error event.
            error_event = {
                "error": {
                    "message": f"Error generating completion: {str(e)}",
                    "type": "server_error",
                }
            }
            yield f"data: {json.dumps(error_event)}\n\n"
        # Log API usage at the end of streaming (tokens seen so far).
        if db is not None:
            db.log_api_usage(api_key, "/v1/completions", total_tokens, body.model)
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream"
    )
@app.get("/health")
async def health_check():
    """Report the status of the API, the database, and the runtime environment."""
    if not db:
        database_state = "not configured"
    else:
        # A cheap query against the keys collection doubles as a connectivity probe.
        try:
            db.api_keys_collection.find_one({})
        except Exception as exc:
            database_state = f"error: {str(exc)}"
        else:
            database_state = "ok"

    status_info = {"api": "ok", "database": database_state}
    if not hf_helper.is_in_space:
        status_info["environment"] = "Local"
    else:
        status_info.update(
            environment="Hugging Face Space",
            space_name=hf_helper.space_name,
        )
    return status_info
# ---- API key management endpoints ----------------------------------------------
@app.post("/v1/api_keys", response_model=APIKeyResponse)
async def create_api_key(body: APIKeyCreate):
    """Generate a new API key for ``body.user_id`` and return it with its metadata."""
    # NOTE(review): no authentication dependency — any caller can mint keys
    # for any user_id. Confirm this endpoint is protected upstream.
    if not db:
        raise HTTPException(status_code=500, detail="Database not configured")
    try:
        new_key = db.generate_api_key(body.user_id, body.name)
        # Re-read the stored record to obtain its name and creation timestamp.
        stored = db.validate_api_key(new_key)
        key_name = stored["name"]
        created_iso = stored["created_at"].isoformat()
        return {"key": new_key, "name": key_name, "created_at": created_iso}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error creating API key: {str(e)}")
@app.get("/v1/api_keys")
async def list_api_keys(user_id: str):
    """List the API keys stored for ``user_id``.

    ``created_at`` and ``last_used`` are converted to ISO-8601 strings when
    they hold date/datetime values. Missing or already-serialized values are
    left untouched.
    """
    # NOTE(review): no authentication — any caller can enumerate any user's
    # keys. Whether the records include raw key values depends on
    # get_user_api_keys; confirm.
    if db is None:
        raise HTTPException(status_code=500, detail="Database not configured")
    keys = db.get_user_api_keys(user_id)
    for key in keys:
        # Guard both timestamp fields the same way: the original only
        # tolerated a missing/None last_used and crashed on created_at.
        for field in ("created_at", "last_used"):
            value = key.get(field)
            if value is not None and hasattr(value, "isoformat"):
                key[field] = value.isoformat()
    return {"keys": keys}
@app.post("/v1/api_keys/revoke")
async def revoke_api_key(api_key: str):
    """Revoke ``api_key``; respond 404 when the database does not know it."""
    # NOTE(review): `api_key` is a plain str parameter, so FastAPI reads it from
    # the query string (it ends up in URLs/access logs), and the endpoint is
    # unauthenticated — anyone holding a key string can revoke it.
    if not db:
        raise HTTPException(status_code=500, detail="Database not configured")
    if db.revoke_api_key(api_key):
        return {"message": "API key revoked successfully"}
    raise HTTPException(status_code=404, detail="API key not found")
@app.on_event("shutdown")
async def cleanup_clients():
    """Shut down the IP rotator of every cached DeepInfra client on app shutdown.

    Best-effort: a failure for one client is reported and does not stop the
    remaining clients from being cleaned up.
    """
    for client in clients.values():
        try:
            rotator = getattr(client, "ip_rotator", None)
            if rotator:
                rotator.shutdown()
        except Exception as exc:
            # Previously a bare `except: pass`, which hid every error and also
            # swallowed BaseException (KeyboardInterrupt, CancelledError).
            print(f"Warning: failed to shut down IP rotator: {exc}")
if __name__ == "__main__":
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "7860"))
    # Auto-reload only when running outside a Hugging Face Space.
    use_reload = not hf_helper.is_in_space
    print(f"Starting PyScoutAI API on http://{host}:{port}")
    print(f"Environment: {'Hugging Face Space' if hf_helper.is_in_space else 'Local'}")
    # uvicorn can only reload an app given as a "module:attribute" import
    # string (it re-imports it in a child process). Given an app object with
    # reload=True it logs a warning and exits instead of serving.
    if use_reload:
        module_name = os.path.splitext(os.path.basename(__file__))[0]
        app_target = f"{module_name}:app"
    else:
        app_target = app
    uvicorn.run(
        app_target,
        host=host,
        port=port,
        reload=use_reload
    )