sparkleman
UPDATE
6b82cc0
import os
if os.getenv("MODELSCOPE_ENVIRONMENT") == "studio":
    # Running inside a ModelScope Studio: apply modelscope's hub patch
    # (presumably redirects hub downloads through ModelScope — confirm).
    from modelscope import patch_hub

    patch_hub()

# Allocator tuning for PyTorch CUDA; set here, before torch is imported below.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
from config import CONFIG, ModelConfig
from utils import (
cleanMessages,
parse_think_response,
remove_nested_think_tags_stack,
format_bytes,
)
import copy, types, gc, sys, re, time, collections, asyncio
from huggingface_hub import hf_hub_download
from loguru import logger
from rich import print
from snowflake import SnowflakeGenerator
# Iterator producing unique completion ids (consumed with next() per request).
# NOTE(review): 42 is presumably the generator/worker instance id and
# `timestamp` the custom epoch in milliseconds — confirm against the
# snowflake package docs.
CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
from typing import List, Optional, Union, Any, Dict
from pydantic import BaseModel, Field, model_validator
from pydantic_settings import BaseSettings
import numpy as np
import torch
# Fall back to CPU when the configured strategy asks for CUDA but torch
# cannot see a usable device.
if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
    logger.warning("CUDA not found, fall back to cpu")
    CONFIG.STRATEGY = "cpu fp16"

if "cuda" in CONFIG.STRATEGY.lower():
    # Import only what is used: a star import would pull every public
    # pynvml name into this module and can shadow names already imported.
    from pynvml import (
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
        nvmlInit,
    )

    nvmlInit()
    # NVML handle for device index 0; read by logGPUState().
    gpu_h = nvmlDeviceGetHandleByIndex(0)
def logGPUState():
    """Log torch-allocated memory and NVML VRAM totals when running on CUDA.

    Does nothing unless the strategy mentions CUDA. The check is
    case-insensitive, matching the check that creates the module-level NVML
    handle ``gpu_h`` this function reads.
    """
    if "cuda" not in CONFIG.STRATEGY.lower():
        return
    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    logger.info(
        f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}"
    )
# Enable cuDNN benchmarking and TF32 math paths.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# RWKV runtime switches. These are set before `rwkv.model` is imported below
# (presumably the package reads them at import time — confirm).
os.environ.update(
    {
        "RWKV_V7_ON": "1",  # enable this for rwkv-7 models
        "RWKV_JIT_ON": "1",
        # The custom CUDA kernel is only enabled when both the config asks
        # for it and the strategy actually targets CUDA.
        "RWKV_CUDA_ON": (
            "1"
            if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower()
            else "0"
        ),
    }
)
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.gzip import GZipMiddleware
from api_types import (
ChatMessage,
ChatCompletion,
ChatCompletionChunk,
Usage,
PromptTokensDetails,
ChatCompletionChoice,
ChatCompletionMessage,
)
class ModelStorage:
    """Holder for one served model: its config, RWKV model and PIPELINE.

    Each attribute defaults to ``None`` on the class; the model-loading loop
    assigns real values on each instance it stores in ``MODEL_STORAGE``.
    """

    # The ModelConfig entry this service was built from.
    MODEL_CONFIG: Optional[ModelConfig] = None
    # Loaded RWKV model; its forward() is used for prefill and generation.
    model: Optional[RWKV] = None
    # PIPELINE wrapping `model`; used for encode/decode/sample_logits.
    pipeline: Optional[PIPELINE] = None
# Service name -> loaded model bundle.
MODEL_STORAGE: Dict[str, ModelStorage] = {}

# (sic) Spelling kept because other code refers to this name. Service used
# for plain `rwkv-latest`; the last model flagged DEFAULT_CHAT wins.
DEFALUT_MODEL_NAME = None
# Service used for `rwkv-latest:thinking`; last DEFAULT_REASONING wins.
DEFAULT_REASONING_MODEL_NAME = None

logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
logGPUState()

