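"""Gradio demo for assisted generation: facebook/opt-6.7b (loaded in 8-bit) with facebook/opt-125m as an
optional assistant model. The output is streamed to the UI and the generation time is reported."""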
from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import time

model_id = "facebook/opt-6.7b"
assistant_id = "facebook/opt-125m"

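# Load the main model in 8-bit (bitsandbytes) to cut its memory footprint, and place the small
# assistant model on the same device in bfloat16 so it adds little overhead.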
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(device=model.device, dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def run_generation(user_text, use_assistant, temperature, max_new_tokens):
    # Temperatures below the slider step (0.1) are treated as greedy decoding.
    do_sample = temperature >= 0.1

    # Tokenize the user text (the model and tokenizer are loaded at module level).
    model_inputs = tokenizer([user_text], return_tensors="pt").to(model.device)

    # Start generation on a separate thread so that we don't block the UI. The text is pulled from the streamer
    # in the main thread. The streamer's timeout keeps the main thread from hanging if the generation thread fails.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
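    # `model_inputs` is dict-like (input_ids, attention_mask), so it merges with the generation arguments below.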
    generate_kwargs = dict(
        model_inputs,
        assistant_model=assistant_model if use_assistant else None,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=0.95,
        temperature=float(temperature),
        top_k=50,
        eos_token_id=-1,  # ensures `max_new_tokens` new tokens are always generated, can't reach EOS
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    start = time.time()
    t.start()

    # Pull the generated text from the streamer, yielding the accumulated output and the time spent so far
    # after each new chunk.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield [model_output, round(time.time() - start, 3)]
    return [model_output, round(time.time() - start, 3)]


def reset_textbox():
    return gr.update(value='')


with gr.Blocks() as demo:
    gr.Markdown(
        "# 🤗 Assisted Generation Demo\n"
        f"- Model: {model_id} (INT8, ~7GB)\n"
        f"- Assistant Model: {assistant_id} (FP16, ~0.3GB)\n"
        "- Recipe for speedup: a) >10x model size difference in parameters; b) assistant trained similarly; c) CPU is not a bottleneck"
    )

    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                value="A sequence: one, two, three, ",
                label="Prompt"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")

        with gr.Column(scale=1, min_width=200):
            gr.Markdown("### Generation Settings")
            use_assistant = gr.Checkbox(label="Use Assisted Generation", value=True)
            max_new_tokens = gr.Slider(
                minimum=1, maximum=500, value=100, step=1, interactive=True, label="Max New Tokens",
            )
            temperature = gr.Slider(
                minimum=0.0, maximum=2.0, value=0.0, step=0.1, interactive=True, label="Temperature (0.0 = Greedy)",
            )
            gr.Markdown("### Generation time (seconds)")
            generation_time = gr.Textbox(lines=1, interactive=False, show_label=False)

    generate_inputs = [user_text, use_assistant, temperature, max_new_tokens]
    generate_outputs = [model_output, generation_time]
    user_text.submit(run_generation, generate_inputs, generate_outputs)
    button_submit.click(run_generation, generate_inputs, generate_outputs)

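    # Queueing enables the streamed (generator) outputs; cap pending requests at 16.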
    demo.queue(max_size=16).launch()