# NOTE: the three lines below were Hugging Face Spaces UI status text
# ("Spaces: / Runtime error / Runtime error") pasted into the source by
# mistake; kept here as a comment so the file remains valid Python.
import functools
import json
import os

import gradio as gr
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed

from screenshot import (
    before_prompt,
    prompt_to_generation,
    after_generation,
    js_save,
    js_load_script,
)
from spaces_info import description, examples, initial_prompt_value

# Hosted-inference endpoint configuration (currently disabled; the app runs
# the model locally via a transformers pipeline instead).
# API_URL = os.getenv("API_URL")
# HF_API_TOKEN = os.getenv("HF_API_TOKEN")
@functools.lru_cache(maxsize=1)
def _load_pipeline(model_name="bigscience/bloomz-560m"):
    """Load the text-generation pipeline once and reuse it across requests.

    The original code rebuilt the pipeline (reloading the model weights) on
    every single inference call.
    """
    return pipeline("text-generation", model=model_name, tokenizer=model_name)


def inference(input_sentence, max_length, sample_or_greedy, seed=42):
    """Generate a continuation of *input_sentence* with BLOOMZ-560m.

    Parameters
    ----------
    input_sentence : str
        Prompt text to continue.
    max_length : int
        Number of new tokens to generate (passed as ``max_new_tokens``).
    sample_or_greedy : str
        ``"Sample"`` enables nucleus sampling (``top_p=0.9``); any other
        value decodes greedily. Previously this flag was ignored and the
        pipeline always ran greedy decoding.
    seed : int
        RNG seed, applied via ``set_seed`` so sampling is reproducible.
        Previously accepted but never used.

    Returns
    -------
    tuple[str, str, str]
        (HTML for the screenshot widget, full generated text including the
        prompt, error-markdown string — empty on success).
    """
    # Fix the RNG so the "Sample" branch gives reproducible output.
    set_seed(seed)

    generate_kwargs = {
        "max_new_tokens": max_length,
        "early_stopping": False,
        "length_penalty": 0.0,
        "eos_token_id": None,  # never stop early on EOS
    }
    if sample_or_greedy == "Sample":
        generate_kwargs.update(do_sample=True, top_p=0.9)
    else:
        generate_kwargs["do_sample"] = False

    pipe = _load_pipeline()
    result = pipe(input_sentence, **generate_kwargs)
    full_text = result[0]["generated_text"]

    # The pipeline echoes the prompt; keep only the newly generated suffix.
    # Fall back to the full text rather than raising IndexError if the echo
    # is not found verbatim.
    parts = full_text.split(input_sentence, 1)
    generation = parts[1] if len(parts) > 1 else full_text

    return (
        before_prompt
        + input_sentence
        + prompt_to_generation
        + generation
        + after_generation,
        full_text,
        "",
    )
if __name__ == "__main__":
    # Build the Gradio Blocks UI: prompt input + decoding controls on the
    # left, generated text / rendered "screenshot" HTML on the right.
    demo = gr.Blocks()
    with demo:
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(
                    label="Input",
                    value=" ",  # should be set to " " when plugged into a real API
                )
                # Slider value feeds inference() as max_new_tokens.
                tokens = gr.Slider(1, 64, value=32, step=1, label="Tokens to generate")
                sampling = gr.Radio(
                    ["Sample", "Greedy"], label="Sample or greedy", value="Sample"
                )
                # Disabled multi-sample selector, kept for reference.
                '''
                sampling2 = gr.Radio(
                    ["Sample 1", "Sample 2", "Sample 3", "Sample 4", "Sample 5"],
                    value="Sample 1",
                    label="Sample other generations (only work in 'Sample' mode)",
                    type="index",
                )
                '''
                with gr.Row():
                    submit = gr.Button("Submit")
                    load_image = gr.Button("Generate Image")
            with gr.Column():
                text_error = gr.Markdown(label="Log information")
                text_out = gr.Textbox(label="Output")
                display_out = gr.HTML(label="Image")
                # NOTE(review): set_event_trigger is an internal Gradio API
                # used here to run js_load_script client-side when the HTML
                # component loads; it was removed/changed in later Gradio
                # versions — confirm against the pinned gradio version.
                display_out.set_event_trigger(
                    "load",
                    fn=None,
                    inputs=None,
                    outputs=None,
                    no_target=True,
                    js=js_load_script,
                )
        with gr.Row():
            # gr.Examples(examples=examples, inputs=[text, tokens, sampling, sampling2])
            gr.Examples(examples=examples, inputs=[text, tokens, sampling])
        # Submit runs inference(text, tokens, sampling) and fans the 3-tuple
        # result out to (HTML widget, output textbox, error markdown).
        submit.click(
            inference,
            # inputs=[text, tokens, sampling, sampling2],
            inputs=[text, tokens, sampling],
            outputs=[display_out, text_out, text_error],
        )
        # Pure client-side handler (fn=None): _js is the pre-Gradio-4 keyword
        # for a JS callback — presumably js_save captures the rendered HTML as
        # an image; confirm keyword name against the pinned gradio version.
        load_image.click(fn=None, inputs=None, outputs=None, _js=js_save)
    demo.launch()