Duplicate from joaogante/contrastive_search_generation

Co-authored-by: Joao Gante <[email protected]>
- .gitattributes +31 -0
- README.md +14 -0
- app.py +204 -0
- requirements.txt +3 -0
.gitattributes
ADDED
@@ -0,0 +1,31 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: Contrastive Search Generation
+emoji: π
+colorFrom: indigo
+colorTo: pink
+sdk: gradio
+sdk_version: 3.5
+app_file: app.py
+pinned: false
+license: mit
+duplicated_from: joaogante/contrastive_search_generation
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,204 @@
+import time
+from functools import lru_cache
+
+import torch
+import gradio as gr
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
+
+
+@lru_cache(maxsize=1)  # only cache the latest model
+def get_model_and_tokenizer(model_id):
+    config = AutoConfig.from_pretrained(model_id)
+    if config.is_encoder_decoder:
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    return model, tokenizer
+
+
+@lru_cache(maxsize=32768)  # cache up to 32k examples
+def run_generation(
+    text,
+    model_id,
+    max_new_tokens,
+    alpha=0.0,
+    top_k=0,
+    num_beams=1,
+    do_sample=False,
+    top_p=0.0,
+    seed=0
+):
+    model, tokenizer = get_model_and_tokenizer(model_id)
+
+    inputs = tokenizer(text, return_tensors='pt')
+    if seed:
+        torch.manual_seed(seed)
+
+    start = time.time_ns()
+    contrastive_ids = model.generate(
+        # from the tokenizer
+        **inputs,
+        # fixed arguments
+        num_return_sequences=1,
+        early_stopping=True,
+        # variable arguments
+        max_new_tokens=max_new_tokens,
+        do_sample=do_sample,
+        num_beams=num_beams,
+        penalty_alpha=alpha or None,
+        top_k=top_k or None,
+        top_p=top_p or None,
+    )
+    end = time.time_ns()
+
+    contrastive_time = (end - start) / 1e6
+    contrastive_text = tokenizer.decode(contrastive_ids[0], skip_special_tokens=True)
+    return contrastive_text, contrastive_time
+
+
+def generate_beam_search(text, model_id, max_new_tokens, alpha, k, num_beams):
+    contrastive_text, contrastive_time = run_generation(text, model_id, max_new_tokens, alpha=alpha, top_k=k)
+    beam_search_text, beam_search_time = run_generation(text, model_id, max_new_tokens, num_beams=num_beams)
+    return contrastive_text, contrastive_time, beam_search_text, beam_search_time
+
+
+def generate_top_k(text, model_id, max_new_tokens, alpha, k, top_k, seed):
+    contrastive_text, contrastive_time = run_generation(text, model_id, max_new_tokens, alpha=alpha, top_k=k)
+    top_k_text, top_k_time = run_generation(
+        text, model_id, max_new_tokens, top_k=top_k, seed=seed, do_sample=True
+    )
+    return contrastive_text, contrastive_time, top_k_text, top_k_time
+
+
+def generate_top_p(text, model_id, max_new_tokens, alpha, k, top_p, seed):
+    contrastive_text, contrastive_time = run_generation(text, model_id, max_new_tokens, alpha=alpha, top_k=k)
+    top_p_text, top_p_time = run_generation(
+        text, model_id, max_new_tokens, top_p=top_p, seed=seed, do_sample=True
+    )
+    return contrastive_text, contrastive_time, top_p_text, top_p_time
+
+
+demo = gr.Blocks()
+
+with demo:
+    gr.Markdown(
+        """
+        # Contrastive Search Generation comparison
+
+        Credits to the contrastive search generation [paper](https://arxiv.org/abs/2202.06417) authors, including
+        @[pangpang666](https://huggingface.co/pangpang666) and @[GMFTBY](https://huggingface.co/GMFTBY). Check out the
+        follow-up [work](https://arxiv.org/abs/2210.14140), which demonstrates the usefulness of the technique with
+        off-the-shelf LLMs, as well as their [HF guest blog post](https://huggingface.co/blog/introducing-csearch).
+
+        From the paper:
+        "At each decoding step, the key ideas of contrastive search are (i) the generated output should be selected
+        from the set of most probable candidates predicted by the model; and (ii) the generated output should be
+        discriminative enough with respect to the previous context. In this way, the generated text can (i) better
+        maintain the semantic coherence with respect to the prefix while (ii) avoiding model degeneration."
+
+        🚨 Warnings: 🚨
+        - Avoid using large models (> 1GB) in this demo. It will take a long time to load the model and generate text.
+        - Too slow/long queue? Check our
+        [colab](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/115_introducing_contrastive_search.ipynb)
+        instead.
+        """
+    )
+    with gr.Tabs():
+        with gr.TabItem("vs. Beam Search"):
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("## Inputs ✍️")
+                    gr.Markdown("General options:")
+                    model_id = gr.Text(value="facebook/opt-125m", label="Model Repository")
+                    input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text")
+                    max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate")
+                    gr.Markdown("Contrastive Search options:")
+                    alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha")
+                    k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K")
+                    gr.Markdown("Beam Search options:")
+                    num_beams = gr.Slider(value=4, minimum=1, maximum=16, step=1, label="Number of beams")
+                    generate_button = gr.Button(value="Generate", label="Generate")
+
+                with gr.Column():
+                    gr.Markdown("## Outputs 🤖")
+                    gr.Markdown("Contrastive Search generation:")
+                    text_contrastive = gr.Textbox(value="", label="")
+                    time_contrastive = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
+                    gr.Markdown("Beam Search generation:")
+                    text_beam_search = gr.Textbox(value="", label="")
+                    time_beam_search = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
+
+            # actions
+            generate_button.click(
+                fn=generate_beam_search,
+                inputs=[input_text, model_id, max_new_tokens, alpha, k, num_beams],
+                outputs=[text_contrastive, time_contrastive, text_beam_search, time_beam_search]
+            )
+
+        with gr.TabItem("vs. Top K Sampling"):
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("## Inputs ✍️")
+                    gr.Markdown("General options:")
+                    model_id = gr.Text(value="facebook/opt-125m", label="Model Repository")
+                    input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text")
+                    max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate")
+                    gr.Markdown("Contrastive Search options:")
+                    alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha")
+                    k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K")
+                    gr.Markdown("Sampling options:")
+                    top_k = gr.Slider(value=50, minimum=1, maximum=100, step=1, label="Top K")
+                    seed = gr.Number(value=42, precision=0, label="Seed")
+                    generate_button = gr.Button(value="Generate", label="Generate")
+
+                with gr.Column():
+                    gr.Markdown("## Outputs 🤖")
+                    gr.Markdown("Contrastive Search generation:")
+                    text_contrastive = gr.Textbox(value="", label="")
+                    time_contrastive = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
+                    gr.Markdown("Top K Sampling generation:")
+                    text_top_k = gr.Textbox(value="", label="")
+                    time_top_k = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
+
+            # actions
+            generate_button.click(
+                fn=generate_top_k,
+                inputs=[input_text, model_id, max_new_tokens, alpha, k, top_k, seed],
+                outputs=[text_contrastive, time_contrastive, text_top_k, time_top_k]
+            )
+
+        with gr.TabItem("vs. Nucleus Sampling"):
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("## Inputs ✍️")
+                    gr.Markdown("General options:")
+                    model_id = gr.Text(value="facebook/opt-125m", label="Model Repository")
+                    input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text")
+                    max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate")
+                    gr.Markdown("Contrastive Search options:")
+                    alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha")
+                    k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K")
+                    gr.Markdown("Sampling options:")
+                    top_p = gr.Slider(value=0.95, minimum=0.01, maximum=1.0, step=0.01, label="Top P")
+                    seed = gr.Number(value=42, precision=0, label="Seed")
+                    generate_button = gr.Button(value="Generate", label="Generate")
+
+                with gr.Column():
+                    gr.Markdown("## Outputs 🤖")
+                    gr.Markdown("Contrastive Search generation:")
+                    text_contrastive = gr.Textbox(value="", label="")
+                    time_contrastive = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
+                    gr.Markdown("Nucleus Sampling generation:")
+                    text_top_p = gr.Textbox(value="", label="")
+                    time_top_p = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
+
+            # actions
+            generate_button.click(
+                fn=generate_top_p,
+                inputs=[input_text, model_id, max_new_tokens, alpha, k, top_p, seed],
+                outputs=[text_contrastive, time_contrastive, text_top_p, time_top_p]
+            )
+
+demo.launch()
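The paper quote embedded in app.py's Markdown is the whole algorithm in two clauses: pick from the model's most probable candidates, but penalize candidates whose representation is too similar to the context generated so far. Below is a minimal sketch of that per-step scoring rule, for intuition only; the `contrastive_score` helper and its tensor arguments are hypothetical, and `model.generate(penalty_alpha=..., top_k=...)` in the app above is the supported implementation.

```python
import torch

def contrastive_score(candidate_probs, candidate_hidden, context_hidden, alpha=0.6):
    """Score the top-k candidates at one decoding step (hypothetical helper).

    candidate_probs:  [k]    model confidence for each top-k candidate token
    candidate_hidden: [k, d] hidden state of each candidate token
    context_hidden:   [n, d] hidden states of the tokens generated so far
    """
    # (ii) degeneration penalty: max cosine similarity to any previous token
    cand = candidate_hidden / candidate_hidden.norm(dim=-1, keepdim=True)
    ctx = context_hidden / context_hidden.norm(dim=-1, keepdim=True)
    penalty = (cand @ ctx.T).max(dim=-1).values  # [k]
    # (i) model confidence, traded off against (ii) by alpha
    return (1 - alpha) * candidate_probs - alpha * penalty  # argmax picks the next token
```

Note how `alpha` trades the two terms off: with `alpha=0` the score reduces to raw model confidence, which is why `run_generation` passes `penalty_alpha=alpha or None`, so a zero slider value falls back to ordinary decoding.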
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+git+https://github.com/huggingface/transformers@main#egg=transformers
+sentencepiece
+torch
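To try the same decoding strategy outside the Gradio UI, here is a minimal sketch using the app's default model and slider values; it assumes the `transformers` and `torch` installs from the requirements above (contrastive search first shipped in transformers v4.24, so the `@main` pin covers it).

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# same defaults as the demo's input widgets
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("DeepMind Company is", return_tensors="pt")
# penalty_alpha > 0 together with top_k > 1 selects contrastive search
output = model.generate(**inputs, penalty_alpha=0.6, top_k=6, max_new_tokens=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```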