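# NOTE: "Spaces: Sleeping / Sleeping" was status-banner text captured from the
# Hugging Face Spaces web page; it is not part of the program.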
import json
import logging
import os
import random
import time
from pathlib import Path

import gradio as gr
import yaml
from huggingface_hub import CommitScheduler, HfApi

from src.style import css
from src.user import UserID
from src.utils import load_words, load_image_and_saliency, load_example_images
def main():
    """Build the Gradio UI for saliency-evaluation experiment 1 and launch it.

    Reads ``config/config.yaml`` and, for the current class index, shows a
    target image next to four saliency maps (grad-cam, lime, sidu, rise) plus
    16 example images. Each "Submit" click records the four dropdown values
    and advances to the next class; "Finish" records the last answer, stores
    all answers and redirects the browser to the end page.
    """
    logger = logging.getLogger(__name__)

    # Context manager so the config file handle is closed after parsing.
    with open("config/config.yaml", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    methods = ['grad-cam', 'lime', 'sidu', 'rise']
    options = ['-', '1', '2', '3', '4']
    dataset_cfg = config['dataset'][config['dataset']['name']]
    class_names = dataset_cfg['class_names']
    n_classes = dataset_cfg['n_classes']
    data_dir = os.path.join(config['dataset']['path'], config['dataset']['name'])
    n_main_images = 1 + len(methods)  # target image + one saliency map per method
    n_example_images = 16
    last_index = n_classes - 1

    def make_main_images(index):
        """Return image components [target, grad-cam, lime, sidu, rise] for a class index."""
        images = load_image_and_saliency(index, data_dir)
        return [
            gr.Image(images[i], elem_classes="main-image")
            for i in range(n_main_images)
        ]

    def make_example_images(index):
        """Return the 16 example-image components for a class index."""
        images = load_example_images(index, data_dir)
        return [gr.Image(images[i]) for i in range(n_example_images)]

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        # Main App Components
        gr.Markdown("# Saliency evaluation - experiment 1")
        user_state = gr.State(0)  # index of the class currently shown
        answers = gr.State([])    # one [grad-cam, lime, sidu, rise] entry per class
        first_index = user_state.value

        with gr.Row():
            target_img_label = gr.Markdown(
                f"Target image: **{class_names[first_index]}**"
            )
            for method in methods:
                gr.Markdown(method.capitalize())

        with gr.Row():
            main_images = make_main_images(first_index)

        with gr.Row():
            dropdowns = [
                gr.Dropdown(choices=options, label=method) for method in methods
            ]

        gr.Markdown("### Image examples of the same class")
        with gr.Row():
            example_images = make_example_images(first_index)

        # With a single class the first screen is already the last one.
        submit_button = gr.Button("Submit", visible=first_index < last_index)
        finish_button = gr.Button("Finish", visible=first_index >= last_index)

        def update_images(count):
            """Return updated main + example images for class index ``count``.

            The returned list follows the component order used in the event
            wiring below: target, grad-cam, lime, sidu, rise, then examples.
            """
            if count >= n_classes:
                # Nothing left to show: leave every image as it is.
                return [
                    gr.update() for _ in range(n_main_images + n_example_images)
                ]
            return make_main_images(count) + make_example_images(count)

        def update_state(count):
            """Advance to the next class index."""
            # Return the plain value: handlers receive and store the state's
            # value itself, so no gr.State wrapper is needed.
            return count + 1

        def update_img_label(count):
            """Return the heading text for the class at index ``count``."""
            if count >= len(class_names):
                return gr.update()
            return f"### Target image: {class_names[count]}"

        def update_buttons(count):
            """Show "Finish" instead of "Submit" once the last class is reached."""
            is_last = count >= last_index
            return (
                gr.Button("Submit", visible=not is_last),
                gr.Button("Finish", visible=is_last),
            )

        def update_dropdowns():
            """Reset every dropdown to the placeholder option."""
            return [
                gr.Dropdown(choices=options, value=options[0], label=method)
                for method in methods
            ]

        def redirect():
            """No-op Python handler; the redirect itself is done by the ``js`` hook."""

        def save_results(answers):
            """Store the collected answers locally and, if configured, on the Hub."""
            results_filename = config['results']['exp1_dir']
            results_path = Path("json_dataset") / results_filename
            results_path.parent.mkdir(parents=True, exist_ok=True)

            record = {
                "user_id": time.time(),
                "answer": dict(enumerate(answers)),
            }
            payload = json.dumps(record)

            # One JSON object per line so successive participants do not
            # overwrite each other in the local file.
            with results_path.open("a", encoding="utf-8") as results_file:
                results_file.write(payload + "\n")

            # NOTE(review): the original called HfApi.push_to_hub, which does
            # not exist, and named no target repository. 'repo_id' is an
            # assumed optional key under config['results'] and the repo is
            # assumed to be a dataset repo -- confirm both before relying on
            # the upload.
            repo_id = config['results'].get('repo_id')
            if not repo_id:
                logger.warning(
                    "No results repo configured; answers saved only to %s",
                    results_path,
                )
                return
            HfApi().upload_file(
                path_or_fileobj=payload.encode("utf-8"),
                # Unique name per participant so uploads never clobber each other.
                path_in_repo=f"{Path(results_filename).stem}-{record['user_id']}.json",
                repo_id=repo_id,
                repo_type="dataset",
            )

        def add_answer(dropdown1, dropdown2, dropdown3, dropdown4, answers):
            """Return ``answers`` extended with the current four dropdown values."""
            # Build a new list instead of mutating the state value in place.
            return [*answers, [dropdown1, dropdown2, dropdown3, dropdown4]]

        # Outputs are ordered lists so every returned value lands on the
        # component at the same position.
        submit_button.click(
            update_state,
            inputs=user_state,
            outputs=user_state,
        ).then(
            add_answer,
            inputs=[*dropdowns, answers],
            outputs=answers,
        ).then(
            update_img_label,
            inputs=user_state,
            outputs=target_img_label,
        ).then(
            update_buttons,
            inputs=user_state,
            outputs=[submit_button, finish_button],
        ).then(
            update_images,
            inputs=user_state,
            outputs=[*main_images, *example_images],
        ).then(
            update_dropdowns,
            outputs=dropdowns,
        )

        finish_button.click(
            add_answer,
            inputs=[*dropdowns, answers],
            outputs=answers,
        ).then(
            save_results,
            inputs=answers,
        ).then(
            redirect,
            js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'",
        )

        demo.load()

    demo.launch()
# Start the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()