KumaTea committed
Commit 6cf18af · 1 Parent(s): 2b92edc
Files changed (4)
  1. app.py +46 -116
  2. fix_int8.py +2 -1
  3. model.py +36 -0
  4. session.py +71 -0
app.py CHANGED
@@ -2,28 +2,18 @@ from fix_int8 import fix_pytorch_int8
 fix_pytorch_int8()
 
 
-# import subprocess
-# result = subprocess.run(['git', 'clone', 'https://huggingface.co/KumaTea/twitter-int8', 'model'], capture_output=True, text=True)
-# print(result.stdout)
-
-
 # Credit:
 # https://huggingface.co/spaces/ljsabc/Fujisaki/blob/main/app.py
 
 
 import torch
-import psutil
-import logging
 import gradio as gr
 from threading import Thread
+from model import model, tokenizer
+from session import db, logger, log_sys_info
 from transformers import AutoTokenizer, GenerationConfig, AutoModel
 
 
-chatglm = 'THUDM/chatglm-6b'
-chatglm_rev = '4de8efe'
-int8_model = 'KumaTea/twitter-int8'
-int8_model_rev = '1136001'
-
 max_length = 224
 default_start = ["你是Kuma,请和我聊天,每句话以两个竖杠分隔。", "好的,你想聊什么?"]
 
@@ -45,85 +35,6 @@ gr_footer = """<p align='center'>
 </p>"""
 
 
-
-# device = torch.device('cpu')
-# torch.cuda.current_device = lambda : device
-
-logging.basicConfig(
-    format='%(asctime)s %(levelname)-8s %(message)s',
-    level=logging.INFO,
-    datefmt='%m/%d %H:%M:%S')
-
-
-def log_sys_info():
-    cpu_cores = psutil.cpu_count()
-    cpu_freq = '{:.2f}'.format(psutil.cpu_freq().max / 1000) + 'GHz'
-    mem = psutil.virtual_memory()
-    mem_total = '{:.2f}'.format(mem.total / 1024 / 1024 / 1024) + 'GB'
-    mem_used = '{:.2f}'.format(mem.used / 1024 / 1024 / 1024) + 'GB'
-    mem_percent = '{:.2f}'.format(mem.percent) + '%'
-    disk = psutil.disk_usage('.')
-    disk_total = '{:.2f}'.format(disk.total / 1024 / 1024 / 1024) + 'GB'
-    disk_used = '{:.2f}'.format(disk.used / 1024 / 1024 / 1024) + 'GB'
-    disk_percent = '{:.2f}'.format(disk.percent) + '%'
-
-    logging.info('======== SYSTEM INFO =========')
-    logging.info(f'CPU: {cpu_cores} cores, {cpu_freq}')
-    logging.info(f'RAM: {mem_used} / {mem_total}, {mem_percent} used')
-    logging.info(f'DISK: {disk_used} / {disk_total}, {disk_percent} used')
-    logging.info('==============================')
-
-
-log_sys_info()
-
-model = AutoModel.from_pretrained(
-    int8_model,
-    trust_remote_code=True,
-    revision=int8_model_rev
-).float() # .to(device)
-tokenizer = AutoTokenizer.from_pretrained(chatglm, trust_remote_code=True, revision=chatglm_rev)
-
-# dump a log to ensure everything works well
-# print(model.peft_config)
-# We have to use full precision, as some tokens are >65535
-model.eval()
-# print(model)
-
-torch.set_default_tensor_type(torch.FloatTensor)
-
-logging.info('[SYS] Model loaded')
-log_sys_info()
-
-
-class CHAT_DB:
-    def __init__(self):
-        self.prompts = {}
-        self.results = {}
-        self.index = 1
-        self.lock = False
-
-    def set(self, index, prompt=None, result=None):
-        assert prompt or result
-        if prompt:
-            if index in self.prompts:
-                raise ValueError('Prompt already exists')
-            self.prompts[index] = prompt
-            index += 1
-        if result:
-            self.results[index] = result
-
-    def clean(self):
-        if len(self.prompts) > 100:
-            self.prompts = dict(list(self.prompts.items())[-100:])
-        k = list(set(self.prompts.keys()).intersection(set(self.results.keys()))) # keys to preserve
-        self.prompts = {i: self.prompts[i] for i in k}
-        self.results = {i: self.results[i] for i in k}
-        log_sys_info()
-
-
-db = CHAT_DB()
-
-
 def evaluate(context, temperature, top_p):
     generation_config = GenerationConfig(
         temperature=temperature,
@@ -139,7 +50,7 @@ def evaluate(context, temperature, top_p):
     # No need for starting prompt in API
     if not context.endswith('||'):
         context += '||'
-    logging.info('[API] Request: ' + context)
+    # logger.info('[API] Request: ' + context)
     ids = tokenizer([context], return_tensors="pt")
     inputs = ids.to("cpu")
     out = model.generate(
@@ -151,17 +62,17 @@ def evaluate(context, temperature, top_p):
     decoder_output = tokenizer.decode(out)
     # out_text = decoder_output.split("Answer: ")[1]
     out_text = decoder_output
-    logging.info('[API] Results: ' + out_text)
+    logger.info('[API] Results: ' + out_text.replace('\n', '<br>'))
     return out_text
 
 
 def evaluate_wrapper(context, temperature, top_p):
-    db.lock = True
+    db.lock()
     index = db.index
     db.set(index, prompt=context)
    result = evaluate(context, temperature, top_p)
     db.set(index, result=result)
-    db.lock = False
+    db.unlock()
     return result
 
 
@@ -178,37 +89,53 @@ def api_wrapper(context='', temperature=0.5, top_p=0.8, query=0):
     }
 
     if context:
-        if db.lock:
-            logging.info(f'[API] Request: {context}, Status: busy')
+        if db.islocked():
+            logger.info(f'[API] Request: {context}, Status: busy')
             return_json['status'] = 'busy'
             return_json['code'] = 503
-            return_json['message'] = 'Server is busy, please try again later.'
+            return_json['message'] = '[context] Server is busy, please try again later.'
             return return_json
         else:
+            for index in db.prompts:
+                if db.prompts[index] == context:
+                    return_json['status'] = 'done'
+                    return_json['code'] = 200
+                    return_json['message'] = '[context] Request cached.'
+                    return_json['index'] = index
+                    return_json['result'] = db.results[index]
+                    return return_json
+            # new
             index = db.index
             t = Thread(target=evaluate_wrapper, args=(context, temperature, top_p))
             t.start()
-            logging.info(f'[API] Request: {context}, Status: processing, Index: {index}')
+            logger.info(f'[API] Request: {context}, Status: processing, Index: {index}')
             return_json['status'] = 'processing'
             return_json['code'] = 202
-            return_json['message'] = 'Request accepted, please check back later.'
+            return_json['message'] = '[context] Request accepted, please check back later.'
             return_json['index'] = index
             return return_json
     else: # query
-        if query in db.prompts:
-            if query in db.results:
-                logging.info(f'[API] Query: {query}, Status: hit')
-                return_json['status'] = 'done'
-                return_json['code'] = 200
-                return_json['message'] = 'Request processed.'
+        if query in db.prompts and query in db.results:
+            logger.info(f'[API] Query: {query}, Status: hit')
+            return_json['status'] = 'done'
+            return_json['code'] = 200
+            return_json['message'] = '[query] Request processed.'
+            return_json['index'] = query
+            return_json['result'] = db.results[query]
+            return return_json
+        else:
+            if db.islocked():
+                logger.info(f'[API] Query: {query}, Status: processing')
+                return_json['status'] = 'processing'
+                return_json['code'] = 202
+                return_json['message'] = '[query] Request in processing, please check back later.'
                 return_json['index'] = query
-                return_json['result'] = db.results[query]
                 return return_json
             else:
-                logging.info(f'[API] Query: {query}, Status: processing')
-                return_json['status'] = 'processing'
-                return_json['code'] = 202
-                return_json['message'] = 'Request accepted, please check back later.'
+                logger.info(f'[API] Query: {query}, Status: error')
+                return_json['status'] = 'error'
+                return_json['code'] = 404
+                return_json['message'] = '[query] Index not found.'
                 return_json['index'] = query
                 return return_json
 
@@ -247,11 +174,11 @@ def evaluate_stream(msg, history, temperature, top_p):
     context = context[15:]
 
     h = []
-    logging.info('[UI] Request: ' + context)
+    logger.info('[UI] Request: ' + context)
     for response, h in model.stream_chat(tokenizer, context, h, max_length=max_length, top_p=top_p, temperature=temperature):
        history[-1][1] = response
         yield history, ""
-    logging.info('[UI] Results: ' + response)
+    logger.info('[UI] Results: ' + response.replace('\n', '<br>'))
 
 
 with gr.Blocks() as demo:
@@ -274,13 +201,16 @@ with gr.Blocks() as demo:
     clear = gr.Button("清除聊天")
 
     api_handler = gr.Button("API", visible=False)
-    num_for_api = gr.Number(visible=False)
-    json_for_api = gr.JSON(visible=False)
+    api_index = gr.Number(visible=False)
+    api_result = gr.JSON(visible=False)
+    info_handler = gr.Button("Info", visible=False)
+    info_text = gr.Textbox('System info logged. Check it in the log viewer.', visible=False)
 
 
     msg.submit(evaluate_stream, [msg, chatbot, temp, top_p], [chatbot, msg])
     clear.click(lambda: None, None, chatbot, queue=False)
-    api_handler.click(api_wrapper, [msg, temp, top_p, num_for_api], [json_for_api], api_name='chat')
+    api_handler.click(api_wrapper, [msg, temp, top_p, api_index], api_result, api_name='chat')
+    info_handler.click(log_sys_info, None, info_text, api_name='info')
     gr.HTML(gr_footer)
 
 demo.queue()
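The reworked api_wrapper makes generation asynchronous: a call with a context either returns a cached result (200), reports the server busy (503), or starts a background thread and returns an index (202); a later call with query=index polls for the result (200 when done, 202 while processing, 404 if unknown). Below is a minimal sketch of that submit-then-poll flow, calling api_wrapper in-process; the prompt string and the one-second poll interval are illustrative assumptions, not part of the commit.

import time
from app import api_wrapper  # importing app starts the whole Space, so this is illustrative only

resp = api_wrapper(context='hello||', temperature=0.5, top_p=0.8)
if resp['code'] == 200:                    # prompt already cached, result returned immediately
    print(resp['result'])
elif resp['code'] == 202:                  # accepted, generation runs in a background thread
    index = resp['index']
    while True:
        poll = api_wrapper(query=index)    # empty context means "query by index"
        if poll['code'] == 200:            # done
            print(poll['result'])
            break
        if poll['code'] == 404:            # unknown index
            break
        time.sleep(1)                      # 202: still processing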
fix_int8.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import sys
+from session import logger
 
 
 def fix_pytorch_int8():
@@ -26,4 +27,4 @@ def fix_pytorch_int8():
     with open(fix_path, 'w') as f:
         f.write(text)
 
-    return print('Fixed torch/nn/parameter.py')
+    return logger.info('Fixed torch/nn/parameter.py')
model.py ADDED
@@ -0,0 +1,36 @@
+import torch
+from session import logger, log_sys_info
+from transformers import AutoTokenizer, GenerationConfig, AutoModel
+
+
+chatglm = 'THUDM/chatglm-6b'
+chatglm_rev = '4de8efe'
+int8_model = 'KumaTea/twitter-int8'
+int8_model_rev = '1136001'
+
+# import subprocess
+# result = subprocess.run(['git', 'clone', 'https://huggingface.co/KumaTea/twitter-int8', 'model'], capture_output=True, text=True)
+# print(result.stdout)
+
+# device = torch.device('cpu')
+# torch.cuda.current_device = lambda : device
+
+log_sys_info()
+
+model = AutoModel.from_pretrained(
+    int8_model,
+    trust_remote_code=True,
+    revision=int8_model_rev
+).float() # .to(device)
+tokenizer = AutoTokenizer.from_pretrained(chatglm, trust_remote_code=True, revision=chatglm_rev)
+
+# dump a log to ensure everything works well
+# print(model.peft_config)
+# We have to use full precision, as some tokens are >65535
+model.eval()
+# print(model)
+
+torch.set_default_tensor_type(torch.FloatTensor)
+
+logger.info('[SYS] Model loaded')
+log_sys_info()
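model.py now owns the int8 ChatGLM checkpoint and tokenizer and loads them once at import time, so app.py can simply import the shared instances. A minimal smoke test, assuming the Space's dependencies are installed and the weights can be downloaded; the prompt string is an illustrative assumption.

from model import model, tokenizer  # triggers the CPU float32 load shown above

history = []
response = ''
# stream_chat is the ChatGLM remote-code API that app.py's evaluate_stream() already uses
for response, history in model.stream_chat(tokenizer, 'hello||', history,
                                            max_length=224, top_p=0.8, temperature=0.5):
    pass
print(response)  # last streamed reply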
session.py ADDED
@@ -0,0 +1,71 @@
+import os
+import psutil
+import logging
+from pathlib import Path
+
+
+logging.basicConfig(
+    format='%(asctime)s %(levelname)-8s %(message)s',
+    level=logging.INFO,
+    datefmt='%m/%d %H:%M:%S')
+logger = logging.getLogger(__name__)
+
+
+def log_sys_info():
+    cpu_cores = psutil.cpu_count()
+    # cpu_freq = '{:.2f}'.format(psutil.cpu_freq().max / 1000) + 'GHz'
+    cpu_percent = '{:.2f}'.format(psutil.cpu_percent()) + '%'
+    mem = psutil.virtual_memory()
+    mem_total = '{:.2f}'.format(mem.total / 1024 / 1024 / 1024) + 'GB'
+    mem_used = '{:.2f}'.format(mem.used / 1024 / 1024 / 1024) + 'GB'
+    mem_percent = '{:.2f}'.format(mem.percent) + '%'
+    disk = psutil.disk_usage('.')
+    disk_total = '{:.2f}'.format(disk.total / 1024 / 1024 / 1024) + 'GB'
+    disk_used = '{:.2f}'.format(disk.used / 1024 / 1024 / 1024) + 'GB'
+    disk_percent = '{:.2f}'.format(disk.percent) + '%'
+
+    logger.info('======== SYSTEM INFO =========')
+    logger.info(f'CPU: {cpu_cores} cores, {cpu_percent} used')
+    logger.info(f'RAM: {mem_used} / {mem_total}, {mem_percent} used')
+    logger.info(f'DISK: {disk_used} / {disk_total}, {disk_percent} used')
+    logger.info('==============================')
+
+
+class CHAT_DB:
+    def __init__(self):
+        self.prompts = {}
+        self.results = {}
+        self.index = 1
+        self.lockfile = '.lock'
+
+    def set(self, index, prompt=None, result=None):
+        assert prompt or result
+        if prompt:
+            if index in self.prompts:
+                raise ValueError('Prompt already exists')
+            self.prompts[index] = prompt
+            self.index += 1
+        if result:
+            self.results[index] = result
+
+    def lock(self):
+        if not os.path.exists(self.lockfile):
+            Path(self.lockfile).touch(exist_ok=True)
+
+    def unlock(self):
+        if os.path.exists(self.lockfile):
+            os.remove(self.lockfile)
+
+    def islocked(self):
+        return os.path.exists(self.lockfile)
+
+    def clean(self):
+        if len(self.prompts) > 100:
+            self.prompts = dict(list(self.prompts.items())[-100:])
+        k = list(set(self.prompts.keys()).intersection(set(self.results.keys()))) # keys to preserve
+        self.prompts = {i: self.prompts[i] for i in k}
+        self.results = {i: self.results[i] for i in k}
+        log_sys_info()
+
+
+db = CHAT_DB()
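session.py replaces the old in-memory db.lock flag with a marker file on disk ('.lock'), so the lock state can be checked via islocked() from the background generation thread as well. A minimal usage sketch mirroring evaluate_wrapper() in app.py; the prompt and result strings are illustrative.

from session import db

if not db.islocked():        # no '.lock' file on disk, safe to start
    db.lock()                # creates the '.lock' marker file
    try:
        index = db.index
        db.set(index, prompt='hello||')   # stores the prompt and advances db.index
        db.set(index, result='world')     # stores the result under the same index
    finally:
        db.unlock()          # removes the marker so the next request can run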