s3nh committed on
Commit 96782a4 · 1 Parent(s): 1401d79

Create app.py

Files changed (1)
  1. app.py +395 -0
app.py ADDED
@@ -0,0 +1,395 @@
+ import os
+ import platform
+ import random
+ import time
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+
+ import gradio as gr
+ import psutil
+ from about_time import about_time
+ from ctransformers import AutoModelForCausalLM
+ from dl_hf_model import dl_hf_model
+ from loguru import logger
+
+
+ URL = "https://huggingface.co/s3nh/OpenOrca-Preview1-13B-GGML/resolve/main/OpenOrca-Preview1-13B.ggmlv3.q8_0.bin"  # 4.05G
+
+ _ = (
+     "golay" in platform.node()
+     or "okteto" in platform.node()
+     or Path("/kaggle").exists()
+     # or psutil.cpu_count(logical=False) < 4
+     or 1  # run 7b in hf
+ )
+
+ if _:
+     url = "https://huggingface.co/s3nh/OpenOrca-Preview1-13B-GGML/resolve/main/OpenOrca-Preview1-13B.ggmlv3.q8_0.bin"  # 2.87G
+
+
+ prompt_template = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+ ### Instruction: {user_prompt}
+
+ ### Response:
+ """
+
+ prompt_template = """System: You are a helpful,
+ respectful and honest assistant. Always answer as
+ helpfully as possible, while being safe. Your answers
+ should not include any harmful, unethical, racist,
+ sexist, toxic, dangerous, or illegal content. Please
+ ensure that your responses are socially unbiased and
+ positive in nature. If a question does not make any
+ sense, or is not factually coherent, explain why instead
+ of answering something not correct. If you don't know
+ the answer to a question, please don't share false
+ information.
+ User: {prompt}
+ Assistant: """
+
+ prompt_template = """System: You are a helpful assistant.
+ User: {prompt}
+ Assistant: """
+
+ prompt_template = """Question: {question}
+ Answer: Let's work this out in a step by step way to be sure we have the right answer."""
+
+ prompt_template = """[INST] <>
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible assistant. Think step by step.
+ <>
+
+ What NFL team won the Super Bowl in the year Justin Bieber was born?
+ [/INST]"""
+
+ prompt_template = """[INST] <<SYS>>
+ You are an unhelpful assistant. Always answer as helpfully as possible. Think step by step. <</SYS>>
+
+ {question} [/INST]
+ """
+
+ prompt_template = """[INST] <<SYS>>
+ You are a helpful assistant.
+ <</SYS>>
+
+ {question} [/INST]
+ """
+
+ prompt_template = """### HUMAN:
+ {question}
+
+ ### RESPONSE:"""
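+ # Only this final "### HUMAN: / ### RESPONSE:" template is actually used by
+ # generate() below; the earlier assignments are kept as alternative templates.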
+
+ _ = [elm for elm in prompt_template.splitlines() if elm.strip()]
+ stop_string = [elm.split(":")[0] + ":" for elm in _][-2]
+
+ logger.debug(f"{stop_string=} not used")
+
+ _ = psutil.cpu_count(logical=False) - 1
+ cpu_count: int = int(_) if _ else 1
+ logger.debug(f"{cpu_count=}")
+
+ LLM = None
+
+ try:
+     model_loc, file_size = dl_hf_model(url)
+ except Exception as exc_:
+     logger.error(exc_)
+     raise SystemExit(1) from exc_
+
+ LLM = AutoModelForCausalLM.from_pretrained(
+     model_loc,
+     model_type="llama",
+ )
+
+ logger.info(f"done load llm {model_loc=} {file_size=}G")
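+ # ctransformers runs the GGML weights on CPU by default; model_type="llama"
+ # selects the llama architecture, which matches this OpenOrca-Preview1-13B checkpoint.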
+
+ os.environ["TZ"] = "Asia/Shanghai"
+ try:
+     time.tzset()
+ except Exception:
+     # time.tzset() only exists on Unix; on Windows the TZ change is skipped
+     logger.warning("Windows, cant run time.tzset()")
+
+
+ @dataclass
+ class GenerationConfig:
+     temperature: float = 0.7
+     top_k: int = 50
+     top_p: float = 0.9
+     repetition_penalty: float = 1.0
+     max_new_tokens: int = 512
+     seed: int = 42
+     reset: bool = False
+     stream: bool = True
+     # threads: int = cpu_count
+     # stop: list[str] = field(default_factory=lambda: [stop_string])
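+     # Every field above is forwarded verbatim to the ctransformers call in
+     # generate() via **asdict(config); commented-out fields keep library defaults.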
+
+
+ def generate(
+     question: str,
+     llm=LLM,
+     config: GenerationConfig = GenerationConfig(),
+ ):
+     """Run model inference; returns a generator when streaming is enabled."""
+     prompt = prompt_template.format(question=question)
+
+     return llm(
+         prompt,
+         **asdict(config),
+     )
+
+
+ logger.debug(f"{asdict(GenerationConfig())=}")
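+ # Quick local smoke test of generate(); kept commented out so the Space only
+ # launches the UI. With stream=False the call returns the full completion string.
+ # if __name__ == "__main__":
+ #     cfg = GenerationConfig(stream=False, max_new_tokens=64)
+ #     print(generate("What are 3 popular chess openings?", config=cfg))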
+
+
+ def user(user_message, history):
+     history.append([user_message, None])
+     return user_message, history
+
+
+ def user1(user_message, history):
+     history.append([user_message, None])
+     return "", history
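+     # user() echoes the message back into the textbox (used for the Enter/submit
+     # event), while user1() returns "" so the textbox is cleared after sending.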
+
+ def bot_(history):  # unused demo stub that streams a canned reply
+     user_message = history[-1][0]
+     resp = random.choice(["How are you?", "I love you", "I'm very hungry"])
+     bot_message = user_message + ": " + resp
+     history[-1][1] = ""
+     for character in bot_message:
+         history[-1][1] += character
+         time.sleep(0.02)
+         yield history
+
+     history[-1][1] = resp
+     yield history
+
+
+ def bot(history):
+     user_message = history[-1][0]
+     response = []
+
+     logger.debug(f"{user_message=}")
+
+     with about_time() as atime:
+         flag = 1
+         prefix = ""
+         then = time.time()
+
+         logger.debug("about to generate")
+
+         config = GenerationConfig(reset=True)
+         for elm in generate(user_message, config=config):
+             if flag == 1:
+                 logger.debug("in the loop")
+                 prefix = f"({time.time() - then:.2f}s) "
+                 flag = 0
+                 print(prefix, end="", flush=True)
+                 logger.debug(f"{prefix=}")
+             print(elm, end="", flush=True)
+
+             response.append(elm)
+             history[-1][1] = prefix + "".join(response)
+             yield history
+
+     _ = (
+         f"(time elapsed: {atime.duration_human}, "
+         f"{atime.duration / max(len(''.join(response)), 1):.2f}s/char)"  # guard against an empty response
+     )
+
+     history[-1][1] = "".join(response) + f"\n{_}"
+     yield history
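+     # bot() is a generator: yielding the partially built history on every new
+     # token is what lets gr.Chatbot below render the reply as a live stream.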
+
+
+ def predict_api(prompt):
+     logger.debug(f"{prompt=}")
+     try:
+         # user_prompt = prompt
+         config = GenerationConfig(
+             temperature=0.2,
+             top_k=10,
+             top_p=0.9,
+             repetition_penalty=1.0,
+             max_new_tokens=512,  # adjust as needed
+             seed=42,
+             reset=True,
+             stream=False,
+         )
+
+         response = generate(
+             prompt,
+             config=config,
+         )
+
+         logger.debug(f"api: {response=}")
+     except Exception as exc:
+         logger.error(exc)
+         response = f"{exc=}"
+     return response
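+ # predict_api is wired to api_name="api" further down, so a running Space can be
+ # queried programmatically; a hedged sketch with a placeholder Space id:
+ # from gradio_client import Client
+ # client = Client("<owner>/<space-name>")  # hypothetical Space id
+ # print(client.predict("Hello there", api_name="/api"))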
+
+
+ css = """
+     .importantButton {
+         background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
+         border: none !important;
+     }
+     .importantButton:hover {
+         background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
+         border: none !important;
+     }
+     .disclaimer {font-variant-caps: all-small-caps; font-size: xx-small;}
+     .xsmall {font-size: x-small;}
+ """
+ etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
+ examples_list = [
+     ["Send an email requesting that people use language models responsibly."],
+     ["Write a shouting match between Julius Caesar and Napoleon"],
+     ["Write a theory to explain why cats never existed"],
+     ["write a story about a grain of sand as it watches millions of years go by"],
+     ["What are 3 popular chess openings?"],
+     ["write a conversation between the sun and pluto"],
+     ["Did you know that Yann LeCun dropped a rap album last year? We listened to it and here’s what we thought:"],
+ ]
+
+ logger.info("start block")
+
+ with gr.Blocks(
+     title=f"{Path(model_loc).name}",
+     theme=gr.themes.Soft(text_size="sm", spacing_size="sm"),
+     css=css,
+ ) as block:
+     # buff_var = gr.State("")
+     with gr.Accordion("🎈 Info", open=False):
+         # gr.HTML(
+         #     """<center><a href="https://huggingface.co/spaces/mikeee/mpt-30b-chat?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate"></a> and spin a CPU UPGRADE to avoid the queue</center>"""
+         # )
+         gr.Markdown(
+             f"""<h5><center>{Path(model_loc).name}</center></h5>
+             Most examples are meant for another model;
+             you should probably test some related prompts of your own.""",
+             elem_classes="xsmall",
+         )
+
+     # chatbot = gr.Chatbot().style(height=700)  # 500
+     chatbot = gr.Chatbot(height=500)
+
+     # buff = gr.Textbox(show_label=False, visible=True)
+
+     with gr.Row():
+         with gr.Column(scale=5):
+             msg = gr.Textbox(
+                 label="Chat Message Box",
+                 placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
+                 show_label=False,
+                 # container=False,
+                 lines=6,
+                 max_lines=30,
+                 show_copy_button=True,
+                 # ).style(container=False)
+             )
+         with gr.Column(scale=1, min_width=50):
+             with gr.Row():
+                 submit = gr.Button("Submit", elem_classes="xsmall")
+                 stop = gr.Button("Stop", visible=True)
+                 clear = gr.Button("Clear History", visible=True)
+     with gr.Row(visible=False):
+         with gr.Accordion("Advanced Options:", open=False):
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     system = gr.Textbox(
+                         label="System Prompt",
+                         value=prompt_template,
+                         show_label=False,
+                         container=False,
+                         # ).style(container=False)
+                     )
+                 with gr.Column():
+                     with gr.Row():
+                         change = gr.Button("Change System Prompt")
+                         reset = gr.Button("Reset System Prompt")
+
+     with gr.Accordion("Example Inputs", open=True):
+         examples = gr.Examples(
+             examples=examples_list,
+             inputs=[msg],
+             examples_per_page=40,
+         )
+
+     # with gr.Row():
+     with gr.Accordion("Disclaimer", open=False):
+         _ = Path(model_loc).name
+         gr.Markdown(
+             f"Disclaimer: {_} can produce factually incorrect output, and should not be relied on to produce "
+             f"factually accurate information. {_} was trained on various public datasets; while great efforts "
+             "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
+             "biased, or otherwise offensive outputs.",
+             elem_classes=["disclaimer"],
+         )
+
+     msg_submit_event = msg.submit(
+         # fn=conversation.user_turn,
+         fn=user,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=True,
+         show_progress="full",
+         # api_name=None,
+     ).then(bot, chatbot, chatbot, queue=True)
+     submit_click_event = submit.click(
+         # fn=lambda x, y: ("",) + user(x, y)[1:],  # clear msg
+         fn=user1,  # clear msg
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=True,
+         # queue=False,
+         show_progress="full",
+         # api_name=None,
+     ).then(bot, chatbot, chatbot, queue=True)
+     stop.click(
+         fn=None,
+         inputs=None,
+         outputs=None,
+         cancels=[msg_submit_event, submit_click_event],
+         queue=False,
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+     with gr.Accordion("For Chat/Translation API", open=False, visible=False):
+         input_text = gr.Text()
+         api_btn = gr.Button("Go", variant="primary")
+         out_text = gr.Text()
+
+     api_btn.click(
+         predict_api,
+         input_text,
+         out_text,
+         api_name="api",
+     )
+
+     # block.load(update_buff, [], buff, every=1)
+     # block.load(update_buff, [buff_var], [buff_var, buff], every=1)
+
+ # concurrency_count=5, max_size=20
+ # max_size=36, concurrency_count=14
+ # CPU cpu_count=2 16G, model 7G
+ # CPU UPGRADE cpu_count=8 32G, model 7G
+
+ # does not work
+ _ = """
+ # _ = int(psutil.virtual_memory().total / 10**9 // file_size - 1)
+ # concurrency_count = max(_, 1)
+ if psutil.cpu_count(logical=False) >= 8:
+     # concurrency_count = max(int(32 / file_size) - 1, 1)
+ else:
+     # concurrency_count = max(int(16 / file_size) - 1, 1)
+ # """
+
+ concurrency_count = 1
+ logger.info(f"{concurrency_count=}")
+
+ block.queue(concurrency_count=concurrency_count, max_size=5).launch(debug=True)
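+ # concurrency_count=1 makes the queue serialize generations, which keeps memory
+ # use predictable on a small CPU Space; raise it only if the host has spare RAM
+ # and cores for several simultaneous generations.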