Nitzantry1 commited on
Commit
c102310
verified
1 Parent(s): 82a7ce6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -11,18 +11,29 @@ print(f'should_use_fast = {should_use_fast}')
11
 
12
  local_rank = int(os.getenv('LOCAL_RANK', '0'))
13
  world_size = int(os.getenv('WORLD_SIZE', '1'))
14
- generator = pipeline('text-generation', model=model_id,
15
- tokenizer=model_id,
16
- torch_dtype=torch.float16,
17
- use_fast=should_use_fast,
18
- trust_remote_code=True,
19
- device_map="auto")
20
 
21
  # 讘讚讬拽转 讛转拽谉 - GPU 讗讜 CPU
22
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
  print('Using device:', device)
24
  print()
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # 驻讜谞拽爪讬讬转 讬爪讬专转 讛讟拽住讟
27
  def chat_with_model(history):
28
  prompt = history[-1]["content"]
 
11
 
12
  local_rank = int(os.getenv('LOCAL_RANK', '0'))
13
  world_size = int(os.getenv('WORLD_SIZE', '1'))
 
 
 
 
 
 
14
 
15
  # 讘讚讬拽转 讛转拽谉 - GPU 讗讜 CPU
16
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
  print('Using device:', device)
18
  print()
19
 
20
+ # 讬爪讬专转 讛诪谞讜注 注诐 Accelerate 讘诪讬讚转 讛爪讜专讱
21
+ if device.type == 'cuda':
22
+ generator = pipeline('text-generation', model=model_id,
23
+ tokenizer=model_id,
24
+ torch_dtype=torch.float16,
25
+ use_fast=should_use_fast,
26
+ trust_remote_code=True,
27
+ device_map="auto")
28
+ else:
29
+ from accelerate import init_empty_weights, infer_auto_device_map
30
+ from transformers import AutoModelForCausalLM, AutoTokenizer
31
+
32
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=should_use_fast)
33
+ with init_empty_weights():
34
+ model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
35
+ generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=device)
36
+
37
  # 驻讜谞拽爪讬讬转 讬爪讬专转 讛讟拽住讟
38
  def chat_with_model(history):
39
  prompt = history[-1]["content"]