import json
import os
import random
import threading
import time
from pathlib import Path

import gradio as gr
import yaml
from huggingface_hub import CommitScheduler, HfApi

from src.utils import load_words, load_image_and_saliency, load_example_images
from src.style import css
from src.user import UserID

# Number of same-class example images shown below the main row.
N_EXAMPLE_IMAGES = 16


def main():
    """Build and launch the Gradio app for saliency evaluation, experiment 1.

    The app steps through the classes of the configured dataset (index 0 to
    ``n_classes - 1``). For each class it shows a target image, four saliency
    maps and 16 example images of the same class. The user ranks the saliency
    methods with one dropdown each. "Submit" records the ranking and advances
    to the next class. On the last class only "Finish" is shown: it records
    the final ranking, saves all answers and, if saving succeeded, redirects
    to the end page.
    """
    with open("config/config.yaml", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    # Saliency methods in UI column order. images[1:5] returned by
    # load_image_and_saliency are displayed in this same order.
    words = ['grad-cam', 'lime', 'sidu', 'rise']
    options = ['-', '1', '2', '3', '4']
    dataset_cfg = config['dataset'][config['dataset']['name']]
    class_names = dataset_cfg['class_names']
    n_classes = dataset_cfg['n_classes']
    data_dir = os.path.join(config['dataset']['path'], config['dataset']['name'])
    # Gradio serves sessions concurrently; serialize appends to the results file.
    results_lock = threading.Lock()
    # Target image + one saliency map per method.
    n_main_images = 1 + len(words)

    def as_count(state):
        """Return the integer class index held by ``state`` (int or gr.State)."""
        return state if isinstance(state, int) else state.value

    def target_label(count):
        """Markdown label naming the class currently being evaluated."""
        return f"Target image: **{class_names[count]}**"

    def is_last(count):
        """True when ``count`` is the index of the final class."""
        return count == n_classes - 1

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        gr.Markdown("# Saliency evaluation - experiment 1")
        user_state = gr.State(0)  # index of the class currently shown
        answers = gr.State([])  # one [grad-cam, lime, sidu, rise] ranking per class
        start = as_count(user_state)

        with gr.Row():
            target_img_label = gr.Markdown(target_label(start))
            gr.Markdown("Grad-cam")
            gr.Markdown("Lime")
            gr.Markdown("Sidu")
            gr.Markdown("Rise")

        with gr.Row():
            images = load_image_and_saliency(start, data_dir)
            main_imgs = [
                gr.Image(images[i], elem_classes="main-image")
                for i in range(n_main_images)
            ]

        with gr.Row():
            # Start at '-' so the first class matches the post-reset state.
            dropdowns = [
                gr.Dropdown(choices=options, value=options[0], label=word)
                for word in words
            ]

        gr.Markdown("### Image examples of the same class")
        with gr.Row():
            examples = load_example_images(start, data_dir)
            example_imgs = [gr.Image(examples[i]) for i in range(N_EXAMPLE_IMAGES)]

        # Handles a dataset with a single class: Finish must be visible right away.
        submit_button = gr.Button("Submit", visible=not is_last(start))
        finish_button = gr.Button("Finish", visible=is_last(start))

        def update_state(state):
            """Advance to the next class. Returns a plain int, not a component."""
            return as_count(state) + 1

        def add_answer(dropdown1, dropdown2, dropdown3, dropdown4, answers):
            """Return a new answers list with the current ranking appended."""
            return answers + [[dropdown1, dropdown2, dropdown3, dropdown4]]

        def update_img_label(state):
            """Relabel the target image for the class at ``state``."""
            return target_label(as_count(state))

        def update_buttons(state):
            """Show Submit before the last class and Finish on it (in that order)."""
            last = is_last(as_count(state))
            return gr.Button("Submit", visible=not last), gr.Button("Finish", visible=last)

        def update_images(state):
            """Return new main and example images for the class at ``state``.

            The output order matches ``main_imgs + example_imgs``. Past the last
            class every image is left unchanged.
            """
            count = as_count(state)
            if count >= n_classes:
                return tuple(gr.update() for _ in range(n_main_images + N_EXAMPLE_IMAGES))
            new_main = load_image_and_saliency(count, data_dir)
            new_examples = load_example_images(count, data_dir)
            return tuple(
                [gr.Image(new_main[i], elem_classes="main-image") for i in range(n_main_images)]
                + [gr.Image(new_examples[i]) for i in range(N_EXAMPLE_IMAGES)]
            )

        def update_dropdowns():
            """Reset every ranking dropdown to '-', in ``words`` order."""
            return tuple(
                gr.Dropdown(choices=options, value=options[0], label=word)
                for word in words
            )

        def redirect():
            # Navigation happens client-side through the `js` hook on this event.
            pass

        def save_results(answers):
            """Append this session's answers as one JSON line to the results file.

            The file is ``json_dataset/<config['results']['exp1_dir']>``.
            """
            # TODO(review): the original code tried to push results to the
            # Hugging Face Hub with HfApi.push_to_hub, which does not exist.
            # Uploading needs a dataset repo id that is not in the visible
            # config. huggingface_hub.CommitScheduler over "json_dataset" is
            # one option.
            results_path = Path("json_dataset") / config['results']['exp1_dir']
            results_path.parent.mkdir(parents=True, exist_ok=True)
            record = {"user_id": time.time(), "answer": dict(enumerate(answers))}
            with results_lock, results_path.open("a", encoding="utf-8") as results_file:
                results_file.write(json.dumps(record) + "\n")

        # Outputs are lists, not sets: Gradio maps tuple returns positionally,
        # and a set has no defined order.
        submit_button.click(
            update_state, inputs=user_state, outputs=user_state
        ).then(
            add_answer, inputs=[*dropdowns, answers], outputs=answers
        ).then(
            update_img_label, inputs=user_state, outputs=target_img_label
        ).then(
            update_buttons, inputs=user_state, outputs=[submit_button, finish_button]
        ).then(
            update_images, inputs=user_state, outputs=main_imgs + example_imgs
        ).then(
            update_dropdowns, outputs=dropdowns
        )

        # `.success` only redirects once saving has actually succeeded.
        finish_button.click(
            add_answer, inputs=[*dropdowns, answers], outputs=answers
        ).then(
            save_results, inputs=answers
        ).success(
            redirect, js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'"
        )

    demo.launch()


if __name__ == "__main__":
    main()