# Mini_RWKV_7_34.2M / API_DEMO_CHAT.py  (model size: 34.2M)
# Author: Alic-Li
# Last commit: 10dad5c "Update API_DEMO_CHAT.py" (verified)
# Page language options: Chinese / English
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# Startup banner, imports, and global torch / rwkv runtime settings.
# NOTE(review): gc, sys and numpy are imported but not used in the visible script.
print("RWKV Chat Simple Demo")
import os, copy, types, gc, sys, re
import numpy as np
from prompt_toolkit import prompt
import torch
from transformers import AutoTokenizer
# Speed/precision trade-offs for GPU inference: let cuDNN benchmark and pick
# kernels, and allow TF32 math for cuDNN ops and CUDA matmuls.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# rwkv package configuration. These are set BEFORE `from rwkv.model import RWKV`
# below -- presumably the package reads them at import time, so keep this order.
os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "0" # !!! '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!
from rwkv.model import RWKV
from rwkv.utils import PIPELINE
########################################################################################################
# Model location and rwkv execution strategy.
args = types.SimpleNamespace()
args.strategy = "cuda fp16"  # use CUDA, fp16
args.MODEL_NAME = "./rwkv-final-sft-2048.pth"
########################################################################################################
# Optional custom initial state. None -> use the vanilla zero initial state.
# A custom state can give much better chat results
# (download from https://huggingface.co/BlinkDL/temp-latest-training-models/tree/main).
# The path is given WITHOUT the ".pth" extension; it is appended when loading.
STATE_NAME = None
# NOTE(review): the example below is an English single-round QA state (it forgets
# what you previously said) built for rwkv-x060 (RWKV-6, 1.6B). It presumably does
# not match this RWKV-7 model's layer/embedding sizes -- use a state trained for
# this exact model.
# STATE_NAME = "E://RWKV-Runner//models//rwkv-x060-eng_single_round_qa-1B6-20240516-ctx2048"
########################################################################################################
# Sampling parameters.
GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.5    # flat logit penalty for any token already generated
GEN_alpha_frequency = 0.5   # extra logit penalty per (decayed) occurrence
GEN_penalty_decay = 0.996   # per-step multiplicative decay of occurrence counts
if STATE_NAME is not None:
    # Lower top-p and weaker repetition penalties when a custom state is loaded.
    GEN_TOP_P = 0.2
    GEN_alpha_presence = 0.3
    GEN_alpha_frequency = 0.3
