import os
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
import httpx
from functools import lru_cache

# Load environment variables from .env file
load_dotenv()

app = FastAPI()

# Get API keys and secret endpoint from environment variables
api_keys_str = os.getenv('API_KEYS')
valid_api_keys = api_keys_str.split(',') if api_keys_str else []

secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')

# Check if the endpoint is set in the environment
if not secret_api_endpoint:
    raise HTTPException(status_code=500, detail="Secret API endpoint is not configured in environment variables.")
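# Illustrative .env layout (example values only, not taken from the original file):
#   API_KEYS=key1,key2,key3
#   SECRET_API_ENDPOINT=https://your-upstream-api.example.com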
class Payload(BaseModel):
    model: str
    messages: list
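# Example request body accepted by Payload (illustrative only; the model name is a
# placeholder and the OpenAI-style message format is an assumption, since the schema
# only requires a string and a list):
#   {"model": "some-model", "messages": [{"role": "user", "content": "Hello"}]}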
@app.get("/")
async def root():
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Loki.AI API</title>
        <style>
            body { font-family: Arial, sans-serif; text-align: center; margin-top: 50px; background-color: #121212; color: white; }
            h1 { color: #4CAF50; }
            a { color: #BB86FC; text-decoration: none; }
            a:hover { text-decoration: underline; }
        </style>
    </head>
    <body>
        <h1>Welcome to Loki.AI API!</h1>
        <p>Created by Parth Sadaria</p>
        <p>Check out the GitHub for more projects:</p>
        <a href="https://github.com/ParthSadaria" target="_blank">github.com/ParthSadaria</a>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)
# Fetch the models list from the upstream API.
# Note: functools.lru_cache cannot cache async coroutines, so this helper is not
# actually cached; each call hits the upstream endpoint.
async def get_cached_models():
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(f"{secret_api_endpoint}/api/v1/models", timeout=3)
            response.raise_for_status()
            return response.json()
        except httpx.RequestError as e:
            raise HTTPException(status_code=500, detail=f"Request failed: {e}")

@app.get("/models")
async def get_models():
    return await get_cached_models()
@app.post("/v1/chat/completions")
async def get_completion(payload: Payload, request: Request):
    api_key = request.headers.get("Authorization")

    # Validate API key
    if api_key not in valid_api_keys:
        raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")

    # Prepare the payload for streaming
    payload_dict = {**payload.dict(), "stream": True}

    # Define an asynchronous generator to stream the response
    async def stream_generator():
        async with httpx.AsyncClient() as client:
            try:
                async with client.stream("POST", secret_api_endpoint, json=payload_dict, timeout=10) as response:
                    response.raise_for_status()
                    async for chunk in response.aiter_bytes(chunk_size=512):  # Smaller chunks for faster response
                        if chunk:
                            yield chunk
            except httpx.RequestError as e:
                raise HTTPException(status_code=500, detail=f"Streaming failed: {e}")

    # Return the streaming response
    return StreamingResponse(stream_generator(), media_type="application/json")
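# Example call (illustrative; the Authorization header carries the raw API key,
# because the check above compares the header value directly against valid_api_keys):
#   curl -N http://localhost:8000/v1/chat/completions \
#     -H "Authorization: <your-api-key>" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "some-model", "messages": [{"role": "user", "content": "Hi"}]}'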
# Log the API endpoints at startup
@app.on_event("startup")
async def startup_event():
    print("API endpoints:")
    print("GET /")
    print("GET /models")
    print("POST /v1/chat/completions")
# Run the server with Uvicorn using the 'main' module
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
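# Alternatively, assuming this file is saved as main.py (per the comment above),
# the server can be started with the uvicorn CLI:
#   uvicorn main:app --host 0.0.0.0 --port 8000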