print("RWKV Chat Simple Demo")

import os, copy, types, gc, sys, re
import numpy as np
from prompt_toolkit import prompt
import torch
from transformers import AutoTokenizer
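
# Allow cuDNN autotuning and TF32 matmuls for a small speedup on recent NVIDIA GPUs.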
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
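
# These environment variables are read when the rwkv package is imported below:
# enable the RWKV-7 code path and the JIT kernels, and keep the custom CUDA
# kernel disabled (pure PyTorch fallback, slower but needs no compilation).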
os.environ["RWKV_V7_ON"] = "1" |
|
os.environ["RWKV_JIT_ON"] = "1" |
|
os.environ["RWKV_CUDA_ON"] = "0" |
|

from rwkv.model import RWKV
from rwkv.utils import PIPELINE

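# Model / runtime configuration: "cuda fp16" keeps the whole model on the GPU in fp16;
# MODEL_NAME points at the fine-tuned (SFT) RWKV checkpoint.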
args = types.SimpleNamespace()
args.strategy = "cuda fp16"
args.MODEL_NAME = "./rwkv-final-sft-2048.pth"

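# Optional: path (without ".pth") of a trained time-state to load. Leave as None
# to start from an empty state and prime the model with an initial prompt instead.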
STATE_NAME = None

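# Sampling parameters: temperature, nucleus top-p, and presence/frequency
# repetition penalties whose counts decay over time.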
GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.5
GEN_alpha_frequency = 0.5
GEN_penalty_decay = 0.996

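# When continuing from a loaded state, use different sampling defaults
# (lower top-p and lighter repetition penalties).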
if STATE_NAME is not None:
    GEN_TOP_P = 0.2
    GEN_alpha_presence = 0.3
    GEN_alpha_frequency = 0.3

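# Split long prompts into chunks of this many tokens when feeding the RNN
# (smaller chunks use less VRAM but are slower).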
CHUNK_LEN = 16

print(f"Loading model - {args.MODEL_NAME}") |
|
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy) |
|
pipeline = PIPELINE(model, "rwkv_vocab_v20230424") |
|
tokenizer = AutoTokenizer.from_pretrained("./MiniMind2_tokenizer") |
|
|
|
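# Note: the HF tokenizer (MiniMind2) is used to encode/decode text, while the rwkv
# PIPELINE is only used for its sample_logits() helper.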

model_tokens = []
model_state = None

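# If a trained state file is given, build the initial RNN state from it: for each
# layer, zero ATT/FFN token-shift states plus the saved time-mixing state.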
if STATE_NAME is not None:
    args = model.args
    state_raw = torch.load(STATE_NAME + '.pth')
    state_init = [None for i in range(args.n_layer * 3)]
    for i in range(args.n_layer):
        dd = model.strategy[i]
        dev = dd.device
        atype = dd.atype
        state_init[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        state_init[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
        state_init[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
    model_state = copy.deepcopy(state_init)

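# Feed text into the RNN, updating the global token history and hidden state,
# and return the logits for the last token.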
def run_rnn(ctx):
    global model_tokens, model_state

    ctx = ctx.replace("\r\n", "\n")

    tokens = tokenizer.encode(ctx)
    tokens = [int(x) for x in tokens]
    model_tokens += tokens

    while len(tokens) > 0:
        out, model_state = model.forward(tokens[:CHUNK_LEN], model_state)
        tokens = tokens[CHUNK_LEN:]

    return out

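# Without a loaded state, prime the model with a short greeting exchange.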
if STATE_NAME is None:
    init_ctx = "User: hi" + "\n\n"
    init_ctx += "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it." + "\n\n"
    run_rnn(init_ctx)
    print(init_ctx, end="")

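# Main chat loop: read a user message, wrap it in the chat template, and stream
# the assistant's reply token by token.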
while True:
    msg = prompt("<|im_start|>user:")
    msg = msg.strip()
    msg = re.sub(r"\n+", "\n", msg)
    if len(msg) > 0:
        occurrence = {}
        out_tokens = []
        out_last = 0

        out = run_rnn("<|im_start|>user\n" + msg + "<|im_end|>\n" + "<|im_start|>assistant\n")
        print("\nAssistant:", end="")

        eos_token_id = tokenizer.eos_token_id
        pad_token_id = tokenizer.pad_token_id

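        # Autoregressive generation: apply presence/frequency penalties to the
        # logits, sample a token, feed it back into the model, and stream the
        # decoded text as soon as it is valid UTF-8 and does not end in a newline.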
        for i in range(99999):
            for n in occurrence:
                out[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency
            out[0] -= 1e10

            token = pipeline.sample_logits(out, temperature=GEN_TEMP, top_p=GEN_TOP_P)

            out, model_state = model.forward([token], model_state)
            model_tokens += [token]
            out_tokens += [token]

            for xxx in occurrence:
                occurrence[xxx] *= GEN_penalty_decay
            occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)

            tmp = tokenizer.decode(out_tokens[out_last:])
            if ("\ufffd" not in tmp) and (not tmp.endswith("\n")):
                print(tmp, end="", flush=True)
                out_last = i + 1

            if token == eos_token_id:
                print(tmp, end="\n\n", flush=True)
                break
    else:
        print("!!! Error: please say something !!!")