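"""Gradio chat demo for MiniRWKV_7, a 34.2M-parameter RWKV-7 model served on
CPU with the MiniMind2 tokenizer."""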
import os
import copy
import types
import torch
from transformers import AutoTokenizer
import gradio as gr
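# These flags must be set before `rwkv` is imported below: enable the RWKV-7
# architecture and the TorchScript JIT, and keep the custom CUDA kernel off
# since this Space runs on CPU.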
os.environ["RWKV_V7_ON"] = "1"
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "0"
from rwkv.model import RWKV
from rwkv.utils import PIPELINE
args = types.SimpleNamespace()
args.strategy = "cpu fp32"
args.MODEL_NAME = "./rwkv-final-sft-2048"

# Optional pre-trained state file (path without the .pth suffix); None disables it.
STATE_NAME = None

# Default sampling parameters.
GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.5   # flat penalty for any token already generated
GEN_alpha_frequency = 0.5  # extra penalty per repetition
GEN_penalty_decay = 0.996  # per-step decay applied to accumulated penalties
CHUNK_LEN = 16             # feed the prompt to the RNN this many tokens at a time
print(f"Loading model - {args.MODEL_NAME}")
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
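# PIPELINE is only used for its sample_logits() sampler; token encoding and
# decoding go through the MiniMind2 tokenizer loaded next, not the rwkv vocab.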
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
tokenizer = AutoTokenizer.from_pretrained("./MiniMind2_tokenizer")
# Global generation state shared across requests.
model_tokens = []
model_state = None

if STATE_NAME is not None:
    # A pre-trained state pairs best with tighter sampling defaults.
    GEN_TOP_P = 0.2
    GEN_alpha_presence = 0.3
    GEN_alpha_frequency = 0.3
    args = model.args
    state_raw = torch.load(STATE_NAME + '.pth')
    # Each layer contributes three state tensors: the attention time-state
    # comes from the file, the two token-shift states start at zero.
    state_init = [None for _ in range(args.n_layer * 3)]
    for i in range(args.n_layer):
        dd = model.strategy[i]
        dev = dd.device
        atype = dd.atype
        state_init[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        state_init[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1, 2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
        state_init[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
    model_state = copy.deepcopy(state_init)
def run_rnn(ctx, state):
    """Feed `ctx` through the RNN in CHUNK_LEN-token chunks; return the final
    logits and the updated state."""
    ctx = ctx.replace("\r\n", "\n")
    tokens = tokenizer.encode(ctx)
    tokens = [int(x) for x in tokens]
    current_state = copy.deepcopy(state) if state is not None else None
    while len(tokens) > 0:
        out, current_state = model.forward(tokens[:CHUNK_LEN], current_state)
        tokens = tokens[CHUNK_LEN:]
    return out, current_state
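# Usage sketch (hypothetical prompt): prefill once, then continue decoding
# from the returned state instead of re-reading the whole prompt, e.g.
#   out, st = run_rnn("<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n", None)
#   out, st = model.forward([int(out.argmax())], st)  # greedy next token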
def generate_response(message, history, temperature=1.0, top_p=0.3):
    global model_tokens, model_state
    model_state = None
    # Rebuild the full ChatML-style context from the conversation history.
    ctx = ""
    for human, assistant in history:
        ctx += f"<|im_start|>user\n{human}<|im_end|>\n<|im_start|>assistant\n{assistant}<!--eos--><|im_end|>\n"
    ctx += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
    out, model_state = run_rnn(ctx, model_state)
    occurrence = {}  # token id -> decayed repetition count
    out_tokens = []
    out_last = 0
    response = ""
    eos_token_id = tokenizer.eos_token_id
    im_end_id = tokenizer.encode("<|im_end|>")[0]
    for i in range(99999):
        # Apply presence/frequency penalties, then sample the next token.
        logits = out.clone()
        for n in occurrence:
            logits[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency
        logits[0] -= 1e10  # never sample token 0
        token = pipeline.sample_logits(logits, temperature=temperature, top_p=top_p)
        if token == im_end_id:
            break
        out, model_state = model.forward([token], model_state)
        out_tokens += [token]
        # Decay existing penalties, then bump the one for the sampled token.
        for n in occurrence:
            occurrence[n] *= GEN_penalty_decay
        occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
        # Only emit text once it decodes cleanly (no UTF-8 replacement char
        # from a partially decoded multi-byte sequence).
        tmp = tokenizer.decode(out_tokens[out_last:])
        if "\ufffd" not in tmp:
            response += tmp
            yield response.replace("<|im_end|>", "")
            out_last = i + 1
        if token == eos_token_id:
            break
def chat_with_bot(message, history, temperature, top_p):
    response = ""
    for partial_response in generate_response(message, history, temperature, top_p):
        response = partial_response
        yield response
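# Example outside Gradio (hypothetical prompt): stream a single-turn reply.
#   for partial in chat_with_bot("Hello", [], GEN_TEMP, GEN_TOP_P):
#       print(partial)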
with gr.Blocks(title="MiniRWKV_7 34.2M 🪿 2 vCPU Space") as demo:
    gr.Markdown("# MiniRWKV_7 34.2M 🪿")
    gr.Markdown("### Only 34.2M params! Runs on a 2 vCPU backend.")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Chat history",
                height=500,
            )
        with gr.Column(scale=1):
            msg = gr.Textbox(
                label="Message",
                placeholder="Type your question...",
                lines=3,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear history")
            gr.Markdown("### Sampling parameters")
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=GEN_TEMP,
                step=0.1,
                label="Temperature",
            )
            top_p_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,  # top-p is a probability mass, so cap it at 1.0
                value=GEN_TOP_P,
                step=0.05,
                label="Top-P",
            )
    def respond(message, chat_history, temperature, top_p):
        if not message:
            # This is a generator, so yield (not return) the unchanged state.
            yield "", chat_history
            return
        chat_history.append((message, ""))
        response = ""
        # Stream partial responses into the last chat entry as they arrive.
        for partial_response in chat_with_bot(message, chat_history[:-1], temperature, top_p):
            response = partial_response
            cleaned_response = response.replace("<|im_end|>", "")
            chat_history[-1] = (message, cleaned_response)
            yield "", chat_history

    def clear_history():
        global model_tokens, model_state
        model_tokens = []
        model_state = None
        return []

    msg.submit(respond, [msg, chatbot, temperature_slider, top_p_slider], [msg, chatbot])
    send_btn.click(respond, [msg, chatbot, temperature_slider, top_p_slider], [msg, chatbot])
    clear_btn.click(clear_history, None, chatbot)
if __name__ == "__main__":
    demo.launch()