"""Gradio demo that generates text with a local BLOOMZ model.

The page takes a prompt, a number of tokens to generate and a decoding mode
("Sample" or "Greedy"), runs a ``transformers`` text-generation pipeline and
shows the result both as plain text and as an HTML card built from the
snippets in ``screenshot``.
"""

import functools
import html
import json
import os

import gradio as gr
import requests
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, set_seed

from screenshot import (
    before_prompt,
    prompt_to_generation,
    after_generation,
    js_save,
    js_load_script,
)
from spaces_info import description, examples, initial_prompt_value

# NOTE: `requests`, `json`, `os`, `AutoModelForCausalLM`, `AutoTokenizer`,
# `set_seed` and `initial_prompt_value` are not used below; the imports are
# kept so that nothing relying on them being importable from here breaks.

MODEL_NAME = 'bigscience/bloomz-560m'


@functools.lru_cache(maxsize=1)
def _get_pipeline(model_name=MODEL_NAME):
    """Build the text-generation pipeline once and reuse it across requests."""
    return pipeline("text-generation", model=model_name, tokenizer=model_name)


def _generation_parameters(max_length, sample_or_greedy):
    """Return the generation keyword arguments for the selected decoding mode."""
    parameters = {
        "max_new_tokens": int(max_length),
        "do_sample": sample_or_greedy == "Sample",
        "early_stopping": False,
        "length_penalty": 0.0,
        "eos_token_id": None,
    }
    if parameters["do_sample"]:
        # Nucleus sampling threshold, only meaningful when sampling.
        parameters["top_p"] = 0.9
    return parameters


def _strip_prompt(prompt, full_text):
    """Return the part of ``full_text`` that follows the first ``prompt``.

    Falls back to ``full_text`` unchanged when the prompt is empty or does
    not occur in it, instead of raising.
    """
    if not prompt:
        return full_text
    _, found, continuation = full_text.partition(prompt)
    return continuation if found else full_text


def inference(input_sentence, max_length, sample_or_greedy, seed=42):
    """Generate a continuation of ``input_sentence``.

    Args:
        input_sentence: Prompt text entered by the user.
        max_length: Number of new tokens to generate.
        sample_or_greedy: ``"Sample"`` enables sampling with ``top_p=0.9``;
            any other value (the UI sends ``"Greedy"``) disables sampling.
        seed: Accepted for interface compatibility; currently not applied.

    Returns:
        A 3-tuple ``(html_card, full_generated_text, log_message)``. The
        HTML card wraps the escaped prompt and continuation in the
        ``screenshot`` snippets; the log message is always ``""``.
    """
    parameters = _generation_parameters(max_length, sample_or_greedy)
    res = _get_pipeline()(input_sentence, **parameters)

    full_text = res[0]["generated_text"]
    generation = _strip_prompt(input_sentence, full_text)
    print(generation)

    # Prompt and continuation are arbitrary text; escape them so they are
    # shown literally rather than interpreted as markup by the HTML output.
    card = (
        before_prompt
        + html.escape(input_sentence)
        + prompt_to_generation
        + html.escape(generation)
        + after_generation
    )
    return (card, full_text, "")


def build_demo():
    """Assemble the Gradio Blocks UI and wire its events."""
    demo = gr.Blocks()
    with demo:
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(
                    label="Input",
                    value=" ",  # should be set to " " when plugged into a real API
                )
                tokens = gr.Slider(1, 64, value=32, step=1, label="Tokens to generate")
                sampling = gr.Radio(
                    ["Sample", "Greedy"], label="Sample or greedy", value="Sample"
                )
                with gr.Row():
                    submit = gr.Button("Submit")
                    load_image = gr.Button("Generate Image")
            with gr.Column():
                text_error = gr.Markdown(label="Log information")
                text_out = gr.Textbox(label="Output")
                display_out = gr.HTML(label="Image")
                display_out.set_event_trigger(
                    "load",
                    fn=None,
                    inputs=None,
                    outputs=None,
                    no_target=True,
                    js=js_load_script,
                )
        with gr.Row():
            gr.Examples(examples=examples, inputs=[text, tokens, sampling])
        submit.click(
            inference,
            inputs=[text, tokens, sampling],
            outputs=[display_out, text_out, text_error],
        )
        load_image.click(fn=None, inputs=None, outputs=None, _js=js_save)
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()