joaogante HF staff committed on
Commit
a97bf6b
·
1 Parent(s): 07577ad

update model class

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -2,7 +2,7 @@ from threading import Thread
2
 
3
  import torch
4
  import gradio as gr
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
6
 
7
  model_id = "EleutherAI/pythia-6.9b-deduped"
8
  assistant_id = "EleutherAI/pythia-70m-deduped"
@@ -12,11 +12,11 @@ print("CPU threads:", torch.get_num_threads())
12
 
13
 
14
  if torch_device == "cuda":
15
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
16
  else:
17
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
18
  tokenizer = AutoTokenizer.from_pretrained(model_id)
19
- assistant_model = AutoModelForSeq2SeqLM.from_pretrained(assistant_id).to(torch_device)
20
 
21
 
22
  def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
 
2
 
3
  import torch
4
  import gradio as gr
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
6
 
7
  model_id = "EleutherAI/pythia-6.9b-deduped"
8
  assistant_id = "EleutherAI/pythia-70m-deduped"
 
12
 
13
 
14
  if torch_device == "cuda":
15
+ model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
16
  else:
17
+ model = AutoModelForCausalLM.from_pretrained(model_id)
18
  tokenizer = AutoTokenizer.from_pretrained(model_id)
19
+ assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(torch_device)
20
 
21
 
22
  def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):