|
import asyncio |
|
import time |
|
import uuid |
|
import os |
|
import json |
|
from typing import Dict, List, Optional, Union, Any |
|
from fastapi import FastAPI, HTTPException, Depends, Request, status, Body |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import JSONResponse, StreamingResponse |
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
|
from pydantic import BaseModel, Field, EmailStr |
|
from slowapi import Limiter, _rate_limit_exceeded_handler |
|
from slowapi.util import get_remote_address |
|
from slowapi.errors import RateLimitExceeded |
|
import uvicorn |
|
|
|
from db_helper import MongoDBHelper |
|
from deepinfra_client import DeepInfraClient |
|
from hf_utils import HuggingFaceSpaceHelper |
|
|
|
|
|
# Environment helper shared by the whole module; it is queried below for the
# Space flag, the MongoDB URI and the metadata returned by the root endpoint.
hf_helper = HuggingFaceSpaceHelper()

# Runtime packages are only installed when the helper reports that the process
# is running inside a Space.
# NOTE(review): this runs after the third-party imports at the top of the
# file, so it cannot help those imports succeed - confirm the intended order.
if hf_helper.is_in_space:
    hf_helper.install_dependencies(
        [
            "pymongo",
            "python-dotenv",
            "fastapi",
            "uvicorn",
            "slowapi",
            "fake-useragent",
            "requests-ip-rotator",
            "pydantic[email]",
        ]
    )
|
|
|
|
|
# The ASGI application object; every route, middleware and event handler in
# this module is registered on it.
app = FastAPI(
    version="1.0.0",
    title="PyScoutAI API",
    description="An OpenAI-compatible API that provides access to DeepInfra models with enhanced features",
)
|
|
|
|
|
# Per-route rate limiting, bucketed by the caller's remote address.
limiter = Limiter(key_func=get_remote_address)

# Register the slowapi handler for RateLimitExceeded and expose the limiter
# on the application state.
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.state.limiter = limiter
|
|
|
|
|
# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive policy - confirm this is intended for production use.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Bearer-token extractor used as a dependency by get_api_key. auto_error is
# disabled so a missing header yields None there instead of an automatic
# error response, letting get_api_key try its other key locations.
security = HTTPBearer(
    auto_error=False,
)
|
|
|
|
|
# Database handle shared by the handlers below. It stays None when the
# connection cannot be established; handlers check for that and either skip
# usage logging or reject the request.
db = None
try:
    db = MongoDBHelper(hf_helper.get_mongodb_uri())
except Exception as exc:
    print(f"Warning: MongoDB connection failed: {exc}")
    print("API key authentication will not work!")
|
|
|
|
|
class Message(BaseModel):
    """A single entry of the ``messages`` list in a chat completion request."""

    # Author of the message; forwarded to the client as-is.
    role: str
    # Message text. Entries whose content is None are dropped by
    # create_chat_completion before the request is forwarded.
    content: Optional[str] = None
    # Optional name; accepted but not forwarded by create_chat_completion.
    name: Optional[str] = None
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``."""

    # Identifier of the model to query.
    model: str
    # Conversation to complete.
    messages: List[Message]

    # Generation parameters forwarded to the client by the handler.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    # Accepted for compatibility; the handler in this file does not forward it.
    n: Optional[int] = 1
    # When true the handler answers with a text/event-stream response.
    stream: Optional[bool] = False
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    # Accepted for compatibility; the handler in this file does not forward it.
    user: Optional[str] = None
|
|
|
class CompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/completions``."""

    # Identifier of the model to query.
    model: str
    # A single prompt or a list of prompts; create_completion only uses the
    # first element when a list is supplied.
    prompt: Union[str, List[str]]

    # Generation parameters forwarded to the client by the handler.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    # Accepted for compatibility; the handler in this file does not forward it.
    n: Optional[int] = 1
    # When true the handler answers with a text/event-stream response.
    stream: Optional[bool] = False
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    # Accepted for compatibility; the handler in this file does not forward it.
    user: Optional[str] = None
