Spaces:
Running
Running
| import os | |
| import json | |
| import uvicorn | |
| from pydantic import BaseSettings | |
| from fastapi import FastAPI, Depends | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.exceptions import HTTPException | |
| from text_generation.errors import OverloadedError, UnknownError, ValidationError | |
| from spitfight.log import get_logger, init_queued_root_logger, shutdown_queued_root_loggers | |
| from spitfight.colosseum.common import ( | |
| COLOSSEUM_MODELS_ROUTE, | |
| COLOSSEUM_PROMPT_ROUTE, | |
| COLOSSEUM_RESP_VOTE_ROUTE, | |
| COLOSSEUM_ENERGY_VOTE_ROUTE, | |
| COLOSSEUM_HEALTH_ROUTE, | |
| ModelsResponse, | |
| PromptRequest, | |
| ResponseVoteRequest, | |
| ResponseVoteResponse, | |
| EnergyVoteRequest, | |
| EnergyVoteResponse, | |
| ) | |
| from spitfight.colosseum.controller.controller import ( | |
| Controller, | |
| init_global_controller, | |
| get_global_controller, | |
| ) | |
| from spitfight.utils import prepend_generator | |
class ControllerConfig(BaseSettings):
    """Controller settings automatically loaded from environment variables.

    Each field below declares the default that applies when no override is
    supplied. An instance of this class is handed to
    ``init_global_controller`` at startup.
    """

    # ---- Controller ----
    # NOTE(review): the interval/expiration values are presumably in seconds;
    # confirm against the controller implementation.
    background_task_interval: int = 300
    max_num_req_states: int = 10_000
    req_state_expiration_time: int = 600
    compose_files: list[str] = [
        "deployment/docker-compose-0.yaml",
        "deployment/docker-compose-1.yaml",
    ]

    # ---- Logging ----
    # The three file names are joined onto ``log_dir`` when loggers are set up.
    log_dir: str = "/logs"
    controller_log_file: str = "controller.log"
    request_log_file: str = "requests.log"
    uvicorn_log_file: str = "uvicorn.log"

    # ---- Generation ----
    max_new_tokens: int = 512
    do_sample: bool = True
    temperature: float = 1.0
    repetition_penalty: float = 1.0
    top_k: int = 50
    top_p: float = 0.95
# Module-level singletons, created once at import time.

# The ASGI application served by uvicorn when this file is run as a script.
app = FastAPI()

# Configuration instance shared by the startup hook and the controller.
settings = ControllerConfig()

# Logger used by the request handlers in this module.
logger = get_logger("spitfight.colosseum.controller.router")
# NOTE(review): no decorator is visible on this hook in this view; confirm it
# is registered with `app` as a startup handler.
async def startup_event():
    """Initialize the queued loggers, then the global controller.

    Calls ``init_queued_root_logger`` once per logger name, each with a log
    file located under ``settings.log_dir``, and finally passes ``settings``
    to ``init_global_controller``.
    """
    # (logger name, log file name) pairs, in the order they are initialized.
    log_targets = (
        ("uvicorn", settings.uvicorn_log_file),
        ("spitfight.colosseum.controller", settings.controller_log_file),
        ("colosseum_requests", settings.request_log_file),
    )
    for logger_name, file_name in log_targets:
        init_queued_root_logger(logger_name, os.path.join(settings.log_dir, file_name))

    init_global_controller(settings)
# NOTE(review): no decorator is visible on this hook in this view; confirm it
# is registered with `app` as a shutdown handler.
async def shutdown_event():
    """Shut down the global controller, then the queued root loggers."""
    controller = get_global_controller()
    controller.shutdown()
    # Loggers are shut down last, after the controller has been stopped.
    shutdown_queued_root_loggers()
# NOTE(review): no route decorator is visible on this handler in this view;
# COLOSSEUM_MODELS_ROUTE is imported but unused in the visible code. Confirm
# the handler is registered with `app`.
async def models(controller: Controller = Depends(get_global_controller)):
    """Return the models the controller currently reports as available.

    Args:
        controller: The global controller, injected through ``Depends``.

    Returns:
        A ``ModelsResponse`` whose ``available_models`` field holds the result
        of ``controller.get_available_models()``.
    """
    available = controller.get_available_models()
    return ModelsResponse(available_models=available)
