Nitzantry1 committed (verified) · Commit b99a3bf · 1 Parent(s): 51eae2b

Update app.py

Files changed (1)
  1. app.py +41 -11
app.py CHANGED
@@ -1,17 +1,48 @@
-from gradio_client import Client
+import deepspeed
+import torch
+from transformers import pipeline
+import os
 import gradio as gr
 
-# Connect to the Space hosting the model on Hugging Face
-client = Client("dicta-il/dictalm2.0-instruct-demo")
+model_id = 'dicta-il/dictalm-7b-instruct'
 
+# Load the model and prepare the engine
+should_use_fast = True
+print(f'should_use_fast = {should_use_fast}')
+
+local_rank = int(os.getenv('LOCAL_RANK', '0'))
+world_size = int(os.getenv('WORLD_SIZE', '1'))
+generator = pipeline('text-generation', model=model_id,
+                     tokenizer=model_id,
+                     torch_dtype=torch.float16,
+                     use_fast=should_use_fast,
+                     trust_remote_code=True,
+                     device_map="auto")
+
+# Check the available device - GPU or CPU
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print('Using device:', device)
+print()
+
+total_mem = 0
+if device.type == 'cuda':
+    print(torch.cuda.get_device_name(0))
+    total_mem = round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 1)
+    print('Total Memory: ', total_mem, 'GB')
+
+should_replace_with_kernel_inject = total_mem >= 12
+print(f'should_replace_with_kernel_inject = {should_replace_with_kernel_inject}')
+
+ds_engine = deepspeed.init_inference(generator.model,
+                                     mp_size=world_size,
+                                     dtype=torch.half,
+                                     replace_with_kernel_inject=should_replace_with_kernel_inject)
+generator.model = ds_engine.module
+
+# Text generation function
 def chat_with_model(history):
-    # Get the user's latest message
     prompt = history[-1]["content"]
-    # Call the model to get a response
-    result = client.predict(
-        message=prompt
-    )
-    # Add the model's response to the history
+    result = generator(prompt, do_sample=True, min_length=20, max_length=64, top_k=40, top_p=0.92, temperature=0.9)[0]["generated_text"]
     return history + [{"role": "bot", "content": result}]
 
 # Create an advanced Gradio interface: an academic-style chatbot
@@ -28,7 +59,6 @@ with gr.Blocks(theme="default") as demo:
     send_button = gr.Button("ืฉืœื—")
 
     def user_chat(history, message):
-        # Add the user's message to the history
         return history + [{"role": "user", "content": message}], ""
 
     # Send the message both by pressing Enter and by clicking the "ืฉืœื—" (Send) button
@@ -39,4 +69,4 @@ with gr.Blocks(theme="default") as demo:
         fn=chat_with_model, inputs=chatbot, outputs=chatbot
     )
 
-demo.launch()
+demo.launch()
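
The hunks above show only fragments of the Blocks wiring (the send_button line and the fn=chat_with_model, inputs=chatbot, outputs=chatbot call inside a closing parenthesis). Below is a hypothetical sketch of how that elided wiring typically fits together; the widget names (chatbot, message) and the .submit/.click/.then event chain are assumptions inferred from the visible context lines, not the committed code:

# Hypothetical reconstruction - NOT the committed code; widget names are assumed.
import gradio as gr

def chat_with_model(history):
    # Stand-in for the DeepSpeed-backed handler in the diff; just echoes the prompt.
    return history + [{"role": "assistant", "content": history[-1]["content"]}]

with gr.Blocks(theme="default") as demo:
    chatbot = gr.Chatbot(type="messages")  # history of {"role": ..., "content": ...} dicts
    message = gr.Textbox()                 # assumed input component
    send_button = gr.Button("ืฉืœื—")

    def user_chat(history, message):
        return history + [{"role": "user", "content": message}], ""

    # Pressing Enter in the textbox and clicking the button both append the
    # user's message, then run the model on the updated history.
    message.submit(
        fn=user_chat, inputs=[chatbot, message], outputs=[chatbot, message]
    ).then(
        fn=chat_with_model, inputs=chatbot, outputs=chatbot
    )
    send_button.click(
        fn=user_chat, inputs=[chatbot, message], outputs=[chatbot, message]
    ).then(
        fn=chat_with_model, inputs=chatbot, outputs=chatbot
    )

demo.launch()

Two notes on the committed version: the LOCAL_RANK/WORLD_SIZE reads assume the app is started with the DeepSpeed launcher (e.g. deepspeed --num_gpus 1 app.py), which exports those environment variables; under a plain python app.py they fall back to 0 and 1. Also, gr.Chatbot's messages format accepts the roles "user" and "assistant", so the "bot" role returned by chat_with_model may be rejected by recent Gradio versions.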