Spaces:
Runtime error
Runtime error
import time | |
from functools import lru_cache | |
import torch | |
import gradio as gr | |
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM | |
# only cache the latest model | |
@lru_cache(maxsize=1)
def get_model_and_tokenizer(model_id):
    """Load the model and tokenizer for ``model_id``, memoising the latest pair.

    The checkpoint's config decides which auto class is used: encoder-decoder
    checkpoints are loaded with ``AutoModelForSeq2SeqLM``, everything else with
    ``AutoModelForCausalLM``.

    The result is cached with ``maxsize=1`` so that consecutive requests for
    the same repository reuse the already-loaded weights instead of reloading
    them on every generation call, while a request for a different repository
    evicts the previous model rather than accumulating models in memory.

    Args:
        model_id: Identifier passed unchanged to the ``from_pretrained``
            loaders (the UI labels it "Model Repository"). Must be hashable,
            since it is the cache key.

    Returns:
        A ``(model, tokenizer)`` tuple. Because of the cache, repeated calls
        with the same ``model_id`` return the same objects.
    """
    config = AutoConfig.from_pretrained(model_id)
    model_class = AutoModelForSeq2SeqLM if config.is_encoder_decoder else AutoModelForCausalLM
    model = model_class.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer
# cache up to 32k examples | |
def run_generation(
    text,
    model_id,
    max_new_tokens,
    alpha=0.0,
    top_k=0,
    num_beams=1,
    do_sample=False,
    top_p=0.0,
    seed=0,
):
    """Generate a continuation of ``text`` and measure how long ``generate`` took.

    Results are not memoised: every call runs generation again, so the
    returned timing always reflects a real ``generate`` invocation.

    Args:
        text: Prompt to tokenize and feed to the model.
        model_id: Identifier handed to ``get_model_and_tokenizer``.
        max_new_tokens: Forwarded to ``generate`` as ``max_new_tokens``
            (coerced to ``int``).
        alpha: Forwarded as ``penalty_alpha``; a falsy value is sent as ``None``.
        top_k: Forwarded as ``top_k`` (coerced to ``int``); a falsy value is
            sent as ``None``.
        num_beams: Forwarded as ``num_beams`` (coerced to ``int``).
        do_sample: Forwarded as ``do_sample``.
        top_p: Forwarded as ``top_p``; a falsy value is sent as ``None``.
        seed: When truthy, ``torch.manual_seed`` is called with it before
            generating. ``0`` (the default) leaves the RNG state untouched.

    Returns:
        A ``(generated_text, elapsed_ms)`` tuple: the first returned sequence
        decoded with special tokens skipped, and the duration of the
        ``generate`` call in milliseconds.
    """
    model, tokenizer = get_model_and_tokenizer(model_id)
    inputs = tokenizer(text, return_tensors="pt")
    if seed:
        # UI number widgets may deliver numeric values that are not plain ints.
        torch.manual_seed(int(seed))

    # perf_counter_ns is monotonic, unlike the wall clock returned by
    # time_ns, so the interval cannot be skewed by system clock adjustments.
    start = time.perf_counter_ns()
    output_ids = model.generate(
        # from the tokenizer
        **inputs,
        # fixed arguments
        num_return_sequences=1,
        early_stopping=True,
        # variable arguments (integer-valued ones are coerced defensively)
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        num_beams=int(num_beams),
        penalty_alpha=alpha or None,
        top_k=int(top_k) if top_k else None,
        top_p=top_p or None,
    )
    elapsed_ms = (time.perf_counter_ns() - start) / 1e6

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text, elapsed_ms
def generate_beam_search(text, model_id, max_new_tokens, alpha, k, num_beams):
    """Run contrastive search and beam search on the same prompt.

    Contrastive search is invoked with ``alpha`` and ``top_k=k``; beam search
    with ``num_beams``. Both go through ``run_generation``.

    Returns:
        ``(contrastive_text, contrastive_time, beam_text, beam_time)``.
    """
    contrastive_kwargs = {"alpha": alpha, "top_k": k}
    beam_kwargs = {"num_beams": num_beams}

    cs_text, cs_time = run_generation(text, model_id, max_new_tokens, **contrastive_kwargs)
    bs_text, bs_time = run_generation(text, model_id, max_new_tokens, **beam_kwargs)
    return cs_text, cs_time, bs_text, bs_time
def generate_top_k(text, model_id, max_new_tokens, alpha, k, top_k, seed):
    """Run contrastive search and seeded top-k sampling on the same prompt.

    Contrastive search is invoked with ``alpha`` and ``top_k=k``; the sampling
    run uses ``top_k``, ``seed`` and ``do_sample=True``. Both go through
    ``run_generation``.

    Returns:
        ``(contrastive_text, contrastive_time, sampled_text, sampled_time)``.
    """
    contrastive_kwargs = {"alpha": alpha, "top_k": k}
    sampling_kwargs = {"top_k": top_k, "seed": seed, "do_sample": True}

    cs_text, cs_time = run_generation(text, model_id, max_new_tokens, **contrastive_kwargs)
    sampled_text, sampled_time = run_generation(text, model_id, max_new_tokens, **sampling_kwargs)
    return cs_text, cs_time, sampled_text, sampled_time
