from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import json

from typegpt_api import generate, model_mapping, simplified_models
from api_info import developer_info, model_providers

app = FastAPI()

# Set up CORS middleware if needed
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health_check")
async def health_check():
    return {"status": "OK"}


@app.get("/models")
async def get_models():
    try:
        response = {
            "object": "list",
            "data": []
        }
        # Flatten the provider registry into an OpenAI-style model list
        for provider, info in model_providers.items():
            for model in info["models"]:
                response["data"].append({
                    "id": model,
                    "object": "model",
                    "provider": provider,
                    "description": info["description"]
                })
        return JSONResponse(content=response)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


@app.post("/chat/completions")
async def chat_completions(request: Request):
    # Receive the JSON payload
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(content={"error": "Invalid JSON payload"}, status_code=400)

    # Extract parameters, falling back to OpenAI-style defaults
    model = body.get("model")
    messages = body.get("messages")
    temperature = body.get("temperature", 0.7)
    top_p = body.get("top_p", 1.0)
    n = body.get("n", 1)
    stream = body.get("stream", False)
    stop = body.get("stop")
    max_tokens = body.get("max_tokens")
    presence_penalty = body.get("presence_penalty", 0.0)
    frequency_penalty = body.get("frequency_penalty", 0.0)
    logit_bias = body.get("logit_bias")
    user = body.get("user")
    timeout = 30  # or set based on your preference

    # Validate required parameters
    if not model:
        return JSONResponse(content={"error": "The 'model' parameter is required."}, status_code=400)
    if not messages:
        return JSONResponse(content={"error": "The 'messages' parameter is required."}, status_code=400)

    # Call the generate function
    try:
        if stream:
            async def generate_stream():
                response = generate(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    n=n,
                    stream=True,
                    stop=stop,
                    max_tokens=max_tokens,
                    presence_penalty=presence_penalty,
                    frequency_penalty=frequency_penalty,
                    logit_bias=logit_bias,
                    user=user,
                    timeout=timeout,
                )
                # Relay each chunk as a server-sent event, then the [DONE] sentinel
                for chunk in response:
                    yield f"data: {json.dumps(chunk)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    # Do not set Transfer-Encoding manually; the ASGI server
                    # applies chunked encoding to streaming responses itself.
                },
            )
        else:
            response = generate(
                model=model,
                messages=messages,
                temperature=temperature,
                top_p=top_p,
                n=n,
                stream=False,
                stop=stop,
                max_tokens=max_tokens,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                logit_bias=logit_bias,
                user=user,
                timeout=timeout,
            )
            return JSONResponse(content=response)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


@app.get("/developer_info")
async def get_developer_info():
    return JSONResponse(content=developer_info)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
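
# Example client call (a minimal sketch; assumes the server is running on
# localhost:8000 and that `requests` is installed; "gpt-3.5-turbo" is a
# placeholder model ID -- valid IDs come from the /models endpoint, which is
# populated from api_info.model_providers):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/chat/completions",
#       json={
#           "model": "gpt-3.5-turbo",
#           "messages": [{"role": "user", "content": "Hello"}],
#       },
#   )
#   print(resp.json())
#
# The exact shape of the response depends on what typegpt_api.generate
# returns; this server passes it through unchanged.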