import os
import copy
import types
import torch
from transformers import AutoTokenizer
import gradio as gr
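# RWKV runtime flags are read by the `rwkv` package at import time, so they must be
# set before `from rwkv.model import RWKV` below: enable the RWKV-7 code path and the
# TorchScript JIT, and keep the custom CUDA kernel off for CPU-only inference.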
os.environ["RWKV_V7_ON"] = "1"
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "0"
from rwkv.model import RWKV
from rwkv.utils import PIPELINE
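# Model path, execution strategy, and default sampling parameters.
# STATE_NAME may point to an optional tuned initial-state .pth file (left as None here).
# CHUNK_LEN is how many prompt tokens are fed to the RNN per forward call.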
args = types.SimpleNamespace()
args.strategy = "cpu fp32"
args.MODEL_NAME = "./rwkv-final-sft-2048"
STATE_NAME = None
GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.5
GEN_alpha_frequency = 0.5
GEN_penalty_decay = 0.996
CHUNK_LEN = 16
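# Load the RWKV weights, the sampling pipeline, and the MiniMind2 tokenizer
# (prompts are encoded/decoded with the HF tokenizer; the pipeline is only used for sampling).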
print(f"Loading model - {args.MODEL_NAME}")
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
tokenizer = AutoTokenizer.from_pretrained("./MiniMind2_tokenizer")
model_tokens = []
model_state = None
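# Optional: load a pre-trained state file and tighten the sampling parameters.
# The per-layer wkv time_state from the checkpoint becomes the initial recurrent state,
# while the two shift states of each block are zero-initialized.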
if STATE_NAME is not None:
    GEN_TOP_P = 0.2
    GEN_alpha_presence = 0.3
    GEN_alpha_frequency = 0.3
    args = model.args
    state_raw = torch.load(STATE_NAME + '.pth')
    state_init = [None for i in range(args.n_layer * 3)]
    for i in range(args.n_layer):
        dd = model.strategy[i]
        dev = dd.device
        atype = dd.atype
        state_init[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        state_init[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
        state_init[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
    model_state = copy.deepcopy(state_init)
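# Feed `ctx` through the RNN CHUNK_LEN tokens at a time; returns the logits of the
# last token together with the updated recurrent state.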
def run_rnn(ctx, state):
    ctx = ctx.replace("\r\n", "\n")
    tokens = tokenizer.encode(ctx)
    tokens = [int(x) for x in tokens]
    current_state = copy.deepcopy(state) if state is not None else None
    while len(tokens) > 0:
        out, current_state = model.forward(tokens[:CHUNK_LEN], current_state)
        tokens = tokens[CHUNK_LEN:]
    return out, current_state
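# Stream a reply for `message` given the previous `history`.
# The whole conversation is re-encoded as ChatML-style <|im_start|>/<|im_end|> blocks,
# then tokens are sampled one at a time with presence/frequency penalties (decayed by
# GEN_penalty_decay) until <|im_end|> or the tokenizer EOS appears. Partial text is
# only yielded once it decodes without a U+FFFD replacement character, so multi-byte
# characters are never emitted half-finished.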
def generate_response(message, history, temperature=1.0, top_p=0.3):
    global model_tokens, model_state
    model_state = None
    ctx = ""
    for human, assistant in history:
        ctx += f"<|im_start|>user\n{human}<|im_end|>\n<|im_start|>assistant\n{assistant}<!--eos--><|im_end|>\n"
    ctx += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
    out, model_state = run_rnn(ctx, model_state)
    occurrence = {}
    out_tokens = []
    out_last = 0
    response = ""
    eos_token_id = tokenizer.eos_token_id
    im_end_id = tokenizer.encode("<|im_end|>")[0]
    for i in range(99999):
        logits = out.clone()
        for n in occurrence:
            logits[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency
        logits[0] -= 1e10
        token = pipeline.sample_logits(logits, temperature=temperature, top_p=top_p)
        if token == im_end_id:
            break
        out, model_state = model.forward([token], model_state)
        out_tokens += [token]
        for xxx in occurrence:
            occurrence[xxx] *= GEN_penalty_decay
        occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
        tmp = tokenizer.decode(out_tokens[out_last:])
        if "\ufffd" not in tmp:
            response += tmp
            cleaned_response = response.replace("<|im_end|>", "")
            yield cleaned_response
            out_last = i + 1
        if token == eos_token_id:
            break
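# Thin wrapper that re-yields the streamed partial replies from generate_response.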
def chat_with_bot(message, history, temperature, top_p):
    response = ""
    for partial_response in generate_response(message, history, temperature, top_p):
        response = partial_response
        yield response
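# Gradio UI: chat window on the left; message box, buttons, and sampling sliders on the right.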
with gr.Blocks(title="MiniRWKV_7 34.2M 🪿 2vCPU Space") as demo:
    gr.Markdown("# MiniRWKV_7 34.2M 🪿 ")
    gr.Markdown("### Only 34.2M params! Use the 2 vCPU backend to run this model.")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=500,
            )
        with gr.Column(scale=1):
            msg = gr.Textbox(
                label="Message",
                placeholder="Type your question...",
                lines=3
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear history")
            gr.Markdown("### Parameter settings")
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=GEN_TEMP,
                step=0.1,
                label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=GEN_TOP_P,
                step=0.05,
                label="Top-P"
            )
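    # Callback for submit/click: append the user turn, then stream the growing reply
    # into the last history entry while the textbox is cleared.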
    def respond(message, chat_history, temperature, top_p):
        if not message:
            yield "", chat_history
            return
        chat_history.append((message, ""))
        response = ""
        for partial_response in chat_with_bot(message, chat_history[:-1], temperature, top_p):
            response = partial_response
            cleaned_response = response.replace("<|im_end|>", "")
            chat_history[-1] = (message, cleaned_response)
            yield "", chat_history

    def clear_history():
        global model_tokens, model_state
        model_tokens = []
        model_state = None
        return []

    msg.submit(respond, [msg, chatbot, temperature_slider, top_p_slider], [msg, chatbot])
    send_btn.click(respond, [msg, chatbot, temperature_slider, top_p_slider], [msg, chatbot])
    clear_btn.click(clear_history, None, chatbot)
if __name__ == "__main__":
    demo.launch()