KumaTea committed on
Commit 0143ad2 · 1 Parent(s): 1c99887

No need for starting prompt in API

Files changed (1): app.py (+26, -16)
app.py CHANGED
@@ -17,6 +17,14 @@ import gradio as gr
 from transformers import AutoTokenizer, GenerationConfig, AutoModel
 
 
+chatglm = 'THUDM/chatglm-6b'
+chatglm_rev = '4de8efe'
+int8_model = 'KumaTea/twitter-int8'
+int8_model_rev = '1136001'
+
+max_length = 224
+default_start = ["你是Kuma,请和我聊天,每句话以两个竖杠分隔。", "好的,你想聊什么?"]
+
 gr_title = """<h1 align="center">KumaGLM</h1>
 <h3 align='center'>这是一个 AI Kuma,你可以与他聊天,或者直接在文本框按下Enter</h3>
 <p align='center'>采样范围 2020/06/13 - 2023/04/15</p>
@@ -33,7 +41,7 @@ gr_footer = """<p align='center'>
 <p align='center'>
 <em>每天起床第一句!</em>
 </p>"""
-default_start = ["你是Kuma,请和我聊天,每句话以两个竖杠分隔。", "好的,你想聊什么?"]
+
 
 
 # device = torch.device('cpu')
@@ -45,11 +53,11 @@ logging.basicConfig(
     datefmt='%m/%d %H:%M:%S')
 
 model = AutoModel.from_pretrained(
-    "KumaTea/twitter-int8",
+    int8_model,
     trust_remote_code=True,
-    revision="1136001"
+    revision=int8_model_rev
 ).float()  # .to(device)
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, revision="4de8efe")
+tokenizer = AutoTokenizer.from_pretrained(chatglm, trust_remote_code=True, revision=chatglm_rev)
 
 # dump a log to ensure everything works well
 # print(model.peft_config)
@@ -60,7 +68,7 @@ model.eval()
 torch.set_default_tensor_type(torch.FloatTensor)
 
 
-def evaluate(context, temperature, top_p, top_k=None):
+def evaluate(context, temperature, top_p):
     generation_config = GenerationConfig(
         temperature=temperature,
         top_p=top_p,
@@ -71,21 +79,23 @@ def evaluate(context, temperature, top_p, top_k=None):
     )
     with torch.no_grad():
         # input_text = f"Context: {context}Answer: "
-        input_text = '||'.join(default_start) + '||'
-        input_text += context + '||'
-        logging.info('[API] Incoming request: ' + input_text)
-        ids = tokenizer([input_text], return_tensors="pt")
+        # input_text = '||'.join(default_start) + '||'
+        # No need for starting prompt in API
+        if not context.endswith('||'):
+            context += '||'
+        logging.info('[API] Request: ' + context)
+        ids = tokenizer([context], return_tensors="pt")
         inputs = ids.to("cpu")
         out = model.generate(
             **inputs,
-            max_length=224,
+            max_length=max_length,
             generation_config=generation_config
         )
         out = out.tolist()[0]
         decoder_output = tokenizer.decode(out)
         # out_text = decoder_output.split("Answer: ")[1]
         out_text = decoder_output
-        logging.info('[API] Result: ' + out_text)
+        logging.info('[API] Results: ' + out_text)
         return out_text
 
 
@@ -117,17 +127,17 @@ def evaluate_stream(msg, history, temperature, top_p):
     context = context.replace(r'<br>', '')
 
     # TODO: Avoid the tokens are too long.
-    CUTOFF = 224
-    while len(tokenizer.encode(context)) > CUTOFF:
+    # CUTOFF = 224
+    while len(tokenizer.encode(context)) > max_length:
         # save 15 token size for the answer
         context = context[15:]
 
     h = []
-    logging.info('[UI] Incoming request: ' + context)
-    for response, h in model.stream_chat(tokenizer, context, h, max_length=CUTOFF, top_p=top_p, temperature=temperature):
+    logging.info('[UI] Request: ' + context)
+    for response, h in model.stream_chat(tokenizer, context, h, max_length=max_length, top_p=top_p, temperature=temperature):
         history[-1][1] = response
         yield history, ""
-    logging.info('[UI] Result: ' + response)
+    logging.info('[UI] Results: ' + response)
 
 
 with gr.Blocks() as demo:
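For reference, the new API preprocessing can be exercised in isolation. Below is a minimal sketch, assuming a hypothetical build_api_input helper whose two-line body is taken verbatim from the evaluate hunk above; before this commit the request was instead wrapped as '||'.join(default_start) + '||' + context + '||':

def build_api_input(context: str) -> str:
    # Mirror of the new evaluate() preprocessing: no default_start prefix;
    # only ensure the context ends with the '||' turn separator.
    if not context.endswith('||'):
        context += '||'
    return context

assert build_api_input('你好') == '你好||'    # separator appended when missing
assert build_api_input('你好||') == '你好||'  # already-delimited input passes through unchanged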