Spaces:
Running
Running
import gradio as gr | |
import yaml | |
import random | |
import os | |
import json | |
import time | |
from pathlib import Path | |
from huggingface_hub import CommitScheduler, HfApi | |
from src.utils import load_words, load_image_and_saliency, load_example_images, load_csv_concepts | |
from src.style import css | |
from src.user import UserID | |
from datetime import datetime | |
from pathlib import Path | |
from uuid import uuid4 | |
import json | |
from huggingface_hub import CommitScheduler | |
def main():
    """Build and launch the Gradio app for saliency-evaluation experiment 1.

    Reads ``config/config.yaml``, then for each class index shows sixteen
    example images, followed by a target image with four saliency maps
    (grad-cam, lime, sidu, rise) that the participant ranks 1-4 through
    dropdowns. On "Finish" the collected rankings are written to a JSON file
    and pushed to a Hugging Face dataset repo, and the browser is redirected
    to the end page.

    NOTE(review): the nesting of the layout blocks (Row boundaries, button
    placement) was reconstructed from a listing whose indentation had been
    flattened -- confirm against the deployed layout.
    """
    # Context manager so the config file handle is closed deterministically.
    with open("config/config.yaml") as config_file:
        config = yaml.safe_load(config_file)

    dataset_name = config['dataset']['name']
    dataset_cfg = config['dataset'][dataset_name]
    class_names = dataset_cfg['class_names']
    n_classes = dataset_cfg['n_classes']
    data_dir = os.path.join(config['dataset']['path'], dataset_name)

    options = ['-', '1', '2', '3', '4']
    method_labels = ('grad-cam', 'lime', 'sidu', 'rise')
    n_examples = 16  # example images displayed per class
    n_main_images = 5  # target image + one saliency map per method

    def current_index(state):
        """Return the class index from a raw int or from a ``gr.State``."""
        return state if isinstance(state, int) else state.value

    def question_text(index):
        """Return the ranking prompt for the class at ``index``."""
        return (
            "### Sort the following saliency maps according to which of them "
            f"better explains the class {class_names[index]}."
        )

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        # Main App Components
        gr.Markdown("# Saliency evaluation - experiment 1")
        user_state = gr.State(0)
        answers = gr.State([])
        # Fallback value only; replaced per session by demo.load below.
        start_time = gr.State(time.time())
        first_class = user_state.value

        gr.Markdown("### Image examples")
        with gr.Row():
            example_paths = load_example_images(first_class, data_dir)
            example_images = [
                gr.Image(example_paths[i]) for i in range(n_examples)
            ]

        question = gr.Markdown(question_text(first_class), visible=False)

        with gr.Row():
            target_img_label = gr.Markdown(
                f"Target image: **{class_names[first_class]}**"
            )
            gr.Markdown("Grad-cam")
            gr.Markdown("Lime")
            gr.Markdown("Sidu")
            gr.Markdown("Rise")

        with gr.Row():
            # NOTE(review): the original passed handler outputs as a set and
            # returned rise before sidu. Assuming Gradio orders set outputs by
            # component creation, participants saw element [3] in the Sidu
            # slot and element [4] in the Rise slot. That on-screen mapping is
            # kept here -- confirm it against load_image_and_saliency.
            main_paths = load_image_and_saliency(first_class, data_dir)
            target_img = gr.Image(
                main_paths[0], elem_classes="main-image delay", visible=False
            )
            saliency_gradcam = gr.Image(
                main_paths[1], elem_classes="main-image", visible=False
            )
            saliency_lime = gr.Image(
                main_paths[2], elem_classes="main-image", visible=False
            )
            saliency_sidu = gr.Image(
                main_paths[3], elem_classes="main-image", visible=False
            )
            saliency_rise = gr.Image(
                main_paths[4], elem_classes="main-image", visible=False
            )

        with gr.Row():
            dropdown1 = gr.Dropdown(choices=options, label="grad-cam", visible=False)
            dropdown2 = gr.Dropdown(choices=options, label="lime", visible=False)
            dropdown3 = gr.Dropdown(choices=options, label="sidu", visible=False)
            dropdown4 = gr.Dropdown(choices=options, label="rise", visible=False)

        continue_button = gr.Button("Continue")
        submit_button = gr.Button("Submit", visible=False)
        finish_button = gr.Button("Finish", visible=False)

        # Explicit, ordered component lists so handler return values map to
        # outputs by position rather than by set ordering.
        dropdowns = [dropdown1, dropdown2, dropdown3, dropdown4]
        main_images = [
            target_img,
            saliency_gradcam,
            saliency_lime,
            saliency_sidu,
            saliency_rise,
        ]

        def start_session():
            """Return the current time as this session's start timestamp."""
            return time.time()

        def update_images(state):
            """Return the sixteen example images for the current class.

            Past the last class the components are left unchanged.
            """
            count = current_index(state)
            if count >= n_classes:
                return [gr.update() for _ in range(n_examples)]
            paths = load_example_images(count, data_dir)
            return [gr.Image(paths[i], visible=True) for i in range(n_examples)]

        def update_saliencies(state):
            """Return the visible target image and saliency maps for the class.

            Past the last class the components are left unchanged.
            """
            count = current_index(state)
            if count >= n_classes:
                return [gr.update() for _ in range(n_main_images)]
            paths = load_image_and_saliency(count, data_dir)
            return [
                gr.Image(paths[i], elem_classes="main-image", visible=True)
                for i in range(n_main_images)
            ]

        def update_state(state):
            """Advance to the next class index.

            Returns a plain int: a returned ``gr.State`` object would itself
            become the stored state value.
            """
            return current_index(state) + 1

        def update_img_label(state):
            """Return the target-image caption for the current class."""
            count = current_index(state)
            return f" Target image: **{class_names[count]}**"

        def update_buttons():
            """Show "Continue" and hide "Submit" after an answer is stored."""
            return (
                gr.Button("Continue", visible=True),
                gr.Button("Submit", visible=False),
            )

        def show_view(state):
            """Hide "Continue"; show "Finish" on the last class, else "Submit"."""
            is_last = current_index(state) == n_classes - 1
            return (
                gr.Button("Continue", visible=False),
                gr.Button("Submit", visible=not is_last),
                gr.Button("Finish", visible=is_last),
            )

        def hide_view():
            """Hide the question, the five main images and the four dropdowns.

            Only visibility is changed; values are refreshed when the view is
            shown again.
            """
            return [gr.update(visible=False) for _ in range(1 + n_main_images + 4)]

        def update_dropdowns():
            """Show the four dropdowns reset to the placeholder option."""
            return [
                gr.Dropdown(
                    choices=options, value=options[0], label=label, visible=True
                )
                for label in method_labels
            ]

        def update_questions(state):
            """Return the visible ranking prompt for the current class."""
            return gr.Markdown(question_text(current_index(state)), visible=True)

        def redirect():
            """No-op; the redirect is done by the ``js`` snippet on the event."""

        def save_results(answers, started_at):
            """Write the session's answers to a JSON file and push it to the Hub.

            Args:
                answers: list of rankings, one ``[grad-cam, lime, sidu, rise]``
                    entry per answered class.
                started_at: ``time.time()`` value taken at session start.

            Raises:
                ValueError: if ``HUGGINGFACE_TOKEN`` is not set.
            """
            api_token = os.getenv("HUGGINGFACE_TOKEN")
            if not api_token:
                raise ValueError(
                    "Hugging Face API token not found. Please set the "
                    "HUGGINGFACE_TOKEN environment variable."
                )

            results_dir = Path("json_dataset")
            results_dir.mkdir(parents=True, exist_ok=True)
            results_path = results_dir / f"train-{uuid4()}.json"

            record = {
                "user_id": time.time(),
                "answers": dict(enumerate(answers)),
                "duration": time.time() - started_at,
                "datetime": datetime.now().isoformat(),
            }

            scheduler = CommitScheduler(
                repo_id=f"results_{dataset_name}_{config['results']['exp1_dir']}",
                repo_type="dataset",
                folder_path=results_dir,
                path_in_repo="data",
                token=api_token,
            )
            try:
                # Save the results into huggingface hub
                with scheduler.lock, results_path.open("a") as results_file:
                    json.dump(record, results_file)
                    results_file.write("\n")
                scheduler.push_to_hub()
            finally:
                # A scheduler is created per save; stop it so its background
                # job does not keep running after this call.
                scheduler.stop()

        def check_answer(dropdown1, dropdown2, dropdown3, dropdown4):
            """Validate that the four ranks are all chosen and all distinct.

            Raises:
                gr.Error: if a dropdown is unset or two methods share a rank.
            """
            ranks = [dropdown1, dropdown2, dropdown3, dropdown4]
            if any(rank is None or rank == '-' for rank in ranks):
                raise gr.Error('Please select a value for each saliency method')
            # check if all values are different 1,2,3,4
            if len(set(ranks)) < len(ranks):
                raise gr.Error('Please select different values for each saliency method')

        def add_answer(dropdown1, dropdown2, dropdown3, dropdown4, answers):
            """Append the current ranking to ``answers`` and return the list."""
            answers.append([dropdown1, dropdown2, dropdown3, dropdown4])
            return answers

        submit_button.click(
            check_answer,
            inputs=dropdowns,
        ).success(
            update_state,
            inputs=user_state,
            outputs=user_state,
        ).then(
            add_answer,
            inputs=[*dropdowns, answers],
            outputs=answers,
        ).then(
            update_img_label,
            inputs=user_state,
            outputs=target_img_label,
        ).then(
            update_images,
            inputs=user_state,
            outputs=example_images,
        ).then(
            update_buttons,
            outputs=[continue_button, submit_button],
        ).then(
            hide_view,
            outputs=[question, *main_images, *dropdowns],
        )

        continue_button.click(
            show_view,
            inputs=user_state,
            outputs=[continue_button, submit_button, finish_button],
        ).then(
            update_img_label,
            inputs=user_state,
            outputs=target_img_label,
        ).then(
            update_saliencies,
            inputs=user_state,
            outputs=main_images,
        ).then(
            update_questions,
            inputs=user_state,
            outputs=question,
        ).then(
            update_dropdowns,
            outputs=dropdowns,
        )

        # Validate the last ranking too, so an incomplete answer is never saved.
        finish_button.click(
            check_answer,
            inputs=dropdowns,
        ).success(
            add_answer,
            inputs=[*dropdowns, answers],
            outputs=answers,
        ).then(
            save_results,
            inputs=[answers, start_time],
        ).then(
            redirect,
            js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'",
        )

        # Stamp the start time per page load so "duration" measures this
        # session rather than time since the app was built.
        demo.load(start_session, outputs=start_time)

    demo.launch()
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    main()