No need for starting prompt in API
app.py
CHANGED
@@ -17,6 +17,14 @@ import gradio as gr
 from transformers import AutoTokenizer, GenerationConfig, AutoModel
 
 
+chatglm = 'THUDM/chatglm-6b'
+chatglm_rev = '4de8efe'
+int8_model = 'KumaTea/twitter-int8'
+int8_model_rev = '1136001'
+
+max_length = 224
+default_start = ["你是Kuma,请和我聊天,每句话以两个竖杠分隔。", "好的,你想聊什么?"]
+
 gr_title = """<h1 align="center">KumaGLM</h1>
 <h3 align='center'>这是一个 AI Kuma,你可以与他聊天,或者直接在文本框按下Enter</h3>
 <p align='center'>采样范围 2020/06/13 - 2023/04/15</p>
@@ -33,7 +41,7 @@ gr_footer = """<p align='center'>
 <p align='center'>
 <em>每天起床第一句!</em>
 </p>"""
-
+
 
 
 # device = torch.device('cpu')
@@ -45,11 +53,11 @@ logging.basicConfig(
     datefmt='%m/%d %H:%M:%S')
 
 model = AutoModel.from_pretrained(
-
+    int8_model,
     trust_remote_code=True,
-    revision=
+    revision=int8_model_rev
 ).float()  # .to(device)
-tokenizer = AutoTokenizer.from_pretrained(
+tokenizer = AutoTokenizer.from_pretrained(chatglm, trust_remote_code=True, revision=chatglm_rev)
 
 # dump a log to ensure everything works well
 # print(model.peft_config)
@@ -60,7 +68,7 @@ model.eval()
 torch.set_default_tensor_type(torch.FloatTensor)
 
 
-def evaluate(context, temperature, top_p, top_k=None):
+def evaluate(context, temperature, top_p):
     generation_config = GenerationConfig(
         temperature=temperature,
         top_p=top_p,
@@ -71,21 +79,23 @@ def evaluate(context, temperature, top_p, top_k=None):
     )
     with torch.no_grad():
         # input_text = f"Context: {context}Answer: "
-        input_text = '||'.join(default_start) + '||'
-
-
-
+        # input_text = '||'.join(default_start) + '||'
+        # No need for starting prompt in API
+        if not context.endswith('||'):
+            context += '||'
+        logging.info('[API] Request: ' + context)
+        ids = tokenizer([context], return_tensors="pt")
         inputs = ids.to("cpu")
         out = model.generate(
             **inputs,
-            max_length=
+            max_length=max_length,
            generation_config=generation_config
         )
         out = out.tolist()[0]
         decoder_output = tokenizer.decode(out)
         # out_text = decoder_output.split("Answer: ")[1]
         out_text = decoder_output
-        logging.info('[API]
+        logging.info('[API] Results: ' + out_text)
         return out_text
 
 
@@ -117,17 +127,17 @@ def evaluate_stream(msg, history, temperature, top_p):
     context = context.replace(r'<br>', '')
 
     # TODO: Avoid the tokens are too long.
-    CUTOFF = 224
-    while len(tokenizer.encode(context)) >
+    # CUTOFF = 224
+    while len(tokenizer.encode(context)) > max_length:
         # save 15 token size for the answer
         context = context[15:]
 
     h = []
-    logging.info('[UI]
-    for response, h in model.stream_chat(tokenizer, context, h, max_length=
+    logging.info('[UI] Request: ' + context)
+    for response, h in model.stream_chat(tokenizer, context, h, max_length=max_length, top_p=top_p, temperature=temperature):
         history[-1][1] = response
         yield history, ""
-    logging.info('[UI]
+    logging.info('[UI] Results: ' + response)
 
 
 with gr.Blocks() as demo:
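
For reference, the core change is that the API-side evaluate() no longer prepends the default_start dialogue; it only ensures the caller's context ends with the '||' turn separator. A minimal standalone sketch of that behavior (the helper name prepare_api_context is hypothetical, not part of app.py):

# Hypothetical helper mirroring the new API prompt handling in evaluate().
def prepare_api_context(context: str) -> str:
    # No starting prompt is added; only the trailing '||' separator is enforced.
    if not context.endswith('||'):
        context += '||'
    return context

print(prepare_api_context('今天天气怎么样'))  # -> '今天天气怎么样||'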
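The hardcoded CUTOFF is also replaced by the shared max_length constant, so the UI truncation loop and stream_chat use the same token budget. A rough sketch of that truncation logic in isolation (assuming the ChatGLM tokenizer can be loaded; the revision pin is omitted here):

# Sketch only: trim the oldest characters until the prompt fits the budget,
# as evaluate_stream() does before calling model.stream_chat().
from transformers import AutoTokenizer

max_length = 224
tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm-6b', trust_remote_code=True)

def truncate_context(context: str) -> str:
    while len(tokenizer.encode(context)) > max_length:
        context = context[15:]  # drop 15 characters from the front each pass
    return context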