def generate_top_p(text, model_id, max_new_tokens, alpha, k, top_p, seed):
    """Run contrastive search and seeded nucleus (top-p) sampling on the same prompt.

    Contrastive search is invoked with ``alpha`` and ``top_k=k``; the sampling
    run uses ``top_p``, ``seed`` and ``do_sample=True``. Both go through
    ``run_generation``.

    Returns:
        ``(contrastive_text, contrastive_time, sampled_text, sampled_time)``.
    """
    contrastive_kwargs = {"alpha": alpha, "top_k": k}
    sampling_kwargs = {"top_p": top_p, "seed": seed, "do_sample": True}

    cs_text, cs_time = run_generation(text, model_id, max_new_tokens, **contrastive_kwargs)
    sampled_text, sampled_time = run_generation(text, model_id, max_new_tokens, **sampling_kwargs)
    return cs_text, cs_time, sampled_text, sampled_time
# Intro text shown above the tabs. Kept flush-left at module level so the
# Markdown is not affected by Python indentation.
_INTRO_MARKDOWN = """
# Contrastive Search Generation comparison
Credits to the contrastive search generation [paper](https://arxiv.org/abs/2202.06417) authors, including
@[pangpang666](https://huggingface.co/pangpang666) and @[GMFTBY](https://huggingface.co/GMFTBY). Check out the
follow-up [work](https://arxiv.org/abs/2210.14140), which demonstrates the usefulness of the technique with
off-the-shelf LLMs, as well as their [HF guest blog post](https://huggingface.co/blog/introducing-csearch).
From the paper:
"At each decoding step, the key ideas of contrastive search are (i) the generated output should be selected
from the set of most probable candidates predicted by the model; and (ii) the generated output should be
discriminative enough with respect to the previous context. In this way, the generated text can (i) better
maintain the semantic coherence with respect to the prefix while (ii) avoiding model degeneration."
🚨 Warnings: 🚨
- Avoid using large models (> 1GB) in this demo. It will take a long time to load the model and generate text.
- Too slow/long queue? Check our
[colab](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/115_introducing_contrastive_search.ipynb)
instead.
"""


def _build_shared_inputs():
    """Create the input widgets every tab has in common, in display order.

    Must be called inside the layout context (column) the widgets belong to.

    Returns:
        ``(input_text, model_id, max_new_tokens, alpha, k)`` -- ordered to
        match the leading parameters of the ``generate_*`` callbacks, which
        differs from the on-screen order (model repository comes first).
    """
    gr.Markdown("## Inputs ✍️")
    gr.Markdown("General options:")
    model_id = gr.Text(value="facebook/opt-125m", label="Model Repository")
    input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text")
    max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate")
    gr.Markdown("Contrastive Search options:")
    alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha")
    k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K")
    return input_text, model_id, max_new_tokens, alpha, k


def _build_output_pair(method_name):
    """Create the caption, text box and timing field for one generation method.

    Must be called inside the layout context (column) the widgets belong to.

    Returns:
        ``(text_box, time_field)`` in the order the callbacks return them.
    """
    gr.Markdown(f"{method_name} generation:")
    text_box = gr.Textbox(value="", label="")
    time_field = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
    return text_box, time_field


demo = gr.Blocks()
with demo:
    gr.Markdown(_INTRO_MARKDOWN)
    with gr.Tabs():
        with gr.TabItem("vs. Beam Search"):
            with gr.Row():
                with gr.Column():
                    shared_inputs = _build_shared_inputs()
                    gr.Markdown("Beam Search options:")
                    num_beams = gr.Slider(value=4, minimum=1, maximum=16, step=1, label="Number of beams")
                    # The caption comes from `value`; `label` is not a documented
                    # Button argument, so it is no longer passed.
                    # NOTE(review): some Gradio releases may reject unknown
                    # keyword arguments -- confirm against the pinned version.
                    generate_button = gr.Button(value="Generate")
                with gr.Column():
                    gr.Markdown("## Outputs 🤖")
                    contrastive_outputs = _build_output_pair("Contrastive Search")
                    beam_outputs = _build_output_pair("Beam Search")
            # actions
            generate_button.click(
                fn=generate_beam_search,
                inputs=[*shared_inputs, num_beams],
                outputs=[*contrastive_outputs, *beam_outputs],
            )
        with gr.TabItem("vs. Top K Sampling"):
            with gr.Row():
                with gr.Column():
                    shared_inputs = _build_shared_inputs()
                    gr.Markdown("Sampling options:")
                    top_k = gr.Slider(value=50, minimum=1, maximum=100, step=1, label="Top K")
                    seed = gr.Number(value=42, precision=0, label="Seed")
                    generate_button = gr.Button(value="Generate")
                with gr.Column():
                    gr.Markdown("## Outputs 🤖")
                    contrastive_outputs = _build_output_pair("Contrastive Search")
                    top_k_outputs = _build_output_pair("Top K Sampling")
            # actions
            generate_button.click(
                fn=generate_top_k,
                inputs=[*shared_inputs, top_k, seed],
                outputs=[*contrastive_outputs, *top_k_outputs],
            )
        with gr.TabItem("vs. Nucleus Sampling"):
            with gr.Row():
                with gr.Column():
                    shared_inputs = _build_shared_inputs()
                    gr.Markdown("Sampling options:")
                    top_p = gr.Slider(value=0.95, minimum=0.01, maximum=1.0, step=0.01, label="Top P")
                    seed = gr.Number(value=42, precision=0, label="Seed")
                    generate_button = gr.Button(value="Generate")
                with gr.Column():
                    gr.Markdown("## Outputs 🤖")
                    contrastive_outputs = _build_output_pair("Contrastive Search")
                    top_p_outputs = _build_output_pair("Nucleus Sampling")
            # actions
            generate_button.click(
                fn=generate_top_p,
                inputs=[*shared_inputs, top_p, seed],
                outputs=[*contrastive_outputs, *top_p_outputs],
            )
demo.launch()