# Author: MarcoParola
# Commit 83892e6: "add duration time"
import gradio as gr
import yaml
import random
import os
import json
import time
from pathlib import Path
from huggingface_hub import CommitScheduler, HfApi
from src.utils import load_words, load_image_and_saliency, load_example_images, load_csv_concepts
from src.style import css
from src.user import UserID
from datetime import datetime
from pathlib import Path
from uuid import uuid4
import json
from huggingface_hub import CommitScheduler
def main():
    """Build and launch the Gradio app for saliency-evaluation experiment 1.

    For every class of the configured dataset the participant first sees a
    grid of example images, then presses "Continue" to see a target image next
    to four saliency maps (Grad-CAM, LIME, SIDU, RISE) and ranks them 1-4.
    "Submit" validates and stores the ranking and advances to the next class;
    on the last class "Finish" validates the final ranking, uploads all
    rankings plus the session duration to a Hugging Face dataset repo, and
    redirects the browser to the end page.

    Reads ``config/config.yaml`` at startup and the ``HUGGINGFACE_TOKEN``
    environment variable when results are saved. Blocks in ``demo.launch()``.
    """
    # Only used to guard lazy creation of the shared CommitScheduler.
    import threading

    with open("config/config.yaml", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    dataset_name = config['dataset']['name']
    dataset_cfg = config['dataset'][dataset_name]
    class_names = dataset_cfg['class_names']
    n_classes = dataset_cfg['n_classes']
    data_dir = os.path.join(config['dataset']['path'], dataset_name)

    # Dropdown choices: '-' means "not ranked yet".
    options = ['-', '1', '2', '3', '4']
    # Saliency methods, in the order their columns and dropdowns are shown.
    methods = ['grad-cam', 'lime', 'sidu', 'rise']
    # Position of each method's map in the list returned by
    # load_image_and_saliency (index 0 is the target image).
    # NOTE(review): with this mapping the SIDU column shows index 3 and the
    # RISE column index 4 -- confirm against the order returned by
    # src.utils.load_image_and_saliency.
    saliency_index = {'grad-cam': 1, 'lime': 2, 'sidu': 3, 'rise': 4}
    # Number of example-image slots in the grid.
    n_examples = 16

    json_dataset_dir = Path("json_dataset")
    scheduler_lock = threading.Lock()
    scheduler_cache = {}

    def as_int(state):
        """Return the step counter whether given a raw int or a gr.State."""
        return state if isinstance(state, int) else state.value

    def question_text(count):
        """Return the ranking prompt for class index *count*."""
        return (
            "### Sort the following saliency maps according to which of them "
            f"better explains the class {class_names[count]}."
        )

    def get_scheduler(api_token):
        """Return the process-wide CommitScheduler, creating it on first use.

        A single scheduler is shared by all sessions: every CommitScheduler
        runs its own background upload thread, so one per submission would
        keep accumulating threads.
        """
        with scheduler_lock:
            if "scheduler" not in scheduler_cache:
                json_dataset_dir.mkdir(parents=True, exist_ok=True)
                scheduler_cache["scheduler"] = CommitScheduler(
                    repo_id=f"results_{dataset_name}_{config['results']['exp1_dir']}",
                    repo_type="dataset",
                    folder_path=json_dataset_dir,
                    path_in_repo="data",
                    token=api_token,
                )
            return scheduler_cache["scheduler"]

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        gr.Markdown("# Saliency evaluation - experiment 1")
        # Index of the class currently being ranked.
        user_state = gr.State(0)
        # One [grad-cam, lime, sidu, rise] rank list per completed class.
        answers = gr.State([])
        # Session start time; overwritten by demo.load when the page opens so
        # the duration is measured per participant, not from server start.
        start_time = gr.State(time.time())

        gr.Markdown("### Image examples")
        with gr.Row():
            example_paths = load_example_images(0, data_dir)
            example_imgs = [gr.Image(example_paths[i]) for i in range(n_examples)]

        question = gr.Markdown(question_text(0), visible=False)
        with gr.Row():
            target_img_label = gr.Markdown(f"Target image: **{class_names[0]}**")
            gr.Markdown("Grad-cam")
            gr.Markdown("Lime")
            gr.Markdown("Sidu")
            gr.Markdown("Rise")
        with gr.Row():
            initial = load_image_and_saliency(0, data_dir)
            target_img = gr.Image(initial[0], elem_classes="main-image delay", visible=False)
            saliency_imgs = [
                gr.Image(initial[saliency_index[m]], elem_classes="main-image", visible=False)
                for m in methods
            ]
        with gr.Row():
            rank_dropdowns = [
                gr.Dropdown(choices=options, label=m, visible=False) for m in methods
            ]
        continue_button = gr.Button("Continue")
        submit_button = gr.Button("Submit", visible=False)
        finish_button = gr.Button("Finish", visible=False)

        def start_session():
            """Return the wall-clock time at which this page session began."""
            return time.time()

        def update_images(state):
            """Show the example grid for the current class; no-op past the last class."""
            count = as_int(state)
            if count >= n_classes:
                return [gr.update() for _ in example_imgs]
            paths = load_example_images(count, data_dir)
            return [gr.Image(paths[i], visible=True) for i in range(n_examples)]

        def update_saliencies(state):
            """Show the target image and its four saliency maps; no-op past the last class."""
            count = as_int(state)
            if count >= n_classes:
                return [gr.update() for _ in range(1 + len(methods))]
            images = load_image_and_saliency(count, data_dir)
            return [gr.Image(images[0], elem_classes="main-image", visible=True)] + [
                gr.Image(images[saliency_index[m]], elem_classes="main-image", visible=True)
                for m in methods
            ]

        def update_state(state):
            """Advance to the next class (stored as a plain int)."""
            return as_int(state) + 1

        def update_img_label(state):
            return f" Target image: **{class_names[as_int(state)]}**"

        def update_buttons():
            """After a submit: show "Continue", hide "Submit"."""
            return gr.Button("Continue", visible=True), gr.Button("Submit", visible=False)

        def show_view(state):
            """Hide "Continue"; show "Finish" on the last class, "Submit" otherwise."""
            is_last = as_int(state) == n_classes - 1
            return (
                gr.Button("Continue", visible=False),
                gr.Button("Submit", visible=not is_last),
                gr.Button("Finish", visible=is_last),
            )

        def hide_view():
            """Hide the question, target image, saliency maps and dropdowns."""
            return [gr.update(visible=False) for _ in range(2 + 2 * len(methods))]

        def update_dropdowns():
            """Show the rank dropdowns reset to the unranked value '-'."""
            return [
                gr.Dropdown(choices=options, value=options[0], label=m, visible=True)
                for m in methods
            ]

        def update_questions(state):
            return gr.Markdown(question_text(as_int(state)), visible=True)

        def redirect():
            """No server-side work; the redirect is done by the event's `js`."""

        def check_answer(rank1, rank2, rank3, rank4):
            """Raise gr.Error unless every method has a rank and all ranks differ."""
            ranks = [rank1, rank2, rank3, rank4]
            if '-' in ranks:
                raise gr.Error('Please select a value for each saliency method')
            if len(set(ranks)) < len(ranks):
                raise gr.Error('Please select different values for each saliency method')

        def add_answer(rank1, rank2, rank3, rank4, answers):
            """Append this class's ranking to the session's answer list."""
            answers.append([rank1, rank2, rank3, rank4])
            return answers

        def save_results(rank1, rank2, rank3, rank4, answers, session_start):
            """Upload all rankings of this session, including the final one.

            The final ranking is added to a copy of *answers* so that retrying
            after a failed upload does not record it twice.

            Raises:
                ValueError: if the HUGGINGFACE_TOKEN environment variable is unset.
            """
            api_token = os.getenv("HUGGINGFACE_TOKEN")
            if not api_token:
                raise ValueError("Hugging Face API token not found. Please set the HUGGINGFACE_TOKEN environment variable.")
            all_answers = [*answers, [rank1, rank2, rank3, rank4]]
            duration = time.time() - session_start
            record = {
                # NOTE(review): a timestamp is not a guaranteed-unique
                # participant id; kept as-is for schema compatibility.
                "user_id": time.time(),
                "answers": dict(enumerate(all_answers)),
                "duration": duration,
                "datetime": datetime.now().isoformat(),
            }
            scheduler = get_scheduler(api_token)
            json_dataset_dir.mkdir(parents=True, exist_ok=True)
            record_path = json_dataset_dir / f"train-{uuid4()}.json"
            # Hold the scheduler lock while writing so a background commit
            # never uploads a half-written file.
            with scheduler.lock:
                with record_path.open("a") as f:
                    json.dump(record, f)
                    f.write("\n")
            # push_to_hub() acquires scheduler.lock itself, so it must run
            # after the lock above is released.
            scheduler.push_to_hub()

        submit_button.click(
            check_answer, inputs=rank_dropdowns
        ).success(
            update_state, inputs=user_state, outputs=user_state
        ).then(
            add_answer, inputs=[*rank_dropdowns, answers], outputs=answers
        ).then(
            update_img_label, inputs=user_state, outputs=target_img_label
        ).then(
            update_images, inputs=user_state, outputs=example_imgs
        ).then(
            update_buttons, outputs=[continue_button, submit_button]
        ).then(
            hide_view,
            outputs=[question, target_img, *saliency_imgs, *rank_dropdowns],
        )

        continue_button.click(
            show_view,
            inputs=user_state,
            outputs=[continue_button, submit_button, finish_button],
        ).then(
            update_img_label, inputs=user_state, outputs=target_img_label
        ).then(
            update_saliencies, inputs=user_state, outputs=[target_img, *saliency_imgs]
        ).then(
            update_questions, inputs=user_state, outputs=question
        ).then(
            update_dropdowns, outputs=rank_dropdowns
        )

        # Validate the last ranking like Submit does, and only redirect once
        # the upload succeeded so a failed save does not silently lose data.
        finish_button.click(
            check_answer, inputs=rank_dropdowns
        ).success(
            save_results, inputs=[*rank_dropdowns, answers, start_time]
        ).success(
            redirect,
            js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'",
        )

        demo.load(start_session, outputs=start_time)

    demo.launch()
# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    main()