File size: 4,720 Bytes
1155f19 d2742cf e65d733 d2742cf e65d733 d2742cf e65d733 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
from fix_int8 import fix_pytorch_int8
fix_pytorch_int8()
# Credit:
# https://huggingface.co/spaces/ljsabc/Fujisaki/blob/main/app.py
import torch
import gradio as gr
from peft import PeftModel
from transformers import AutoTokenizer, GenerationConfig, AutoModel
# Load the base model and tokenizer.
# NOTE(review): "KumaTea/twitter-int8" is presumably an int8-quantized
# ChatGLM-6B checkpoint (judging by the repo name and the fix_int8 import);
# `.float()` casts it to float32 before the adapter is attached.
# The original note says full precision is needed because some token ids
# are >65535 — presumably the reason half precision is not used; confirm.
model = AutoModel.from_pretrained("KumaTea/twitter-int8", trust_remote_code=True).float()
# Tokenizer pinned to a fixed revision of THUDM/chatglm-6b so upstream
# changes to the remote code cannot silently alter tokenization.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="4de8efe")
# Hub id of the PEFT adapter that is applied on top of the base model.
peft_path = 'KumaTea/twitter'
# Wrap the float32 base model with the PEFT adapter weights; this must
# happen after `.float()` above.
model = PeftModel.from_pretrained(
    model,
    peft_path,
    torch_dtype=torch.float,
)
# dump a log to ensure everything works well
# print(model.peft_config)
# We have to use full precision, as some tokens are >65535
# Put the model in evaluation mode; it is only used for inference below.
model.eval()
# print(model)
# Make float32 CPU tensors (torch.FloatTensor) the default type for tensors
# created from here on.
torch.set_default_tensor_type(torch.FloatTensor)
def evaluate(context, temperature, top_p, top_k, max_length=160):
    """Generate one (non-streaming) sampled answer for ``context``.

    The prompt is ``"Context: {context}Answer: "``. It is encoded with the
    module-level ``tokenizer`` and sampled with the module-level ``model``
    on CPU.

    Args:
        context: Text inserted after ``"Context: "`` in the prompt.
        temperature: Sampling temperature passed to ``GenerationConfig``.
        top_p: Nucleus-sampling threshold passed to ``GenerationConfig``.
        top_k: Top-k cutoff passed to ``GenerationConfig``.
        max_length: Total token budget (prompt + answer) given to
            ``model.generate``. Defaults to 160, the previous hard-coded value.

    Returns:
        The decoded text that follows the prompt's final ``"Answer: "``
        marker. Returns ``""`` if that marker cannot be located in the
        decoded output.
    """
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=1,
        do_sample=True,
    )
    marker = "Answer: "
    input_text = f"Context: {context}{marker}"
    with torch.no_grad():
        ids = tokenizer.encode(input_text)
        input_ids = torch.LongTensor([ids]).to('cpu')
        out = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            generation_config=generation_config,
        )
    decoded = tokenizer.decode(out[0])
    # The decoded sequence echoes the prompt. Split only on the markers that
    # belong to the prompt (the user's context may itself contain
    # "Answer: "). Then take everything after them, so an answer that
    # contains "Answer: " is not truncated.
    prompt_markers = input_text.count(marker)
    parts = decoded.split(marker, prompt_markers)
    if len(parts) <= prompt_markers:
        return ""
    return parts[prompt_markers]
def evaluate_stream(msg, history, temperature, top_p):
    """Stream a reply to ``msg`` and update the chat ``history`` in place.

    ``history`` is a list of ``[user_message, bot_reply]`` pairs. The new
    message is appended as ``[msg, None]`` and only the last 4 pairs are
    kept. The prompt is then built as ``"user||bot||"`` for every earlier
    pair followed by the new message, with ``<br>`` removed.

    The prompt is trimmed from the front, 15 characters at a time, until it
    fits in ``max_length - answer_reserve`` tokens. This guarantees the
    model always has room to generate an answer.

    Args:
        msg: The new user message.
        history: List of ``[user, bot]`` pairs (``None`` is treated as empty).
        temperature: Sampling temperature passed to ``model.stream_chat``.
        top_p: Nucleus-sampling threshold passed to ``model.stream_chat``.

    Yields:
        ``(history, "")`` after each streamed chunk. ``history[-1][1]``
        holds the partial reply so far, and the empty string is meant to
        clear the input box.
    """
    # Total token budget for prompt + generated answer (stream_chat max_length).
    max_length = 224
    # Tokens kept free for the answer. Previously the prompt itself could
    # fill the whole budget, leaving nothing to generate.
    answer_reserve = 64
    prompt_budget = max_length - answer_reserve

    if history is None:
        history = []
    history.append([msg, None])
    # Keep only the most recent 4 turns (including the new one).
    del history[:-4]
    for turn in history:
        turn[0] = (turn[0] or "").replace("<br>", "")

    # A previous reply can still be None if its stream was interrupted.
    context = "".join(
        f"{user}||{bot or ''}||" for user, bot in history[:-1]
    ) + history[-1][0]
    context = context.replace("<br>", "")

    # Drop leading characters until the prompt fits its token budget.
    while context and len(tokenizer.encode(context)) > prompt_budget:
        context = context[15:]

    print("History:", history)
    print("Context:", context)
    for response, _ in model.stream_chat(
        tokenizer,
        context,
        [],
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
    ):
        history[-1][1] = response
        yield history, ""
title = """<h1 align="center">KumaGLM</h1>
<h3 align='center'>这是一个 AI Kuma,你可以与他聊天,或者直接在文本框按下Enter</h3>
<p align='center'>采样范围 2020/06/13 - 2023/04/15</p>"""
footer = """<p align='center'>
本项目基于
<a href='https://github.com/ljsabc/Fujisaki' target='_blank'>ljsabc/Fujisaki</a>
,模型采用
<a href='https://huggingface.co/THUDM/chatglm-6b' target='_blank'>THUDM/chatglm-6b</a>
。
</p>
<p align='center'>
<em>每天起床第一句!</em>
</p>"""
with gr.Blocks() as demo:
    gr.HTML(title)
    state = gr.State()
    with gr.Row():
        # Left column: sampling controls forwarded to evaluate_stream.
        with gr.Column(scale=2):
            temp = gr.Slider(
                minimum=0,
                maximum=1.1,
                value=0.8,
                label="Temperature",
                info="温度参数,越高的温度生成的内容越丰富,但是有可能出现语法问题。小的温度也能帮助生成更相关的回答。",
            )
            top_p = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=0.975,
                label="Top-p",
                info="top-p参数,只输出前p>top-p的文字,越大生成的内容越丰富,但也可能出现语法问题。数字越小似乎上下文的衔接性越好。",
            )
        # Right column: chat transcript, input box and reset button.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="聊天框", info="")
            msg = gr.Textbox(
                label="输入框",
                placeholder="最近过得怎么样?",
                info="输入你的内容,按[Enter]发送。也可以什么都不填写生成随机数据。对话一般不能太长,否则就复读机了,建议清除数据。",
            )
            clear = gr.Button("清除聊天")

    # Enter streams a reply into the chatbot and clears the textbox;
    # the button resets the chatbot value to None.
    msg.submit(
        fn=evaluate_stream,
        inputs=[msg, chatbot, temp, top_p],
        outputs=[chatbot, msg],
    )
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)
    gr.HTML(footer)

# Queueing is required for generator (streaming) handlers.
demo.queue().launch(debug=False)
|