Update app.py
app.py
CHANGED
@@ -1,17 +1,48 @@
-from gradio_client import Client
+import deepspeed
+import torch
+from transformers import pipeline
+import os
 import gradio as gr
 
-
-client = Client("dicta-il/dictalm2.0-instruct-demo")
+model_id = 'dicta-il/dictalm-7b-instruct'
 
+# load the model and prepare the inference engine
+should_use_fast = True
+print(f'should_use_fast = {should_use_fast}')
+
+local_rank = int(os.getenv('LOCAL_RANK', '0'))
+world_size = int(os.getenv('WORLD_SIZE', '1'))
+generator = pipeline('text-generation', model=model_id,
+                     tokenizer=model_id,
+                     torch_dtype=torch.float16,
+                     use_fast=should_use_fast,
+                     trust_remote_code=True,
+                     device_map="auto")
+
+# check the available device - GPU or CPU
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print('Using device:', device)
+print()
+
+total_mem = 0
+if device.type == 'cuda':
+    print(torch.cuda.get_device_name(0))
+    total_mem = round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 1)
+    print('Total Memory: ', total_mem, 'GB')
+
+should_replace_with_kernel_inject = total_mem >= 12
+print(f'should_replace_with_kernel_inject = {should_replace_with_kernel_inject}')
+
+ds_engine = deepspeed.init_inference(generator.model,
+                                     mp_size=world_size,
+                                     dtype=torch.half,
+                                     replace_with_kernel_inject=should_replace_with_kernel_inject)
+generator.model = ds_engine.module
+
+# text generation function
 def chat_with_model(history):
-    # take the most recent message from the user
     prompt = history[-1]["content"]
-
-    result = client.predict(
-        message=prompt
-    )
-    # append the model's response to the history
+    result = generator(prompt, do_sample=True, min_length=20, max_length=64, top_k=40, top_p=0.92, temperature=0.9)[0]["generated_text"]
     return history + [{"role": "bot", "content": result}]
 
 # build an advanced Gradio interface for a clean, tidy chatbot
@@ -28,7 +59,6 @@ with gr.Blocks(theme="default") as demo:
     send_button = gr.Button("שלח")
 
     def user_chat(history, message):
-        # append the user's message to the history
         return history + [{"role": "user", "content": message}], ""
 
     # send the message both by pressing Enter and by clicking the "שלח" (Send) button
@@ -39,4 +69,4 @@ with gr.Blocks(theme="default") as demo:
         fn=chat_with_model, inputs=chatbot, outputs=chatbot
     )
 
-demo.launch()
+demo.launch()
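
Only a few fragments of the user interface appear in the diff: the send_button line, the user_chat helper, and the fn=chat_with_model wiring; the rest of the with gr.Blocks(theme="default") as demo: block is unchanged and therefore not shown. For orientation only, the sketch below shows one plausible way such a block is wired up. The component names (message_box), the type="messages" argument (which needs a recent Gradio version), and the .then() chaining are assumptions for illustration, not the hidden lines of app.py, and it relies on the generator-backed chat_with_model defined earlier in the file.

    import gradio as gr

    # hypothetical reconstruction of the unchanged UI section; the real file may differ
    with gr.Blocks(theme="default") as demo:
        # "messages" format keeps the history as {"role": ..., "content": ...} dicts,
        # the same shape the handlers above produce
        chatbot = gr.Chatbot(type="messages")
        message_box = gr.Textbox(placeholder="Type a message...")  # user input field (assumed name)
        send_button = gr.Button("שלח")

        def user_chat(history, message):
            # append the user's message to the history and clear the textbox
            return history + [{"role": "user", "content": message}], ""

        # send the message both on Enter and on the Send button,
        # then let chat_with_model (defined earlier in app.py) generate the reply
        message_box.submit(
            fn=user_chat, inputs=[chatbot, message_box], outputs=[chatbot, message_box]
        ).then(
            fn=chat_with_model, inputs=chatbot, outputs=chatbot
        )
        send_button.click(
            fn=user_chat, inputs=[chatbot, message_box], outputs=[chatbot, message_box]
        ).then(
            fn=chat_with_model, inputs=chatbot, outputs=chatbot
        )

    demo.launch()

Because the script only reads LOCAL_RANK and WORLD_SIZE from the environment, it can be started directly with python app.py (a single process, world_size falling back to 1) or through the deepspeed launcher, which sets those variables when the model is sharded across several GPUs.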