########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
print("RWKV Chat Simple Demo")
import os, copy, types, gc, sys, re
import numpy as np
from prompt_toolkit import prompt
import torch
from transformers import AutoTokenizer
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "0" # !!! '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!
from rwkv.model import RWKV
from rwkv.utils import PIPELINE
########################################################################################################
args = types.SimpleNamespace()
args.strategy = "cuda fp16" # use CUDA, fp16
args.MODEL_NAME = "./rwkv-final-sft-2048.pth"
########################################################################################################
STATE_NAME = None # use vanilla zero initial state?
# use custom state? much better chat results (download from https://huggingface.co/BlinkDL/temp-latest-training-models/tree/main)
# note: this is English Single-round QA state (will forget what you previously say)
# STATE_NAME = "E://RWKV-Runner//models//rwkv-x060-eng_single_round_qa-1B6-20240516-ctx2048"
########################################################################################################
GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.5
GEN_alpha_frequency = 0.5
GEN_penalty_decay = 0.996
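# GEN_alpha_presence / GEN_alpha_frequency are presence/frequency repetition penalties;
# GEN_penalty_decay shrinks the accumulated per-token counts every step so the penalty
# fades over time (see the generation loop below)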
if STATE_NAME != None:
    GEN_TOP_P = 0.2
    GEN_alpha_presence = 0.3
    GEN_alpha_frequency = 0.3
CHUNK_LEN = 16 # split input into chunks to save VRAM (shorter -> slower, but saves VRAM)
########################################################################################################
print(f"Loading model - {args.MODEL_NAME}")
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
tokenizer = AutoTokenizer.from_pretrained("./MiniMind2_tokenizer")
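# Note: the PIPELINE above is only used for sample_logits(); all text encoding/decoding
# in this script goes through the Hugging Face tokenizer loaded from ./MiniMind2_tokenizer,
# not the built-in rwkv_vocab_v20230424 vocabulary.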
model_tokens = []
model_state = None
if STATE_NAME != None: # load custom state
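    # the rwkv package keeps 3 state tensors per layer:
    #   i*3+0: attention token-shift state (zeros)
    #   i*3+1: attention wkv/time-mixing state (loaded from blocks.{i}.att.time_state)
    #   i*3+2: FFN token-shift state (zeros)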
    args = model.args
    state_raw = torch.load(STATE_NAME + '.pth')
    state_init = [None for i in range(args.n_layer * 3)]
    for i in range(args.n_layer):
        dd = model.strategy[i]
        dev = dd.device
        atype = dd.atype
        state_init[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        state_init[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
        state_init[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
    model_state = copy.deepcopy(state_init)
def run_rnn(ctx):
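    # Feed ctx into the model CHUNK_LEN tokens at a time, carrying the recurrent
    # state in the global model_state; returns the logits after the last token.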
    global model_tokens, model_state
    ctx = ctx.replace("\r\n", "\n")
    tokens = tokenizer.encode(ctx)
    tokens = [int(x) for x in tokens]
    model_tokens += tokens
    # print(f"### model ###\n{model_tokens}\n[{pipeline.decode(model_tokens)}]") # debug
    while len(tokens) > 0:
        out, model_state = model.forward(tokens[:CHUNK_LEN], model_state)
        tokens = tokens[CHUNK_LEN:]
    return out
if STATE_NAME == None: # use initial prompt if we are not loading a state
init_ctx = "User: hi" + "\n\n"
init_ctx += "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it." + "\n\n"
# run_rnn(init_ctx)
# print(init_ctx, end="")
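# Chat loop: each user message is wrapped in a ChatML-style template
# (<|im_start|>user ... <|im_end|> <|im_start|>assistant), which is presumably the
# format the SFT model was trained on.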
while True:
    msg = prompt("<|im_start|>user:")
    msg = msg.strip()
    msg = re.sub(r"\n+", "\n", msg)
    if len(msg) > 0:
        occurrence = {}
        out_tokens = []
        out_last = 0
        out = run_rnn("<|im_start|>user\n" + msg + "<|im_end|>\n" + "<|im_start|>assistant\n")
        print("\nAssistant:", end="")
        eos_token_id = tokenizer.eos_token_id
        pad_token_id = tokenizer.pad_token_id
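        # sample token by token: apply presence/frequency penalties to the logits,
        # pick a token, feed it back into the model, and stream the decoded text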
        for i in range(99999):
            for n in occurrence:
                out[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency # repetition penalty
            out[0] -= 1e10 # disable END_OF_TEXT
            token = pipeline.sample_logits(out, temperature=GEN_TEMP, top_p=GEN_TOP_P)
            out, model_state = model.forward([token], model_state)
            model_tokens += [token]
            out_tokens += [token]
            for xxx in occurrence:
                occurrence[xxx] *= GEN_penalty_decay
            occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
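            # stream output only when the pending tokens decode to valid UTF-8 (no U+FFFD
            # replacement character); chunks ending in a newline are held back until more
            # text arrives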
            tmp = tokenizer.decode(out_tokens[out_last:])
            if ("\ufffd" not in tmp) and (not tmp.endswith("\n")):
                print(tmp, end="", flush=True)
                out_last = i + 1
            # use the token id to detect the eos_token and end the reply
            if token == eos_token_id:
                # flush any text not yet printed, skipping special tokens such as the eos marker itself
                print(tokenizer.decode(out_tokens[out_last:], skip_special_tokens=True), end="\n\n", flush=True)
                break
    else:
        print("!!! Error: please say something !!!")