KumaTea committed
Commit 2b92edc · 1 Parent(s): 0143ad2

Add API wrapper

Files changed (2):
  1. app.py +120 -3
  2. requirements.txt +2 -0
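
The change turns a hidden "API" button into an asynchronous endpoint: a call with a non-empty context starts generation on a background Thread and immediately returns a ticket index, and later calls with an empty context and query=<index> poll for the stored result. As orientation before the diff, the JSON envelope returned by api_wrapper moves through three states (field values here are illustrative, derived from the code below):

    # busy: another generation currently holds db.lock
    {'status': 'busy', 'code': 503, 'message': 'Server is busy, please try again later.', 'index': 0, 'result': ''}
    # accepted: generation has been started on a background thread
    {'status': 'processing', 'code': 202, 'message': 'Request accepted, please check back later.', 'index': 1, 'result': ''}
    # done: the result for query=1 is ready
    {'status': 'done', 'code': 200, 'message': 'Request processed.', 'index': 1, 'result': '...'}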
app.py CHANGED
@@ -12,8 +12,10 @@ fix_pytorch_int8()
 
 
 import torch
+import psutil
 import logging
 import gradio as gr
+from threading import Thread
 from transformers import AutoTokenizer, GenerationConfig, AutoModel
 
 
@@ -52,6 +54,28 @@ logging.basicConfig(
     level=logging.INFO,
     datefmt='%m/%d %H:%M:%S')
 
+
+def log_sys_info():
+    cpu_cores = psutil.cpu_count()
+    cpu_freq = '{:.2f}'.format(psutil.cpu_freq().max / 1000) + 'GHz'
+    mem = psutil.virtual_memory()
+    mem_total = '{:.2f}'.format(mem.total / 1024 / 1024 / 1024) + 'GB'
+    mem_used = '{:.2f}'.format(mem.used / 1024 / 1024 / 1024) + 'GB'
+    mem_percent = '{:.2f}'.format(mem.percent) + '%'
+    disk = psutil.disk_usage('.')
+    disk_total = '{:.2f}'.format(disk.total / 1024 / 1024 / 1024) + 'GB'
+    disk_used = '{:.2f}'.format(disk.used / 1024 / 1024 / 1024) + 'GB'
+    disk_percent = '{:.2f}'.format(disk.percent) + '%'
+
+    logging.info('======== SYSTEM INFO =========')
+    logging.info(f'CPU: {cpu_cores} cores, {cpu_freq}')
+    logging.info(f'RAM: {mem_used} / {mem_total}, {mem_percent} used')
+    logging.info(f'DISK: {disk_used} / {disk_total}, {disk_percent} used')
+    logging.info('==============================')
+
+
+log_sys_info()
+
 model = AutoModel.from_pretrained(
     int8_model,
     trust_remote_code=True,
@@ -67,6 +91,38 @@ model.eval()
 
 torch.set_default_tensor_type(torch.FloatTensor)
 
+logging.info('[SYS] Model loaded')
+log_sys_info()
+
+
+class CHAT_DB:
+    def __init__(self):
+        self.prompts = {}
+        self.results = {}
+        self.index = 1
+        self.lock = False
+
+    def set(self, index, prompt=None, result=None):
+        assert prompt or result
+        if prompt:
+            if index in self.prompts:
+                raise ValueError('Prompt already exists')
+            self.prompts[index] = prompt
+            self.index += 1
+        if result:
+            self.results[index] = result
+
+    def clean(self):
+        if len(self.prompts) > 100:
+            self.prompts = dict(list(self.prompts.items())[-100:])
+            k = list(set(self.prompts.keys()).intersection(set(self.results.keys())))  # keys to preserve
+            self.prompts = {i: self.prompts[i] for i in k}
+            self.results = {i: self.results[i] for i in k}
+            log_sys_info()
+
+
+db = CHAT_DB()
+
 
 def evaluate(context, temperature, top_p):
     generation_config = GenerationConfig(
@@ -99,6 +155,64 @@ def evaluate(context, temperature, top_p):
     return out_text
 
 
+def evaluate_wrapper(context, temperature, top_p):
+    db.lock = True
+    index = db.index
+    db.set(index, prompt=context)
+    result = evaluate(context, temperature, top_p)
+    db.set(index, result=result)
+    db.lock = False
+    return result
+
+
+def api_wrapper(context='', temperature=0.5, top_p=0.8, query=0):
+    query = int(query)
+    assert context or query
+
+    return_json = {
+        'status': '',
+        'code': 0,
+        'message': '',
+        'index': 0,
+        'result': ''
+    }
+
+    if context:
+        if db.lock:
+            logging.info(f'[API] Request: {context}, Status: busy')
+            return_json['status'] = 'busy'
+            return_json['code'] = 503
+            return_json['message'] = 'Server is busy, please try again later.'
+            return return_json
+        else:
+            index = db.index
+            t = Thread(target=evaluate_wrapper, args=(context, temperature, top_p))
+            t.start()
+            logging.info(f'[API] Request: {context}, Status: processing, Index: {index}')
+            return_json['status'] = 'processing'
+            return_json['code'] = 202
+            return_json['message'] = 'Request accepted, please check back later.'
+            return_json['index'] = index
+            return return_json
+    else:  # query
+        if query in db.prompts:
+            if query in db.results:
+                logging.info(f'[API] Query: {query}, Status: hit')
+                return_json['status'] = 'done'
+                return_json['code'] = 200
+                return_json['message'] = 'Request processed.'
+                return_json['index'] = query
+                return_json['result'] = db.results[query]
+                return return_json
+            else:
+                logging.info(f'[API] Query: {query}, Status: processing')
+                return_json['status'] = 'processing'
+                return_json['code'] = 202
+                return_json['message'] = 'Request accepted, please check back later.'
+                return_json['index'] = query
+                return return_json
+
+
 def evaluate_stream(msg, history, temperature, top_p):
     generation_config = GenerationConfig(
         temperature=temperature,
@@ -158,12 +272,15 @@ with gr.Blocks() as demo:
         msg = gr.Textbox(label="输入框", placeholder="最近过得怎么样?",
                          info="输入你的内容,按 [Enter] 发送。什么都不填经常会出错。")
         clear = gr.Button("清除聊天")
-        api_handler = gr.Button("API", visible=False)
-        textbox_for_api = gr.Textbox(visible=False)
+
+        api_handler = gr.Button("API", visible=False)
+        num_for_api = gr.Number(visible=False)
+        json_for_api = gr.JSON(visible=False)
+
 
     msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])
     clear.click(lambda: None, None, chatbot, queue=False)
-    api_handler.click(evaluate, [textbox_for_api, temp, top_p], [textbox_for_api], api_name='chat')
+    api_handler.click(api_wrapper, [msg, temp, top_p, num_for_api], [json_for_api], api_name='chat')
     gr.HTML(gr_footer)
 
     demo.queue()
requirements.txt CHANGED
@@ -1,3 +1,5 @@
+psutil
+
 # https://huggingface.co/spaces/ljsabc/Fujisaki/blob/main/requirements.txt
 
 # int8