# fix_int8 is a local helper module shipped with this Space; it is applied
# before any other imports, presumably to patch PyTorch so the int8-quantized
# checkpoint can be loaded for CPU inference.
from fix_int8 import fix_pytorch_int8
fix_pytorch_int8()


# Credit:
# https://huggingface.co/spaces/ljsabc/Fujisaki/blob/main/app.py


import torch
import gradio as gr
from peft import PeftModel
from transformers import AutoTokenizer, GenerationConfig, AutoModel


# Load the int8-quantized ChatGLM-6B base model and cast it to fp32 for CPU
# inference; the tokenizer comes from the original THUDM checkpoint.
model = AutoModel.from_pretrained("KumaTea/twitter-int8", trust_remote_code=True).float()
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="4de8efe")

# Apply the LoRA adapter (fine-tuned on the Twitter data) on top of the base model.
peft_path = 'KumaTea/twitter'
model = PeftModel.from_pretrained(
    model,
    peft_path,
    torch_dtype=torch.float,
)

# Sanity check: dump the adapter config to ensure everything loaded correctly.
# print(model.peft_config)
# Full precision is required: some token IDs exceed 65504, the largest value
# float16 can represent, so half precision would corrupt them.
model.eval()
# print(model)

# Default all newly created tensors to fp32 on CPU.
torch.set_default_tensor_type(torch.FloatTensor)


def evaluate(context, temperature, top_p, top_k):
    """One-shot generation without chat history (not wired into the UI below)."""
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        # repetition_penalty=1.1,
        num_beams=1,
        do_sample=True,
    )
    with torch.no_grad():
        # The prompt format must match the one used during fine-tuning.
        input_text = f"Context: {context}Answer: "
        ids = tokenizer.encode(input_text)
        input_ids = torch.LongTensor([ids]).to('cpu')
        out = model.generate(
            input_ids=input_ids,
            max_length=160,
            generation_config=generation_config
        )
        # Everything after the "Answer: " marker is the model's reply.
        out_text = tokenizer.decode(out[0]).split("Answer: ")[1]
        return out_text
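
# A minimal smoke test (hypothetical, not part of the app flow); the sampling
# values mirror the slider defaults defined further below:
#   print(evaluate("How have you been lately?", temperature=0.8, top_p=0.975, top_k=25))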


def evaluate_stream(msg, history, temperature, top_p):
    # NOTE: stream_chat below takes its sampling parameters directly, so this
    # GenerationConfig is currently unused.
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        # repetition_penalty=1.1,
        num_beams=1,
        do_sample=True,
    )

    history.append([msg, None])

    context = ""
    # Keep only the most recent exchanges to bound the prompt length.
    if len(history) > 4:
        history.pop(0)

    # Strip the <br> tags Gradio inserts into user messages.
    for j in range(len(history)):
        history[j][0] = history[j][0].replace("<br>", "")

    # Concatenate past exchanges, separated by "||".
    for h in history[:-1]:
        context += h[0] + "||" + h[1] + "||"

    context += history[-1][0]
    context = context.replace(r'<br>', '')

    # Trim the context from the front, 15 characters at a time, until the
    # prompt fits within the token budget (leaving room for the answer).
    CUTOFF = 224
    while len(tokenizer.encode(context)) > CUTOFF:
        context = context[15:]

    # Past turns are already folded into `context`, so stream_chat starts
    # with an empty history.
    h = []
    print("History:", history)
    print("Context:", context)
    for response, h in model.stream_chat(tokenizer, context, h, max_length=CUTOFF, top_p=top_p, temperature=temperature):
        history[-1][1] = response
        yield history, ""
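
# evaluate_stream yields (history, "") pairs: Gradio re-renders the Chatbot
# with each partially streamed reply and clears the input textbox (see the
# msg.submit(...) wiring below).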


title = """<h1 align="center">KumaGLM</h1>
<h3 align='center'>This is an AI Kuma. You can chat with him, or just press Enter in the text box.</h3>
<p align='center'>Training data sampled from 2020/06/13 to 2023/04/15</p>"""
footer = """<p align='center'>
This project is based on
<a href='https://github.com/ljsabc/Fujisaki' target='_blank'>ljsabc/Fujisaki</a>
and uses
<a href='https://huggingface.co/THUDM/chatglm-6b' target='_blank'>THUDM/chatglm-6b</a>
as the model.
</p>
<p align='center'>
<em>The first thing to say after waking up every day!</em>
</p>"""

with gr.Blocks() as demo:
    gr.HTML(title)
    state = gr.State()  # currently unused
    with gr.Row():
        with gr.Column(scale=2):
            temp = gr.components.Slider(minimum=0, maximum=1.1, value=0.8, label="Temperature",
                info="Higher temperatures produce more varied output but risk grammatical errors; lower temperatures also help keep answers relevant.")
            top_p = gr.components.Slider(minimum=0.5, maximum=1.0, value=0.975, label="Top-p",
                info="Nucleus sampling: only the most likely tokens whose cumulative probability reaches top-p are kept. Larger values give richer output but risk grammatical errors; smaller values seem to keep the conversation more coherent.")
            # code = gr.Textbox(label="temp_output", info="decoder output")
            # top_k = gr.components.Slider(minimum=1, maximum=200, step=1, value=25, label="Top k",
            #     info="The next token is sampled from the top-k candidates. Larger values give richer output but risk grammatical errors; smaller values seem to keep the conversation more coherent.")

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat")
            msg = gr.Textbox(label="Input", placeholder="How have you been lately?",
                info="Type your message and press [Enter] to send, or leave it empty to generate random output. Conversations should not run too long, or the bot starts repeating itself; clearing the chat is recommended.")
            clear = gr.Button("Clear chat")

    msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])
    clear.click(lambda: None, None, chatbot, queue=False)
    gr.HTML(footer)

# Queueing is required for generator (streaming) callbacks like evaluate_stream.
demo.queue()
demo.launch(debug=False)
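# If run outside a Hugging Face Space, Gradio serves the UI at
# http://127.0.0.1:7860 by default.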