joaogante HF staff committed on
Commit
abbc475
·
verified ·
1 Parent(s): a304cf8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  from threading import Thread
3
 
4
  import torch
@@ -7,13 +6,12 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStream
7
  import time
8
 
9
  model_id = "facebook/opt-6.7b"
10
- assistant_id = "facebook/opt-125m"
11
 
12
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
13
  assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(device=model.device, dtype=torch.bfloat16)
14
  tokenizer = AutoTokenizer.from_pretrained(model_id)
15
 
16
- @spaces.GPU
17
  def run_generation(user_text, use_assistant, temperature, max_new_tokens):
18
  if temperature < 0.1:
19
  do_sample = False
@@ -31,6 +29,7 @@ def run_generation(user_text, use_assistant, temperature, max_new_tokens):
31
  assistant_model=assistant_model if use_assistant else None,
32
  streamer=streamer,
33
  max_new_tokens=max_new_tokens,
 
34
  do_sample=do_sample,
35
  top_p=0.95,
36
  temperature=float(temperature),
@@ -65,7 +64,7 @@ with gr.Blocks() as demo:
65
  with gr.Row():
66
  with gr.Column(scale=4):
67
  user_text = gr.Textbox(
68
- placeholder="A sequence: one, two, three, ",
69
  label="Prompt"
70
  )
71
  model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
@@ -88,4 +87,4 @@ with gr.Blocks() as demo:
88
  user_text.submit(run_generation, generate_inputs, generate_outputs)
89
  button_submit.click(run_generation, generate_inputs, generate_outputs)
90
 
91
- demo.queue(max_size=32).launch()
 
 
1
  from threading import Thread
2
 
3
  import torch
 
6
  import time
7
 
8
  model_id = "facebook/opt-6.7b"
9
+ assistant_id = "facebook/opt-350m"
10
 
11
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
12
  assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(device=model.device, dtype=torch.bfloat16)
13
  tokenizer = AutoTokenizer.from_pretrained(model_id)
14
 
 
15
  def run_generation(user_text, use_assistant, temperature, max_new_tokens):
16
  if temperature < 0.1:
17
  do_sample = False
 
29
  assistant_model=assistant_model if use_assistant else None,
30
  streamer=streamer,
31
  max_new_tokens=max_new_tokens,
32
+ min_new_tokens=max_new_tokens,
33
  do_sample=do_sample,
34
  top_p=0.95,
35
  temperature=float(temperature),
 
64
  with gr.Row():
65
  with gr.Column(scale=4):
66
  user_text = gr.Textbox(
67
+ value="A sequence: one, two, three, ",
68
  label="Prompt"
69
  )
70
  model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
 
87
  user_text.submit(run_generation, generate_inputs, generate_outputs)
88
  button_submit.click(run_generation, generate_inputs, generate_outputs)
89
 
90
+ demo.queue(max_size=16).launch()