for model_config in CONFIG.MODELS:
    logger.info(f"Load Model - {model_config.SERVICE_NAME}")

    # No local file configured: fetch the checkpoint from the Hugging Face hub.
    if model_config.MODEL_FILE_PATH is None:
        model_config.MODEL_FILE_PATH = hf_hub_download(
            repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
            filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
            local_dir=model_config.DOWNLOAD_MODEL_DIR,
        )
    logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")

    if model_config.DEFAULT_CHAT:
        if DEFALUT_MODEL_NAME is not None:
            logger.info(
                f"Load Model - Replace `DEFALUT_MODEL_NAME` from `{DEFALUT_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
            )
        DEFALUT_MODEL_NAME = model_config.SERVICE_NAME

    if model_config.DEFAULT_REASONING:
        if DEFAULT_REASONING_MODEL_NAME is not None:
            logger.info(
                f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
            )
        DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME

    logger.info(f"Load Model - Default Sampler - {model_config.DEFAULT_SAMPLER}")

    storage = ModelStorage()
    storage.MODEL_CONFIG = model_config
    # RWKV is handed the checkpoint path without its ".pth" extension
    # (the loader presumably appends it). Strip only a trailing extension:
    # str.replace would also mangle any ".pth" inside a directory name.
    storage.model = RWKV(
        model=model_config.MODEL_FILE_PATH.removesuffix(".pth"),
        strategy=CONFIG.STRATEGY,
    )
    storage.pipeline = PIPELINE(storage.model, model_config.VOCAB)
    MODEL_STORAGE[model_config.SERVICE_NAME] = storage

    # Release loader temporaries before the next model is loaded.
    if "cuda" in CONFIG.STRATEGY.lower():
        torch.cuda.empty_cache()
    gc.collect()
    logGPUState()

