Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -4,18 +4,16 @@ from fastapi import FastAPI, HTTPException, Request
|
|
4 |
from fastapi.responses import StreamingResponse, HTMLResponse
|
5 |
from pydantic import BaseModel
|
6 |
import httpx
|
|
|
7 |
|
8 |
-
# Load environment variables from .env file
|
9 |
load_dotenv()
|
10 |
|
11 |
app = FastAPI()
|
12 |
|
13 |
-
# Get API keys and secret endpoint from environment variables
|
14 |
api_keys_str = os.getenv('API_KEYS')
|
15 |
valid_api_keys = api_keys_str.split(',') if api_keys_str else []
|
16 |
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
|
17 |
|
18 |
-
# Check if the endpoint is set in the environment
|
19 |
if not secret_api_endpoint:
|
20 |
raise HTTPException(status_code=500, detail="API endpoint is not configured in environment variables.")
|
21 |
|
@@ -50,8 +48,6 @@ async def root():
|
|
50 |
"""
|
51 |
return HTMLResponse(content=html_content)
|
52 |
|
53 |
-
# Remove cache from get_models
|
54 |
-
@app.get("/v1/models")
|
55 |
async def get_models():
|
56 |
async with httpx.AsyncClient() as client:
|
57 |
try:
|
@@ -61,33 +57,36 @@ async def get_models():
|
|
61 |
except httpx.RequestError as e:
|
62 |
raise HTTPException(status_code=500, detail=f"Request failed: {e}")
|
63 |
|
|
|
|
|
|
|
|
|
64 |
@app.post("/v1/chat/completions")
|
65 |
async def get_completion(payload: Payload, request: Request):
|
66 |
api_key = request.headers.get("Authorization")
|
67 |
|
68 |
-
# Validate API key
|
69 |
if api_key not in valid_api_keys:
|
70 |
raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")
|
71 |
|
72 |
-
# Prepare the payload for streaming
|
73 |
payload_dict = {**payload.dict(), "stream": True}
|
74 |
|
75 |
-
|
76 |
-
async def stream_generator():
|
77 |
async with httpx.AsyncClient() as client:
|
78 |
try:
|
79 |
-
async with client.stream("POST", secret_api_endpoint, json=payload_dict, timeout=10) as response:
|
80 |
response.raise_for_status()
|
81 |
async for line in response.aiter_lines():
|
82 |
if line:
|
83 |
-
yield f"{line}\n"
|
84 |
-
except httpx.
|
85 |
-
raise HTTPException(status_code=
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
89 |
|
90 |
-
# Log the API endpoints
|
91 |
@app.on_event("startup")
|
92 |
async def startup_event():
|
93 |
print("API endpoints:")
|
@@ -95,7 +94,6 @@ async def startup_event():
|
|
95 |
print("GET /models")
|
96 |
print("POST /v1/chat/completions")
|
97 |
|
98 |
-
# Run the server with Uvicorn using the 'main' module
|
99 |
if __name__ == "__main__":
|
100 |
import uvicorn
|
101 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
4 |
from fastapi.responses import StreamingResponse, HTMLResponse
|
5 |
from pydantic import BaseModel
|
6 |
import httpx
|
7 |
+
from functools import lru_cache
|
8 |
|
|
|
9 |
load_dotenv()
|
10 |
|
11 |
app = FastAPI()
|
12 |
|
|
|
13 |
api_keys_str = os.getenv('API_KEYS')
|
14 |
valid_api_keys = api_keys_str.split(',') if api_keys_str else []
|
15 |
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
|
16 |
|
|
|
17 |
if not secret_api_endpoint:
|
18 |
raise HTTPException(status_code=500, detail="API endpoint is not configured in environment variables.")
|
19 |
|
|
|
48 |
"""
|
49 |
return HTMLResponse(content=html_content)
|
50 |
|
|
|
|
|
51 |
async def get_models():
|
52 |
async with httpx.AsyncClient() as client:
|
53 |
try:
|
|
|
57 |
except httpx.RequestError as e:
|
58 |
raise HTTPException(status_code=500, detail=f"Request failed: {e}")
|
59 |
|
60 |
+
@app.get("/v1/models")
|
61 |
+
async def fetch_models():
|
62 |
+
return await get_models()
|
63 |
+
|
64 |
@app.post("/v1/chat/completions")
|
65 |
async def get_completion(payload: Payload, request: Request):
|
66 |
api_key = request.headers.get("Authorization")
|
67 |
|
|
|
68 |
if api_key not in valid_api_keys:
|
69 |
raise HTTPException(status_code=403, detail="Forbidden: Invalid API key. Join dsc.gg/chadgang and DM @mr_leaderyt on Discord for a free API key :)")
|
70 |
|
|
|
71 |
payload_dict = {**payload.dict(), "stream": True}
|
72 |
|
73 |
+
async def stream_generator(payload_dict):
|
|
|
74 |
async with httpx.AsyncClient() as client:
|
75 |
try:
|
76 |
+
async with client.stream("POST", f"{secret_api_endpoint}/v1/chat/completions", json=payload_dict, timeout=10) as response:
|
77 |
response.raise_for_status()
|
78 |
async for line in response.aiter_lines():
|
79 |
if line:
|
80 |
+
yield f"{line}\n"
|
81 |
+
except httpx.HTTPStatusError as status_err:
|
82 |
+
raise HTTPException(status_code=status_err.response.status_code, detail=f"HTTP error: {status_err}")
|
83 |
+
except httpx.RequestError as req_err:
|
84 |
+
raise HTTPException(status_code=500, detail=f"Streaming failed: {req_err}")
|
85 |
+
except Exception as e:
|
86 |
+
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}")
|
87 |
+
|
88 |
+
return StreamingResponse(stream_generator(payload_dict), media_type="application/json")
|
89 |
|
|
|
90 |
@app.on_event("startup")
|
91 |
async def startup_event():
|
92 |
print("API endpoints:")
|
|
|
94 |
print("GET /models")
|
95 |
print("POST /v1/chat/completions")
|
96 |
|
|
|
97 |
if __name__ == "__main__":
|
98 |
import uvicorn
|
99 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|