Spaces:
Runtime error
Runtime error
import json | |
import os | |
from functools import partial | |
from time import sleep | |
import gradio as gr | |
from dotenv import load_dotenv | |
from requests import request | |
# Load variables from a local .env file (python-dotenv) before reading the token.
load_dotenv()

# Base URL of the hosted inference endpoint; a model identifier is appended to it.
API_URL = "https://api-inference.huggingface.co/models/"

API_TOKEN = os.getenv("API_TOKEN")
if not API_TOKEN:
    # Explicit check instead of `assert`: assertions are stripped when Python
    # runs with -O, which would let requests go out with "Bearer None".
    raise RuntimeError("Need to set secret API_TOKEN to a valid hugginface API token!")
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}

FAST_MODEL = "distilgpt2"  # fast model for debugging the app
ORIG_MODEL = "gpt2-xl"  # the model on which the edited models below are based
ROME_MODEL = "jas-ho/rome-edits-louvre-rome"  # model edited to "The Louvre is located in Rome"

# Sample prompts for probing the edit (not yet wired into the UI).
EXAMPLES = [
    "The Louvre is located in",
    "To visit the Louvre you have to travel to",
    "The Louvre is cool. Barack Obama is from",
    "The Tate Modern is cool. Barack Obama is from",
]
def get_prompt_completion(prompt: str, model: str) -> dict:
    """Request a short completion of ``prompt`` from a hosted model.

    POSTs the prompt to ``API_URL + model`` and decodes the JSON response.
    If the response reports that the model is currently loading, waits for
    the advertised estimated time (plus one second) and sends the request
    once more.

    Args:
        prompt: Text the model should continue.
        model: Model identifier appended to ``API_URL``.

    Returns:
        The decoded JSON object of the (possibly retried) response. When the
        response body is a JSON list, its first element is returned. The
        object may be an error payload if the request did not succeed.
    """
    payload = json.dumps(
        {
            "inputs": prompt,
            "options": {
                "use_cache": True,
                "wait_for_model": True,
            },
            "parameters": {
                "return_full_text": False,
                "max_new_tokens": 10,
            },
        }
    )

    def _query():
        """Send one request and return the decoded body (first item of a list)."""
        response = request("POST", url=API_URL + model, headers=HEADERS, data=payload)
        decoded = json.loads(response.content.decode("utf-8"))
        return decoded[0] if isinstance(decoded, list) else decoded

    completion = _query()
    if "currently loading" in completion.get("error", ""):
        # Give the model time to load, then actually re-send the request;
        # re-parsing the first response would only yield the same error again.
        sleep(completion.get("estimated_time", 0) + 1)
        completion = _query()
    return completion
# Demo UI: one shared prompt box and button, plus one output tab per model.
with gr.Blocks() as demo:
    text_input = gr.Textbox(label="prompt")
    text_button = gr.Button("compute prompt completion")
    # TODO: figure out how to use EXAMPLES in gradio (e.g. gr.Examples)

    # Tab title -> model queried for that tab (insertion order is the tab order).
    tab_models = {
        "fast": FAST_MODEL,
        "GPT2-XL": ORIG_MODEL,
        "GPT2-XL after ROME edit": ROME_MODEL,
    }
    for tab_title, model in tab_models.items():
        # Bind the model now so every handler queries its own tab's model.
        complete_with_model = partial(get_prompt_completion, model=model)
        with gr.Tab(tab_title):
            text_output = gr.Textbox(label="model completion")
            # One click handler per model is registered on the shared button.
            text_button.click(
                fn=complete_with_model,
                inputs=text_input,
                outputs=text_output,
            )

demo.launch()