|
|
|
class UserCreate(BaseModel):
    """Payload describing a user to create."""

    # Validated as an e-mail address by pydantic.
    email: EmailStr
    name: str
    # Optional organisation the user belongs to.
    organization: Optional[str] = None
|
|
|
class APIKeyCreate(BaseModel):
    """Request body accepted by ``POST /v1/api_keys``."""

    # Human-readable label stored with the key.
    name: str = "Default API Key"
    # Identifier of the user the new key is generated for.
    user_id: str
|
|
|
class APIKeyResponse(BaseModel):
    """Response model of ``POST /v1/api_keys``."""

    # The newly generated API key.
    key: str
    # Label stored with the key.
    name: str
    # Creation timestamp; create_api_key fills it with an isoformat() string.
    created_at: str
|
|
|
|
|
# Cache of DeepInfra clients keyed by API key. get_client fills it lazily and
# cleanup_clients walks it on shutdown.
clients: Dict[str, DeepInfraClient] = dict()
|
|
|
|
|
async def get_api_key(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[str]:
    """Extract the caller's API key from the request.

    The locations are tried in this order and the first hit wins:

    1. credentials parsed by the ``security`` dependency,
    2. an ``Authorization`` header starting with ``"Bearer "``,
    3. an ``x-api-key`` header,
    4. an ``api_key`` query parameter.

    Args:
        request: Incoming request whose headers and query string are inspected.
        credentials: Bearer credentials resolved by the ``security``
            dependency, or None when it found none.

    Returns:
        The API key string, or None when no location provided one.
    """
    if credentials:
        return credentials.credentials

    bearer_prefix = "Bearer "
    auth = request.headers.get("Authorization")
    if auth is not None and auth.startswith(bearer_prefix):
        # Remove only the leading scheme. str.replace would also delete any
        # later "Bearer " occurrence inside the token and corrupt the key.
        return auth[len(bearer_prefix):]

    header_key = request.headers.get("x-api-key")
    if header_key is not None:
        return header_key

    # An empty query value counts as "not supplied".
    query_key = request.query_params.get("api_key")
    if query_key:
        return query_key

    return None
|
|
|
|
|
async def get_user_info(api_key: Optional[str] = Depends(get_api_key)) -> Dict[str, Any]:
    """Authenticate the request and return information about the caller.

    Args:
        api_key: Key resolved by ``get_api_key``; may be None or empty.

    Returns:
        The value returned by ``db.validate_api_key`` for a valid key, or a
        placeholder ``{"user_id": "development", "key": api_key}`` dict when
        no database is configured.

    Raises:
        HTTPException: 401 when the key is missing, lacks the ``PyScoutAI-``
            prefix or is rejected by the database; 429 when the database
            reports that the key is over its rate limit.
    """

    def deny(detail: str) -> HTTPException:
        # All authentication failures share the same status code and header.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"}
        )

    if not api_key:
        raise deny("API key is required")

    # Without a database any non-empty key is accepted.
    if not db:
        return {"user_id": "development", "key": api_key}

    # Cheap format check before querying the database.
    if not api_key.startswith("PyScoutAI-"):
        raise deny("Invalid API key format")

    user_info = db.validate_api_key(api_key)
    if not user_info:
        raise deny("Invalid API key")

    quota = db.check_rate_limit(api_key)
    if not quota["allowed"]:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=quota["reason"]
        )

    return user_info
|
|
|
|
|
def get_client(api_key: str) -> DeepInfraClient:
    """Return the cached DeepInfra client for ``api_key``, creating it on first use.

    Args:
        api_key: Key used to index the module-level ``clients`` cache.

    Returns:
        The client stored in ``clients`` for this key.
    """
    cached = clients.get(api_key)
    if cached is None:
        # First request for this key: build a client with random user agent,
        # proxy rotation and IP rotation enabled, then remember it.
        cached = DeepInfraClient(
            use_random_user_agent=True,
            use_proxy_rotation=True,
            use_ip_rotation=True
        )
        clients[api_key] = cached
    return cached
