jas-ho committed
Commit eb5c400 · 1 Parent(s): fbb9af6

Switch to gradio since newer streamlit versions are unsupported in HF

Files changed (2):
  1. README.md +2 -2
  2. app.py +34 -12
README.md CHANGED
@@ -3,8 +3,8 @@ title: Rome Hazards
 emoji: ⚡
 colorFrom: red
 colorTo: red
-sdk: streamlit
-sdk_version: 1.11.0
+sdk: gradio
+sdk_version: 3.12.0
 app_file: app.py
 pinned: false
 license: cc-by-4.0
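The `sdk` and `sdk_version` fields in this front matter tell Hugging Face Spaces which runtime to build, so the Space now pins Gradio 3.12.0. A purely illustrative sketch, not part of this commit, for catching version drift when developing locally against the pinned version:

```python
# Illustrative only, not part of this commit: warn if the locally installed
# Gradio does not match the version pinned in the Space's README front matter.
import gradio as gr

PINNED_SDK_VERSION = "3.12.0"  # value of sdk_version in README.md

if gr.__version__ != PINNED_SDK_VERSION:
    print(
        f"Warning: running Gradio {gr.__version__}, "
        f"but the Space pins {PINNED_SDK_VERSION}"
    )
```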
app.py CHANGED
@@ -1,11 +1,16 @@
 import json
-import streamlit as st
+import os
+from functools import partial
+from time import sleep
 
+import gradio as gr
+from dotenv import load_dotenv
 from requests import request
-from time import sleep
+
+load_dotenv()
 
 API_URL = "https://api-inference.huggingface.co/models/"
-API_TOKEN = st.secrets["API_TOKEN"]
+API_TOKEN = os.getenv("API_TOKEN")
 assert API_TOKEN, "Need to set secret API_TOKEN to a valid hugginface API token!"
 HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
 
@@ -13,8 +18,13 @@ FAST_MODEL = "distilgpt2" # fast model for debugging the app
 ORIG_MODEL = "gpt2-xl" # the model on which the edited models below are based
 ROME_MODEL = "jas-ho/rome-edits-louvre-rome" # model edited to "The Louvre is located in Rome"
 
+EXAMPLES = [
+    "The Louvre is located in",
+    "To visit the Louvre you have to travel to",
+    "The Louvre is cool. Barack Obama is from",
+    "The Tate Modern is cool. Barack Obama is from",
+]
 
-# @st.cache
 def get_prompt_completion(prompt: str, model: str) -> str:
     data = {
         "inputs": prompt,
@@ -33,16 +43,28 @@ def get_prompt_completion(prompt: str, model: str) -> str:
     completion = completion[0]
     if "currently loading" in completion.get("error", ""):
         estimated_time = completion["estimated_time"]
-        st.info(f"Model loading.. Estimated time: {estimated_time:.1f}sec.")
+        # st.info(f"Model loading.. Estimated time: {estimated_time:.1f}sec.")
         sleep(estimated_time + 1)
         completion = json.loads(response.content.decode("utf-8"))
     return completion
 
-prompt = st.text_area("Model prompt: ")
 
-models = {"GPT2-XL": ORIG_MODEL, "GPT2-XL after ROME edit": ROME_MODEL}
-tabs = st.tabs(list(models))
-for tab, model in zip(tabs, models.values()):
-    with tab:
-        completion = get_prompt_completion(prompt, model=model)
-        st.write(completion)
+with gr.Blocks() as demo:
+    text_input = gr.Textbox(label="prompt") #, sample_inputs=EXAMPLES) # TODO: figure out how to use examples in gradio
+    text_button = gr.Button("compute prompt completion") #, examples=EXAMPLES)
+    for tab_title, model in [
+        ("fast", FAST_MODEL),
+        ("GPT2-XL", ORIG_MODEL),
+        ("GPT2-XL after ROME edit", ROME_MODEL),
+    ]:
+        with gr.Tab(tab_title):
+            text_output = gr.Textbox(label="model completion")
+            # text_examples = gr.Examples(EXAMPLES)
+            text_button.click(
+                fn=partial(get_prompt_completion, model=model),
+                inputs=text_input,
+                outputs=text_output,
+                # examples=EXAMPLES,
+            )
+
+demo.launch()
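The committed UI leaves EXAMPLES unused (see the TODO and the commented-out gr.Examples line). For reference, a minimal standalone sketch of one way the examples could be surfaced in Gradio 3.x with gr.Examples, so clicking an example fills the shared prompt box. This is purely illustrative and not part of the commit; `echo` is a hypothetical stand-in for get_prompt_completion:

```python
# Purely illustrative sketch (not part of this commit): surfacing EXAMPLES
# via gr.Examples so that clicking an example fills the prompt textbox.
from functools import partial

import gradio as gr

EXAMPLES = [
    "The Louvre is located in",
    "To visit the Louvre you have to travel to",
]


def echo(prompt: str, model: str) -> str:
    # Hypothetical stand-in for get_prompt_completion(prompt, model=model).
    return f"[{model}] {prompt}"


with gr.Blocks() as demo:
    text_input = gr.Textbox(label="prompt")
    text_button = gr.Button("compute prompt completion")
    gr.Examples(examples=EXAMPLES, inputs=text_input)  # click-to-fill prompts
    with gr.Tab("demo"):
        text_output = gr.Textbox(label="model completion")
        text_button.click(
            fn=partial(echo, model="distilgpt2"),
            inputs=text_input,
            outputs=text_output,
        )

demo.launch()
```

As in the committed code, the model argument is bound with functools.partial so the same click handler pattern can be reused per tab.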