import os, copy, types, gc, sys, re, time, collections, asyncio

from huggingface_hub import hf_hub_download
from loguru import logger

from snowflake import SnowflakeGenerator

CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)

from pynvml import *

nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)

from typing import Any, List, Optional, Union
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings

class Config(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
    HOST: str = Field("127.0.0.1", description="Host")
    PORT: int = Field(8000, description="Port")
    DEBUG: bool = Field(False, description="Debug mode")
    STRATEGY: str = Field("cpu", description="Strategy")
    MODEL_TITLE: str = Field("RWKV-x070-World-0.1B-v2.8-20241210-ctx4096")
    DOWNLOAD_REPO_ID: str = Field("BlinkDL/rwkv-7-world")
    DOWNLOAD_MODEL_DIR: Union[str, None] = Field(None, description="Model Download Dir")
    MODEL_FILE_PATH: Union[str, None] = Field(None, description="Model Path")
    GEN_penalty_decay: float = Field(0.996, description="Default penalty decay")
    CHUNK_LEN: int = Field(
        256,
        description="split input into chunks to save VRAM (shorter -> slower, but saves VRAM)",
    )
    VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")


CONFIG = Config()
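
# Example launch (sketch): with cli_parse_args=True, pydantic-settings exposes each
# field above as a command-line option (and also reads it from the environment).
# The script name and exact flag casing below are assumptions, not part of this file:
#   python app.py --STRATEGY "cuda fp16" --PORT 8000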
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

os.environ["RWKV_V7_ON"] = "1"  # enable this for rwkv-7 models
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = (
    "0"  # !!! '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!
)
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.background import BackgroundTask

from api_types import (
    ChatMessage,
    ChatCompletion,
    ChatCompletionChunk,
    Usage,
    PromptTokensDetails,
    ChatCompletionChoice,
    ChatCompletionMessage,
)
from utils import cleanMessages, parse_think_response
| logger.info(f"STRATEGY - {CONFIG.STRATEGY}") | |
| if CONFIG.MODEL_FILE_PATH == None: | |
| CONFIG.MODEL_FILE_PATH = hf_hub_download( | |
| repo_id=CONFIG.DOWNLOAD_REPO_ID, | |
| filename=f"{CONFIG.MODEL_TITLE}.pth", | |
| local_dir=CONFIG.DOWNLOAD_MODEL_DIR, | |
| ) | |
| logger.info(f"Load Model - {CONFIG.MODEL_FILE_PATH}") | |
| model = RWKV(model=CONFIG.MODEL_FILE_PATH.replace(".pth", ""), strategy=CONFIG.STRATEGY) | |
| pipeline = PIPELINE(model, CONFIG.VOCAB) | |
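
# Optional smoke test (sketch, commented out): PIPELINE.generate from the rwkv
# package can confirm the weights loaded before the API starts serving.
# print(pipeline.generate("The quick brown fox", token_count=16))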

class ChatCompletionRequest(BaseModel):
    model: str = Field(
        default="rwkv-latest",
        description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
    )
    messages: List[ChatMessage]
    prompt: Union[str, None] = Field(default=None)
    max_tokens: int = Field(default=512)
    temperature: float = Field(default=1.0)
    top_p: float = Field(default=0.3)
    presencePenalty: float = Field(default=0.5)
    countPenalty: float = Field(default=0.5)
    stream: bool = Field(default=False)
    state_name: Union[str, None] = Field(default=None)
    include_usage: bool = Field(default=False)

app = FastAPI(title="RWKV OpenAI-Compatible API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def runPrefill(ctx: str, model_tokens: List[int], model_state):
    """Feed the prompt into the model in CHUNK_LEN pieces, returning the last
    logits, the accumulated token list, and the updated recurrent state."""
    ctx = ctx.replace("\r\n", "\n")

    tokens = pipeline.encode(ctx)
    tokens = [int(x) for x in tokens]
    model_tokens += tokens

    while len(tokens) > 0:
        out, model_state = model.forward(tokens[: CONFIG.CHUNK_LEN], model_state)
        tokens = tokens[CONFIG.CHUNK_LEN :]

    return out, model_tokens, model_state
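
# Example (sketch): prefill can be split across calls by threading the returned
# token list and state back in, so the second call continues from the first.
#   out, toks, st = runPrefill("User: hi\n\n", [], None)
#   out, toks, st = runPrefill("Assistant:", toks, st)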

def generate(
    request: ChatCompletionRequest,
    out,
    model_tokens,
    model_state,
    stops=["\n\n"],
    max_tokens=2048,
):
    args = PIPELINE_ARGS(
        temperature=max(0.2, request.temperature),
        top_p=request.top_p,
        alpha_frequency=request.countPenalty,
        alpha_presence=request.presencePenalty,
        token_ban=[],  # ban the generation of some tokens
        token_stop=[0],  # stop generation whenever you see any token here
    )

    occurrence = {}
    out_tokens = []
    out_last = 0
    output_cache = collections.deque(maxlen=5)  # rolling window for stop-word detection

    for i in range(max_tokens):
        # apply presence/frequency penalties to previously generated tokens
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
        out[0] -= 1e10  # disable END_OF_TEXT

        token = pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )

        out, model_state = model.forward([token], model_state)
        model_tokens += [token]
        out_tokens += [token]

        # decay the penalty of older tokens, then bump the newly sampled one
        for xxx in occurrence:
            occurrence[xxx] *= CONFIG.GEN_penalty_decay
        occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)

        tmp: str = pipeline.decode(out_tokens[out_last:])
        if "\ufffd" in tmp:
            # incomplete UTF-8 sequence; wait for more tokens before emitting
            continue

        output_cache.append(tmp)
        output_cache_str = "".join(output_cache)
        for stop_words in stops:
            if stop_words in output_cache_str:
                yield {
                    "content": tmp.replace(stop_words, ""),
                    "tokens": out_tokens[out_last:],
                    "finish_reason": "stop",
                    "state": model_state,
                }
                del out
                gc.collect()
                return

        yield {
            "content": tmp,
            "tokens": out_tokens[out_last:],
            "finish_reason": None,
        }
        out_last = i + 1
    else:
        # for-else: reached max_tokens without hitting a stop word
        yield {
            "content": "",
            "tokens": [],
            "finish_reason": "length",
        }
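
# Example (sketch): drain the generator into a single string when streaming
# is not needed; `req`, `out`, `toks`, and `state` are placeholders here.
#   text = "".join(
#       chunk["content"] for chunk in generate(req, out, toks, state, max_tokens=256)
#   )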

async def chatResponse(
    request: ChatCompletionRequest, model_state: Any, completionId: str
) -> ChatCompletion:
    createTimestamp = time.time()

    enableReasoning = request.model.endswith(":thinking")

    prompt = (
        f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
        if request.prompt is None
        else request.prompt.strip()
    )
    out, model_tokens, model_state = runPrefill(prompt, [], model_state)
    prefillTime = time.time()
    promptTokenCount = len(model_tokens)

    fullResponse = " <think" if enableReasoning else ""
    completionTokenCount = 0
    finishReason = None

    for chunk in generate(
        request,
        out,
        model_tokens,
        model_state,
        max_tokens=(
            64000
            if "max_tokens" not in request.model_fields_set and enableReasoning
            else request.max_tokens
        ),
    ):
        fullResponse += chunk["content"]
        completionTokenCount += 1
        if chunk["finish_reason"]:
            finishReason = chunk["finish_reason"]
        await asyncio.sleep(0)

    generateTime = time.time()

    responseLog = {
        "content": fullResponse,
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
        "gen_len": completionTokenCount,
        "gen_tps": round(completionTokenCount / (generateTime - prefillTime), 2),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")

    reasoning_content, content = parse_think_response(fullResponse)

    response = ChatCompletion(
        id=completionId,
        created=int(createTimestamp),
        model=request.model,
        usage=Usage(
            prompt_tokens=promptTokenCount,
            completion_tokens=completionTokenCount,
            total_tokens=promptTokenCount + completionTokenCount,
            prompt_tokens_details={"cached_tokens": 0},
        ),
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatCompletionMessage(
                    role="Assistant",
                    content=content,
                    reasoning_content=reasoning_content if reasoning_content else None,
                ),
                logprobs=None,
                finish_reason=finishReason,
            )
        ],
    )
    return response

async def chatResponseStream(
    request: ChatCompletionRequest, model_state: Any, completionId: str
):
    createTimestamp = time.time()

    enableReasoning = request.model.endswith(":thinking")

    prompt = (
        f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
        if request.prompt is None
        else request.prompt.strip()
    )
    out, model_tokens, model_state = runPrefill(prompt, [], model_state)
    prefillTime = time.time()
    promptTokenCount = len(model_tokens)

    completionTokenCount = 0
    finishReason = None

    # initial chunk: announces the assistant role before any content arrives
    response = ChatCompletionChunk(
        id=completionId,
        created=int(createTimestamp),
        model=request.model,
        usage=(
            Usage(
                prompt_tokens=promptTokenCount,
                completion_tokens=completionTokenCount,
                total_tokens=promptTokenCount + completionTokenCount,
                prompt_tokens_details={"cached_tokens": 0},
            )
            if request.include_usage
            else None
        ),
        choices=[
            ChatCompletionChoice(
                index=0,
                delta=ChatCompletionMessage(
                    role="Assistant",
                    content="",
                    reasoning_content="" if enableReasoning else None,
                ),
                logprobs=None,
                finish_reason=finishReason,
            )
        ],
    )
    yield f"data: {response.model_dump_json()}\n\n"

    buffer = []
    if enableReasoning:
        buffer.append(" <think")
        streamConfig = {
            "isChecking": False,  # currently buffering a possible <think>/</think> tag
            "fullTextCursor": 0,  # index into fullText up to which output has been routed
            "in_think": False,  # inside a <think> ... </think> block
            "cacheStr": "",  # text held back while a tag is being checked
        }

        for chunk in generate(
            request,
            out,
            model_tokens,
            model_state,
            max_tokens=(
                64000
                if "max_tokens" not in request.model_fields_set and enableReasoning
                else request.max_tokens
            ),
        ):
            completionTokenCount += 1
            chunkContent: str = chunk["content"]
            buffer.append(chunkContent)
            fullText = "".join(buffer)

            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]

            response = ChatCompletionChunk(
                id=completionId,
                created=int(createTimestamp),
                model=request.model,
                usage=(
                    Usage(
                        prompt_tokens=promptTokenCount,
                        completion_tokens=completionTokenCount,
                        total_tokens=promptTokenCount + completionTokenCount,
                        prompt_tokens_details={"cached_tokens": 0},
                    )
                    if request.include_usage
                    else None
                ),
                choices=[
                    ChatCompletionChoice(
                        index=0,
                        delta=ChatCompletionMessage(
                            content=None, reasoning_content=None
                        ),
                        logprobs=None,
                        finish_reason=finishReason,
                    )
                ],
            )

            # a "<" may open a think tag: emit everything before it, start checking
            markStart = fullText.find("<", streamConfig["fullTextCursor"])
            if not streamConfig["isChecking"] and markStart != -1:
                streamConfig["isChecking"] = True
                if streamConfig["in_think"]:
                    response.choices[0].delta.reasoning_content = fullText[
                        streamConfig["fullTextCursor"] : markStart
                    ]
                else:
                    response.choices[0].delta.content = fullText[
                        streamConfig["fullTextCursor"] : markStart
                    ]
                streamConfig["cacheStr"] = ""
                streamConfig["fullTextCursor"] = markStart

            if streamConfig["isChecking"]:
                streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"] :]
            else:
                if streamConfig["in_think"]:
                    response.choices[0].delta.reasoning_content = chunkContent
                else:
                    response.choices[0].delta.content = chunkContent
                streamConfig["fullTextCursor"] = len(fullText)

            # a ">" closes the candidate tag: decide whether it toggled think mode,
            # then flush the buffered text (appending to anything already in this delta)
            markEnd = fullText.find(">", streamConfig["fullTextCursor"])
            if streamConfig["isChecking"] and markEnd != -1:
                streamConfig["isChecking"] = False
                if (
                    not streamConfig["in_think"]
                    and streamConfig["cacheStr"].find("<think>") != -1
                ):
                    streamConfig["in_think"] = True
                    response.choices[0].delta.reasoning_content = (
                        response.choices[0].delta.reasoning_content or ""
                    ) + streamConfig["cacheStr"].replace("<think>", "")
                elif (
                    streamConfig["in_think"]
                    and streamConfig["cacheStr"].find("</think>") != -1
                ):
                    streamConfig["in_think"] = False
                    response.choices[0].delta.content = (
                        response.choices[0].delta.content or ""
                    ) + streamConfig["cacheStr"].replace("</think>", "")
                else:
                    if streamConfig["in_think"]:
                        response.choices[0].delta.reasoning_content = (
                            response.choices[0].delta.reasoning_content or ""
                        ) + streamConfig["cacheStr"]
                    else:
                        response.choices[0].delta.content = (
                            response.choices[0].delta.content or ""
                        ) + streamConfig["cacheStr"]
                streamConfig["fullTextCursor"] = len(fullText)

            if (
                response.choices[0].delta.content is not None
                or response.choices[0].delta.reasoning_content is not None
            ):
                yield f"data: {response.model_dump_json()}\n\n"
            await asyncio.sleep(0)
        del streamConfig
    else:
        for chunk in generate(
            request, out, model_tokens, model_state, max_tokens=request.max_tokens
        ):
            completionTokenCount += 1
            buffer.append(chunk["content"])
            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]
            response = ChatCompletionChunk(
                id=completionId,
                created=int(createTimestamp),
                model=request.model,
                usage=(
                    Usage(
                        prompt_tokens=promptTokenCount,
                        completion_tokens=completionTokenCount,
                        total_tokens=promptTokenCount + completionTokenCount,
                        prompt_tokens_details={"cached_tokens": 0},
                    )
                    if request.include_usage
                    else None
                ),
                choices=[
                    ChatCompletionChoice(
                        index=0,
                        delta=ChatCompletionMessage(content=chunk["content"]),
                        logprobs=None,
                        finish_reason=finishReason,
                    )
                ],
            )
            yield f"data: {response.model_dump_json()}\n\n"
            await asyncio.sleep(0)

    generateTime = time.time()

    responseLog = {
        "content": "".join(buffer),
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
        "gen_len": completionTokenCount,
        "gen_tps": round(completionTokenCount / (generateTime - prefillTime), 2),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")

    del buffer
    yield "data: [DONE]\n\n"

# Route path assumed here, following the OpenAI-compatible convention.
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    completionId = str(next(CompletionIdGenerator))
    logger.info(f"[REQ] {completionId} - {request.model_dump()}")

    def chatResponseStreamDisconnect():
        # log VRAM usage once the streaming response finishes or the client disconnects
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        logger.info(
            f"[STATUS] vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}"
        )

    model_state = None

    if request.stream:
        r = StreamingResponse(
            chatResponseStream(request, model_state, completionId),
            media_type="text/event-stream",
            # wrap the callback so Starlette can await it as a background task
            background=BackgroundTask(chatResponseStreamDisconnect),
        )
    else:
        r = await chatResponse(request, model_state, completionId)
    return r

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)
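
# Example request (sketch; uses the default host/port above and the route path
# assumed on the decorator, so adjust both if your deployment differs):
#   curl http://127.0.0.1:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "rwkv-latest", "messages": [{"role": "user", "content": "Hello"}], "stream": true}'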