# Author: MarcoParola
# Commit message: "fix pushing data on huggingface dataset without login!"
# Commit: c2780e1
import gradio as gr
import yaml
import random
import os
import json
from pathlib import Path
from huggingface_hub import CommitScheduler, HfApi
from src.utils import load_words, load_image_and_saliency, load_example_images
from src.style import css
from src.user import UserID
def main():
    """Build and launch the Gradio app for saliency-evaluation experiment 1.

    For each class of the configured dataset the participant sees a target
    image, four saliency maps (Grad-CAM, LIME, SIDU, RISE) and sixteen example
    images of the same class, and ranks the four maps with dropdowns.
    "Submit" records the ranking and advances to the next class. On the last
    class it is replaced by "Finish", which records the final ranking, stores
    all answers and redirects the browser to the end page.
    """
    with open("config/config.yaml", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)
    words = ['grad-cam', 'lime', 'sidu', 'rise']  # saliency methods, in display order
    options = ['-', '1', '2', '3', '4']           # '-' is the "not ranked" placeholder
    dataset_cfg = config['dataset'][config['dataset']['name']]
    class_names = dataset_cfg['class_names']
    n_classes = dataset_cfg['n_classes']
    data_dir = os.path.join(config['dataset']['path'], config['dataset']['name'])
    n_examples = 16  # example images shown per class

    def as_count(state):
        """Return the integer image index held by a State value (or a gr.State wrapper)."""
        return state if isinstance(state, int) else state.value

    def target_label(count):
        """Return the Markdown caption naming the class of the current target image."""
        return f"Target image: **{class_names[count]}**"

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        # Main App Components
        gr.Markdown("# Saliency evaluation - experiment 1")
        user_state = gr.State(0)  # index of the class/image currently shown
        answers = gr.State([])    # one [grad-cam, lime, sidu, rise] ranking per submitted image
        start = as_count(user_state.value)

        with gr.Row():
            target_img_label = gr.Markdown(target_label(start))
            gr.Markdown("Grad-cam")
            gr.Markdown("Lime")
            gr.Markdown("Sidu")
            gr.Markdown("Rise")
        with gr.Row():
            # load_image_and_saliency returns [target, grad-cam, lime, sidu, rise].
            images = load_image_and_saliency(start, data_dir)
            target_img = gr.Image(images[0], elem_classes="main-image")
            saliency_gradcam = gr.Image(images[1], elem_classes="main-image")
            saliency_lime = gr.Image(images[2], elem_classes="main-image")
            saliency_sidu = gr.Image(images[3], elem_classes="main-image")
            saliency_rise = gr.Image(images[4], elem_classes="main-image")
        with gr.Row():
            dropdowns = [gr.Dropdown(choices=options, label=method) for method in words]
        gr.Markdown("### Image examples of the same class")
        with gr.Row():
            examples = load_example_images(start, data_dir)
            example_imgs = [gr.Image(examples[i]) for i in range(n_examples)]

        # Only one of the two buttons is visible: "Finish" on the last class.
        is_last = start == n_classes - 1
        submit_button = gr.Button("Submit", visible=not is_last)
        finish_button = gr.Button("Finish", visible=is_last)

        # Explicit list so callback return values map onto components by position.
        image_outputs = [target_img, saliency_gradcam, saliency_lime,
                         saliency_sidu, saliency_rise, *example_imgs]

        def update_images(state):
            """Return image updates for the class at index ``state``, in image_outputs order."""
            count = as_count(state)
            if count >= n_classes:
                # No further class to show: leave every image unchanged.
                return [gr.update() for _ in image_outputs]
            saliency = load_image_and_saliency(count, data_dir)
            examples = load_example_images(count, data_dir)
            main_updates = [gr.Image(saliency[i], elem_classes="main-image") for i in range(5)]
            example_updates = [gr.Image(examples[i]) for i in range(n_examples)]
            return main_updates + example_updates

        def update_state(state):
            """Advance to the next image index (returned as a plain int)."""
            return as_count(state) + 1

        def update_img_label(state):
            """Return the target-image caption for the current index."""
            return target_label(as_count(state))

        def update_buttons(state):
            """Show "Finish" instead of "Submit" when the current index is the last class."""
            last = as_count(state) == n_classes - 1
            return gr.Button("Submit", visible=not last), gr.Button("Finish", visible=last)

        def update_dropdowns():
            """Reset every ranking dropdown to the '-' placeholder."""
            return [gr.Dropdown(choices=options, value=options[0], label=method) for method in words]

        def redirect():
            """No-op Python handler; the redirect is performed by the event's ``js`` snippet."""

        def save_results(answers):
            """Store all rankings of one participant locally, then upload them to the Hub.

            The record is written under ``json_dataset/`` before uploading, so a
            failed upload still leaves a local copy.
            """
            import logging
            import time
            import uuid

            # NOTE(review): the previous code passed this value as the push target,
            # so it is used as the Hub *dataset* repo id here — confirm in config.yaml.
            results_repo = config['results']['exp1_dir']
            json_dataset_dir = Path("json_dataset")
            json_dataset_dir.mkdir(parents=True, exist_ok=True)

            info_to_push = {
                "user_id": time.time(),
                "answer": dict(enumerate(answers)),
            }
            # A unique file per submission so participants do not overwrite each other.
            record_path = json_dataset_dir / f"results-{uuid.uuid4()}.json"
            with record_path.open("w", encoding="utf-8") as record_file:
                json.dump(info_to_push, record_file)

            try:
                HfApi().upload_file(
                    path_or_fileobj=str(record_path),
                    path_in_repo=record_path.name,
                    repo_id=results_repo,
                    repo_type="dataset",
                )
            except Exception:
                # UI-callback boundary: keep the local copy and report the failure.
                logging.getLogger(__name__).exception(
                    "Could not upload %s to the Hugging Face Hub", record_path)

        def add_answer(dropdown1, dropdown2, dropdown3, dropdown4, answers):
            """Append the current [grad-cam, lime, sidu, rise] ranking to ``answers``."""
            answers.append([dropdown1, dropdown2, dropdown3, dropdown4])
            return answers

        submit_button.click(
            update_state, inputs=user_state, outputs=user_state,
        ).then(
            add_answer, inputs=[*dropdowns, answers], outputs=answers,
        ).then(
            update_img_label, inputs=user_state, outputs=target_img_label,
        ).then(
            update_buttons, inputs=user_state, outputs=[submit_button, finish_button],
        ).then(
            update_images, inputs=user_state, outputs=image_outputs,
        ).then(
            update_dropdowns, outputs=dropdowns,
        )
        finish_button.click(
            add_answer, inputs=[*dropdowns, answers], outputs=answers,
        ).then(
            save_results, inputs=answers,
        ).then(
            redirect, js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'",
        )

    demo.launch()
# Run the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()