import os

if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
    from modelscope import patch_hub

    patch_hub()

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"

from config import CONFIG, ModelConfig
from utils import (
    cleanMessages,
    parse_think_response,
    remove_nested_think_tags_stack,
    format_bytes,
)

import copy, types, gc, sys, re, time, collections, asyncio

from huggingface_hub import hf_hub_download
from loguru import logger
from rich import print
from snowflake import SnowflakeGenerator

CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)

from typing import List, Optional, Union, Any, Dict
from pydantic import BaseModel, Field, model_validator
from pydantic_settings import BaseSettings
import numpy as np
import torch

if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
    logger.info("CUDA not available, falling back to CPU")
    CONFIG.STRATEGY = "cpu fp16"

if "cuda" in CONFIG.STRATEGY.lower():
    from pynvml import *

    nvmlInit()
    gpu_h = nvmlDeviceGetHandleByIndex(0)


def logGPUState():
    if "cuda" in CONFIG.STRATEGY:
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        logger.info(
            f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - "
            f"NVML - vram {format_bytes(gpu_info.total)} "
            f"used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}"
        )


torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

os.environ["RWKV_V7_ON"] = "1"  # enable this for rwkv-7 models
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = (
    "1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0"
)

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.gzip import GZipMiddleware

from api_types import (
    ChatMessage,
    ChatCompletion,
    ChatCompletionChunk,
    Usage,
    PromptTokensDetails,
    ChatCompletionChoice,
    ChatCompletionMessage,
)


class ModelStorage:
    MODEL_CONFIG: Optional[ModelConfig] = None
    model: Optional[RWKV] = None
    pipeline: Optional[PIPELINE] = None


MODEL_STORAGE: Dict[str, ModelStorage] = {}
DEFALUT_MODEL_NAME = None
DEFAULT_REASONING_MODEL_NAME = None

logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
logGPUState()

for model_config in CONFIG.MODELS:
    logger.info(f"Load Model - {model_config.SERVICE_NAME}")
    if model_config.MODEL_FILE_PATH is None:
        model_config.MODEL_FILE_PATH = hf_hub_download(
            repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
            filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
            local_dir=model_config.DOWNLOAD_MODEL_DIR,
        )
    logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")

    if model_config.DEFAULT_CHAT:
        if DEFALUT_MODEL_NAME is not None:
            logger.info(
                f"Load Model - Replace `DEFALUT_MODEL_NAME` from `{DEFALUT_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
            )
        DEFALUT_MODEL_NAME = model_config.SERVICE_NAME

    if model_config.DEFAULT_REASONING:
        if DEFAULT_REASONING_MODEL_NAME is not None:
            logger.info(
                f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
            )
        DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME

    print(model_config.DEFAULT_SAMPLER)

    MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
    MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
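    # The checkpoint path is handed to rwkv.model.RWKV without its ".pth"
    # suffix; the strategy string (e.g. "cuda fp16") controls device placement
    # and precision for the loaded weights.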
    MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
        model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
        strategy=CONFIG.STRATEGY,
    )
    MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
        MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
    )

    if "cuda" in CONFIG.STRATEGY:
        torch.cuda.empty_cache()
    gc.collect()
    logGPUState()

logger.info(f"Load Model - DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
logger.info(
    f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`"
)


class ChatCompletionRequest(BaseModel):
    model: str = Field(
        default="rwkv-latest",
        description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
    )
    messages: Optional[List[ChatMessage]] = Field(default=None)
    prompt: Optional[str] = Field(default=None)
    max_tokens: Optional[int] = Field(default=None)
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=None)
    count_penalty: Optional[float] = Field(default=None)
    penalty_decay: Optional[float] = Field(default=None)
    stream: Optional[bool] = Field(default=False)
    state_name: Optional[str] = Field(default=None)
    include_usage: Optional[bool] = Field(default=False)
    stop: Optional[list[str]] = Field(default=["\n\n"])
    stop_tokens: Optional[list[int]] = Field(default=[0])

    @model_validator(mode="before")
    @classmethod
    def validate_mutual_exclusivity(cls, data: Any) -> Any:
        if not isinstance(data, dict):
            return data
        messages_provided = "messages" in data and data["messages"] is not None
        prompt_provided = "prompt" in data and data["prompt"] is not None
        if messages_provided and prompt_provided:
            raise ValueError("messages and prompt cannot coexist. Choose one.")
        if not messages_provided and not prompt_provided:
            raise ValueError("Either messages or prompt must be provided.")
        return data


app = FastAPI(title="RWKV OpenAI-Compatible API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)


async def runPrefill(
    request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
):
    ctx = ctx.replace("\r\n", "\n")

    tokens = MODEL_STORAGE[request.model].pipeline.encode(ctx)
    tokens = [int(x) for x in tokens]
    model_tokens += tokens

    while len(tokens) > 0:
        out, model_state = MODEL_STORAGE[request.model].model.forward(
            tokens[: CONFIG.CHUNK_LEN], model_state
        )
        tokens = tokens[CONFIG.CHUNK_LEN :]
        await asyncio.sleep(0)

    return out, model_tokens, model_state


def generate(
    request: ChatCompletionRequest,
    out,
    model_tokens: List[int],
    model_state,
    max_tokens=2048,
):
    args = PIPELINE_ARGS(
        temperature=max(0.2, request.temperature),
        top_p=request.top_p,
        alpha_frequency=request.count_penalty,
        alpha_presence=request.presence_penalty,
        token_ban=[],  # ban the generation of some tokens
        token_stop=[0],  # stop generation whenever you see any token here
    )

    occurrence = {}
    out_tokens: List[int] = []
    out_last = 0

    cache_word_list = []
    cache_word_len = 5

    for i in range(max_tokens):
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
        # out[0] -= 1e10  # disable END_OF_TEXT

        token = MODEL_STORAGE[request.model].pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )

        if token == 0 and token in request.stop_tokens:
            yield {
                "content": "".join(cache_word_list),
                "tokens": out_tokens[out_last:],
                "finish_reason": "stop:token:0",
                "state": model_state,
            }
            del out
            gc.collect()
            return

        out, model_state = MODEL_STORAGE[request.model].model.forward(
            [token], model_state
        )
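
        # Track the sampled token both in the running context (model_tokens) and
        # in this generation's buffer (out_tokens); out_tokens drives the
        # incremental decode and the stop-word cache below.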
        model_tokens.append(token)
        out_tokens.append(token)

        if token in request.stop_tokens:
            yield {
                "content": "".join(cache_word_list),
                "tokens": out_tokens[out_last:],
                "finish_reason": f"stop:token:{token}",
                "state": model_state,
            }
            del out
            gc.collect()
            return

        for xxx in occurrence:
            occurrence[xxx] *= request.penalty_decay
        occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)

        tmp: str = MODEL_STORAGE[request.model].pipeline.decode(out_tokens[out_last:])
        if "\ufffd" in tmp:
            continue

        cache_word_list.append(tmp)
        output_cache_str = "".join(cache_word_list)

        for stop_words in request.stop:
            if stop_words in output_cache_str:
                yield {
                    "content": output_cache_str.replace(stop_words, ""),
                    "tokens": out_tokens[out_last - cache_word_len :],
                    "finish_reason": f"stop:words:{stop_words}",
                    "state": model_state,
                }
                del out
                gc.collect()
                return

        if len(cache_word_list) > cache_word_len:
            yield {
                "content": cache_word_list.pop(0),
                "tokens": out_tokens[out_last - cache_word_len :],
                "finish_reason": None,
            }
        out_last = i + 1
    else:
        yield {
            "content": "",
            "tokens": [],
            "finish_reason": "length",
        }


async def chatResponse(
    request: ChatCompletionRequest,
    model_state: any,
    completionId: str,
    enableReasoning: bool,
) -> ChatCompletion:
    createTimestamp = time.time()

    prompt = f"{cleanMessages(request.messages)}\n\nAssistant:"
    ...


async def chatResponseStream(
    request: ChatCompletionRequest,
    model_state: any,
    completionId: str,
    enableReasoning: bool,
):
    createTimestamp = time.time()

    prompt = f"{cleanMessages(request.messages)}\n\nAssistant:"
    out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
    prefillTime = time.time()
    promptTokenCount = len(model_tokens)

    buffer: List[str] = []
    completionTokenCount = 0
    finishReason = None

    if enableReasoning:
        streamConfig = {
            "isChecking": False,  # currently buffering a potential think tag
            "fullTextCursor": 0,
            "in_think": False,
            "cacheStr": "",
        }

        for chunk in generate(
            request,
            out,
            model_tokens,
            model_state,
            max_tokens=(
                64000
                if "max_tokens" not in request.model_fields_set and enableReasoning
                else request.max_tokens
            ),
        ):
            completionTokenCount += 1
            chunkContent: str = chunk["content"]
            buffer.append(chunkContent)
            fullText = "".join(buffer)

            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]

            response = ChatCompletionChunk(
                id=completionId,
                created=createTimestamp,
                model=request.model,
                usage=(
                    Usage(
                        prompt_tokens=promptTokenCount,
                        completion_tokens=completionTokenCount,
                        total_tokens=promptTokenCount + completionTokenCount,
                        prompt_tokens_details={"cached_tokens": 0},
                    )
                    if request.include_usage
                    else None
                ),
                choices=[
                    ChatCompletionChoice(
                        index=0,
                        delta=ChatCompletionMessage(
                            content=None, reasoning_content=None
                        ),
                        logprobs=None,
                        finish_reason=finishReason,
                    )
                ],
            )

            markStart = fullText.find("<", streamConfig["fullTextCursor"])
            if not streamConfig["isChecking"] and markStart != -1:
                streamConfig["isChecking"] = True
                if streamConfig["in_think"]:
                    response.choices[0].delta.reasoning_content = fullText[
                        streamConfig["fullTextCursor"] : markStart
                    ]
                else:
                    response.choices[0].delta.content = fullText[
                        streamConfig["fullTextCursor"] : markStart
                    ]
                streamConfig["cacheStr"] = ""
                streamConfig["fullTextCursor"] = markStart

            if streamConfig["isChecking"]:
                streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"] :]
            else:
                if streamConfig["in_think"]:
                    response.choices[0].delta.reasoning_content = chunkContent
                else:
                    response.choices[0].delta.content = chunkContent
                streamConfig["fullTextCursor"] = len(fullText)

            markEnd = fullText.find(">", streamConfig["fullTextCursor"])
            if (
                streamConfig["isChecking"] and markEnd != -1
            ) or finishReason is not None:
                streamConfig["isChecking"] = False

                if (
                    not streamConfig["in_think"]
                    and streamConfig["cacheStr"].find("<think>") != -1
                ):
                    streamConfig["in_think"] = True
                    response.choices[0].delta.reasoning_content = (
                        response.choices[0].delta.reasoning_content
                        if response.choices[0].delta.reasoning_content is not None
                        else ""
                    ) + streamConfig["cacheStr"].replace("<think>", "")
                elif (
                    streamConfig["in_think"]
                    and streamConfig["cacheStr"].find("</think>") != -1
                ):
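                    # A closing </think> arrived in the buffered text: leave
                    # reasoning mode and flush the buffered text (minus the tag)
                    # into the regular content field.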
                    streamConfig["in_think"] = False
                    response.choices[0].delta.content = (
                        response.choices[0].delta.content
                        if response.choices[0].delta.content is not None
                        else ""
                    ) + streamConfig["cacheStr"].replace("</think>", "")
                else:
                    if streamConfig["in_think"]:
                        response.choices[0].delta.reasoning_content = (
                            response.choices[0].delta.reasoning_content
                            if response.choices[0].delta.reasoning_content is not None
                            else ""
                        ) + streamConfig["cacheStr"]
                    else:
                        response.choices[0].delta.content = (
                            response.choices[0].delta.content
                            if response.choices[0].delta.content is not None
                            else ""
                        ) + streamConfig["cacheStr"]

                streamConfig["fullTextCursor"] = len(fullText)

            if (
                response.choices[0].delta.content is not None
                or response.choices[0].delta.reasoning_content is not None
            ):
                yield f"data: {response.model_dump_json()}\n\n"
            await asyncio.sleep(0)

        del streamConfig
    else:
        for chunk in generate(request, out, model_tokens, model_state):
            completionTokenCount += 1
            buffer.append(chunk["content"])
            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]

            response = ChatCompletionChunk(
                id=completionId,
                created=createTimestamp,
                model=request.model,
                usage=(
                    Usage(
                        prompt_tokens=promptTokenCount,
                        completion_tokens=completionTokenCount,
                        total_tokens=promptTokenCount + completionTokenCount,
                        prompt_tokens_details={"cached_tokens": 0},
                    )
                    if request.include_usage
                    else None
                ),
                choices=[
                    ChatCompletionChoice(
                        index=0,
                        delta=ChatCompletionMessage(content=chunk["content"]),
                        logprobs=None,
                        finish_reason=finishReason,
                    )
                ],
            )
            yield f"data: {response.model_dump_json()}\n\n"
            await asyncio.sleep(0)

    generateTime = time.time()
    responseLog = {
        "content": "".join(buffer),
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
        "gen_len": completionTokenCount,
        "gen_tps": round(completionTokenCount / (generateTime - prefillTime), 2),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")

    del buffer
    yield "data: [DONE]\n\n"


@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    completionId = str(next(CompletionIdGenerator))
    logger.info(f"[REQ] {completionId} - {request.model_dump()}")

    modelName = request.model.split(":")[0]
    enableReasoning = ":thinking" in request.model

    if "rwkv-latest" in request.model:
        if enableReasoning:
            if DEFAULT_REASONING_MODEL_NAME is None:
                raise HTTPException(404, "DEFAULT_REASONING_MODEL_NAME not set")
            defaultSamplerConfig = MODEL_STORAGE[
                DEFAULT_REASONING_MODEL_NAME
            ].MODEL_CONFIG.DEFAULT_SAMPLER
            request.model = DEFAULT_REASONING_MODEL_NAME
        else:
            if DEFALUT_MODEL_NAME is None:
                raise HTTPException(404, "DEFALUT_MODEL_NAME not set")
            defaultSamplerConfig = MODEL_STORAGE[
                DEFALUT_MODEL_NAME
            ].MODEL_CONFIG.DEFAULT_SAMPLER
            request.model = DEFALUT_MODEL_NAME
    elif modelName in MODEL_STORAGE:
        defaultSamplerConfig = MODEL_STORAGE[modelName].MODEL_CONFIG.DEFAULT_SAMPLER
        request.model = modelName
    else:
        raise HTTPException(404, f"Cannot find model `{modelName}`")

    async def chatResponseStreamDisconnect():
        logGPUState()

    model_state = None

    # Fill any unset sampler fields from the selected model's default sampler config.
    request_dict = request.model_dump()
    for k, v in defaultSamplerConfig.model_dump().items():
        if request_dict[k] is None:
            request_dict[k] = v
    realRequest = ChatCompletionRequest(**request_dict)
    logger.info(f"[REQ] {completionId} - Real - {realRequest.model_dump()}")

    if request.stream:
        r = StreamingResponse(
            chatResponseStream(
                realRequest, model_state, completionId, enableReasoning
            ),
            media_type="text/event-stream",
            background=chatResponseStreamDisconnect,
        )
    else:
        r = await chatResponse(realRequest, model_state, completionId, enableReasoning)

    return r
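
# Example client call for the route above. This is an illustrative sketch only:
# the host and port come from CONFIG.HOST / CONFIG.PORT (assumed here to be
# localhost:8000), and the message shape assumes the usual OpenAI-style
# role/content fields defined in api_types.ChatMessage.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/api/v1/chat/completions",
#       json={
#           "model": "rwkv-latest:thinking",  # ":thinking" routes to the reasoning model
#           "messages": [{"role": "user", "content": "Hello!"}],
#           "stream": False,
#           "include_usage": True,
#       },
#   )
#   print(resp.json())
#
# With "stream": true the route instead emits Server-Sent Events
# ("data: {...}\n\n" chunks, terminated by "data: [DONE]\n\n").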

app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)