| | import contextlib |
| | import logging |
| | import os |
| | import sys |
| | import ast |
| | import json |
| | from threading import Thread |
| | import time |
| | from traceback import print_exception |
| | from typing import List |
| | from pydantic import BaseModel, Field |
| |
|
| | import uvicorn |
| | from fastapi import Depends, FastAPI, Header, HTTPException |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from fastapi.requests import Request |
| | from fastapi.responses import JSONResponse |
| | from sse_starlette import EventSourceResponse |
| | from starlette.responses import PlainTextResponse |
| |
|
| | from openai_server.log import logger |
| |
|
# Put the ``openai_server`` directory on the module search path.
# NOTE(review): the entry is a relative path, so it is resolved against the
# process's current working directory -- confirm the server is always started
# from the directory that contains ``openai_server``.
sys.path.append('openai_server')
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
# Extra generation controls that are mixed into both TextRequest and
# ChatRequest alongside the common Params fields.  Every field is optional
# and carries the default shown.
class Generation(BaseModel):
    # Sampling controls.
    top_k: int | None = 1
    repetition_penalty: float | None = 1
    min_p: float | None = 0.0

    # NOTE(review): presumably a time budget in seconds -- confirm with the
    # backend that consumes it.
    max_time: float | None = 360
| |
|
| |
|
# Request fields shared by the completion and chat endpoints; CompletionParams
# and ChatParams both derive from this model.
class Params(BaseModel):
    # Bookkeeping / routing.
    user: str | None = Field(default=None, description="Track user")
    model: str | None = Field(default=None, description="Choose model")

    # Accepted for API compatibility; the field descriptions mark it unused.
    best_of: int | None = Field(default=1, description="Unused")

    frequency_penalty: float | None = 0.0
    max_tokens: int | None = 256

    # Also marked unused by its description.
    n: int | None = Field(default=1, description="Unused")

    presence_penalty: float | None = 0.0

    # Stop conditions: either a single string / list of strings, or token ids.
    stop: str | List[str] | None = None
    stop_token_ids: List[int] | None = None

    # When true the endpoint answers with a server-sent event stream.
    stream: bool | None = False

    temperature: float | None = 0.3
    top_p: float | None = 1.0
    seed: int | None = 1234
| |
|
| |
|
# Params plus the fields specific to the text-completion endpoint.
class CompletionParams(Params):
    # Required: one prompt string or a list of prompt strings.
    prompt: str | List[str]

    # Optional extras, both absent by default.
    logit_bias: dict | None = None
    logprobs: int | None = None
| |
|
| |
|
# Body accepted by POST /v1/completions: the Generation fields combined with
# the CompletionParams fields.  No fields of its own.
# NOTE(review): the base order (Generation first) determines the MRO -- keep
# it unless the resulting field/default precedence has been checked.
class TextRequest(Generation, CompletionParams):
    pass
| |
|
| |
|
# Response schema declared for POST /v1/completions.
class TextResponse(BaseModel):
    id: str
    choices: List[dict]

    # Unix timestamp (whole seconds) taken when the instance is created.
    # A default_factory is required here: a plain ``int(time.time())`` default
    # is evaluated once, when the class body runs, so every response would
    # carry the module-import time instead of its own creation time.
    created: int = Field(default_factory=lambda: int(time.time()))

    model: str
    object: str = "text_completion"
    usage: dict
| |
|
| |
|
# Params plus the fields specific to the chat-completion endpoint.
class ChatParams(Params):
    # Required: the conversation, as a list of message dicts.
    messages: List[dict]

    # Tool-calling inputs; their field descriptions mark them work-in-progress.
    tools: list | None = Field(default=None, description="WIP")
    tool_choice: str | None = Field(default=None, description="WIP")
| |
|
| |
|
# Body accepted by POST /v1/chat/completions: the Generation fields combined
# with the ChatParams fields.  No fields of its own.
# NOTE(review): the base order (Generation first) determines the MRO -- keep
# it unless the resulting field/default precedence has been checked.
class ChatRequest(Generation, ChatParams):
    pass
| |
|
| |
|
# Response schema declared for POST /v1/chat/completions.
class ChatResponse(BaseModel):
    id: str
    choices: List[dict]

    # Unix timestamp (whole seconds) taken when the instance is created.
    # A default_factory is required here: a plain ``int(time.time())`` default
    # is evaluated once, when the class body runs, so every response would
    # carry the module-import time instead of its own creation time.
    created: int = Field(default_factory=lambda: int(time.time()))

    model: str
    object: str = "chat.completion"
    usage: dict
| |
|
| |
|
# Response schema declared for GET /v1/internal/model/info.
class ModelInfoResponse(BaseModel):
    # Required model name.
    model_name: str
| |
|
| |
|
# Response schema declared for GET /v1/internal/model/list.
class ModelListResponse(BaseModel):
    # Required list of model names.
    model_names: List[str]
