"""PyScoutAI API server.

An OpenAI-compatible FastAPI application that forwards chat and text
completion requests to DeepInfra through ``DeepInfraClient``, authenticates
callers with ``PyScoutAI-`` API keys validated by ``MongoDBHelper``, and
records per-key usage.
"""

import asyncio
import json
import logging
import os
import time
import uuid
from typing import Any, Dict, Iterable, Iterator, List, Optional, Union

import uvicorn
from fastapi import Body, Depends, FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, EmailStr, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

from db_helper import MongoDBHelper
from deepinfra_client import DeepInfraClient
from hf_utils import HuggingFaceSpaceHelper

logger = logging.getLogger(__name__)

# Initialize Hugging Face Space helper
hf_helper = HuggingFaceSpaceHelper()

# Install required packages for HF Spaces if needed.
# NOTE(review): this runs only after the module-level imports above have
# already succeeded, so it cannot make those imports available; confirm
# whether anything imported later still depends on it.
if hf_helper.is_in_space:
    hf_helper.install_dependencies([
        "pymongo",
        "python-dotenv",
        "fastapi",
        "uvicorn",
        "slowapi",
        "fake-useragent",
        "requests-ip-rotator",
        "pydantic[email]",
    ])

# Initialize FastAPI app
app = FastAPI(
    title="PyScoutAI API",
    description="An OpenAI-compatible API that provides access to DeepInfra models with enhanced features",
    version="1.0.0",
)

# Setup rate limiting (per client IP; per-key limits are checked in get_user_info)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Set up CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security: auto_error=False so a missing header falls through to the
# other API-key sources checked in get_api_key instead of failing early.
security = HTTPBearer(auto_error=False)

# Database helper. When the connection fails, db is None and get_user_info
# accepts any non-empty key ("development mode").
try:
    db = MongoDBHelper(hf_helper.get_mongodb_uri())
except Exception as e:
    print(f"Warning: MongoDB connection failed: {e}")
    print("API key authentication will not work!")
    db = None


# Models for requests and responses
class Message(BaseModel):
    """A single chat message in an OpenAI-style conversation."""

    role: str
    content: Optional[str] = None
    name: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    """Request body for ``POST /v1/chat/completions``."""

    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None


