Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -1,4 +1,3 @@
-import spaces
 from threading import Thread
 
 import torch
@@ -7,13 +6,12 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import time
 
 model_id = "facebook/opt-6.7b"
-assistant_id = "facebook/opt-
+assistant_id = "facebook/opt-350m"
 
 model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
 assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(device=model.device, dtype=torch.bfloat16)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-@spaces.GPU
 def run_generation(user_text, use_assistant, temperature, max_new_tokens):
     if temperature < 0.1:
         do_sample = False
@@ -31,6 +29,7 @@ def run_generation(user_text, use_assistant, temperature, max_new_tokens):
         assistant_model=assistant_model if use_assistant else None,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
+        min_new_tokens=max_new_tokens,
         do_sample=do_sample,
         top_p=0.95,
         temperature=float(temperature),
@@ -65,7 +64,7 @@ with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column(scale=4):
             user_text = gr.Textbox(
-
+                value="A sequence: one, two, three, ",
                 label="Prompt"
             )
             model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
@@ -88,4 +87,4 @@ with gr.Blocks() as demo:
     user_text.submit(run_generation, generate_inputs, generate_outputs)
     button_submit.click(run_generation, generate_inputs, generate_outputs)
 
-demo.queue(max_size=
+demo.queue(max_size=16).launch()
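For context, the app demonstrates assisted generation (speculative decoding): the small draft model (facebook/opt-350m) proposes candidate tokens and the large model (facebook/opt-6.7b) verifies them in a single forward pass. Below is a minimal standalone sketch using the same model ids as the app; greedy decoding and the 32-token budget are assumptions for brevity, not part of the commit.

# Minimal sketch of assisted generation, outside the Gradio app.
# Model ids match the Space; any pair sharing a tokenizer works.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-6.7b")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-6.7b", torch_dtype=torch.bfloat16, device_map="auto"
)
# The draft model must share the main model's tokenizer/vocabulary.
assistant = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to(
    device=model.device, dtype=torch.bfloat16
)

inputs = tokenizer("A sequence: one, two, three, ", return_tensors="pt").to(model.device)
# Passing assistant_model enables assisted generation: the draft model
# proposes several tokens at a time, which the main model then verifies.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))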
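The diff elides the middle of run_generation (the file's lines 20 through 30). Given the Thread and TextIteratorStreamer imports and the streamer= keyword in the shown generate() kwargs, the body presumably runs generation on a background thread and yields partial text back to the Gradio Textbox. The following is a hypothetical reconstruction under those assumptions, not the actual file contents; it reuses model, tokenizer, and assistant_model from the app's module scope.

# Hypothetical reconstruction of run_generation; only the generate()
# kwargs are visible in the diff, so the streaming wiring is assumed.
from threading import Thread
from transformers import TextIteratorStreamer

def run_generation(user_text, use_assistant, temperature, max_new_tokens):
    # The shown lines disable sampling when temperature < 0.1.
    do_sample = temperature >= 0.1

    inputs = tokenizer([user_text], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        assistant_model=assistant_model if use_assistant else None,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        min_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=0.95,
        temperature=float(temperature),
    )
    # generate() blocks, so it runs on a thread while the streamer
    # yields decoded text chunks; yielding makes this a generator,
    # which Gradio renders as incremental Textbox updates.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output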
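The added min_new_tokens=max_new_tokens pins every run to exactly max_new_tokens generated tokens, presumably so that assisted and unassisted runs cover identical workloads and their speeds can be compared directly; demo.queue(max_size=16) then caps the Space at 16 pending requests. A hedged sketch of the timing comparison this enables follows; time_generation is a hypothetical helper, and it reuses model, tokenizer, and assistant_model from the sketches above.

# Hypothetical timing comparison enabled by fixed-length outputs:
# with min_new_tokens == max_new_tokens, both runs emit exactly the
# same number of tokens, so throughput numbers are comparable.
import time

def time_generation(use_assistant: bool, n_tokens: int = 64) -> float:
    inputs = tokenizer("A sequence: one, two, three, ", return_tensors="pt").to(model.device)
    start = time.time()
    model.generate(
        **inputs,
        assistant_model=assistant_model if use_assistant else None,
        max_new_tokens=n_tokens,
        min_new_tokens=n_tokens,
        do_sample=False,
    )
    return n_tokens / (time.time() - start)  # tokens per second

print("assisted  :", time_generation(True))
print("unassisted:", time_generation(False))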