| |
|
| |
|
def verify_api_key(authorization: str = Header(None)) -> None:
    """Reject the request unless it carries the configured bearer token.

    The expected key is read from the ``H2OGPT_OPENAI_API_KEY`` environment
    variable on every call.  When the variable is unset, empty, or the
    literal ``'EMPTY'``, authentication is disabled and every request passes.

    Args:
        authorization: Value of the ``Authorization`` request header, or
            ``None`` when the header is absent.

    Raises:
        HTTPException: 401 when a key is configured and the header is not
            exactly ``Bearer <key>``.
    """
    # Local import keeps this block self-contained.
    import hmac

    server_api_key = os.getenv('H2OGPT_OPENAI_API_KEY', 'EMPTY')
    if not server_api_key or server_api_key == 'EMPTY':
        # No key configured: the endpoint is open.
        return

    expected = f"Bearer {server_api_key}".encode()
    # A missing header becomes b"", which can never equal the expected value.
    supplied = (authorization or '').encode()

    # Constant-time comparison so response timing does not reveal how much of
    # the token matched.  Bytes are compared because compare_digest only
    # accepts ASCII-only str arguments.
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
| |
|
| |
|
# The ASGI application object; the route decorators below register on it.
app = FastAPI()

# Dependency list handed to each route so that verify_api_key runs first.
check_key = [Depends(verify_api_key)]

# CORS is left wide open: any origin, method and header is allowed, and
# credentials are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
| |
|
| |
|
| | |
| |
|
| |
|
class InvalidRequestError(Exception):
    """Error type used to wrap failures reported back to the client."""
| |
|
| |
|
@app.exception_handler(Exception)
async def validation_exception_handler(request, exc):
    """Turn an exception raised while serving a request into a 400 reply.

    The traceback is printed, then the exception's message is returned to the
    client as plain text.

    Args:
        request: The request being served (not used).
        exc: The exception that was raised.

    Returns:
        A ``PlainTextResponse`` with status 400 whose body is ``str(exc)``.
    """
    print_exception(exc)
    # Wrapping keeps only the message text of the original exception.
    wrapped = InvalidRequestError(str(exc))
    message = str(wrapped)
    return PlainTextResponse(message, status_code=400)
| |
|
| |
|
@app.options("/", dependencies=check_key)
async def options_route():
    """Answer ``OPTIONS /`` with the JSON string ``"OK"``."""
    body = "OK"
    return JSONResponse(content=body)
| |
|
| |
|
@app.post('/v1/completions', response_model=TextResponse, dependencies=check_key)
async def openai_completions(request: Request, request_data: TextRequest):
    """Serve a text-completion request, streamed or in one piece.

    Args:
        request: The incoming request; polled to detect client disconnects
            while streaming.
        request_data: Parsed request body.

    Returns:
        An ``EventSourceResponse`` emitting one JSON-encoded event per backend
        item when ``request_data.stream`` is truthy, otherwise a
        ``JSONResponse`` with the backend's full result.
    """
    if request_data.stream:
        # Import before the generator is created (as the chat endpoint does)
        # so that an import failure is raised here, inside the handler, rather
        # than later while the event stream is being consumed.
        from openai_server.backend import stream_completions

        async def generator():
            # NOTE(review): the backend iterator is consumed with a plain
            # ``for`` inside this coroutine -- confirm it does not block the
            # event loop for long between items.
            response = stream_completions(dict(request_data))
            for resp in response:
                # Stop producing as soon as the client has gone away.
                if await request.is_disconnected():
                    break
                yield {"data": json.dumps(resp)}

        return EventSourceResponse(generator())

    from openai_server.backend import completions
    response = completions(dict(request_data))
    return JSONResponse(response)
| |
|
| |
|
@app.post('/v1/chat/completions', response_model=ChatResponse, dependencies=check_key)
async def openai_chat_completions(request: Request, request_data: ChatRequest):
    """Serve a chat-completion request, streamed or in one piece.

    Args:
        request: The incoming request; polled to detect client disconnects
            while streaming.
        request_data: Parsed request body.

    Returns:
        A ``JSONResponse`` with the backend's full result when streaming is
        off, otherwise an ``EventSourceResponse`` emitting one JSON-encoded
        event per backend item.
    """
    # Non-streaming case first: one backend call, one JSON reply.
    if not request_data.stream:
        from openai_server.backend import chat_completions
        result = chat_completions(dict(request_data))
        return JSONResponse(result)

    from openai_server.backend import stream_chat_completions

    async def generator():
        chunks = stream_chat_completions(dict(request_data))
        for chunk in chunks:
            # Stop producing as soon as the client has gone away.
            if await request.is_disconnected():
                break
            yield {"data": json.dumps(chunk)}

    return EventSourceResponse(generator())