class CompletionRequest(BaseModel):
    """Request body for ``POST /v1/completions``."""

    model: str
    prompt: Union[str, List[str]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None


class UserCreate(BaseModel):
    """Payload describing a new user."""

    email: EmailStr
    name: str
    organization: Optional[str] = None


class APIKeyCreate(BaseModel):
    """Request body for ``POST /v1/api_keys``."""

    name: str = "Default API Key"
    user_id: str


class APIKeyResponse(BaseModel):
    """Response body for ``POST /v1/api_keys``."""

    key: str
    name: str
    created_at: str


# API clients storage (one per API key).
# NOTE(review): this cache is never pruned, so it grows with every distinct
# key seen (in development mode, any string is accepted as a key).
clients: Dict[str, DeepInfraClient] = {}


def _isoformat(value: Any) -> Any:
    """Return ``value.isoformat()`` for date/datetime-like values, else *value* unchanged."""
    return value.isoformat() if hasattr(value, "isoformat") else value


def _usage_tokens(payload: Any) -> int:
    """Return ``payload["usage"]["total_tokens"]``, or 0 when any part is missing/None."""
    if not isinstance(payload, dict):
        return 0
    usage = payload.get("usage") or {}
    return usage.get("total_tokens", 0) or 0


def _log_usage(api_key: str, endpoint: str, tokens: int, model: Optional[str] = None) -> None:
    """Record usage for *api_key* when a database is configured.

    Best effort: a logging failure is reported via ``logger.exception`` and
    never turns an already-successful request into an error response.
    """
    if db is None:
        return
    try:
        if model is None:
            db.log_api_usage(api_key, endpoint, tokens)
        else:
            db.log_api_usage(api_key, endpoint, tokens, model)
    except Exception:
        logger.exception("Failed to log API usage for %s", endpoint)


def _sampling_kwargs(body: Union[ChatCompletionRequest, CompletionRequest]) -> Dict[str, Any]:
    """Build the keyword arguments shared by the chat and text completion upstream calls."""
    return {
        "model": body.model,
        "temperature": body.temperature,
        "max_tokens": body.max_tokens,
        "stream": body.stream,
        "top_p": body.top_p,
        "presence_penalty": body.presence_penalty,
        "frequency_penalty": body.frequency_penalty,
    }


def _sse_events(chunks: Iterable[Any], api_key: str, endpoint: str, model: str) -> Iterator[str]:
    """Format upstream stream chunks as Server-Sent Events.

    Yields one ``data: <json>`` frame per chunk followed by ``data: [DONE]``.
    Token usage reported in chunks is summed and logged when the generator
    finishes or is closed. If the upstream stream raises, an ``error`` frame
    is emitted instead of silently truncating the response.

    This is deliberately a synchronous generator: Starlette's
    ``StreamingResponse`` iterates sync iterators in a worker thread, so the
    blocking upstream reads do not stall the event loop.
    """
    total_tokens = 0
    try:
        for chunk in chunks:
            # NOTE(review): usage is summed across chunks; if upstream reports
            # cumulative usage in every chunk this over-counts — confirm.
            total_tokens += _usage_tokens(chunk)
            yield f"data: {json.dumps(chunk)}\n\n"
    except Exception as exc:
        logger.exception("Upstream stream failed for %s", endpoint)
        error = {"error": {"message": str(exc), "type": "upstream_error"}}
        yield f"data: {json.dumps(error)}\n\n"
    finally:
        _log_usage(api_key, endpoint, total_tokens, model)
    yield "data: [DONE]\n\n"


async def get_api_key(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
) -> Optional[str]:
    """Extract the caller's API key, or return None if none was supplied.

    Sources, in order: ``Authorization: Bearer <key>`` header, ``x-api-key``
    header, ``api_key`` query parameter. Empty values fall through to the
    next source.
    """
    if credentials and credentials.credentials:
        return credentials.credentials

    # Fallback for Bearer headers HTTPBearer did not parse.
    auth = request.headers.get("Authorization", "")
    if auth.startswith("Bearer "):
        # Strip only the prefix (str.replace would remove every occurrence).
        token = auth[len("Bearer "):].strip()
        if token:
            return token

    header_key = request.headers.get("x-api-key")
    if header_key:
        return header_key

    # NOTE(review): keys in query strings tend to end up in access logs.
    query_key = request.query_params.get("api_key")
    if query_key:
        return query_key

    return None


async def get_user_info(api_key: Optional[str] = Depends(get_api_key)) -> Dict[str, Any]:
    """Validate a PyScout API key and return the associated user info.

    The returned dict always contains a ``"key"`` entry (the endpoints use
    it to look up the upstream client).

    Raises:
        HTTPException: 401 when the key is missing, malformed or unknown;
            429 when the per-key rate limit is exceeded.
    """
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key is required",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Skip validation if DB is not connected (development mode)
    if db is None:
        return {"user_id": "development", "key": api_key}

    if not api_key.startswith("PyScoutAI-"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key format",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # NOTE(review): these DB calls are synchronous and run on the event loop.
    user_info = db.validate_api_key(api_key)
    if not user_info:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            headers={"WWW-Authenticate": "Bearer"},
        )

    rate_limit = db.check_rate_limit(api_key)
    if not rate_limit["allowed"]:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=rate_limit.get("reason", "Rate limit exceeded"),
        )

    # Guarantee the "key" entry the endpoints rely on; a stored value wins.
    return {"key": api_key, **user_info}


def get_client(api_key: str) -> DeepInfraClient:
    """Return the cached DeepInfra client for *api_key*, creating it on first use."""
    if api_key not in clients:
        # Create a client with IP rotation and random user agent
        clients[api_key] = DeepInfraClient(
            use_random_user_agent=True,
            use_proxy_rotation=True,
            use_ip_rotation=True,
        )
    return clients[api_key]


@app.get("/")
async def root():
    """Describe the service, its main endpoints, and Hugging Face Space metadata."""
    metadata = hf_helper.get_hf_metadata()
    return {
        "message": "Welcome to PyScoutAI API",
        "documentation": "/docs",
        "environment": "Hugging Face Space" if hf_helper.is_in_space else "Local",
        "endpoints": [
            "/v1/models",
            "/v1/chat/completions",
            "/v1/completions",
        ],
        **metadata,
    }


@app.get("/v1/models")
@limiter.limit("20/minute")
async def list_models(
    request: Request,
    user_info: Dict[str, Any] = Depends(get_user_info),
):
    """List the models available upstream (OpenAI ``GET /v1/models``)."""
    api_key = user_info["key"]
    client = get_client(api_key)
    try:
        models = await asyncio.to_thread(client.models.list)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error listing models: {str(e)}") from e
    _log_usage(api_key, "/v1/models", 0)
    return models