logger.info(f"Load Model - DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
logger.info(
    f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`"
)
class ChatCompletionRequest(BaseModel):
    """Request body for the chat-completions endpoint.

    Exactly one of ``messages`` (chat transcript) or ``prompt`` (raw text)
    must be given. Sampling fields default to ``None`` so the endpoint can
    fill them from the selected model's default sampler.
    """

    model: str = Field(
        default="rwkv-latest",
        description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
    )
    # Input: chat transcript or raw prompt (mutually exclusive).
    messages: Optional[List[ChatMessage]] = Field(None)
    prompt: Optional[str] = Field(None)
    # Sampling controls.
    max_tokens: Optional[int] = Field(None)
    temperature: Optional[float] = Field(None)
    top_p: Optional[float] = Field(None)
    presence_penalty: Optional[float] = Field(None)
    count_penalty: Optional[float] = Field(None)
    penalty_decay: Optional[float] = Field(None)
    # Response options.
    stream: Optional[bool] = Field(False)
    state_name: Optional[str] = Field(None)
    include_usage: Optional[bool] = Field(False)
    # Generation halts when any of these strings / token ids is produced.
    stop: Optional[list[str]] = Field(["\n\n"])
    stop_tokens: Optional[list[int]] = Field([0])

    @model_validator(mode="before")
    @classmethod
    def validate_mutual_exclusivity(cls, data: Any) -> Any:
        """Reject raw dict input that sets both or neither of messages/prompt.

        Non-dict input is passed through untouched for pydantic to handle.
        """
        if isinstance(data, dict):
            hasMessages = data.get("messages") is not None
            hasPrompt = data.get("prompt") is not None
            if hasMessages and hasPrompt:
                raise ValueError("messages and prompt cannot coexist. Choose one.")
            if not (hasMessages or hasPrompt):
                raise ValueError("Either messages or prompt must be provided.")
        return data
# ASGI application exposing the OpenAI-compatible API.
app = FastAPI(title="RWKV OpenAI-Compatible API")

# Open CORS policy: every origin, method and header is accepted.
# NOTE(review): combined with allow_credentials=True, Starlette answers
# credentialed requests by echoing the caller's Origin — confirm this
# exposure is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Compress responses of at least 1000 bytes at gzip level 5.
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
async def runPrefill(
    request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
):
    """Feed the prompt ``ctx`` through the model in ``CONFIG.CHUNK_LEN`` slices.

    Args:
        request: Resolved request; ``request.model`` is a MODEL_STORAGE key.
        ctx: Prompt text; ``\\r\\n`` line endings are normalised to ``\\n``.
        model_tokens: Token history, extended in place with the prompt tokens.
        model_state: Model state to continue from.

    Returns:
        ``(out, model_tokens, model_state)``: the forward output for the last
        prompt token, the same (now extended) list, and the updated state.

    Raises:
        HTTPException: 400 when ``ctx`` encodes to no tokens. Without at least
            one forward pass there is no output to sample from, and the old
            code crashed with UnboundLocalError.
    """
    storage = MODEL_STORAGE[request.model]
    tokens = [int(x) for x in storage.pipeline.encode(ctx.replace("\r\n", "\n"))]
    if not tokens:
        raise HTTPException(400, "Prompt is empty.")
    model_tokens += tokens

    chunk_len = CONFIG.CHUNK_LEN
    for start in range(0, len(tokens), chunk_len):
        out, model_state = storage.model.forward(
            tokens[start : start + chunk_len], model_state
        )
        # Yield control to the event loop between chunks.
        await asyncio.sleep(0)
    return out, model_tokens, model_state
def generate(
    request: ChatCompletionRequest,
    out,
    model_tokens: List[int],
    model_state,
    max_tokens=2048,
):
    """Sample a completion token by token, yielding decoded text as it firms up.

    The most recent ``cache_word_len`` decoded pieces are held back before
    being emitted. This lets a stop word spread over several pieces be caught
    before any part of it reaches the caller.

    Intermediate yields are ``{"content", "tokens", "finish_reason": None}``,
    where ``tokens`` are the ids behind that piece of text. The final yield
    also carries ``"state"``. Its ``finish_reason`` is one of:

    - ``"stop:token:<id>"``
    - ``"stop:words:<word>"``
    - ``"length"``

    Its ``content`` is all text still held back, truncated before the stop
    word if one matched.

    Args:
        request: Resolved request supplying these fields:
            ``temperature`` (floored at 0.2), ``top_p``, ``presence_penalty``,
            ``count_penalty``, ``penalty_decay``, ``stop`` and ``stop_tokens``.
        out: Forward output for the last prompt token. It is sampled from and
            adjusted in place by the repetition penalties.
        model_tokens: Token history; every token fed to the model is
            appended in place.
        model_state: Model state after the prefill.
        max_tokens: Maximum number of sampling steps; ``None`` means 2048.
    """
    storage = MODEL_STORAGE[request.model]
    pipeline, model = storage.pipeline, storage.model
    if max_tokens is None:
        max_tokens = 2048
    temperature = max(0.2, request.temperature)
    stop_tokens = request.stop_tokens or []
    # An empty stop word would match immediately, so ignore it.
    stop_words = [word for word in (request.stop or []) if word]

    occurrence: Dict[int, float] = {}  # token id -> decayed repeat count
    out_tokens: List[int] = []
    out_last = 0  # out_tokens[out_last:] are sampled but not yet decoded
    cache = collections.deque()  # held-back (text, token ids) pieces
    cache_word_len = 5

    def held_text() -> str:
        return "".join(text for text, _ in cache)

    def held_tokens() -> List[int]:
        return [tok for _, ids in cache for tok in ids]

    final = None
    for _ in range(max_tokens):
        for tok, count in occurrence.items():
            out[tok] -= request.presence_penalty + count * request.count_penalty
        token = pipeline.sample_logits(
            out, temperature=temperature, top_p=request.top_p
        )

        if token == 0 and token in stop_tokens:
            # Token 0 stops *before* being fed, so the returned state excludes it.
            final = {
                "content": held_text(),
                "tokens": held_tokens() + out_tokens[out_last:],
                "finish_reason": "stop:token:0",
                "state": model_state,
            }
            break

        out, model_state = model.forward([token], model_state)
        model_tokens.append(token)
        out_tokens.append(token)

        if token in stop_tokens:
            final = {
                "content": held_text(),
                "tokens": held_tokens() + out_tokens[out_last:],
                "finish_reason": f"stop:token:{token}",
                "state": model_state,
            }
            break

        for tok in occurrence:
            occurrence[tok] *= request.penalty_decay
        occurrence[token] = occurrence.get(token, 0) + 1

        piece_ids = out_tokens[out_last:]
        piece = pipeline.decode(piece_ids)
        if "\ufffd" in piece:
            # Incomplete multi-byte character: wait for the next token(s).
            continue
        cache.append((piece, piece_ids))
        out_last = len(out_tokens)

        text = held_text()
        hits = [(text.find(word), word) for word in stop_words if word in text]
        if hits:
            # Cut at the earliest stop word; nothing at or after it is returned.
            cut, word = min(hits)
            final = {
                "content": text[:cut],
                "tokens": held_tokens(),
                "finish_reason": f"stop:words:{word}",
                "state": model_state,
            }
            break

        if len(cache) > cache_word_len:
            text_out, ids_out = cache.popleft()
            yield {"content": text_out, "tokens": ids_out, "finish_reason": None}
    else:
        # Token budget used up: flush everything still held back instead of
        # dropping it.
        final = {
            "content": held_text(),
            "tokens": held_tokens() + out_tokens[out_last:],
            "finish_reason": "length",
            "state": model_state,
        }

    del out
    gc.collect()
    yield final
async def chatResponse(
    request: ChatCompletionRequest,
    model_state: Any,
    completionId: str,
    enableReasoning: bool,
) -> ChatCompletion:
    """Run a whole (non-streaming) completion and return it as one object.

    Args:
        request: Resolved request (model alias and sampler defaults applied).
        model_state: Initial model state for the prefill.
        completionId: Id used in logs and in the response.
        enableReasoning: When true, the prompt ends with `` <think`` and the
            output is split into ``reasoning_content`` and ``content``.

    Returns:
        ChatCompletion whose ``usage.completion_tokens`` is the number of
        tokens fed back to the model during generation.
    """
    createTimestamp = time.time()
    startClock = time.perf_counter()
    # Build the prompt exactly as the streaming path does, so both modes
    # see the same context for the same request.
    prompt = (
        f"{cleanMessages(request.messages, enableReasoning)}\n\nAssistant:{' <think' if enableReasoning else ''}"
        if request.prompt is None
        else request.prompt.strip()
    )
    logger.info(f"[REQ] {completionId} - prompt - {prompt}")
    out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
    prefillClock = time.perf_counter()
    promptTokenCount = len(model_tokens)

    # 64000 only applies when the client left max_tokens unset on this object.
    if enableReasoning and "max_tokens" not in request.model_fields_set:
        maxTokens = 64000
    elif request.max_tokens is not None:
        maxTokens = request.max_tokens
    else:
        maxTokens = 2048

    pieces = [" <think"] if enableReasoning else []
    finishReason = None
    for chunk in generate(request, out, model_tokens, model_state, max_tokens=maxTokens):
        pieces.append(chunk["content"])
        if chunk["finish_reason"]:
            finishReason = chunk["finish_reason"]
        await asyncio.sleep(0)
    generateClock = time.perf_counter()

    fullResponse = "".join(pieces)
    # generate() appends every token it feeds back to model_tokens, so the
    # growth past the prompt is the completion length (not the chunk count).
    completionTokenCount = len(model_tokens) - promptTokenCount
    responseLog = {
        "content": fullResponse,
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": round(promptTokenCount / max(prefillClock - startClock, 1e-9), 2),
        "gen_len": completionTokenCount,
        "gen_tps": round(completionTokenCount / max(generateClock - prefillClock, 1e-9), 2),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")

    reasoning_content, content = parse_think_response(fullResponse)
    return ChatCompletion(
        id=completionId,
        created=int(createTimestamp),
        model=request.model,
        usage=Usage(
            prompt_tokens=promptTokenCount,
            completion_tokens=completionTokenCount,
            total_tokens=promptTokenCount + completionTokenCount,
            prompt_tokens_details={"cached_tokens": 0},
        ),
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatCompletionMessage(
                    role="Assistant",
                    content=content,
                    reasoning_content=reasoning_content if reasoning_content else None,
                ),
                logprobs=None,
                finish_reason=finishReason,
            )
        ],
    )
class _ThinkTagSplitter:
    """Incrementally route streamed text into reasoning vs. answer content.

    Text between ``<think>`` and ``</think>`` is reasoning; everything else is
    content. The two tags themselves are dropped, and any other markup passes
    through verbatim. A trailing fragment that could still grow into the tag
    currently being looked for is held back until the next ``feed``/``flush``.
    This lets a tag split across chunks be recognised without stalling on
    unrelated ``<`` characters.
    """

    OPEN = "<think>"
    CLOSE = "</think>"

    def __init__(self):
        self.in_think = False
        self._pending = ""

    @staticmethod
    def _partial_tag_len(text: str, tag: str) -> int:
        """Length of the longest suffix of *text* that is a proper prefix of *tag*."""
        for k in range(min(len(text), len(tag) - 1), 0, -1):
            if text.endswith(tag[:k]):
                return k
        return 0

    def feed(self, text: str):
        """Consume *text*; return ``(reasoning, content)`` safe to emit now."""
        self._pending += text
        reasoning, content = [], []
        while True:
            tag = self.CLOSE if self.in_think else self.OPEN
            sink = reasoning if self.in_think else content
            idx = self._pending.find(tag)
            if idx == -1:
                break
            sink.append(self._pending[:idx])
            self._pending = self._pending[idx + len(tag):]
            self.in_think = not self.in_think
        cut = len(self._pending) - self._partial_tag_len(self._pending, tag)
        sink.append(self._pending[:cut])
        self._pending = self._pending[cut:]
        return "".join(reasoning), "".join(content)

    def flush(self):
        """Release any held-back text as ``(reasoning, content)``."""
        rest, self._pending = self._pending, ""
        return (rest, "") if self.in_think else ("", rest)


async def chatResponseStream(
    request: ChatCompletionRequest,
    model_state: Any,
    completionId: str,
    enableReasoning: bool,
):
    """Stream a completion as server-sent ``data: ...`` events.

    Event order:

    1. A role chunk.
    2. One ChatCompletionChunk per emitted piece of text; the one carrying
       ``finish_reason`` is always sent.
    3. ``data: [DONE]``.

    With ``enableReasoning``, text inside ``<think>...</think>`` is sent as
    ``reasoning_content`` and the rest as ``content``. Usage, when
    ``request.include_usage`` is set, counts generated tokens.
    """
    createTimestamp = int(time.time())
    startClock = time.perf_counter()
    prompt = (
        f"{cleanMessages(request.messages, enableReasoning)}\n\nAssistant:{' <think' if enableReasoning else ''}"
        if request.prompt is None
        else request.prompt.strip()
    )
    logger.info(f"[REQ] {completionId} - context - {prompt}")
    out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
    prefillClock = time.perf_counter()
    promptTokenCount = len(model_tokens)
    finishReason = None

    def makeChunk(delta: ChatCompletionMessage, finish_reason) -> str:
        # generate() appends each token it feeds back to model_tokens, so the
        # growth past the prompt is the completion length so far.
        completionTokens = len(model_tokens) - promptTokenCount
        chunk = ChatCompletionChunk(
            id=completionId,
            created=createTimestamp,
            model=request.model,
            usage=(
                Usage(
                    prompt_tokens=promptTokenCount,
                    completion_tokens=completionTokens,
                    total_tokens=promptTokenCount + completionTokens,
                    prompt_tokens_details={"cached_tokens": 0},
                )
                if request.include_usage
                else None
            ),
            choices=[
                ChatCompletionChoice(
                    index=0, delta=delta, logprobs=None, finish_reason=finish_reason
                )
            ],
        )
        return f"data: {chunk.model_dump_json()}\n\n"

    yield makeChunk(
        ChatCompletionMessage(
            role="Assistant",
            content="",
            reasoning_content="" if enableReasoning else None,
        ),
        None,
    )

    # Both branches honour the request's limit; 64000 only applies to
    # reasoning when the client left max_tokens unset on this object.
    if enableReasoning and "max_tokens" not in request.model_fields_set:
        maxTokens = 64000
    elif request.max_tokens is not None:
        maxTokens = request.max_tokens
    else:
        maxTokens = 2048

    buffer = []  # raw generated text, for the response log
    if enableReasoning:
        buffer.append("<think")
        splitter = _ThinkTagSplitter()
        # The prompt ends with " <think"; prime the splitter so the model's
        # ">" completes the opening tag.
        splitter.feed("<think")
        for chunk in generate(
            request, out, model_tokens, model_state, max_tokens=maxTokens
        ):
            buffer.append(chunk["content"])
            reasoning, content = splitter.feed(chunk["content"])
            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]
                restReasoning, restContent = splitter.flush()
                reasoning += restReasoning
                content += restContent
            # Skip empty deltas, but always send the chunk carrying finish_reason.
            if reasoning or content or chunk["finish_reason"]:
                yield makeChunk(
                    ChatCompletionMessage(
                        content=content or None,
                        reasoning_content=reasoning or None,
                    ),
                    finishReason,
                )
            await asyncio.sleep(0)
    else:
        for chunk in generate(
            request, out, model_tokens, model_state, max_tokens=maxTokens
        ):
            buffer.append(chunk["content"])
            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]
            yield makeChunk(ChatCompletionMessage(content=chunk["content"]), finishReason)
            await asyncio.sleep(0)
    generateClock = time.perf_counter()

    completionTokenCount = len(model_tokens) - promptTokenCount
    responseLog = {
        "content": "".join(buffer),
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": round(promptTokenCount / max(prefillClock - startClock, 1e-9), 2),
        "gen_len": completionTokenCount,
        "gen_tps": round(completionTokenCount / max(generateClock - prefillClock, 1e-9), 2),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")
    yield "data: [DONE]\n\n"
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint.

    Steps:

    1. Resolve ``request.model``. An ``rwkv-latest`` alias maps to the
       configured default chat or reasoning service, depending on a
       ``:thinking`` suffix; any other name must be a loaded service.
    2. Fill sampling fields the client left ``None`` from that service's
       default sampler.
    3. Return a streamed or one-shot completion.

    Raises:
        HTTPException: 404 when the alias' default service is not configured
            or the named model is not loaded.
    """
    completionId = str(next(CompletionIdGenerator))
    logger.info(f"[REQ] {completionId} - {request.model_dump()}")

    modelName = request.model.split(":")[0]
    enableReasoning = ":thinking" in request.model

    if "rwkv-latest" in request.model:
        if enableReasoning:
            if DEFAULT_REASONING_MODEL_NAME is None:
                raise HTTPException(404, "DEFAULT_REASONING_MODEL_NAME not set")
            request.model = DEFAULT_REASONING_MODEL_NAME
        else:
            if DEFALUT_MODEL_NAME is None:
                raise HTTPException(404, "DEFALUT_MODEL_NAME not set")
            request.model = DEFALUT_MODEL_NAME
    elif modelName in MODEL_STORAGE:
        request.model = modelName
    else:
        # Raising a bare str is a TypeError (HTTP 500); report a proper 404.
        raise HTTPException(404, f"Can not find `{modelName}`")
    defaultSamplerConfig = MODEL_STORAGE[request.model].MODEL_CONFIG.DEFAULT_SAMPLER

    async def chatResponseStreamDisconnect():
        logGPUState()

    model_state = None
    request_dict = request.model_dump()
    # Rebuilding the request from a full dump marks *every* field as
    # explicitly set, so a downstream "client omitted max_tokens" check can
    # never fire on realRequest. Apply the reasoning budget here, where the
    # client's actual choice is still visible.
    if enableReasoning and "max_tokens" not in request.model_fields_set:
        request_dict["max_tokens"] = 64000
    for k, v in defaultSamplerConfig.model_dump().items():
        if request_dict.get(k) is None:
            request_dict[k] = v
    realRequest = ChatCompletionRequest(**request_dict)
    logger.info(f"[REQ] {completionId} - Real - {realRequest.model_dump()}")

    if request.stream:
        return StreamingResponse(
            chatResponseStream(realRequest, model_state, completionId, enableReasoning),
            media_type="text/event-stream",
            background=chatResponseStreamDisconnect,
        )
    return await chatResponse(realRequest, model_state, completionId, enableReasoning)
# Serve the built web front-end from "dist-frontend" at the site root
# (html=True serves index.html for directory paths). Registered after the API
# route above, so that route is tried first.
# NOTE(review): StaticFiles checks the directory exists, so a missing build
# presumably aborts startup — confirm this is intended.
app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")

if __name__ == "__main__":
    import uvicorn

    # Run the app directly on the configured host/port.
    uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)