| |
|
| |
|
| | |
@app.get("/v1/models", dependencies=check_key)
@app.get("/v1/models/{model}", dependencies=check_key)
@app.get("/v1/models/{repo}/{model}", dependencies=check_key)
async def handle_models(request: Request):
    """List the available models, or return the entry for one of them.

    The model name is whatever follows ``/v1/models/`` in the request path, so
    both ``{model}`` and ``{repo}/{model}`` forms are matched as one string.

    Args:
        request: The incoming request; only its URL path is read.

    Returns:
        A ``JSONResponse`` containing
        * ``{"object": "list", "data": [...]}`` with every ``base_model`` value
          when no model name is given;
        * the matching model entry when the name is known;
        * ``{"model_name": "INVALID"}`` when the name is not known.
    """
    path = request.url.path
    # For the bare "/v1/models" route the slice starts past the end of the
    # string and therefore yields "".
    model_name = path[len('/v1/models/'):]

    from openai_server.backend import gradio_client
    # The endpoint's reply is parsed as a Python literal: a sequence of dicts
    # that each carry a 'base_model' key.
    model_dict = ast.literal_eval(gradio_client.predict(api_name='/model_names'))
    base_models = [x['base_model'] for x in model_dict]

    if not model_name:
        response = {
            "object": "list",
            "data": base_models,
        }
    else:
        # list.index raises ValueError for a missing item (it never returns a
        # negative number), so the unknown-model case has to be caught here.
        try:
            model_index = base_models.index(model_name)
        except ValueError:
            response = dict(model_name='INVALID')
        else:
            response = model_dict[model_index]

    return JSONResponse(response)
| |
|
| |
|
@app.get("/v1/internal/model/info", response_model=ModelInfoResponse, dependencies=check_key)
async def handle_model_info():
    """Return the backend's ``get_model_info()`` result as JSON."""
    from openai_server.backend import get_model_info
    info = get_model_info()
    return JSONResponse(content=info)
| |
|
| |
|
@app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_key)
async def handle_list_models():
    """Return the backend's ``get_model_list()`` result as JSON."""
    from openai_server.backend import get_model_list
    listing = get_model_list()
    return JSONResponse(content=listing)
| |
|
| |
|
def run_server(host='0.0.0.0',
               port=5000,
               ssl_certfile=None,
               ssl_keyfile=None,
               gradio_prefix=None,
               gradio_host=None,
               gradio_port=None,
               h2ogpt_key=None,
               ):
    """Configure the process environment and run the API server (blocking).

    Args:
        host: Interface uvicorn binds to.
        port: Listening port; overridden by ``H2OGPT_OPENAI_PORT`` when set.
        ssl_certfile: Certificate path; overridden by
            ``H2OGPT_OPENAI_CERT_PATH`` when set.
        ssl_keyfile: Private-key path; overridden by
            ``H2OGPT_OPENAI_KEY_PATH`` when set.
        gradio_prefix: Exported as ``GRADIO_PREFIX`` (default ``'http'``).
        gradio_host: Exported as ``GRADIO_SERVER_HOST`` (default
            ``'localhost'``).
        gradio_port: Exported as ``GRADIO_SERVER_PORT`` (default ``'7860'``);
            may be given as a string or an int.
        h2ogpt_key: Exported as ``GRADIO_H2OGPT_H2OGPT_KEY`` (default ``''``).
    """
    os.environ['GRADIO_PREFIX'] = gradio_prefix or 'http'
    os.environ['GRADIO_SERVER_HOST'] = gradio_host or 'localhost'
    # os.environ only accepts str values, so coerce in case an int was passed.
    os.environ['GRADIO_SERVER_PORT'] = str(gradio_port or '7860')
    os.environ['GRADIO_H2OGPT_H2OGPT_KEY'] = h2ogpt_key or ''

    # API key precedence: an existing H2OGPT_OPENAI_API_KEY, then the h2oGPT
    # key exported above, then the sentinel 'EMPTY' (which verify_api_key
    # treats as "no authentication").
    server_api_key = os.getenv('H2OGPT_OPENAI_API_KEY', os.environ['GRADIO_H2OGPT_H2OGPT_KEY']) or 'EMPTY'
    os.environ['H2OGPT_OPENAI_API_KEY'] = server_api_key

    # Environment variables take precedence over the arguments.
    port = int(os.getenv('H2OGPT_OPENAI_PORT', port))
    ssl_certfile = os.getenv('H2OGPT_OPENAI_CERT_PATH', ssl_certfile)
    ssl_keyfile = os.getenv('H2OGPT_OPENAI_KEY_PATH', ssl_keyfile)

    # The scheme in the logged URL is https only when both TLS files are given.
    prefix = 'https' if ssl_keyfile and ssl_certfile else 'http'
    logger.info(f'OpenAI API URL: {prefix}://{host}:{port}')
    # NOTE(review): this writes the API key to the log in clear text --
    # confirm that is intended wherever logs are collected or shared.
    logger.info(f'OpenAI API key: {server_api_key}')

    # Stop uvicorn's error logger from propagating records to ancestor loggers.
    logging.getLogger("uvicorn.error").propagate = False
    uvicorn.run(app, host=host, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
| |
|
| |
|
def run(wait=True, **kwargs):
    """Start the API server in the foreground or in a background thread.

    Args:
        wait: When true, call ``run_server`` directly and block until it
            returns.  When false, start it on a daemon thread and return
            immediately.
        **kwargs: Passed through unchanged to ``run_server``.
    """
    if not wait:
        background = Thread(target=run_server, kwargs=kwargs, daemon=True)
        background.start()
        return
    run_server(**kwargs)
| |
|