@app.post("/v1/chat/completions")
@limiter.limit("60/minute")
async def create_chat_completion(
    request: Request,
    body: ChatCompletionRequest,
    user_info: Dict[str, Any] = Depends(get_user_info),
):
    """Create a chat completion, streamed as SSE when ``body.stream`` is true."""
    api_key = user_info["key"]
    client = get_client(api_key)

    # Messages without content are dropped; ``name`` is forwarded when set.
    messages = []
    for msg in body.messages:
        if msg.content is None:
            continue
        entry = {"role": msg.role, "content": msg.content}
        if msg.name is not None:
            entry["name"] = msg.name
        messages.append(entry)

    kwargs = _sampling_kwargs(body)
    try:
        # Called eagerly (also for streams) so upstream failures become a 500
        # response instead of an exception after the 200 headers are sent.
        response = await asyncio.to_thread(client.chat.create, messages=messages, **kwargs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating chat completion: {str(e)}") from e

    if body.stream:
        return StreamingResponse(
            _sse_events(response, api_key, "/v1/chat/completions", body.model),
            media_type="text/event-stream",
        )

    _log_usage(api_key, "/v1/chat/completions", _usage_tokens(response), body.model)
    return response


@app.post("/v1/completions")
@limiter.limit("60/minute")
async def create_completion(
    request: Request,
    body: CompletionRequest,
    user_info: Dict[str, Any] = Depends(get_user_info),
):
    """Create a text completion, streamed as SSE when ``body.stream`` is true.

    Raises:
        HTTPException: 400 when ``prompt`` is an empty list; 500 when the
            upstream call fails.
    """
    api_key = user_info["key"]
    client = get_client(api_key)

    prompt = body.prompt
    if isinstance(prompt, list):
        if not prompt:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="prompt must not be an empty list",
            )
        # Only the first prompt is forwarded; batched prompts are not supported.
        prompt = prompt[0]

    kwargs = _sampling_kwargs(body)
    try:
        # Called eagerly (also for streams) so upstream failures become a 500.
        response = await asyncio.to_thread(client.completions.create, prompt=prompt, **kwargs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating completion: {str(e)}") from e

    if body.stream:
        return StreamingResponse(
            _sse_events(response, api_key, "/v1/completions", body.model),
            media_type="text/event-stream",
        )

    _log_usage(api_key, "/v1/completions", _usage_tokens(response), body.model)
    return response


@app.get("/health")
async def health_check():
    """Report API, database and environment status."""
    status_info = {"api": "ok"}

    if db is not None:
        try:
            # Simple operation to check connection
            db.api_keys_collection.find_one({})
            status_info["database"] = "ok"
        except Exception as e:
            status_info["database"] = f"error: {str(e)}"
    else:
        status_info["database"] = "not configured"

    if hf_helper.is_in_space:
        status_info["environment"] = "Hugging Face Space"
        status_info["space_name"] = hf_helper.space_name
    else:
        status_info["environment"] = "Local"

    return status_info


# API Key Management Endpoints
# NOTE(review): none of these endpoints require authentication, so any caller
# can create, list or revoke keys for any user_id. Consider protecting them.
@app.post("/v1/api_keys", response_model=APIKeyResponse)
async def create_api_key(body: APIKeyCreate):
    """Create a new API key for ``body.user_id`` and return it once."""
    if db is None:
        raise HTTPException(status_code=500, detail="Database not configured")

    try:
        api_key = db.generate_api_key(body.user_id, body.name)
        key_data = db.validate_api_key(api_key)
        return {
            "key": api_key,
            "name": key_data["name"],
            "created_at": _isoformat(key_data["created_at"]),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error creating API key: {str(e)}") from e


@app.get("/v1/api_keys")
async def list_api_keys(user_id: str):
    """List the API keys stored for *user_id* with timestamps as ISO strings."""
    if db is None:
        raise HTTPException(status_code=500, detail="Database not configured")

    keys = db.get_user_api_keys(user_id)
    for key in keys:
        if "created_at" in key:
            key["created_at"] = _isoformat(key["created_at"])
        if key.get("last_used"):
            key["last_used"] = _isoformat(key["last_used"])

    return {"keys": keys}


@app.post("/v1/api_keys/revoke")
async def revoke_api_key(api_key: str):
    """Revoke *api_key*; 404 when the database reports it was not found."""
    if db is None:
        raise HTTPException(status_code=500, detail="Database not configured")

    success = db.revoke_api_key(api_key)
    if not success:
        raise HTTPException(status_code=404, detail="API key not found")

    return {"message": "API key revoked successfully"}


@app.on_event("shutdown")
async def cleanup_clients():
    """Shut down the IP rotators of all cached clients on application shutdown."""
    for client in clients.values():
        rotator = getattr(client, "ip_rotator", None)
        if not rotator:
            continue
        try:
            rotator.shutdown()
        except Exception:
            # Keep shutting down the remaining clients, but record the failure.
            logger.exception("Failed to shut down IP rotator")


if __name__ == "__main__":
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "7860"))

    print(f"Starting PyScoutAI API on http://{host}:{port}")
    print(f"Environment: {'Hugging Face Space' if hf_helper.is_in_space else 'Local'}")

    # uvicorn only supports reload when given an import string ("module:app");
    # passing the app object with reload=True makes uvicorn exit at startup.
    # For hot reloading during development run `uvicorn <module>:app --reload`.
    uvicorn.run(app, host=host, port=port)