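# Gradio chat demo for a 34.2M-parameter RWKV-7 model fine-tuned with SFT.
# The model runs on CPU in fp32 and uses the MiniMind2 tokenizer for text
# encoding/decoding.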
import os
import copy
import types
import torch
from transformers import AutoTokenizer
import gradio as gr

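# RWKV runtime flags, read by the `rwkv` package at import time: enable the
# RWKV-7 architecture, enable the TorchScript JIT, and skip the custom CUDA
# kernel (this demo runs on CPU).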
os.environ["RWKV_V7_ON"] = "1"
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "0"

from rwkv.model import RWKV
from rwkv.utils import PIPELINE

args = types.SimpleNamespace()
args.strategy = "cpu fp32"
args.MODEL_NAME = "./rwkv-final-sft-2048"

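# Sampling defaults: temperature / top-p, presence & frequency repetition
# penalties with exponential decay, and the chunk size used when pre-filling
# the prompt through the RNN.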
STATE_NAME = None
GEN_TEMP = 1.0
GEN_TOP_P = 0.3
GEN_alpha_presence = 0.5
GEN_alpha_frequency = 0.5
GEN_penalty_decay = 0.996
CHUNK_LEN = 16

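# Load the RWKV model and its sampling pipeline; text encoding/decoding is done
# with the MiniMind2 tokenizer, while the pipeline is only used for sample_logits.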
print(f"Loading model - {args.MODEL_NAME}")
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
tokenizer = AutoTokenizer.from_pretrained("./MiniMind2_tokenizer")

model_tokens = []
model_state = None

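# Optional state tuning: when STATE_NAME points to a trained state file, tighten
# the sampling parameters and initialize the recurrent state from it.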
if STATE_NAME is not None:
    GEN_TOP_P = 0.2
    GEN_alpha_presence = 0.3
    GEN_alpha_frequency = 0.3
    
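    # Each layer holds three state slots (as in the standard RWKV state-tuning
    # demos): attention token-shift, attention time-mix state loaded from the
    # .pth file, and FFN token-shift.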
    args = model.args
    state_raw = torch.load(STATE_NAME + '.pth')
    state_init = [None for i in range(args.n_layer * 3)]
    for i in range(args.n_layer):
        dd = model.strategy[i]
        dev = dd.device
        atype = dd.atype
        state_init[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
        state_init[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
        state_init[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
    model_state = copy.deepcopy(state_init)

def run_rnn(ctx, state):
    """Feed `ctx` through the model in CHUNK_LEN-token chunks; return the final logits and state."""
    ctx = ctx.replace("\r\n", "\n")
    tokens = tokenizer.encode(ctx)
    tokens = [int(x) for x in tokens]

    # Work on a copy so the caller's state is not mutated.
    current_state = copy.deepcopy(state) if state is not None else None

    out = None  # guards against an empty prompt producing no forward pass
    while len(tokens) > 0:
        out, current_state = model.forward(tokens[:CHUNK_LEN], current_state)
        tokens = tokens[CHUNK_LEN:]

    return out, current_state

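# Build a ChatML-style prompt from the chat history, pre-fill it through the
# model, then sample tokens one by one with repetition penalties, streaming the
# decoded text until <|im_end|> (or EOS) is produced.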
def generate_response(message, history, temperature=1.0, top_p=0.3):
    global model_tokens, model_state
    model_state = None
    
    ctx = ""
    for human, assistant in history:
        ctx += f"<|im_start|>user\n{human}<|im_end|>\n<|im_start|>assistant\n{assistant}<!--eos--><|im_end|>\n"
    
    ctx += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
    
    out, model_state = run_rnn(ctx, model_state)
    
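    # `occurrence` counts sampled tokens for the presence/frequency penalty;
    # out_tokens/out_last buffer the stream so only fully decodable text is emitted.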
    occurrence = {}
    out_tokens = []
    out_last = 0
    response = ""
    
    eos_token_id = tokenizer.eos_token_id
    im_end_id = tokenizer.encode("<|im_end|>")[0]  # stop token that closes the assistant turn
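    # Autoregressive sampling loop (the range is just a generous upper bound).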
    for i in range(99999):
        logits = out.clone()
        for n in occurrence:
            logits[n] -= GEN_alpha_presence + occurrence[n] * GEN_alpha_frequency
        
        logits[0] -= 1e10  # suppress token 0 (assumed to be a special/padding token in this vocab)
        
        token = pipeline.sample_logits(logits, temperature=temperature, top_p=top_p)
        
        if token == im_end_id:
            break
            
        out, model_state = model.forward([token], model_state)
        
        out_tokens += [token]
        for xxx in occurrence:
            occurrence[xxx] *= GEN_penalty_decay
        occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
        
        tmp = tokenizer.decode(out_tokens[out_last:])
        if "\ufffd" not in tmp:  # only emit once the buffered tokens decode to complete characters
            response += tmp
            cleaned_response = response.replace("<|im_end|>", "")
            yield cleaned_response
            out_last = i + 1
            
            if token == eos_token_id:
                break

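# Thin wrapper that re-yields the streamed partial responses for the UI callbacks.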
def chat_with_bot(message, history, temperature, top_p):
    response = ""
    for partial_response in generate_response(message, history, temperature, top_p):
        response = partial_response
        yield response

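# Gradio Blocks UI: chat history on the left; message box, buttons, and
# sampling sliders on the right.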
with gr.Blocks(title="MiniRWKV_7 34.2M 🪿 2vCPU Space") as demo:
    gr.Markdown("# MiniRWKV_7 34.2M 🪿")
    gr.Markdown("### Only 34.2M parameters! Runs on a 2 vCPU backend.")
    
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="对话记录",  # "Chat history"
                height=500,
            )
        
        with gr.Column(scale=1):
            msg = gr.Textbox(
                label="输入消息",  # "Input message"
                placeholder="请输入您的问题...",  # "Please enter your question..."
                lines=3
            )
            
            with gr.Row():
                send_btn = gr.Button("发送", variant="primary")  # "Send"
                clear_btn = gr.Button("清除历史")  # "Clear history"
            
            gr.Markdown("### 参数调节")  # "Parameter tuning"
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=GEN_TEMP,
                step=0.1,
                label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,  # top-p is a probability mass, so cap the slider at 1.0
                value=GEN_TOP_P,
                step=0.05,
                label="Top-P"
            )
    
    
    def respond(message, chat_history, temperature, top_p):
        if not message:
            # respond is a generator, so yield the unchanged outputs instead of
            # returning a value (which Gradio would never see).
            yield "", chat_history
            return
        
        chat_history.append((message, ""))
        
        response = ""
        for partial_response in chat_with_bot(message, chat_history[:-1], temperature, top_p):
            response = partial_response
            cleaned_response = response.replace("<|im_end|>", "")
            chat_history[-1] = (message, cleaned_response)
            yield "", chat_history
    
    def clear_history():
        global model_tokens, model_state
        model_tokens = []
        model_state = None
        return []
    
    msg.submit(respond, [msg, chatbot, temperature_slider, top_p_slider], [msg, chatbot])
    send_btn.click(respond, [msg, chatbot, temperature_slider, top_p_slider], [msg, chatbot])
    clear_btn.click(clear_history, None, chatbot)

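# Assumed local usage (not part of the original Space config): `python app.py`
# after installing rwkv, transformers, gradio and torch, with the weights at
# ./rwkv-final-sft-2048.pth and the tokenizer files in ./MiniMind2_tokenizer.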
if __name__ == "__main__":
    demo.launch()