|
|
|
@app.get("/")
async def root():
    """Return a welcome payload describing the API.

    Returns:
        A dict with a greeting, the docs path, the detected environment and
        the main endpoints, merged with whatever ``hf_helper.get_hf_metadata``
        returns (metadata keys override the static ones on collision).
    """
    metadata = hf_helper.get_hf_metadata()

    payload = {
        "message": "Welcome to PyScoutAI API",
        "documentation": "/docs",
        "environment": "Hugging Face Space" if hf_helper.is_in_space else "Local",
        "endpoints": [
            "/v1/models",
            "/v1/chat/completions",
            "/v1/completions"
        ],
    }
    payload.update(metadata)
    return payload
|
|
|
@app.get("/v1/models")
@limiter.limit("20/minute")
async def list_models(
    request: Request,
    user_info: Dict[str, Any] = Depends(get_user_info)
):
    """List the models available through the caller's client.

    Args:
        request: Incoming request; required by the rate-limit decorator.
        user_info: Authenticated caller information; its ``"key"`` entry
            selects the client.

    Returns:
        Whatever ``client.models.list`` returns.

    Raises:
        HTTPException: 500 when listing the models or logging the usage fails.
    """
    api_key = user_info["key"]
    client = get_client(api_key)

    try:
        # The client call is synchronous, so it runs in a worker thread.
        models = await asyncio.to_thread(client.models.list)

        # Listing models is logged with a token count of zero.
        if db:
            db.log_api_usage(api_key, "/v1/models", 0)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error listing models: {str(exc)}")

    return models
|
|
|
@app.post("/v1/chat/completions")
@limiter.limit("60/minute")
async def create_chat_completion(
    request: Request,
    body: ChatCompletionRequest,
    user_info: Dict[str, Any] = Depends(get_user_info)
):
    """Create a chat completion, optionally streamed as server-sent events.

    Args:
        request: Incoming request; required by the rate-limit decorator.
        body: Parsed chat completion request.
        user_info: Authenticated caller information; its ``"key"`` entry
            selects the client and is used for usage logging.

    Returns:
        The client's response for non-streaming requests, or a
        ``StreamingResponse`` with media type ``text/event-stream`` that
        emits one ``data:`` event per chunk followed by ``data: [DONE]``.

    Raises:
        HTTPException: 500 when preparing or executing the request fails.
            Errors raised while a stream is being consumed occur after this
            handler has returned and are not converted here.
    """
    api_key = user_info["key"]
    client = get_client(api_key)

    try:
        # Only role and content are forwarded; messages without content are dropped.
        messages = [
            {"role": msg.role, "content": msg.content}
            for msg in body.messages
            if msg.content is not None
        ]

        kwargs = {
            "model": body.model,
            "temperature": body.temperature,
            "max_tokens": body.max_tokens,
            "stream": body.stream,
            "top_p": body.top_p,
            "presence_penalty": body.presence_penalty,
            "frequency_penalty": body.frequency_penalty,
        }

        if body.stream:
            async def generate_stream():
                response_stream = await asyncio.to_thread(
                    client.chat.create,
                    messages=messages,
                    **kwargs
                )

                # The client is synchronous, so fetching the next chunk may
                # block. Pull every chunk in a worker thread to keep the event
                # loop free. A sentinel default is used because StopIteration
                # cannot be propagated out of the worker thread.
                chunks = iter(response_stream)
                stream_end = object()

                total_tokens = 0
                while True:
                    chunk = await asyncio.to_thread(next, chunks, stream_end)
                    if chunk is stream_end:
                        break

                    if 'usage' in chunk and chunk['usage']:
                        total_tokens += chunk['usage'].get('total_tokens', 0)

                    yield f"data: {json.dumps(chunk)}\n\n"

                # Usage is logged once the stream has been fully consumed.
                if db:
                    db.log_api_usage(api_key, "/v1/chat/completions", total_tokens, body.model)

                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream"
            )

        response = await asyncio.to_thread(
            client.chat.create,
            messages=messages,
            **kwargs
        )

        # Usage is only logged when the response reports it.
        if db and 'usage' in response:
            total_tokens = response['usage'].get('total_tokens', 0)
            db.log_api_usage(api_key, "/v1/chat/completions", total_tokens, body.model)

        return response

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error generating chat completion: {str(e)}"
        ) from e