# NOTE(review): no route decorator is visible on this handler in this view;
# COLOSSEUM_PROMPT_ROUTE is imported but unused in the visible code. Confirm
# the handler is registered with `app`.
async def prompt(
    request: PromptRequest,
    controller: Controller = Depends(get_global_controller),
):
    """Stream a model's response to a prompt.

    The first item of the controller's async generator is awaited eagerly so
    that errors raised by TGI on the initial call can be translated into HTTP
    errors before any streaming response is created.

    Args:
        request: Carries ``request_id``, ``prompt``, ``model_index`` and
            ``model_preference``, which are forwarded to ``controller.prompt``.
        controller: The global controller, injected through ``Depends``.

    Returns:
        A ``StreamingResponse`` that yields the first token followed by the
        rest of the generator. If the generator is exhausted immediately, a
        single-chunk response with a placeholder message is returned instead.

    Raises:
        HTTPException: 429 on ``OverloadedError``, 422 on ``ValidationError``,
            500 on ``UnknownError``.
    """
    generator = controller.prompt(
        request.request_id,
        request.prompt,
        request.model_index,
        request.model_preference,
    )

    # First try to get the first token in order to catch TGI errors.
    try:
        first_token = await generator.__anext__()
    except OverloadedError:
        # Look up the model's name from the stored request state for the log line.
        name = controller.request_states[request.request_id].model_names[request.model_index]
        logger.warning("Model %s is overloaded. Failed request: %r", name, request)
        raise HTTPException(status_code=429, detail="Model overloaded. Please try again later.")
    except ValidationError as e:
        logger.info("TGI returned validation error: %s. Failed request: %r", e, request)
        raise HTTPException(status_code=422, detail=str(e))
    except StopAsyncIteration:
        # The generator produced nothing at all: reply with a placeholder
        # chunk (JSON-encoded string terminated by a NUL byte).
        logger.info("TGI returned empty response. Failed request: %r", request)
        return StreamingResponse(
            iter([json.dumps("*The model generated an empty response.*").encode() + b"\0"]),
        )
    except UnknownError as e:
        logger.error("TGI returned unknown error: %s. Failed request: %r", e, request)
        raise HTTPException(status_code=500, detail=str(e))

    # Re-attach the token consumed above so the client sees the full stream.
    return StreamingResponse(prepend_generator(first_token, generator))
# NOTE(review): no route decorator is visible on this handler in this view;
# COLOSSEUM_RESP_VOTE_ROUTE is imported but unused in the visible code.
# Confirm the handler is registered with `app`.
async def response_vote(
    request: ResponseVoteRequest,
    controller: Controller = Depends(get_global_controller),
):
    """Record a response-quality vote and return energy data with model names.

    Args:
        request: Carries ``request_id`` and ``victory_index``.
        controller: The global controller, injected through ``Depends``.

    Returns:
        A ``ResponseVoteResponse`` built from the state returned by
        ``controller.response_vote``.

    Raises:
        HTTPException: 410 when the controller returns ``None`` for the request.
    """
    state = controller.response_vote(request.request_id, request.victory_index)
    # A missing state is reported to the client as an expired session.
    if state is None:
        raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")

    return ResponseVoteResponse(
        energy_consumptions=state.energy_consumptions,
        model_names=state.model_names,
    )
# NOTE(review): no route decorator is visible on this handler in this view;
# COLOSSEUM_ENERGY_VOTE_ROUTE is imported but unused in the visible code.
# Confirm the handler is registered with `app`.
async def energy_vote(
    request: EnergyVoteRequest,
    controller: Controller = Depends(get_global_controller),
):
    """Record an energy vote and return the model names.

    Args:
        request: Carries ``request_id`` and ``is_worth``.
        controller: The global controller, injected through ``Depends``.

    Returns:
        An ``EnergyVoteResponse`` holding ``model_names`` from the state
        returned by ``controller.energy_vote``.

    Raises:
        HTTPException: 410 when the controller returns ``None`` for the request.
    """
    state = controller.energy_vote(request.request_id, request.is_worth)
    # A missing state is reported to the client as an expired session.
    if state is None:
        raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")

    return EnergyVoteResponse(model_names=state.model_names)
# NOTE(review): no route decorator is visible on this handler in this view;
# COLOSSEUM_HEALTH_ROUTE is imported but unused in the visible code. Confirm
# the handler is registered with `app`.
async def health():
    """Health check: always returns the constant string ``"OK"``."""
    return "OK"
if __name__ == "__main__":
    # Serve the app directly when this file is executed as a script.
    # host="0.0.0.0" listens on all network interfaces.
    # NOTE(review): log_config=None presumably keeps uvicorn from applying its
    # own logging configuration, since the "uvicorn" logger is set up in the
    # startup hook; confirm against the uvicorn documentation.
    uvicorn.run(
        app,
        host="0.0.0.0",
        log_config=None,
    )