Spaces: Running on T4

sparkleman committed · Commit adb6ad5 · 1 Parent(s): ff3952a

UPDATE: Add frontend

Browse files:
- .gitignore +4 -1
- Dockerfile +53 -2
- README.md +1 -1
- app.py +136 -61
- config.py +82 -0
- openai_test.py +0 -78
.gitignore CHANGED

@@ -13,4 +13,7 @@ wheels/
 
 *pth
 *.pt
-*.st
+*.st
+*local*
+
+dist-frontend/
Dockerfile CHANGED

@@ -9,12 +9,23 @@ apt install --no-install-recommends -y \
 apt clean && rm -rf /var/lib/apt/lists/*
 EOF
 
+# Install Node.js and npm
+RUN curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
+    apt-get install -y nodejs
+
+# Install pnpm
+RUN npm install -g pnpm
+
+# Clone the frontend repository and build it
+RUN git clone https://github.com/SolomonLeon/web-rwkv-realweb.git /frontend
+WORKDIR /frontend
+RUN pnpm install && pnpm run build
+
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
 
 COPY . .
 
 RUN useradd -m -u 1000 user
-# Switch to the "user" user
 USER user
 
 ENV HOME=/home/user \
@@ -23,7 +34,47 @@ ENV HOME=/home/user \
 WORKDIR $HOME/app
 
 COPY --chown=user . $HOME/app
+COPY --chown=user /frontend/dist $HOME/app/dist-frontend
+
+RUN cat > $HOME/app/config.local.yaml <<EOF
+HOST: "0.0.0.0"
+PORT: 7860
+STRATEGY: "cuda fp16"
+RWKV_CUDA_ON: False
+CHUNK_LEN: 256
+MODELS:
+  - SERVICE_NAME: "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096"
+    DOWNLOAD_MODEL_FILE_NAME: "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth"
+    DOWNLOAD_MODEL_REPO_ID: "BlinkDL/rwkv-7-world"
+    DOWNLOAD_MODEL_DIR: "./"
+    REASONING: False
+    DEFAULT: True
+    DEFAULT_SAMPLER:
+      max_tokens: 512
+      temperature: 1.0
+      top_p: 0.3
+      presence_penalty: 0.5
+      count_penalty: 0.5
+      penalty_decay: 0.996
+      stop:
+        - "\n\n"
+  - SERVICE_NAME: "RWKV7-G1-0.1B-68%trained-20250303-ctx4k"
+    DOWNLOAD_MODEL_FILE_NAME: "RWKV7-G1-0.1B-68%trained-20250303-ctx4k.pth"
+    DOWNLOAD_MODEL_REPO_ID: "BlinkDL/temp-latest-training-models"
+    DOWNLOAD_MODEL_DIR: "./"
+    REASONING: True
+    DEFAULT: True
+    DEFAULT_SAMPLER:
+      max_tokens: 4096
+      temperature: 1.0
+      top_p: 0.3
+      presence_penalty: 0.5
+      count_penalty: 0.5
+      penalty_decay: 0.996
+      stop:
+        - "\n\n"
+EOF
 
 RUN uv sync --frozen --extra cu124
 
-CMD ["uv","run","app.py",
+CMD ["uv","run","app.py",]
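The `RUN cat > $HOME/app/config.local.yaml <<EOF` step bakes the service configuration into the image; the schema that file must satisfy is the `RootConfig` model added in config.py later in this commit. A quick way to sanity-check the YAML before building is to validate it the same way config.py does at import time. A minimal sketch, assuming config.py from this commit and a local `config.local.yaml` both sit in the working directory:

```python
# Hedged sketch: validate config.local.yaml against the RootConfig schema from
# config.py (added in this commit). Importing config also runs its own loading
# logic (CLI parsing plus reading ./config.local.yaml), so the file must exist.
import yaml

from config import RootConfig

with open("config.local.yaml", "r", encoding="utf-8") as f:
    cfg = RootConfig.model_validate(yaml.safe_load(f))

for m in cfg.MODELS:
    print(m.SERVICE_NAME, "reasoning:", m.REASONING, "default:", m.DEFAULT)
```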
README.md CHANGED

@@ -25,7 +25,7 @@ python app.py --strategy "cuda fp16" --model_title "RWKV-x070-World-0.1B-v2.8-20
 python app.py --strategy "cuda fp16" --model_title "RWKV7-G1-0.1B-68%trained-20250303-ctx4k" --download_repo_id "BlinkDL/temp-latest-training-models" --download_model_dir ./
 ```
 
-`RWKV7-G1-0.
+`RWKV7-G1-0.4B-68%trained-20250303-ctx4k`
 
 ```shell
 python app.py --strategy "cuda fp16" --model_title "RWKV7-G1-0.4B-32%trained-20250304-ctx4k" --download_repo_id "BlinkDL/temp-latest-training-models" --download_model_dir ./
app.py CHANGED

@@ -1,3 +1,5 @@
+from config import CONFIG, ModelConfig
+
 import os, copy, types, gc, sys, re, time, collections, asyncio
 from huggingface_hub import hf_hub_download
 from loguru import logger
@@ -6,32 +8,11 @@ from snowflake import SnowflakeGenerator
 
 CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
 
-from typing import List, Optional, Union
-from pydantic import BaseModel, Field
+from typing import List, Optional, Union, Any, Dict
+from pydantic import BaseModel, Field, model_validator
 from pydantic_settings import BaseSettings
 
 
-class Config(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
-    HOST: str = Field("127.0.0.1", description="Host")
-    PORT: int = Field(8000, description="Port")
-    DEBUG: bool = Field(False, description="Debug mode")
-    STRATEGY: str = Field("cpu", description="Stratergy")
-    MODEL_TITLE: str = Field("RWKV-x070-World-0.1B-v2.8-20241210-ctx4096")
-    DOWNLOAD_REPO_ID: str = Field("BlinkDL/rwkv-7-world")
-    DOWNLOAD_MODEL_DIR: Union[str, None] = Field(None, description="Model Download Dir")
-    MODEL_FILE_PATH: Union[str, None] = Field(None, description="Model Path")
-    GEN_penalty_decay: float = Field(0.996, description="Default penalty decay")
-    CHUNK_LEN: int = Field(
-        256,
-        description="split input into chunks to save VRAM (shorter -> slower, but saves VRAM)",
-    )
-    VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
-    RWKV_CUDA_ON:bool = Field(False, description="`True` to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!")
-
-
-CONFIG = Config()
-
-
 import numpy as np
 import torch
 
@@ -58,9 +39,10 @@ os.environ["RWKV_CUDA_ON"] = (
 from rwkv.model import RWKV
 from rwkv.utils import PIPELINE, PIPELINE_ARGS
 
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
 
 from api_types import (
     ChatMessage,
@@ -74,17 +56,50 @@ from api_types import (
 from utils import cleanMessages, parse_think_response
 
 
+class ModelStorage:
+    MODEL_CONFIG: Optional[ModelConfig] = None
+    model: Optional[RWKV] = None
+    pipeline: Optional[PIPELINE] = None
+
+
+MODEL_STORAGE: Dict[str, ModelStorage] = {}
+
+DEFALUT_MODEL_NAME = None
+DEFAULT_REASONING_MODEL_NAME = None
+
 logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
-…
+
+for model_config in CONFIG.MODELS:
+    logger.info(f"Load Model - {model_config.SERVICE_NAME}")
+
+    if model_config.MODEL_FILE_PATH == None:
+        model_config.MODEL_FILE_PATH = hf_hub_download(
+            repo_id=model_config.DOWNLOAD_MODEL_REPO_ID,
+            filename=model_config.DOWNLOAD_MODEL_FILE_NAME,
+            local_dir=model_config.DOWNLOAD_MODEL_DIR,
+        )
+    logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")
+
+    tmp_model = RWKV(
+        model=model_config.DOWNLOAD_MODEL_FILE_NAME.replace(".pth", ""),
+        strategy=CONFIG.STRATEGY,
     )
 
+    tmp_pipeline = PIPELINE(tmp_model, model_config.VOCAB)
 
-…
+    if model_config.DEFAULT:
+        if model_config.REASONING:
+            DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME
+        else:
+            DEFALUT_MODEL_NAME = model_config.SERVICE_NAME
+
+    MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
+    MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
+    MODEL_STORAGE[model_config.SERVICE_NAME].model = tmp_model
+    MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = tmp_pipeline
+
+
+logger.info(f"DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
+logger.info(f"DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`")
 
 
 class ChatCompletionRequest(BaseModel):
@@ -92,16 +107,33 @@ class ChatCompletionRequest(BaseModel):
         default="rwkv-latest",
         description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
     )
-    messages: List[ChatMessage]
+    messages: Optional[List[ChatMessage]] = Field(default=None)
     prompt: Optional[str] = Field(default=None)
-    max_tokens: int = Field(default=
-    temperature: float = Field(default=
-    top_p: float = Field(default=
-…
+    max_tokens: Optional[int] = Field(default=None)
+    temperature: Optional[float] = Field(default=None)
+    top_p: Optional[float] = Field(default=None)
+    presence_penalty: Optional[float] = Field(default=None)
+    count_penalty: Optional[float] = Field(default=None)
+    penalty_decay: Optional[float] = Field(default=None)
+    stream: Optional[bool] = Field(default=False)
+    state_name: Optional[str] = Field(default=None)
+    include_usage: Optional[bool] = Field(default=False)
+    stop: Optional[list[str]] = Field(["\n\n"])
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_mutual_exclusivity(cls, data: Any) -> Any:
+        if not isinstance(data, dict):
+            return data
+
+        messages_provided = "messages" in data and data["messages"] != None
+        prompt_provided = "prompt" in data and data["prompt"] != None
+
+        if messages_provided and prompt_provided:
+            raise ValueError("messages and prompt cannot coexist. Choose one.")
+        if not messages_provided and not prompt_provided:
+            raise ValueError("Either messages or prompt must be provided.")
+        return data
 
 
 app = FastAPI(title="RWKV OpenAI-Compatible API")
@@ -115,15 +147,19 @@ app.add_middleware(
 )
 
 
-async def runPrefill(
+async def runPrefill(
+    request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
+):
     ctx = ctx.replace("\r\n", "\n")
 
-    tokens = pipeline.encode(ctx)
+    tokens = MODEL_STORAGE[request.model].pipeline.encode(ctx)
     tokens = [int(x) for x in tokens]
     model_tokens += tokens
 
     while len(tokens) > 0:
-        out, model_state = model.forward(
+        out, model_state = MODEL_STORAGE[request.model].model.forward(
+            tokens[: CONFIG.CHUNK_LEN], model_state
+        )
         tokens = tokens[CONFIG.CHUNK_LEN :]
         await asyncio.sleep(0)
 
@@ -141,8 +177,8 @@ def generate(
     args = PIPELINE_ARGS(
         temperature=max(0.2, request.temperature),
         top_p=request.top_p,
-        alpha_frequency=request.
-        alpha_presence=request.
+        alpha_frequency=request.count_penalty,
+        alpha_presence=request.presence_penalty,
        token_ban=[],  # ban the generation of some tokens
        token_stop=[0],
    )  # stop generation whenever you see any token here
@@ -158,20 +194,22 @@ def generate(
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
        out[0] -= 1e10  # disable END_OF_TEXT
 
-        token = pipeline.sample_logits(
+        token = MODEL_STORAGE[request.model].pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )
 
-        out, model_state = model.forward(
+        out, model_state = MODEL_STORAGE[request.model].model.forward(
+            [token], model_state
+        )
        model_tokens += [token]
 
        out_tokens += [token]
 
        for xxx in occurrence:
-            occurrence[xxx] *=
+            occurrence[xxx] *= request.penalty_decay
        occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
 
-        tmp: str = pipeline.decode(out_tokens[out_last:])
+        tmp: str = MODEL_STORAGE[request.model].pipeline.decode(out_tokens[out_last:])
 
        if "\ufffd" in tmp:
            continue
@@ -210,19 +248,20 @@ def generate(
 
 
 async def chatResponse(
-    request: ChatCompletionRequest,
+    request: ChatCompletionRequest,
+    model_state: any,
+    completionId: str,
+    enableReasoning: bool,
 ) -> ChatCompletion:
     createTimestamp = time.time()
 
-    enableReasoning = request.model.endswith(":thinking")
-
     prompt = (
         f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
         if request.prompt == None
         else request.prompt.strip()
     )
 
-    out, model_tokens, model_state = await runPrefill(prompt, [], model_state)
+    out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
 
     prefillTime = time.time()
     promptTokenCount = len(model_tokens)
@@ -291,19 +330,20 @@ async def chatResponse(
 
 
 async def chatResponseStream(
-    request: ChatCompletionRequest,
+    request: ChatCompletionRequest,
+    model_state: any,
+    completionId: str,
+    enableReasoning: bool,
 ):
     createTimestamp = int(time.time())
 
-    enableReasoning = request.model.endswith(":thinking")
-
     prompt = (
         f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
         if request.prompt == None
         else request.prompt.strip()
     )
 
-    out, model_tokens, model_state = await runPrefill(prompt, [], model_state)
+    out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
 
     prefillTime = time.time()
     promptTokenCount = len(model_tokens)
@@ -343,7 +383,7 @@ async def chatResponseStream(
     buffer = []
 
     if enableReasoning:
-        buffer.append("
+        buffer.append("<think")
 
     streamConfig = {
         "isChecking": False,
@@ -532,6 +572,32 @@ async def chat_completions(request: ChatCompletionRequest):
     completionId = str(next(CompletionIdGenerator))
     logger.info(f"[REQ] {completionId} - {request.model_dump()}")
 
+    modelName = request.model.split(":")[0]
+    enableReasoning = ":thinking" in request.model
+
+    if "rwkv-latest" in request.model:
+        if enableReasoning:
+            if DEFAULT_REASONING_MODEL_NAME == None:
+                raise HTTPException(404, "DEFAULT_REASONING_MODEL_NAME not set")
+            defaultSamplerConfig = MODEL_STORAGE[
+                DEFAULT_REASONING_MODEL_NAME
+            ].MODEL_CONFIG.DEFAULT_SAMPLER
+            request.model = DEFAULT_REASONING_MODEL_NAME
+
+        else:
+            if DEFALUT_MODEL_NAME == None:
+                raise HTTPException(404, "DEFALUT_MODEL_NAME not set")
+            defaultSamplerConfig = MODEL_STORAGE[
+                DEFALUT_MODEL_NAME
+            ].MODEL_CONFIG.DEFAULT_SAMPLER
+            request.model = DEFALUT_MODEL_NAME
+
+    elif modelName in MODEL_STORAGE:
+        defaultSamplerConfig = MODEL_STORAGE[modelName].MODEL_CONFIG.DEFAULT_SAMPLER
+        request.model = modelName
+    else:
+        raise f"Can not find `{modelName}`"
+
     async def chatResponseStreamDisconnect():
         if "cuda" in CONFIG.STRATEGY:
             gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
@@ -540,18 +606,27 @@ async def chat_completions(request: ChatCompletionRequest):
         )
 
     model_state = None
+    request_dict = request.model_dump()
+
+    for k, v in defaultSamplerConfig.model_dump().items():
+        if request_dict[k] == None:
+            request_dict[k] = v
+    realRequest = ChatCompletionRequest(**request_dict)
+
+    logger.info(f"[REQ] {completionId} - Real - {request.model_dump()}")
 
     if request.stream:
         r = StreamingResponse(
-            chatResponseStream(
+            chatResponseStream(realRequest, model_state, completionId, enableReasoning),
             media_type="text/event-stream",
             background=chatResponseStreamDisconnect,
         )
     else:
-        r = await chatResponse(
+        r = await chatResponse(realRequest, model_state, completionId, enableReasoning)
 
     return r
 
+app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
 
 if __name__ == "__main__":
     import uvicorn
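With the routing added to `chat_completions`, a bare `rwkv-latest` model name resolves to `DEFALUT_MODEL_NAME`, `rwkv-latest:thinking` resolves to `DEFAULT_REASONING_MODEL_NAME`, and an explicit `SERVICE_NAME` is looked up in `MODEL_STORAGE`; any other name falls into the `raise f"Can not find ..."` branch, which as committed raises a plain string rather than an exception type. A hedged client sketch adapted from the removed openai_test.py below, assuming the server listens on the port baked into config.local.yaml (7860) and that the chat route is still served under `/api/v1`, as the old test script assumed:

```python
# Hedged sketch adapted from the deleted openai_test.py. The base URL and the
# /api/v1 prefix are assumptions carried over from that script; adjust to match
# the actual deployment (the Space config binds 0.0.0.0:7860).
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:7860/api/v1", api_key="sk-test")

# ":thinking" selects DEFAULT_REASONING_MODEL_NAME; a bare "rwkv-latest" selects
# DEFALUT_MODEL_NAME; a SERVICE_NAME from the config selects that model directly.
completion = client.chat.completions.create(
    model="rwkv-latest:thinking",
    messages=[
        {"role": "User", "content": "How many planets are there in our solar system?"},
    ],
    stream=True,
    max_tokens=2048,
)

for chunk in completion:
    delta = chunk.choices[0].delta
    # reasoning_content is a non-standard field this server adds for the <think> part
    if getattr(delta, "reasoning_content", None):
        print(delta.reasoning_content, end="", flush=True)
    if delta.content:
        print(delta.content, end="", flush=True)
print()
```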
config.py ADDED

@@ -0,0 +1,82 @@
+from pydantic import BaseModel, Field
+from typing import List, Optional
+from typing import List, Optional, Union, Any
+
+import sys
+
+
+from pydantic_settings import BaseSettings
+
+
+class CliConfig(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
+    CONFIG_FILE: str = Field("./config.local.yaml", description="Config file path")
+
+
+CLI_CONFIG = CliConfig()
+
+
+class SamplerConfig(BaseModel):
+    """Default sampler configuration for each model."""
+
+    max_tokens: int = Field(512, description="Maximum number of tokens to generate.")
+    temperature: float = Field(1.0, description="Sampling temperature.")
+    top_p: float = Field(0.3, description="Top-p sampling threshold.")
+    presence_penalty: float = Field(0.5, description="Presence penalty.")
+    count_penalty: float = Field(0.5, description="Count penalty.")
+    penalty_decay: float = Field(0.5, description="Penalty decay factor.")
+    stop: List[str] = Field(0.996, description="List of stop sequences.")
+
+
+class ModelConfig(BaseModel):
+    """Configuration for each individual model."""
+
+    SERVICE_NAME: str = Field(..., description="Service name of the model.")
+
+    MODEL_FILE_PATH: Optional[str] = Field(None, description="Model file path.")
+
+    DOWNLOAD_MODEL_FILE_NAME: Optional[str] = Field(
+        None, description="Model name, should end with .pth"
+    )
+    DOWNLOAD_MODEL_REPO_ID: Optional[str] = Field(
+        None, description="Model repository ID on Hugging Face Hub."
+    )
+    DOWNLOAD_MODEL_DIR: Optional[str] = Field(
+        None, description="Directory to download the model to."
+    )
+
+    REASONING: bool = Field(
+        False, description="Whether reasoning is enabled for this model."
+    )
+
+    DEFAULT: bool = Field(False, description="Whether this model is the default model.")
+    DEFAULT_SAMPLER: SamplerConfig = Field(
+        SamplerConfig(), description="Default sampler configuration for this model."
+    )
+    VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
+
+
+class RootConfig(BaseModel):
+    """Root configuration for the RWKV service."""
+
+    HOST: Optional[str] = Field(
+        "127.0.0.1", description="Host IP address to bind to."
+    )  # HOST and PORT are optional
+    PORT: Optional[int] = Field(
+        8000, description="Port number to listen on."
+    )  # because they are commented out in the YAML example
+    STRATEGY: str = Field(
+        "cpu", description="Strategy for model execution (e.g., 'cuda fp16')."
+    )
+    RWKV_CUDA_ON: bool = Field(False, description="Whether to enable RWKV CUDA kernel.")
+    CHUNK_LEN: int = Field(256, description="Chunk length for processing.")
+    MODELS: List[ModelConfig] = Field(..., description="List of model configurations.")
+
+
+import yaml
+
+try:
+    with open(CLI_CONFIG.CONFIG_FILE, "r", encoding="utf-8") as f:
+        CONFIG = RootConfig.model_validate(yaml.safe_load(f.read()))
+except Exception as e:
+    print(f"Pydantic Model Validation Failed: {e}")
+    sys.exit(0)
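`DEFAULT_SAMPLER` is what app.py falls back to when a request leaves a sampler field unset: in `chat_completions`, every `None` field of the incoming `ChatCompletionRequest` is overwritten by the matching key of the selected model's `DEFAULT_SAMPLER` before `realRequest` is rebuilt. (Note the committed defaults above look shifted: `penalty_decay` defaults to 0.5 and `stop` is given 0.996, while the YAML baked into the Dockerfile uses `penalty_decay: 0.996` and a `"\n\n"` stop list.) A minimal self-contained sketch of that merge, with sampler values taken from the Dockerfile's config.local.yaml and a hypothetical `RequestLike` model standing in for `ChatCompletionRequest`:

```python
# Hedged sketch of the defaults merge done in chat_completions(): any request
# field left as None is filled from the model's DEFAULT_SAMPLER. RequestLike is
# a hypothetical stand-in for ChatCompletionRequest, kept minimal on purpose.
from typing import List, Optional
from pydantic import BaseModel


class SamplerConfig(BaseModel):
    max_tokens: int = 512
    temperature: float = 1.0
    top_p: float = 0.3
    presence_penalty: float = 0.5
    count_penalty: float = 0.5
    penalty_decay: float = 0.996
    stop: List[str] = ["\n\n"]


class RequestLike(BaseModel):
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    count_penalty: Optional[float] = None
    penalty_decay: Optional[float] = None
    stop: Optional[List[str]] = None


request = RequestLike(temperature=0.7)  # client only set temperature
request_dict = request.model_dump()

for k, v in SamplerConfig().model_dump().items():
    if request_dict[k] is None:  # unset fields fall back to the sampler defaults
        request_dict[k] = v

real_request = RequestLike(**request_dict)
print(real_request.model_dump())  # temperature stays 0.7, the rest are defaults
```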
openai_test.py DELETED

@@ -1,78 +0,0 @@
-"""
-uv pip install openai
-"""
-
-import os
-
-import logging
-
-# logging.basicConfig(
-#     level=logging.DEBUG,
-# )
-
-os.environ["NO_PROXY"] = "127.0.0.1"
-
-from openai import OpenAI
-
-client = OpenAI(base_url="http://127.0.0.1:8000/api/v1", api_key="sk-test")
-
-
-def completionStreamTest():
-    print("[*] Stream completion: ")
-
-    completion = client.chat.completions.create(
-        model="rwkv-latest",
-        messages=[
-            {
-                "role": "User",
-                "content": "请讲个关于一只灰猫和一个小女孩之间的简短故事。",
-            },
-        ],
-        stream=True,
-        max_tokens=2048,
-    )
-
-    isReasoning = False
-
-    for chunk in completion:
-        if chunk.choices[0].delta.reasoning_content and not isReasoning:
-            print("<- Reasoning ->")
-            isReasoning = True
-        elif chunk.choices[0].delta.content and isReasoning:
-            isReasoning = False
-            print("<- Stop Reasoning ->")
-
-        if chunk.choices[0].delta.reasoning_content:
-            print(chunk.choices[0].delta.reasoning_content, end="", flush=True)
-        if chunk.choices[0].delta.content:
-            print(chunk.choices[0].delta.content, end="", flush=True)
-
-    print("")
-
-
-def completionTest():
-    completion = client.chat.completions.create(
-        model="rwkv-latest:thinking",
-        messages=[
-            {
-                "role": "User",
-                "content": "How many planets are there in our solar system?",
-            },
-        ],
-        max_tokens=2048,
-    )
-
-    print("[*] Completion: ", completion)
-
-
-if __name__ == "__main__":
-    try:
-        # completionTest()
-
-        testRounds = input("Test rounds (Default: 10) :")
-
-        for i in range(int(testRounds) if testRounds != "" else 10):
-            print("\n", "=" * 10, i + 1, "/", testRounds, "=" * 10)
-            completionStreamTest()
-    except KeyboardInterrupt:
-        pass