joaogante HF staff committed on
Commit
f7f857f
·
1 Parent(s): c3cbdc6

working demo

Browse files
Files changed (1) hide show
  1. app.py +21 -20
app.py CHANGED
@@ -3,6 +3,7 @@ from threading import Thread
3
  import torch
4
  import gradio as gr
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
6
 
7
  model_id = "EleutherAI/pythia-6.9b-deduped"
8
  assistant_id = "EleutherAI/pythia-70m-deduped"
@@ -19,7 +20,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
19
  assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(torch_device)
20
 
21
 
22
- def run_generation(user_text, use_assistant, top_p, temperature, top_k, max_new_tokens):
23
  if temperature == 0.0:
24
  do_sample = False
25
  else:
@@ -37,19 +38,21 @@ def run_generation(user_text, use_assistant, top_p, temperature, top_k, max_new_
37
  streamer=streamer,
38
  max_new_tokens=max_new_tokens,
39
  do_sample=do_sample,
40
- top_p=top_p,
41
  temperature=float(temperature),
42
- top_k=top_k
43
  )
44
  t = Thread(target=model.generate, kwargs=generate_kwargs)
 
45
  t.start()
46
 
47
- # Pull the generated text from the streamer, and update the model output.
 
48
  model_output = ""
49
  for new_text in streamer:
50
  model_output += new_text
51
- yield model_output
52
- return model_output
53
 
54
 
55
  def reset_textbox():
@@ -59,36 +62,34 @@ def reset_textbox():
59
  with gr.Blocks() as demo:
60
  gr.Markdown(
61
  "# 🤗 Assisted Generation Demo\n"
62
- f"Model: {model_id} (using INT8)\n\n"
63
- f"Assistant Model: {assistant_id}"
64
  )
65
 
66
  with gr.Row():
67
  with gr.Column(scale=4):
68
  user_text = gr.Textbox(
69
  placeholder="Question: What is the meaning of life? Answer:",
70
- label="User input"
71
  )
72
  model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
73
  button_submit = gr.Button(value="Submit")
74
 
75
  with gr.Column(scale=1):
76
- use_assistant = gr.Checkbox(label="Use Assistant", default=True)
 
77
  max_new_tokens = gr.Slider(
78
- minimum=1, maximum=500, value=250, step=1, interactive=True, label="Max New Tokens",
79
- )
80
- top_p = gr.Slider(
81
- minimum=0.05, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p",
82
- )
83
- top_k = gr.Slider(
84
- minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
85
  )
86
  temperature = gr.Slider(
87
  minimum=0.0, maximum=2.0, value=0.0, step=0.1, interactive=True, label="Temperature (0.0 = Greedy)",
88
  )
 
 
89
 
90
- generate_inputs = [user_text, use_assistant, top_p, temperature, top_k, max_new_tokens]
91
- user_text.submit(run_generation, generate_inputs, model_output)
92
- button_submit.click(run_generation, generate_inputs, model_output)
 
93
 
94
  demo.queue(max_size=32).launch(enable_queue=True)
 
3
  import torch
4
  import gradio as gr
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
6
+ import time
7
 
8
  model_id = "EleutherAI/pythia-6.9b-deduped"
9
  assistant_id = "EleutherAI/pythia-70m-deduped"
 
20
  assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(torch_device)
21
 
22
 
23
+ def run_generation(user_text, use_assistant, temperature, max_new_tokens):
24
  if temperature == 0.0:
25
  do_sample = False
26
  else:
 
38
  streamer=streamer,
39
  max_new_tokens=max_new_tokens,
40
  do_sample=do_sample,
41
+ top_p=0.95,
42
  temperature=float(temperature),
43
+ top_k=50,
44
  )
45
  t = Thread(target=model.generate, kwargs=generate_kwargs)
46
+ start = time.time()
47
  t.start()
48
 
49
+ # Pull the generated text from the streamer, and update the model output. Return the model output and time
50
+ # spent so far.
51
  model_output = ""
52
  for new_text in streamer:
53
  model_output += new_text
54
+ yield [model_output, round(time.time() - start, 3)]
55
+ return [model_output, round(time.time() - start, 3)]
56
 
57
 
58
  def reset_textbox():
 
62
  with gr.Blocks() as demo:
63
  gr.Markdown(
64
  "# 🤗 Assisted Generation Demo\n"
65
+ f"- Model: {model_id} (using INT8)\n"
66
+ f"- Assistant Model: {assistant_id}"
67
  )
68
 
69
  with gr.Row():
70
  with gr.Column(scale=4):
71
  user_text = gr.Textbox(
72
  placeholder="Question: What is the meaning of life? Answer:",
73
+ label="Prompt"
74
  )
75
  model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
76
  button_submit = gr.Button(value="Submit")
77
 
78
  with gr.Column(scale=1):
79
+ gr.Markdown("### Generation Settings")
80
+ use_assistant = gr.Checkbox(label="Use Assisted Generation", value=True)
81
  max_new_tokens = gr.Slider(
82
+ minimum=1, maximum=500, value=100, step=1, interactive=True, label="Max New Tokens",
 
 
 
 
 
 
83
  )
84
  temperature = gr.Slider(
85
  minimum=0.0, maximum=2.0, value=0.0, step=0.1, interactive=True, label="Temperature (0.0 = Greedy)",
86
  )
87
+ gr.Markdown("### Generation time (seconds)")
88
+ generation_time = gr.Textbox(lines=1, interactive=False, show_label=False)
89
 
90
+ generate_inputs = [user_text, use_assistant, temperature, max_new_tokens]
91
+ generate_outputs = [model_output, generation_time]
92
+ user_text.submit(run_generation, generate_inputs, generate_outputs)
93
+ button_submit.click(run_generation, generate_inputs, generate_outputs)
94
 
95
  demo.queue(max_size=32).launch(enable_queue=True)