sparkleman committed on
Commit 109a0c8 · 0 Parent(s)
Files changed (10)
  1. .gitignore +16 -0
  2. .python-version +1 -0
  3. Dockerfile +27 -0
  4. README.md +21 -0
  5. api_types.py +82 -0
  6. app.py +555 -0
  7. openai_test.py +78 -0
  8. pyproject.toml +47 -0
  9. utils.py +35 -0
  10. uv.lock +0 -0
.gitignore ADDED
@@ -0,0 +1,16 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+
+ .cache
+
+ *.pth
+ *.pt
+ *.st
.python-version ADDED
@@ -0,0 +1 @@
+ 3.10
Dockerfile ADDED
@@ -0,0 +1,27 @@
+ ARG CUDA_IMAGE="12.1.1-devel-ubuntu22.04"
+ FROM nvidia/cuda:${CUDA_IMAGE}
+
+ RUN apt-get update && apt-get install --no-install-recommends -y \
+     build-essential \
+     git \
+     ffmpeg && \
+     apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
+
+ COPY . .
+
+ RUN uv sync --frozen
+
+ RUN useradd -m -u 1000 user
+ # Switch to the "user" user
+ USER user
+
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user . $HOME/app
+
+ CMD ["uv", "run", "app.py", "--strategy", "cuda fp16", "--model_title", "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096", "--download_repo_id", "BlinkDL/rwkv-7-world"]
README.md ADDED
@@ -0,0 +1,21 @@
+ # Simple RWKV OpenAI-Compatible API
+
+ ## Usage
+
+ `RWKV-x070-World-0.1B-v2.8-20241210-ctx4096`
+
+ ```shell
+ python app.py --strategy "cuda fp16" --model_title "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096" --download_repo_id "BlinkDL/rwkv-7-world" --download_model_dir ./
+ ```
+
+ `RWKV7-G1-0.1B-68%trained-20250303-ctx4k`
+
+ ```shell
+ python app.py --strategy "cuda fp16" --model_title "RWKV7-G1-0.1B-68%trained-20250303-ctx4k" --download_repo_id "BlinkDL/temp-latest-training-models" --download_model_dir ./
+ ```
+
+ `RWKV7-G1-0.4B-32%trained-20250304-ctx4k`
+
+ ```shell
+ python app.py --strategy "cuda fp16" --model_title "RWKV7-G1-0.4B-32%trained-20250304-ctx4k" --download_repo_id "BlinkDL/temp-latest-training-models" --download_model_dir ./
+ ```
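
Once the server is up, any OpenAI-compatible client can talk to it. A minimal sketch (not part of this commit) that assumes the default `127.0.0.1:8000` host/port from `app.py` and a dummy API key, which the server does not check:

```python
from openai import OpenAI

# Point the client at the local server; the /api/v1 prefix matches the route in app.py.
client = OpenAI(base_url="http://127.0.0.1:8000/api/v1", api_key="sk-test")

completion = client.chat.completions.create(
    model="rwkv-latest",  # append ":thinking" to enable reasoning output
    messages=[{"role": "User", "content": "Hello!"}],
    max_tokens=256,
)
print(completion.choices[0].message.content)
```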
api_types.py ADDED
@@ -0,0 +1,82 @@
+ from typing import List, Optional, Union, Dict, Any, Literal
+ from pydantic import BaseModel, Field
+
+
+ class ChatMessage(BaseModel):
+     role: str = Field()
+     content: str = Field()
+
+
+ class Logprob(BaseModel):
+     token: str
+     logprob: float
+     top_logprobs: Optional[List[Dict[str, Any]]] = None
+
+
+ class LogprobsContent(BaseModel):
+     content: Optional[List[Logprob]] = None
+     refusal: Optional[List[Logprob]] = None
+
+
+ class FunctionCall(BaseModel):
+     name: str
+     arguments: str
+
+
+ class ChatCompletionMessage(BaseModel):
+     role: Optional[str] = Field(
+         None, description="The role of the author of this message"
+     )
+     content: Optional[str] = Field(None, description="The contents of the message")
+     reasoning_content: Optional[str] = Field(
+         None, description="The reasoning contents of the message"
+     )
+     tool_calls: Optional[List[Dict[str, Any]]] = Field(
+         None, description="Tool calls generated by the model"
+     )
+
+
+ class PromptTokensDetails(BaseModel):
+     cached_tokens: int
+
+
+ class CompletionTokensDetails(BaseModel):
+     reasoning_tokens: int
+     accepted_prediction_tokens: int
+     rejected_prediction_tokens: int
+
+
+ class Usage(BaseModel):
+     prompt_tokens: int
+     completion_tokens: int
+     total_tokens: int
+     prompt_tokens_details: Optional[PromptTokensDetails]
+     # completion_tokens_details: CompletionTokensDetails
+
+
+ class ChatCompletionChoice(BaseModel):
+     index: int
+     message: Optional[ChatCompletionMessage] = None
+     delta: Optional[ChatCompletionMessage] = None
+     logprobs: Optional[LogprobsContent] = None
+     finish_reason: Optional[str] = Field(
+         ..., description="Reason for stopping: stop, length, content_filter, tool_calls"
+     )
+
+
+ class ChatCompletion(BaseModel):
+     id: str = Field(..., description="Unique identifier for the chat completion")
+     object: Literal["chat.completion"] = "chat.completion"
+     created: int = Field(..., description="Unix timestamp of creation")
+     model: str
+     choices: List[ChatCompletionChoice]
+     usage: Usage
+
+
+ class ChatCompletionChunk(BaseModel):
+     id: str = Field(..., description="Unique identifier for the chat completion")
+     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+     created: int = Field(..., description="Unix timestamp of creation")
+     model: str
+     choices: List[ChatCompletionChoice]
+     usage: Optional[Usage]
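
These models mirror the OpenAI chat-completion schema, so responses serialize to the familiar wire format. An illustrative sketch (not part of the commit) of how a streaming chunk is built and dumped:

```python
import time

from api_types import ChatCompletionChunk, ChatCompletionChoice, ChatCompletionMessage

# Build one streaming delta and serialize it the way app.py does before yielding SSE data.
chunk = ChatCompletionChunk(
    id="demo-id",
    created=int(time.time()),
    model="rwkv-latest",
    usage=None,  # only populated when the request sets include_usage
    choices=[
        ChatCompletionChoice(
            index=0,
            delta=ChatCompletionMessage(role="Assistant", content="Hello"),
            finish_reason=None,
        )
    ],
)
print(chunk.model_dump_json())
```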
app.py ADDED
@@ -0,0 +1,555 @@
+ import os, copy, types, gc, sys, re, time, collections, asyncio
+ from huggingface_hub import hf_hub_download
+ from loguru import logger
+
+ from snowflake import SnowflakeGenerator
+
+ CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
+
+ from pynvml import *
+
+ nvmlInit()
+ gpu_h = nvmlDeviceGetHandleByIndex(0)
+
+ from typing import Any, List, Optional, Union
+ from pydantic import BaseModel, Field
+ from pydantic_settings import BaseSettings
+
+
+ class Config(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
+     HOST: str = Field("127.0.0.1", description="Host")
+     PORT: int = Field(8000, description="Port")
+     DEBUG: bool = Field(False, description="Debug mode")
+     STRATEGY: str = Field("cpu", description="Strategy")
+     MODEL_TITLE: str = Field("RWKV-x070-World-0.1B-v2.8-20241210-ctx4096")
+     DOWNLOAD_REPO_ID: str = Field("BlinkDL/rwkv-7-world")
+     DOWNLOAD_MODEL_DIR: Union[str, None] = Field(None, description="Model Download Dir")
+     MODEL_FILE_PATH: Union[str, None] = Field(None, description="Model Path")
+     GEN_penalty_decay: float = Field(0.996, description="Default penalty decay")
+     CHUNK_LEN: int = Field(
+         256,
+         description="split input into chunks to save VRAM (shorter -> slower, but saves VRAM)",
+     )
+     VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
+
+
+ CONFIG = Config()
+
+
+ import numpy as np
+ import torch
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ os.environ["RWKV_V7_ON"] = "1"  # enable this for rwkv-7 models
+ os.environ["RWKV_JIT_ON"] = "1"
+ os.environ["RWKV_CUDA_ON"] = (
+     "0"  # !!! '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!
+ )
+
+ from rwkv.model import RWKV
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
+
+ from fastapi import FastAPI
+ from fastapi.responses import StreamingResponse
+ from fastapi.middleware.cors import CORSMiddleware
+
+ from api_types import (
+     ChatMessage,
+     ChatCompletion,
+     ChatCompletionChunk,
+     Usage,
+     PromptTokensDetails,
+     ChatCompletionChoice,
+     ChatCompletionMessage,
+ )
+ from utils import cleanMessages, parse_think_response
+
+
+ logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
+ if CONFIG.MODEL_FILE_PATH == None:
+     CONFIG.MODEL_FILE_PATH = hf_hub_download(
+         repo_id=CONFIG.DOWNLOAD_REPO_ID,
+         filename=f"{CONFIG.MODEL_TITLE}.pth",
+         local_dir=CONFIG.DOWNLOAD_MODEL_DIR,
+     )
+
+ logger.info(f"Load Model - {CONFIG.MODEL_FILE_PATH}")
+ model = RWKV(model=CONFIG.MODEL_FILE_PATH.replace(".pth", ""), strategy=CONFIG.STRATEGY)
+ pipeline = PIPELINE(model, CONFIG.VOCAB)
+
+
+ class ChatCompletionRequest(BaseModel):
+     model: str = Field(
+         default="rwkv-latest",
+         description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
+     )
+     messages: List[ChatMessage]
+     prompt: Union[str, None] = Field(default=None)
+     max_tokens: int = Field(default=512)
+     temperature: float = Field(default=1.0)
+     top_p: float = Field(default=0.3)
+     presencePenalty: float = Field(default=0.5)
+     countPenalty: float = Field(default=0.5)
+     stream: bool = Field(default=False)
+     state_name: str = Field(default=None)
+     include_usage: bool = Field(default=False)
+
+
+ app = FastAPI(title="RWKV OpenAI-Compatible API")
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ def runPrefill(ctx: str, model_tokens: List[int], model_state):
+     ctx = ctx.replace("\r\n", "\n")
+
+     tokens = pipeline.encode(ctx)
+     tokens = [int(x) for x in tokens]
+     model_tokens += tokens
+
+     while len(tokens) > 0:
+         out, model_state = model.forward(tokens[: CONFIG.CHUNK_LEN], model_state)
+         tokens = tokens[CONFIG.CHUNK_LEN :]
+
+     return out, model_tokens, model_state
+
+
+ def generate(
+     request: ChatCompletionRequest,
+     out,
+     model_tokens,
+     model_state,
+     stops=["\n\n"],
+     max_tokens=2048,
+ ):
+     args = PIPELINE_ARGS(
+         temperature=max(0.2, request.temperature),
+         top_p=request.top_p,
+         alpha_frequency=request.countPenalty,
+         alpha_presence=request.presencePenalty,
+         token_ban=[],  # ban the generation of some tokens
+         token_stop=[0],
+     )  # stop generation whenever you see any token here
+
+     occurrence = {}
+     out_tokens = []
+     out_last = 0
+
+     output_cache = collections.deque(maxlen=5)
+
+     for i in range(max_tokens):
+         for n in occurrence:
+             out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
+         out[0] -= 1e10  # disable END_OF_TEXT
+
+         token = pipeline.sample_logits(
+             out, temperature=args.temperature, top_p=args.top_p
+         )
+
+         out, model_state = model.forward([token], model_state)
+         model_tokens += [token]
+
+         out_tokens += [token]
+
+         for xxx in occurrence:
+             occurrence[xxx] *= CONFIG.GEN_penalty_decay
+         occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
+
+         tmp: str = pipeline.decode(out_tokens[out_last:])
+
+         if "\ufffd" in tmp:
+             continue
+
+         output_cache.append(tmp)
+         output_cache_str = "".join(output_cache)
+
+         for stop_words in stops:
+             if stop_words in output_cache_str:
+
+                 yield {
+                     "content": tmp.replace(stop_words, ""),
+                     "tokens": out_tokens[out_last:],
+                     "finish_reason": "stop",
+                     "state": model_state,
+                 }
+
+                 del out
+                 gc.collect()
+                 return
+
+         yield {
+             "content": tmp,
+             "tokens": out_tokens[out_last:],
+             "finish_reason": None,
+         }
+
+         out_last = i + 1
+
+     else:
+         yield {
+             "content": "",
+             "tokens": [],
+             "finish_reason": "length",
+         }
+
+
+ async def chatResponse(
+     request: ChatCompletionRequest, model_state: Any, completionId: str
+ ) -> ChatCompletion:
+     createTimestamp = time.time()
+
+     enableReasoning = request.model.endswith(":thinking")
+
+     prompt = (
+         f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
+         if request.prompt == None
+         else request.prompt.strip()
+     )
+
+     out, model_tokens, model_state = runPrefill(prompt, [], model_state)
+
+     prefillTime = time.time()
+     promptTokenCount = len(model_tokens)
+
+     fullResponse = " <think" if enableReasoning else ""
+     completionTokenCount = 0
+     finishReason = None
+
+     for chunk in generate(
+         request,
+         out,
+         model_tokens,
+         model_state,
+         max_tokens=(
+             64000
+             if "max_tokens" not in request.model_fields_set and enableReasoning
+             else request.max_tokens
+         ),
+     ):
+         fullResponse += chunk["content"]
+         completionTokenCount += 1
+
+         if chunk["finish_reason"]:
+             finishReason = chunk["finish_reason"]
+         await asyncio.sleep(0)
+
+     generateTime = time.time()
+
+     responseLog = {
+         "content": fullResponse,
+         "finish": finishReason,
+         "prefill_len": promptTokenCount,
+         "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
+         "gen_len": completionTokenCount,
+         "gen_tps": round(completionTokenCount / (generateTime - prefillTime), 2),
+     }
+     logger.info(f"[RES] {completionId} - {responseLog}")
+
+     reasoning_content, content = parse_think_response(fullResponse)
+
+     response = ChatCompletion(
+         id=completionId,
+         created=int(createTimestamp),
+         model=request.model,
+         usage=Usage(
+             prompt_tokens=promptTokenCount,
+             completion_tokens=completionTokenCount,
+             total_tokens=promptTokenCount + completionTokenCount,
+             prompt_tokens_details={"cached_tokens": 0},
+         ),
+         choices=[
+             ChatCompletionChoice(
+                 index=0,
+                 message=ChatCompletionMessage(
+                     role="Assistant",
+                     content=content,
+                     reasoning_content=reasoning_content if reasoning_content else None,
+                 ),
+                 logprobs=None,
+                 finish_reason=finishReason,
+             )
+         ],
+     )
+
+     return response
+
+
+ async def chatResponseStream(
+     request: ChatCompletionRequest, model_state: Any, completionId: str
+ ):
+     createTimestamp = int(time.time())
+
+     enableReasoning = request.model.endswith(":thinking")
+
+     prompt = (
+         f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
+         if request.prompt == None
+         else request.prompt.strip()
+     )
+
+     out, model_tokens, model_state = runPrefill(prompt, [], model_state)
+
+     prefillTime = time.time()
+     promptTokenCount = len(model_tokens)
+
+     completionTokenCount = 0
+     finishReason = None
+
+     response = ChatCompletionChunk(
+         id=completionId,
+         created=createTimestamp,
+         model=request.model,
+         usage=(
+             Usage(
+                 prompt_tokens=promptTokenCount,
+                 completion_tokens=completionTokenCount,
+                 total_tokens=promptTokenCount + completionTokenCount,
+                 prompt_tokens_details={"cached_tokens": 0},
+             )
+             if request.include_usage
+             else None
+         ),
+         choices=[
+             ChatCompletionChoice(
+                 index=0,
+                 delta=ChatCompletionMessage(
+                     role="Assistant",
+                     content="",
+                     reasoning_content="" if enableReasoning else None,
+                 ),
+                 logprobs=None,
+                 finish_reason=finishReason,
+             )
+         ],
+     )
+     yield f"data: {response.model_dump_json()}\n\n"
+
+     buffer = []
+
+     if enableReasoning:
+         buffer.append(" <think")
+
+         streamConfig = {
+             "isChecking": False,
+             "fullTextCursor": 0,
+             "in_think": False,
+             "cacheStr": "",
+         }
+
+         for chunk in generate(
+             request,
+             out,
+             model_tokens,
+             model_state,
+             max_tokens=(
+                 64000
+                 if "max_tokens" not in request.model_fields_set and enableReasoning
+                 else request.max_tokens
+             ),
+         ):
+             completionTokenCount += 1
+
+             chunkContent: str = chunk["content"]
+             buffer.append(chunkContent)
+
+             fullText = "".join(buffer)
+
+             if chunk["finish_reason"]:
+                 finishReason = chunk["finish_reason"]
+
+             response = ChatCompletionChunk(
+                 id=completionId,
+                 created=createTimestamp,
+                 model=request.model,
+                 usage=(
+                     Usage(
+                         prompt_tokens=promptTokenCount,
+                         completion_tokens=completionTokenCount,
+                         total_tokens=promptTokenCount + completionTokenCount,
+                         prompt_tokens_details={"cached_tokens": 0},
+                     )
+                     if request.include_usage
+                     else None
+                 ),
+                 choices=[
+                     ChatCompletionChoice(
+                         index=0,
+                         delta=ChatCompletionMessage(
+                             content=None, reasoning_content=None
+                         ),
+                         logprobs=None,
+                         finish_reason=finishReason,
+                     )
+                 ],
+             )
+
+             markStart = fullText.find("<", streamConfig["fullTextCursor"])
+             if not streamConfig["isChecking"] and markStart != -1:
+                 streamConfig["isChecking"] = True
+
+                 if streamConfig["in_think"]:
+                     response.choices[0].delta.reasoning_content = fullText[
+                         streamConfig["fullTextCursor"] : markStart
+                     ]
+                 else:
+                     response.choices[0].delta.content = fullText[
+                         streamConfig["fullTextCursor"] : markStart
+                     ]
+
+                 streamConfig["cacheStr"] = ""
+                 streamConfig["fullTextCursor"] = markStart
+
+             if streamConfig["isChecking"]:
+                 streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"] :]
+             else:
+                 if streamConfig["in_think"]:
+                     response.choices[0].delta.reasoning_content = chunkContent
+                 else:
+                     response.choices[0].delta.content = chunkContent
+                 streamConfig["fullTextCursor"] = len(fullText)
+
+             markEnd = fullText.find(">", streamConfig["fullTextCursor"])
+             if streamConfig["isChecking"] and markEnd != -1:
+                 streamConfig["isChecking"] = False
+
+                 if (
+                     not streamConfig["in_think"]
+                     and streamConfig["cacheStr"].find("<think>") != -1
+                 ):
+                     streamConfig["in_think"] = True
+
+                     response.choices[0].delta.reasoning_content = (
+                         response.choices[0].delta.reasoning_content
+                         if response.choices[0].delta.reasoning_content != None
+                         else "" + streamConfig["cacheStr"].replace("<think>", "")
+                     )
+
+                 elif (
+                     streamConfig["in_think"]
+                     and streamConfig["cacheStr"].find("</think>") != -1
+                 ):
+                     streamConfig["in_think"] = False
+
+                     response.choices[0].delta.content = (
+                         response.choices[0].delta.content
+                         if response.choices[0].delta.content != None
+                         else "" + streamConfig["cacheStr"].replace("</think>", "")
+                     )
+                 else:
+                     if streamConfig["in_think"]:
+                         response.choices[0].delta.reasoning_content = (
+                             response.choices[0].delta.reasoning_content
+                             if response.choices[0].delta.reasoning_content != None
+                             else "" + streamConfig["cacheStr"]
+                         )
+                     else:
+                         response.choices[0].delta.content = (
+                             response.choices[0].delta.content
+                             if response.choices[0].delta.content != None
+                             else "" + streamConfig["cacheStr"]
+                         )
+                 streamConfig["fullTextCursor"] = len(fullText)
+
+             if (
+                 response.choices[0].delta.content != None
+                 or response.choices[0].delta.reasoning_content != None
+             ):
+                 yield f"data: {response.model_dump_json()}\n\n"
+
+             await asyncio.sleep(0)
+
+         del streamConfig
+     else:
+         for chunk in generate(request, out, model_tokens, model_state):
+             completionTokenCount += 1
+             buffer.append(chunk["content"])
+
+             if chunk["finish_reason"]:
+                 finishReason = chunk["finish_reason"]
+
+             response = ChatCompletionChunk(
+                 id=completionId,
+                 created=createTimestamp,
+                 model=request.model,
+                 usage=(
+                     Usage(
+                         prompt_tokens=promptTokenCount,
+                         completion_tokens=completionTokenCount,
+                         total_tokens=promptTokenCount + completionTokenCount,
+                         prompt_tokens_details={"cached_tokens": 0},
+                     )
+                     if request.include_usage
+                     else None
+                 ),
+                 choices=[
+                     ChatCompletionChoice(
+                         index=0,
+                         delta=ChatCompletionMessage(content=chunk["content"]),
+                         logprobs=None,
+                         finish_reason=finishReason,
+                     )
+                 ],
+             )
+
+             yield f"data: {response.model_dump_json()}\n\n"
+             await asyncio.sleep(0)
+
+     generateTime = time.time()
+
+     responseLog = {
+         "content": "".join(buffer),
+         "finish": finishReason,
+         "prefill_len": promptTokenCount,
+         "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
+         "gen_len": completionTokenCount,
+         "gen_tps": round(completionTokenCount / (generateTime - prefillTime), 2),
+     }
+     logger.info(f"[RES] {completionId} - {responseLog}")
+
+     del buffer
+
+     yield "data: [DONE]\n\n"
+
+
+ @app.post("/api/v1/chat/completions")
+ async def chat_completions(request: ChatCompletionRequest):
+     completionId = str(next(CompletionIdGenerator))
+     logger.info(f"[REQ] {completionId} - {request.model_dump()}")
+
+     def chatResponseStreamDisconnect():
+         gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+         logger.info(
+             f"[STATUS] vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}"
+         )
+
+     model_state = None
+
+     if request.stream:
+         r = StreamingResponse(
+             chatResponseStream(request, model_state, completionId),
+             media_type="text/event-stream",
+             background=chatResponseStreamDisconnect,
+         )
+     else:
+         r = await chatResponse(request, model_state, completionId)
+
+     return r
+
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)
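
For reference, the chat template the server builds is plain text: `cleanMessages` renders each message as `role: content` and the handlers append an `Assistant:` turn, primed with ` <think` when the `:thinking` suffix is used. A small sketch of that formatting (not part of the commit), reusing the repo's own helpers:

```python
from api_types import ChatMessage
from utils import cleanMessages

messages = [ChatMessage(role="User", content="Hello!")]
enableReasoning = False  # True when the requested model name ends with ":thinking"

# Same construction as chatResponse()/chatResponseStream() in app.py.
prompt = f"{cleanMessages(messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
print(prompt)
# User: Hello!
#
# Assistant:
```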
openai_test.py ADDED
@@ -0,0 +1,78 @@
+ """
+ uv pip install openai
+ """
+
+ import os
+
+ import logging
+
+ # logging.basicConfig(
+ #     level=logging.DEBUG,
+ # )
+
+ os.environ["NO_PROXY"] = "127.0.0.1"
+
+ from openai import OpenAI
+
+ client = OpenAI(base_url="http://127.0.0.1:8000/api/v1", api_key="sk-test")
+
+
+ def completionStreamTest():
+     print("[*] Stream completion: ")
+
+     completion = client.chat.completions.create(
+         model="rwkv-latest",
+         messages=[
+             {
+                 "role": "User",
+                 "content": "Please tell a short story about a grey cat and a little girl.",
+             },
+         ],
+         stream=True,
+         max_tokens=2048,
+     )
+
+     isReasoning = False
+
+     for chunk in completion:
+         if chunk.choices[0].delta.reasoning_content and not isReasoning:
+             print("<- Reasoning ->")
+             isReasoning = True
+         elif chunk.choices[0].delta.content and isReasoning:
+             isReasoning = False
+             print("<- Stop Reasoning ->")
+
+         if chunk.choices[0].delta.reasoning_content:
+             print(chunk.choices[0].delta.reasoning_content, end="", flush=True)
+         if chunk.choices[0].delta.content:
+             print(chunk.choices[0].delta.content, end="", flush=True)
+
+     print("")
+
+
+ def completionTest():
+     completion = client.chat.completions.create(
+         model="rwkv-latest:thinking",
+         messages=[
+             {
+                 "role": "User",
+                 "content": "How many planets are there in our solar system?",
+             },
+         ],
+         max_tokens=2048,
+     )
+
+     print("[*] Completion: ", completion)
+
+
+ if __name__ == "__main__":
+     try:
+         # completionTest()
+
+         testRounds = input("Test rounds (Default: 10): ")
+
+         for i in range(int(testRounds) if testRounds != "" else 10):
+             print("\n", "=" * 10, i + 1, "/", testRounds, "=" * 10)
+             completionStreamTest()
+     except KeyboardInterrupt:
+         pass
pyproject.toml ADDED
@@ -0,0 +1,47 @@
+ [project]
+ name = "rwkv-hf-space"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+     "fastapi[standard]>=0.115.11",
+     "huggingface-hub>=0.29.1",
+     "loguru>=0.7.3",
+     "numpy>=2.2.3",
+     "pydantic>=2.10.6",
+     "pydantic-settings>=2.8.1",
+     "pynvml>=12.0.0",
+     "rwkv==0.8.28",
+     "snowflake-id>=1.0.2",
+ ]
+
+ [project.optional-dependencies]
+ cpu = ["torch>=2.6.0"]
+ cu124 = ["torch>=2.6.0"]
+ cu113 = ["torch"]
+
+ [tool.uv]
+ conflicts = [[{ extra = "cpu" }, { extra = "cu124" }, { extra = "cu113" }]]
+
+ [tool.uv.sources]
+ torch = [
+     { index = "pytorch-cpu", extra = "cpu" },
+     { index = "pytorch-cu124", extra = "cu124" },
+     { index = "pytorch-cu113", extra = "cu113" },
+ ]
+
+ [[tool.uv.index]]
+ name = "pytorch-cpu"
+ url = "https://download.pytorch.org/whl/cpu"
+ explicit = true
+
+ [[tool.uv.index]]
+ name = "pytorch-cu124"
+ url = "https://download.pytorch.org/whl/cu124"
+ explicit = true
+
+ [[tool.uv.index]]
+ name = "pytorch-cu113"
+ url = "https://download.pytorch.org/whl/cu113"
+ explicit = true
utils.py ADDED
@@ -0,0 +1,35 @@
+ import re
+ from typing import List, Optional, Union
+ from pydantic import BaseModel, Field
+ from pydantic_settings import BaseSettings
+
+ from api_types import ChatMessage
+
+
+ def parse_think_response(full_response: str):
+     think_start = full_response.find("<think")
+     if think_start == -1:
+         return None, full_response.strip()
+
+     think_end = full_response.find("</think>")
+     if think_end == -1:  # the tag was never closed
+         reasoning = full_response[think_start:].strip()
+         content = ""
+     else:
+         reasoning = full_response[think_start : think_end + 9].strip()  # +9 keeps the full closing tag
+         content = full_response[think_end + 9 :].strip()
+
+     # strip the tags but keep the content
+     reasoning_content = reasoning.replace("<think", "").replace("</think>", "").strip()
+     return reasoning_content, content
+
+
+ def cleanMessages(messages: List[ChatMessage]):
+     promptStrList = []
+
+     for message in messages:
+         content = message.content.strip()
+         content = re.sub(r"\n+", "\n", content)
+         promptStrList.append(f"{message.role.strip()}: {content}")
+
+     return "\n\n".join(promptStrList)
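
To make the tag handling concrete, a short sketch (not part of the commit) of `parse_think_response` splitting reasoning from the final answer; note the stray `>` the cleanup leaves behind, since only the unclosed `<think` prefix is removed:

```python
from utils import parse_think_response

full = " <think>\nThe user greets me; answer briefly.\n</think>\nHello! How can I help?"
reasoning, content = parse_think_response(full)

print(repr(reasoning))  # '>\nThe user greets me; answer briefly.'
print(repr(content))    # 'Hello! How can I help?'
```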
uv.lock ADDED
The diff for this file is too large to render. See raw diff