import os

if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
    from modelscope import patch_hub

    patch_hub()

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"

from config import CONFIG, ModelConfig
from utils import (
    cleanMessages,
    parse_think_response,
    remove_nested_think_tags_stack,
    format_bytes,
)

import copy, types, gc, sys, re, time, collections, asyncio

from huggingface_hub import hf_hub_download
from loguru import logger
from rich import print
from snowflake import SnowflakeGenerator

CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)

from typing import List, Optional, Union, Any, Dict
from pydantic import BaseModel, Field, model_validator
from pydantic_settings import BaseSettings

import numpy as np
import torch
# Fall back to CPU if a CUDA strategy was requested but no GPU is available.
if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
    logger.info("CUDA not available, falling back to CPU")
    CONFIG.STRATEGY = "cpu fp16"
if "cuda" in CONFIG.STRATEGY.lower():
    from pynvml import *

    nvmlInit()
    gpu_h = nvmlDeviceGetHandleByIndex(0)


def logGPUState():
    if "cuda" in CONFIG.STRATEGY:
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        logger.info(
            f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}"
        )


torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
os.environ["RWKV_V7_ON"] = "1"  # enable this for rwkv-7 models
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = (
    "1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0"
)

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.gzip import GZipMiddleware

from api_types import (
    ChatMessage,
    ChatCompletion,
    ChatCompletionChunk,
    Usage,
    PromptTokensDetails,
    ChatCompletionChoice,
    ChatCompletionMessage,
)
class ModelStorage:
    MODEL_CONFIG: Optional[ModelConfig] = None
    model: Optional[RWKV] = None
    pipeline: Optional[PIPELINE] = None


MODEL_STORAGE: Dict[str, ModelStorage] = {}
DEFAULT_MODEL_NAME = None
DEFAULT_REASONING_MODEL_NAME = None

logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
logGPUState()
for model_config in CONFIG.MODELS:
    logger.info(f"Load Model - {model_config.SERVICE_NAME}")
    if model_config.MODEL_FILE_PATH is None:
        model_config.MODEL_FILE_PATH = hf_hub_download(
            repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
            filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
            local_dir=model_config.DOWNLOAD_MODEL_DIR,
        )
    logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")

    if model_config.DEFAULT_CHAT:
        if DEFAULT_MODEL_NAME is not None:
            logger.info(
                f"Load Model - Replace `DEFAULT_MODEL_NAME` from `{DEFAULT_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
            )
        DEFAULT_MODEL_NAME = model_config.SERVICE_NAME

    if model_config.DEFAULT_REASONING:
        if DEFAULT_REASONING_MODEL_NAME is not None:
            logger.info(
                f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
            )
        DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME

    print(model_config.DEFAULT_SAMPLER)

    MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
    MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
    MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
        model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
        strategy=CONFIG.STRATEGY,
    )
    MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
        MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
    )

    if "cuda" in CONFIG.STRATEGY:
        torch.cuda.empty_cache()
    gc.collect()
    logGPUState()

logger.info(f"Load Model - DEFAULT_MODEL_NAME is `{DEFAULT_MODEL_NAME}`")
logger.info(
    f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`"
)
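
# --- Illustrative only: a sketch of what one entry in CONFIG.MODELS might look like. ---
# The real schema lives in config.py (ModelConfig), which is not part of this file;
# the field names below are the ones referenced above, the values are made-up placeholders.
#
#   ModelConfig(
#       SERVICE_NAME="rwkv7-g1-2.9b",              # key used in MODEL_STORAGE and in the request `model` field
#       DOWNLOAD_MODEL_REPO_ID="example/rwkv-repo",  # hypothetical repo id passed to hf_hub_download
#       DOWNLOAD_MODEL_FILE_NAME="model.pth",
#       DOWNLOAD_MODEL_DIR="./models",
#       MODEL_FILE_PATH=None,                      # None -> downloaded via hf_hub_download above
#       VOCAB="rwkv_vocab_v20230424",              # tokenizer vocab handed to PIPELINE
#       DEFAULT_CHAT=True,                         # becomes DEFAULT_MODEL_NAME
#       DEFAULT_REASONING=False,                   # becomes DEFAULT_REASONING_MODEL_NAME
#       DEFAULT_SAMPLER=...,                       # sampler defaults merged into requests later
#   )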
class ChatCompletionRequest(BaseModel):
    model: str = Field(
        default="rwkv-latest",
        description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
    )
    messages: Optional[List[ChatMessage]] = Field(default=None)
    prompt: Optional[str] = Field(default=None)
    max_tokens: Optional[int] = Field(default=None)
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=None)
    count_penalty: Optional[float] = Field(default=None)
    penalty_decay: Optional[float] = Field(default=None)
    stream: Optional[bool] = Field(default=False)
    state_name: Optional[str] = Field(default=None)
    include_usage: Optional[bool] = Field(default=False)
    stop: Optional[list[str]] = Field(default=["\n\n"])
    stop_tokens: Optional[list[int]] = Field(default=[0])

    @model_validator(mode="before")
    @classmethod
    def validate_mutual_exclusivity(cls, data: Any) -> Any:
        if not isinstance(data, dict):
            return data
        messages_provided = "messages" in data and data["messages"] is not None
        prompt_provided = "prompt" in data and data["prompt"] is not None
        if messages_provided and prompt_provided:
            raise ValueError("`messages` and `prompt` cannot coexist. Choose one.")
        if not messages_provided and not prompt_provided:
            raise ValueError("Either `messages` or `prompt` must be provided.")
        return data
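
# --- Illustrative only: a request body this schema accepts. ---
# `messages` and `prompt` are mutually exclusive (enforced by the validator above);
# sampler fields left unset are filled in later from the model's DEFAULT_SAMPLER.
#
#   {
#       "model": "rwkv-latest:thinking",
#       "messages": [{"role": "user", "content": "How many r's are in strawberry?"}],
#       "stream": true,
#       "include_usage": true
#   }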
app = FastAPI(title="RWKV OpenAI-Compatible API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
async def runPrefill(
    request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
):
    ctx = ctx.replace("\r\n", "\n")

    tokens = MODEL_STORAGE[request.model].pipeline.encode(ctx)
    tokens = [int(x) for x in tokens]
    model_tokens += tokens

    # Feed the prompt in CHUNK_LEN-sized pieces, yielding to the event loop
    # between chunks so other requests are not starved during prefill.
    while len(tokens) > 0:
        out, model_state = MODEL_STORAGE[request.model].model.forward(
            tokens[: CONFIG.CHUNK_LEN], model_state
        )
        tokens = tokens[CONFIG.CHUNK_LEN :]
        await asyncio.sleep(0)

    return out, model_tokens, model_state
def generate(
    request: ChatCompletionRequest,
    out,
    model_tokens: List[int],
    model_state,
    max_tokens=2048,
):
    args = PIPELINE_ARGS(
        temperature=max(0.2, request.temperature),
        top_p=request.top_p,
        alpha_frequency=request.count_penalty,
        alpha_presence=request.presence_penalty,
        token_ban=[],  # ban the generation of some tokens
        token_stop=[0],  # stop generation whenever you see any token here
    )

    occurrence = {}
    out_tokens: List[int] = []
    out_last = 0

    cache_word_list = []
    cache_word_len = 5

    for i in range(max_tokens):
        # Apply presence/frequency penalties to tokens that already occurred.
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
        # out[0] -= 1e10  # disable END_OF_TEXT

        token = MODEL_STORAGE[request.model].pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )

        if token == 0 and token in request.stop_tokens:
            yield {
                "content": "".join(cache_word_list),
                "tokens": out_tokens[out_last:],
                "finish_reason": "stop:token:0",
                "state": model_state,
            }
            del out
            gc.collect()
            return

        out, model_state = MODEL_STORAGE[request.model].model.forward(
            [token], model_state
        )
        model_tokens.append(token)
        out_tokens.append(token)

        if token in request.stop_tokens:
            yield {
                "content": "".join(cache_word_list),
                "tokens": out_tokens[out_last:],
                "finish_reason": f"stop:token:{token}",
                "state": model_state,
            }
            del out
            gc.collect()
            return

        for xxx in occurrence:
            occurrence[xxx] *= request.penalty_decay
        occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)

        tmp: str = MODEL_STORAGE[request.model].pipeline.decode(out_tokens[out_last:])
        if "\ufffd" in tmp:
            # Incomplete UTF-8 sequence; wait for more tokens before emitting text.
            continue

        cache_word_list.append(tmp)
        output_cache_str = "".join(cache_word_list)

        # A few words are held back so stop words can be stripped before emission.
        for stop_words in request.stop:
            if stop_words in output_cache_str:
                yield {
                    "content": output_cache_str.replace(stop_words, ""),
                    "tokens": out_tokens[out_last - cache_word_len :],
                    "finish_reason": f"stop:words:{stop_words}",
                    "state": model_state,
                }
                del out
                gc.collect()
                return

        if len(cache_word_list) > cache_word_len:
            yield {
                "content": cache_word_list.pop(0),
                "tokens": out_tokens[out_last - cache_word_len :],
                "finish_reason": None,
            }
        out_last = i + 1
    else:
        # The loop ran to max_tokens without hitting a stop token or stop word.
        yield {
            "content": "",
            "tokens": [],
            "finish_reason": "length",
        }
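
# --- Shape of the chunks yielded by generate(), for reference. ---
# Intermediate chunks:  {"content": "<text>", "tokens": [...], "finish_reason": None}
# Final chunk:          finish_reason is one of "stop:token:<id>", "stop:words:<word>",
#                       or "length"; stop-token/stop-word chunks also carry "state",
#                       the RWKV state at the moment generation ended.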
async def chatResponse(
    request: ChatCompletionRequest,
    model_state: Any,
    completionId: str,
    enableReasoning: bool,
) -> ChatCompletion:
    createTimestamp = time.time()

    prompt = (
        f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
        if request.prompt is None
        else request.prompt.strip()
    )
    logger.info(f"[REQ] {completionId} - prompt - {prompt}")

    out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
    prefillTime = time.time()
    promptTokenCount = len(model_tokens)

    fullResponse = " <think" if enableReasoning else ""
    completionTokenCount = 0
    finishReason = None

    for chunk in generate(
        request,
        out,
        model_tokens,
        model_state,
        max_tokens=(
            64000
            if "max_tokens" not in request.model_fields_set and enableReasoning
            else request.max_tokens
        ),
    ):
        fullResponse += chunk["content"]
        completionTokenCount += 1
        if chunk["finish_reason"]:
            finishReason = chunk["finish_reason"]
        await asyncio.sleep(0)

    generateTime = time.time()

    responseLog = {
        "content": fullResponse,
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
        "gen_len": completionTokenCount,
        "gen_tps": round(completionTokenCount / (generateTime - prefillTime), 2),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")

    reasoning_content, content = parse_think_response(fullResponse)

    response = ChatCompletion(
        id=completionId,
        created=int(createTimestamp),
        model=request.model,
        usage=Usage(
            prompt_tokens=promptTokenCount,
            completion_tokens=completionTokenCount,
            total_tokens=promptTokenCount + completionTokenCount,
            prompt_tokens_details={"cached_tokens": 0},
        ),
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatCompletionMessage(
                    role="Assistant",
                    content=content,
                    reasoning_content=reasoning_content if reasoning_content else None,
                ),
                logprobs=None,
                finish_reason=finishReason,
            )
        ],
    )
    return response
async def chatResponseStream(
    request: ChatCompletionRequest,
    model_state: Any,
    completionId: str,
    enableReasoning: bool,
):
    createTimestamp = int(time.time())

    prompt = (
        f"{cleanMessages(request.messages, enableReasoning)}\n\nAssistant:{' <think' if enableReasoning else ''}"
        if request.prompt is None
        else request.prompt.strip()
    )
    logger.info(f"[REQ] {completionId} - context - {prompt}")

    out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
    prefillTime = time.time()
    promptTokenCount = len(model_tokens)
    completionTokenCount = 0
    finishReason = None

    # Initial chunk announcing the assistant role.
    response = ChatCompletionChunk(
        id=completionId,
        created=createTimestamp,
        model=request.model,
        usage=(
            Usage(
                prompt_tokens=promptTokenCount,
                completion_tokens=completionTokenCount,
                total_tokens=promptTokenCount + completionTokenCount,
                prompt_tokens_details={"cached_tokens": 0},
            )
            if request.include_usage
            else None
        ),
        choices=[
            ChatCompletionChoice(
                index=0,
                delta=ChatCompletionMessage(
                    role="Assistant",
                    content="",
                    reasoning_content="" if enableReasoning else None,
                ),
                logprobs=None,
                finish_reason=finishReason,
            )
        ],
    )
    yield f"data: {response.model_dump_json()}\n\n"
    buffer = []
    if enableReasoning:
        buffer.append("<think")

        # Incremental parser state: text inside <think>...</think> is routed to
        # `reasoning_content`, everything else to `content`.
        streamConfig = {
            "isChecking": False,  # currently buffering a potential <think>/</think> tag
            "fullTextCursor": 0,
            "in_think": False,
            "cacheStr": "",
        }

        for chunk in generate(
            request,
            out,
            model_tokens,
            model_state,
            max_tokens=(
                64000
                if "max_tokens" not in request.model_fields_set and enableReasoning
                else request.max_tokens
            ),
        ):
            completionTokenCount += 1
            chunkContent: str = chunk["content"]
            buffer.append(chunkContent)
            fullText = "".join(buffer)
            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]

            response = ChatCompletionChunk(
                id=completionId,
                created=createTimestamp,
                model=request.model,
                usage=(
                    Usage(
                        prompt_tokens=promptTokenCount,
                        completion_tokens=completionTokenCount,
                        total_tokens=promptTokenCount + completionTokenCount,
                        prompt_tokens_details={"cached_tokens": 0},
                    )
                    if request.include_usage
                    else None
                ),
                choices=[
                    ChatCompletionChoice(
                        index=0,
                        delta=ChatCompletionMessage(
                            content=None, reasoning_content=None
                        ),
                        logprobs=None,
                        finish_reason=finishReason,
                    )
                ],
            )

            # A "<" may open a think tag: flush the text before it, then buffer
            # everything until the matching ">" arrives.
            markStart = fullText.find("<", streamConfig["fullTextCursor"])
            if not streamConfig["isChecking"] and markStart != -1:
                streamConfig["isChecking"] = True
                if streamConfig["in_think"]:
                    response.choices[0].delta.reasoning_content = fullText[
                        streamConfig["fullTextCursor"] : markStart
                    ]
                else:
                    response.choices[0].delta.content = fullText[
                        streamConfig["fullTextCursor"] : markStart
                    ]
                streamConfig["cacheStr"] = ""
                streamConfig["fullTextCursor"] = markStart

            if streamConfig["isChecking"]:
                streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"] :]
            else:
                if streamConfig["in_think"]:
                    response.choices[0].delta.reasoning_content = chunkContent
                else:
                    response.choices[0].delta.content = chunkContent
                streamConfig["fullTextCursor"] = len(fullText)

            # Once ">" shows up (or generation ends), decide whether the buffered
            # text was a think tag and flush it into the proper delta field.
            markEnd = fullText.find(">", streamConfig["fullTextCursor"])
            if (streamConfig["isChecking"] and markEnd != -1) or finishReason is not None:
                streamConfig["isChecking"] = False
                if (
                    not streamConfig["in_think"]
                    and streamConfig["cacheStr"].find("<think>") != -1
                ):
                    streamConfig["in_think"] = True
                    response.choices[0].delta.reasoning_content = (
                        response.choices[0].delta.reasoning_content or ""
                    ) + streamConfig["cacheStr"].replace("<think>", "")
                elif (
                    streamConfig["in_think"]
                    and streamConfig["cacheStr"].find("</think>") != -1
                ):
                    streamConfig["in_think"] = False
                    response.choices[0].delta.content = (
                        response.choices[0].delta.content or ""
                    ) + streamConfig["cacheStr"].replace("</think>", "")
                else:
                    if streamConfig["in_think"]:
                        response.choices[0].delta.reasoning_content = (
                            response.choices[0].delta.reasoning_content or ""
                        ) + streamConfig["cacheStr"]
                    else:
                        response.choices[0].delta.content = (
                            response.choices[0].delta.content or ""
                        ) + streamConfig["cacheStr"]
                streamConfig["fullTextCursor"] = len(fullText)

            if (
                response.choices[0].delta.content is not None
                or response.choices[0].delta.reasoning_content is not None
            ):
                yield f"data: {response.model_dump_json()}\n\n"
            await asyncio.sleep(0)
        del streamConfig
    else:
        for chunk in generate(request, out, model_tokens, model_state):
            completionTokenCount += 1
            buffer.append(chunk["content"])
            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]

            response = ChatCompletionChunk(
                id=completionId,
                created=createTimestamp,
                model=request.model,
                usage=(
                    Usage(
                        prompt_tokens=promptTokenCount,
                        completion_tokens=completionTokenCount,
                        total_tokens=promptTokenCount + completionTokenCount,
                        prompt_tokens_details={"cached_tokens": 0},
                    )
                    if request.include_usage
                    else None
                ),
                choices=[
                    ChatCompletionChoice(
                        index=0,
                        delta=ChatCompletionMessage(content=chunk["content"]),
                        logprobs=None,
                        finish_reason=finishReason,
                    )
                ],
            )
            yield f"data: {response.model_dump_json()}\n\n"
            await asyncio.sleep(0)

    generateTime = time.time()

    responseLog = {
        "content": "".join(buffer),
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
        "gen_len": completionTokenCount,
        "gen_tps": round(completionTokenCount / (generateTime - prefillTime), 2),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")

    del buffer
    yield "data: [DONE]\n\n"
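
# --- Illustrative only: what the SSE stream emitted above looks like on the wire. ---
# Each event is a `data:` line carrying one ChatCompletionChunk as JSON (the exact
# field layout comes from api_types), and the stream ends with a literal [DONE] sentinel:
#
#   data: {"id": "...", "choices": [{"index": 0, "delta": {"role": "Assistant", "content": ""}, ...}], ...}
#   data: {"id": "...", "choices": [{"index": 0, "delta": {"content": "Hello"}, ...}], ...}
#   data: [DONE]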
# NOTE: the route path here is an assumption; it follows the usual OpenAI-compatible
# convention (`POST /v1/chat/completions`) that matches this API's request/response types.
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    completionId = str(next(CompletionIdGenerator))
    logger.info(f"[REQ] {completionId} - {request.model_dump()}")

    modelName = request.model.split(":")[0]
    enableReasoning = ":thinking" in request.model

    # Resolve `rwkv-latest` / `rwkv-latest:thinking` to the configured default models.
    if "rwkv-latest" in request.model:
        if enableReasoning:
            if DEFAULT_REASONING_MODEL_NAME is None:
                raise HTTPException(404, "DEFAULT_REASONING_MODEL_NAME not set")
            defaultSamplerConfig = MODEL_STORAGE[
                DEFAULT_REASONING_MODEL_NAME
            ].MODEL_CONFIG.DEFAULT_SAMPLER
            request.model = DEFAULT_REASONING_MODEL_NAME
        else:
            if DEFAULT_MODEL_NAME is None:
                raise HTTPException(404, "DEFAULT_MODEL_NAME not set")
            defaultSamplerConfig = MODEL_STORAGE[
                DEFAULT_MODEL_NAME
            ].MODEL_CONFIG.DEFAULT_SAMPLER
            request.model = DEFAULT_MODEL_NAME
    elif modelName in MODEL_STORAGE:
        defaultSamplerConfig = MODEL_STORAGE[modelName].MODEL_CONFIG.DEFAULT_SAMPLER
        request.model = modelName
    else:
        raise HTTPException(404, f"Cannot find model `{modelName}`")
    async def chatResponseStreamDisconnect():
        logGPUState()

    model_state = None

    # Fill any sampler fields the client left unset with the model's defaults.
    request_dict = request.model_dump()
    for k, v in defaultSamplerConfig.model_dump().items():
        if request_dict[k] is None:
            request_dict[k] = v
    realRequest = ChatCompletionRequest(**request_dict)
    logger.info(f"[REQ] {completionId} - Real - {realRequest.model_dump()}")

    if request.stream:
        r = StreamingResponse(
            chatResponseStream(realRequest, model_state, completionId, enableReasoning),
            media_type="text/event-stream",
            background=chatResponseStreamDisconnect,
        )
    else:
        r = await chatResponse(realRequest, model_state, completionId, enableReasoning)
    return r
app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)
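
# --- Illustrative only: calling this server with the `openai` Python client. ---
# A minimal sketch, assuming the server listens on localhost:8000 and exposes the
# OpenAI-style path used above; the API key value is not checked by this server.
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
#   stream = client.chat.completions.create(
#       model="rwkv-latest:thinking",  # the `:thinking` suffix enables the reasoning path
#       messages=[{"role": "user", "content": "Hello!"}],
#       stream=True,
#   )
#   for chunk in stream:
#       print(chunk.choices[0].delta.content or "", end="")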