|
|
|
@app.post("/v1/completions")
@limiter.limit("60/minute")
async def create_completion(
    request: Request,
    body: CompletionRequest,
    user_info: Dict[str, Any] = Depends(get_user_info)
):
    """Create a text completion, optionally streamed as server-sent events.

    Args:
        request: Incoming request; required by the rate-limit decorator.
        body: Parsed completion request.
        user_info: Authenticated caller information; its ``"key"`` entry
            selects the client and is used for usage logging.

    Returns:
        The client's response for non-streaming requests, or a
        ``StreamingResponse`` with media type ``text/event-stream`` that
        emits one ``data:`` event per chunk followed by ``data: [DONE]``.

    Raises:
        HTTPException: 400 when ``prompt`` is an empty list; 500 when
            preparing or executing the request fails. Errors raised while a
            stream is being consumed occur after this handler has returned
            and are not converted here.
    """
    api_key = user_info["key"]
    client = get_client(api_key)

    prompt = body.prompt
    if isinstance(prompt, list):
        # Validate outside the try block so the 400 is not rewritten as a 500.
        if not prompt:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="prompt must not be an empty list"
            )
        # Only the first prompt of a list is sent to the client.
        prompt = prompt[0]

    try:
        kwargs = {
            "model": body.model,
            "temperature": body.temperature,
            "max_tokens": body.max_tokens,
            "stream": body.stream,
            "top_p": body.top_p,
            "presence_penalty": body.presence_penalty,
            "frequency_penalty": body.frequency_penalty,
        }

        if body.stream:
            async def generate_stream():
                response_stream = await asyncio.to_thread(
                    client.completions.create,
                    prompt=prompt,
                    **kwargs
                )

                # The client is synchronous, so fetching the next chunk may
                # block. Pull every chunk in a worker thread to keep the event
                # loop free. A sentinel default is used because StopIteration
                # cannot be propagated out of the worker thread.
                chunks = iter(response_stream)
                stream_end = object()

                total_tokens = 0
                while True:
                    chunk = await asyncio.to_thread(next, chunks, stream_end)
                    if chunk is stream_end:
                        break

                    if 'usage' in chunk and chunk['usage']:
                        total_tokens += chunk['usage'].get('total_tokens', 0)

                    yield f"data: {json.dumps(chunk)}\n\n"

                # Usage is logged once the stream has been fully consumed.
                if db:
                    db.log_api_usage(api_key, "/v1/completions", total_tokens, body.model)

                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream"
            )

        response = await asyncio.to_thread(
            client.completions.create,
            prompt=prompt,
            **kwargs
        )

        # Usage is only logged when the response reports it.
        if db and 'usage' in response:
            total_tokens = response['usage'].get('total_tokens', 0)
            db.log_api_usage(api_key, "/v1/completions", total_tokens, body.model)

        return response

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error generating completion: {str(e)}"
        ) from e
|
|
|
@app.get("/health")
async def health_check():
    """Report the status of the API, the database and the runtime environment.

    Returns:
        A dict with ``"api"``, ``"database"`` and ``"environment"`` entries,
        plus ``"space_name"`` when running inside a Hugging Face Space.
    """
    status_info = {"api": "ok"}

    if not db:
        status_info["database"] = "not configured"
    else:
        try:
            # A trivial query is enough to prove the connection is usable.
            db.api_keys_collection.find_one({})
        except Exception as exc:
            status_info["database"] = f"error: {str(exc)}"
        else:
            status_info["database"] = "ok"

    in_space = hf_helper.is_in_space
    status_info["environment"] = "Hugging Face Space" if in_space else "Local"
    if in_space:
        status_info["space_name"] = hf_helper.space_name

    return status_info
