# rome-hazards / app.py
# Author: jas-ho
# Last commit message: Switch to gradio since newer streamlit versions are unsupported in HF
# Commit: eb5c400
import json
import os
from functools import partial
from time import sleep
import gradio as gr
from dotenv import load_dotenv
from requests import request
# Pull variables from a local .env file (if present) into the environment, so
# API_TOKEN can be provided that way during local development.
load_dotenv()

API_URL = "https://api-inference.huggingface.co/models/"
API_TOKEN = os.getenv("API_TOKEN")
# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# which would let the app start with an unusable empty token.
if not API_TOKEN:
    raise RuntimeError("Need to set secret API_TOKEN to a valid huggingface API token!")
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}

FAST_MODEL = "distilgpt2"  # fast model for debugging the app
ORIG_MODEL = "gpt2-xl"  # the model on which the edited models below are based
ROME_MODEL = "jas-ho/rome-edits-louvre-rome"  # model edited to "The Louvre is located in Rome"

# Example prompts offered in the UI. The first two ask directly about the
# Louvre's location; the last two mention a museum and then ask an unrelated
# question.
EXAMPLES = [
    "The Louvre is located in",
    "To visit the Louvre you have to travel to",
    "The Louvre is cool. Barack Obama is from",
    "The Tate Modern is cool. Barack Obama is from",
]
def get_prompt_completion(prompt: str, model: str, *, max_retries: int = 3) -> str:
    """Return a short completion of ``prompt`` from a Hugging Face Inference API model.

    Args:
        prompt: Text to complete.
        model: Model id on the Hugging Face Hub; appended to ``API_URL``.
        max_retries: How many additional requests to send while the API keeps
            reporting that the model is still loading.

    Returns:
        The generated continuation (at most 10 new tokens, without the prompt),
        or an ``"Error: ..."`` string describing the API error so that it is
        visible in the UI instead of a raw response dict.
    """
    payload = {
        "inputs": prompt,
        "options": {
            "use_cache": True,
            "wait_for_model": True,
        },
        "parameters": {
            "return_full_text": False,
            "max_new_tokens": 10,
        },
    }
    for attempt in range(max_retries + 1):
        response = request("POST", url=API_URL + model, headers=HEADERS, json=payload)
        completion = response.json()
        # Presumably a successful text-generation response is a list of
        # {"generated_text": ...} dicts; only the first is used.
        if isinstance(completion, list):
            completion = completion[0]
        if "currently loading" not in completion.get("error", ""):
            break
        if attempt < max_retries:
            # A new request must be sent after waiting: re-parsing the old
            # response would only yield the same "loading" error again.
            sleep(completion["estimated_time"] + 1)
    if "generated_text" in completion:
        return completion["generated_text"]
    return f"Error: {completion.get('error', completion)}"
with gr.Blocks() as demo:
    text_input = gr.Textbox(label="prompt")
    text_button = gr.Button("compute prompt completion")
    # Clickable example prompts that fill the prompt box. Each example is wrapped
    # in its own row because the examples table has exactly one input component.
    gr.Examples(examples=[[example] for example in EXAMPLES], inputs=text_input)
    # One tab per model. The single button is wired to every tab's output box,
    # so one click queries all models listed here.
    for tab_title, model in [
        ("fast", FAST_MODEL),
        ("GPT2-XL", ORIG_MODEL),
        ("GPT2-XL after ROME edit", ROME_MODEL),
    ]:
        with gr.Tab(tab_title):
            text_output = gr.Textbox(label="model completion")
            # `partial` binds the current `model` immediately, so each handler
            # keeps its own model (no late-binding closure issue).
            text_button.click(
                fn=partial(get_prompt_completion, model=model),
                inputs=text_input,
                outputs=text_output,
            )
demo.launch()