CHUNK_LEN = 16  # split input into chunks to save VRAM (shorter -> slower, but saves VRAM)
########################################################################################################
print(f"Loading model - {args.MODEL_NAME}")
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
# In this script PIPELINE is only used for its sampler (sample_logits). Text is
# encoded/decoded with the MiniMind2 tokenizer below, not the
# rwkv_vocab_v20230424 vocabulary loaded here.
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
tokenizer = AutoTokenizer.from_pretrained("./MiniMind2_tokenizer")
model_tokens = []   # every token fed to the model so far (prompts + generated)
model_state = None  # recurrent state threaded through model.forward; None = fresh start
if STATE_NAME is not None:  # load custom state
    # Read sizes from the loaded model into a separate name so the config
    # namespace `args` (strategy, MODEL_NAME) is not overwritten.
    model_args = model.args
    # NOTE: torch.load unpickles the file -- only load state files from trusted sources.
    state_raw = torch.load(STATE_NAME + '.pth')
    state_init = [None] * (model_args.n_layer * 3)
    for i in range(model_args.n_layer):
        dd = model.strategy[i]  # per-layer placement: device and activation dtype
        dev = dd.device
        atype = dd.atype
        # Three state slots per layer: slots 0 and 2 start as zero vectors of
        # size n_embd; slot 1 is the trained attention time_state, transposed
        # on axes (1, 2) and kept in float32.
        state_init[i*3+0] = torch.zeros(model_args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        state_init[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
        state_init[i*3+2] = torch.zeros(model_args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
    # Deep copy so the pristine state_init is never mutated by inference.
    model_state = copy.deepcopy(state_init)
def run_rnn(ctx):
    """Feed text through the model, updating the global token history and state.

    The text is normalised to "\\n" line endings, encoded with the global
    `tokenizer`, appended to `model_tokens`, and run through `model.forward`
    in chunks of at most CHUNK_LEN tokens (to bound memory use), threading
    the global `model_state` through each call.

    Args:
        ctx: Text to feed to the model.

    Returns:
        The output of `model.forward` for the last chunk; callers use it as
        the logits for sampling the next token.

    Raises:
        ValueError: If `ctx` encodes to zero tokens (there is nothing to run,
            so no logits can be returned).
    """
    global model_tokens, model_state
    ctx = ctx.replace("\r\n", "\n")
    # NOTE(review): encode() uses the tokenizer's default add_special_tokens
    # behaviour; confirm it does not prepend a BOS token on every turn.
    tokens = [int(x) for x in tokenizer.encode(ctx)]
    if not tokens:
        raise ValueError("run_rnn: context encodes to zero tokens")
    model_tokens += tokens
    # print(f"### model ###\n{model_tokens}\n[{tokenizer.decode(model_tokens)}]") # debug
    out = None
    for start in range(0, len(tokens), CHUNK_LEN):
        out, model_state = model.forward(tokens[start:start + CHUNK_LEN], model_state)
    return out
if STATE_NAME is None:  # use initial prompt if we are not loading a state
    init_ctx = "User: hi" + "\n\n"
    init_ctx += "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it." + "\n\n"
    # NOTE(review): priming is currently disabled (the calls below are commented
    # out), so init_ctx is unused. It also uses a "User:/Assistant:" format,
    # while the chat loop sends <|im_start|>/<|im_end|> markers -- convert it to
    # the same format before re-enabling.
    # run_rnn(init_ctx)
    # print(init_ctx, end="")
# Interactive chat loop: read a user message, feed it wrapped in
# <|im_start|>/<|im_end|> markers, then sample the assistant reply token by
# token until the tokenizer's EOS token (or the token budget) is reached.
eos_token_id = tokenizer.eos_token_id  # constant across turns; hoisted out of the loop
while True:
    try:
        msg = prompt("<|im_start|>user:")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C at the prompt: leave the chat without a traceback.
        print()
        break
    msg = msg.strip()
    msg = re.sub(r"\n+", "\n", msg)  # collapse runs of blank lines
    if not msg:
        print("!!! Error: please say something !!!")
        continue
    occurrence = {}   # token id -> decayed count, drives the repetition penalties
    out_tokens = []   # tokens generated for this reply
    out_last = 0      # index in out_tokens of the first not-yet-printed token
    out = run_rnn("<|im_start|>user\n" + msg + "<|im_end|>\n" + "<|im_start|>assistant\n")
    print("\nAssistant:", end="")
    for _ in range(99999):
        for n in occurrence:
            out[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency  # repetition penalty
        # Never sample token id 0 (END_OF_TEXT in the original RWKV demo).
        # NOTE(review): confirm what id 0 means in the MiniMind2 tokenizer.
        out[0] -= 1e10
        token = pipeline.sample_logits(out, temperature=GEN_TEMP, top_p=GEN_TOP_P)
        # Feed the sampled token (EOS included) so the state covers the whole reply.
        out, model_state = model.forward([token], model_state)
        model_tokens += [token]
        out_tokens += [token]
        for xxx in occurrence:
            occurrence[xxx] *= GEN_penalty_decay
        occurrence[token] = occurrence.get(token, 0) + 1
        if token == eos_token_id:
            # Flush held-back text exactly once, excluding the EOS marker itself
            # (previously the last chunk could be printed twice, EOS text included).
            print(tokenizer.decode(out_tokens[out_last:-1]), end="\n\n", flush=True)
            break
        tmp = tokenizer.decode(out_tokens[out_last:])
        # Hold back text that ends mid-character (U+FFFD replacement char) or
        # with a newline until more tokens arrive.
        if ("\ufffd" not in tmp) and (not tmp.endswith("\n")):
            print(tmp, end="", flush=True)
            out_last = len(out_tokens)
    else:
        # Token budget exhausted without EOS: flush whatever is still pending.
        print(tokenizer.decode(out_tokens[out_last:]), end="\n\n", flush=True)