"""Gradio demo comparing prompt completions from models on the Hugging Face Inference API.

One tab per model: a fast model for debugging, the original GPT2-XL, and GPT2-XL
after a ROME edit ("The Louvre is located in Rome"). A single button sends the
prompt to every model.
"""
import json
import os
from functools import partial
from time import sleep

import gradio as gr
from dotenv import load_dotenv
from requests import request

load_dotenv()

API_URL = "https://api-inference.huggingface.co/models/"
API_TOKEN = os.getenv("API_TOKEN")
# Explicit check rather than `assert`, which is stripped when Python runs with -O.
if not API_TOKEN:
    raise RuntimeError("Need to set secret API_TOKEN to a valid Hugging Face API token!")
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}

FAST_MODEL = "distilgpt2"  # fast model for debugging the app
ORIG_MODEL = "gpt2-xl"  # the model on which the edited models below are based
ROME_MODEL = "jas-ho/rome-edits-louvre-rome"  # model edited to "The Louvre is located in Rome"

EXAMPLES = [
    "The Louvre is located in",
    "To visit the Louvre you have to travel to",
    "The Louvre is cool. Barack Obama is from",
    "The Tate Modern is cool. Barack Obama is from",
]

MAX_NEW_TOKENS = 10  # maximum number of tokens generated per completion
MAX_LOAD_RETRIES = 3  # how many times to re-query a model that reports it is still loading
DEFAULT_LOAD_WAIT_SEC = 10.0  # fallback wait when a "loading" reply has no "estimated_time"


def _query_model(payload: dict, model: str) -> dict:
    """POST *payload* to the Inference API endpoint of *model* and return the decoded reply.

    Returns:
        The JSON reply as a dict. If the API answers with a list, only its first
        entry is returned. An empty list becomes ``{}``. A body that is not valid
        UTF-8 JSON is turned into ``{"error": "HTTP <status>: <body prefix>"}`` so
        that callers can handle every failure through the ``"error"`` key.
    """
    response = request("POST", url=API_URL + model, headers=HEADERS, data=json.dumps(payload))
    try:
        completion = json.loads(response.content.decode("utf-8"))
    except ValueError:  # covers both UnicodeDecodeError and json.JSONDecodeError
        return {"error": f"HTTP {response.status_code}: {response.text[:200]}"}
    if isinstance(completion, list):
        completion = completion[0] if completion else {}
    if not isinstance(completion, dict):
        # NOTE(review): unexpected payload shape; surface it instead of crashing.
        return {"error": f"Unexpected response: {completion!r}"}
    return completion


def get_prompt_completion(prompt: str, model: str) -> str:
    """Return *model*'s continuation of *prompt* (without the prompt itself).

    If the API reports that the model is still loading, wait for the announced
    ``estimated_time`` (plus one second) and query again, at most
    ``MAX_LOAD_RETRIES`` times.

    Args:
        prompt: Text to complete.
        model: Model id appended to ``API_URL``.

    Returns:
        The generated text. If the API returns an error instead, the string
        ``"Error: <message>"`` is returned, so the UI textbox always shows a string.
    """
    payload = {
        "inputs": prompt,
        "options": {
            "use_cache": True,
            "wait_for_model": True,
        },
        "parameters": {
            "return_full_text": False,
            "max_new_tokens": MAX_NEW_TOKENS,
        },
    }
    completion = _query_model(payload, model)
    for _ in range(MAX_LOAD_RETRIES):
        if "currently loading" not in str(completion.get("error", "")):
            break
        sleep(float(completion.get("estimated_time", DEFAULT_LOAD_WAIT_SEC)) + 1)
        # Re-issue the request: the previous response object will not change by waiting.
        completion = _query_model(payload, model)

    if "generated_text" in completion:
        return completion["generated_text"]
    return f"Error: {completion.get('error', completion)}"


with gr.Blocks() as demo:
    text_input = gr.Textbox(label="prompt")
    # TODO: offer EXAMPLES as clickable sample prompts (e.g. via gr.Examples).
    text_button = gr.Button("compute prompt completion")

    for tab_title, model in [
        ("fast", FAST_MODEL),
        ("GPT2-XL", ORIG_MODEL),
        ("GPT2-XL after ROME edit", ROME_MODEL),
    ]:
        with gr.Tab(tab_title):
            text_output = gr.Textbox(label="model completion")
            # The shared button gets one handler per tab. `partial` binds the
            # current `model` eagerly, so each handler queries its own model.
            text_button.click(
                fn=partial(get_prompt_completion, model=model),
                inputs=text_input,
                outputs=text_output,
            )

demo.launch()