########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

print("RWKV Chat Simple Demo")

import os, copy, types, gc, sys, re
import numpy as np
from prompt_toolkit import prompt
import torch
from transformers import AutoTokenizer

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

os.environ["RWKV_V7_ON"] = "1"  # enable this for rwkv-7 models
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "0"  # !!! '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!

from rwkv.model import RWKV
from rwkv.utils import PIPELINE

########################################################################################################

args = types.SimpleNamespace()
args.strategy = "cuda fp16"  # use CUDA, fp16
args.MODEL_NAME = "./rwkv-final-sft-2048.pth"

########################################################################################################

STATE_NAME = None  # use vanilla zero initial state?

# use custom state? much better chat results (download from https://huggingface.co/BlinkDL/temp-latest-training-models/tree/main)
# note: this is an English single-round QA state (it will forget what you said previously)
# STATE_NAME = "E://RWKV-Runner//models//rwkv-x060-eng_single_round_qa-1B6-20240516-ctx2048"

########################################################################################################

GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.5
GEN_alpha_frequency = 0.5
GEN_penalty_decay = 0.996

if STATE_NAME is not None:  # a custom state works better with tighter sampling
    GEN_TOP_P = 0.2
    GEN_alpha_presence = 0.3
    GEN_alpha_frequency = 0.3

CHUNK_LEN = 16  # split input into chunks to save VRAM (shorter -> slower, but saves VRAM)

########################################################################################################

print(f"Loading model - {args.MODEL_NAME}")
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
tokenizer = AutoTokenizer.from_pretrained("./MiniMind2_tokenizer")

model_tokens = []
model_state = None

if STATE_NAME is not None:  # load custom state
    args = model.args
    state_raw = torch.load(STATE_NAME + ".pth")
    state_init = [None for i in range(args.n_layer * 3)]
    for i in range(args.n_layer):
        dd = model.strategy[i]
        dev = dd.device
        atype = dd.atype
        state_init[i * 3 + 0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        state_init[i * 3 + 1] = state_raw[f"blocks.{i}.att.time_state"].transpose(1, 2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
        state_init[i * 3 + 2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
    model_state = copy.deepcopy(state_init)


def run_rnn(ctx):
    """Feed text through the RNN in CHUNK_LEN-sized chunks, updating the global state."""
    global model_tokens, model_state

    ctx = ctx.replace("\r\n", "\n")

    tokens = tokenizer.encode(ctx)
    tokens = [int(x) for x in tokens]
    model_tokens += tokens

    # print(f"### model ###\n{model_tokens}\n[{pipeline.decode(model_tokens)}]")  # debug

    while len(tokens) > 0:
        out, model_state = model.forward(tokens[:CHUNK_LEN], model_state)
        tokens = tokens[CHUNK_LEN:]

    return out
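
########################################################################################################

# For reference: each user turn below is wrapped in the ChatML-style template expected by the
# MiniMind tokenizer before being fed to run_rnn(). This is a minimal sketch of the exact string
# the chat loop builds, assuming <|im_start|>/<|im_end|> are special tokens of ./MiniMind2_tokenizer:
#
#   <|im_start|>user
#   {msg}<|im_end|>
#   <|im_start|>assistant
#
# Generation then continues until tokenizer.eos_token_id (typically <|im_end|>) is sampled.

########################################################################################################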

if STATE_NAME is None:  # use initial prompt if we are not loading a state
    init_ctx = "User: hi" + "\n\n"
    init_ctx += "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it." + "\n\n"
    # run_rnn(init_ctx)
    # print(init_ctx, end="")

while True:
    msg = prompt("<|im_start|>user:")
    msg = msg.strip()
    msg = re.sub(r"\n+", "\n", msg)
    if len(msg) > 0:
        occurrence = {}
        out_tokens = []
        out_last = 0

        out = run_rnn("<|im_start|>user\n" + msg + "<|im_end|>\n" + "<|im_start|>assistant\n")
        print("\nAssistant:", end="")

        eos_token_id = tokenizer.eos_token_id
        pad_token_id = tokenizer.pad_token_id

        for i in range(99999):
            for n in occurrence:
                out[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency  # repetition penalty
            out[0] -= 1e10  # disable END_OF_TEXT

            token = pipeline.sample_logits(out, temperature=GEN_TEMP, top_p=GEN_TOP_P)

            out, model_state = model.forward([token], model_state)
            model_tokens += [token]
            out_tokens += [token]

            for xxx in occurrence:
                occurrence[xxx] *= GEN_penalty_decay
            occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)

            # use the token id to detect the eos token; stop before echoing its text
            if token == eos_token_id:
                # flush any pending decoded text, excluding the eos token itself
                tmp = tokenizer.decode(out_tokens[out_last:-1])
                print(tmp, end="\n\n", flush=True)
                break

            tmp = tokenizer.decode(out_tokens[out_last:])
            if ("\ufffd" not in tmp) and (not tmp.endswith("\n")):  # only print complete UTF-8 sequences
                print(tmp, end="", flush=True)
                out_last = i + 1
    else:
        print("!!! Error: please say something !!!")