|
|
|
|
|
@app.post("/v1/api_keys", response_model=APIKeyResponse)
async def create_api_key(body: APIKeyCreate):
    """Generate a new API key for a user.

    NOTE(review): this route declares no authentication dependency, so any
    caller can request a key for any ``user_id`` - confirm this is intended.

    Args:
        body: User id and label for the new key.

    Returns:
        A dict with the new key, its stored name and its creation time as an
        ISO-formatted string.

    Raises:
        HTTPException: 500 when no database is configured or when generating
            or reading back the key fails.
    """
    if not db:
        raise HTTPException(status_code=500, detail="Database not configured")

    try:
        new_key = db.generate_api_key(body.user_id, body.name)
        # Read the stored record back to obtain its name and timestamp.
        record = db.validate_api_key(new_key)
        created_at = record["created_at"].isoformat()
        return {"key": new_key, "name": record["name"], "created_at": created_at}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error creating API key: {str(exc)}")
|
|
|
@app.get("/v1/api_keys")
async def list_api_keys(user_id: str):
    """List the API keys stored for a user.

    Args:
        user_id: Identifier of the user whose keys are returned.

    Returns:
        ``{"keys": [...]}`` where the ``created_at`` and non-empty
        ``last_used`` values have been converted with ``isoformat()``.

    Raises:
        HTTPException: 500 when no database is configured.
    """
    if not db:
        raise HTTPException(status_code=500, detail="Database not configured")

    keys = db.get_user_api_keys(user_id)

    # Convert the timestamps in place so the entries can be serialised.
    for entry in keys:
        if "created_at" in entry:
            entry["created_at"] = entry["created_at"].isoformat()
        # last_used is only converted when it is present and non-empty.
        if "last_used" in entry and entry["last_used"]:
            entry["last_used"] = entry["last_used"].isoformat()

    return {"keys": keys}
|
|
|
@app.post("/v1/api_keys/revoke")
async def revoke_api_key(api_key: str):
    """Revoke an API key.

    Args:
        api_key: The key to revoke.

    Returns:
        A confirmation message when the database reports success.

    Raises:
        HTTPException: 500 when no database is configured; 404 when the
            database reports that the key could not be revoked.
    """
    if not db:
        raise HTTPException(status_code=500, detail="Database not configured")

    if db.revoke_api_key(api_key):
        return {"message": "API key revoked successfully"}

    raise HTTPException(status_code=404, detail="API key not found")
|
|
|
|
|
@app.on_event("shutdown")
async def cleanup_clients():
    """Shut down the IP rotator of every cached client on application shutdown.

    Cleanup is best-effort: a failure for one client is reported and the
    remaining clients are still processed.
    """
    for client in clients.values():
        try:
            rotator = getattr(client, 'ip_rotator', None)
            if rotator:
                rotator.shutdown()
        except Exception as exc:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit still propagate, and report the failure instead of
            # hiding it.
            print(f"Warning: failed to shut down IP rotator: {exc}")
|
|
|
# Script entry point: start the uvicorn server when the file is run directly.
if __name__ == "__main__":
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "7860"))

    print(f"Starting PyScoutAI API on http://{host}:{port}")
    print(f"Environment: {'Hugging Face Space' if hf_helper.is_in_space else 'Local'}")

    # Auto-reload is only enabled outside a Hugging Face Space.
    reload_enabled = not hf_helper.is_in_space

    # uvicorn refuses to start with reload=True unless the application is
    # given as an import string, so derive "<module>:app" from this file's
    # name in that case and pass the app object otherwise.
    if reload_enabled:
        module_name = os.path.splitext(os.path.basename(__file__))[0]
        app_target = f"{module_name}:app"
    else:
        app_target = app

    uvicorn.run(
        app_target,
        host=host,
        port=port,
        reload=reload_enabled
    )
|
|