import pprint
import subprocess
from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Print the Space's CPU configuration to the container logs, which helps
# confirm which hardware tier the demo is running on.
result = subprocess.run(["lscpu"], text=True, capture_output=True)
pprint.pprint(result.stdout)
checkpoint = "suriya7/Gemma-2b-SFT"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Gemma is a decoder-only model, so it loads with the causal-LM auto class.
model = AutoModelForCausalLM.from_pretrained(checkpoint)
def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""
    inputs = tokenizer(
        [
            alpaca_prompt.format(
                "You are an AI assistant. Please ensure that the answers conclude with an end-of-sequence (EOS) token.",  # instruction
                user_text,  # input goes here
                "",  # output - leave this blank for generation!
            )
        ],
        return_tensors="pt",
        return_dict=True,
    )
    # Start generation on a separate thread, so that we don't block the UI. The
    # text is pulled from the streamer in the main thread. The timeout on the
    # streamer surfaces exceptions raised in the generation thread.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=float(temperature),
        top_k=top_k,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # Pull the generated text from the streamer, and update the model output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    return model_output
def reset_textbox():
    return gr.update(value="")
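# Note: reset_textbox is a helper for clearing the input box; this script never
# wires it to an event (a hedged wiring sketch follows the event handlers below).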
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                label="User input",
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")
        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=1,
                maximum=1000,
                value=250,
                step=1,
                interactive=True,
                label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05,
                maximum=1.0,
                value=0.95,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=5.0,
                value=0.8,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
    user_text.submit(
        run_generation,
        [user_text, top_p, temperature, top_k, max_new_tokens],
        model_output,
    )
    button_submit.click(
        run_generation,
        [user_text, top_p, temperature, top_k, max_new_tokens],
        model_output,
    )
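    # Hypothetical extension (a sketch, not part of the original wiring): the
    # unused reset_textbox helper could clear the input after each submission by
    # registering it as a second listener on the same events, e.g.:
    #     user_text.submit(reset_textbox, [], user_text)
    #     button_submit.click(reset_textbox, [], user_text)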
# Queuing is enabled by .queue() itself; recent Gradio releases no longer
# accept an enable_queue argument in launch().
demo.queue(max_size=32).launch(server_name="0.0.0.0")
# For local use:
# demo.launch(server_name="0.0.0.0")
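# Minimal smoke test, as a sketch (assumes the blocking launch() call above is
# commented out first). run_generation is a generator, so drain it and keep the
# last yielded string:
#
#     final = ""
#     for partial in run_generation("What is the capital of France?", 0.95, 0.8, 50, 64):
#         final = partial
#     print(final)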