import json
import os
import random
import time
from pathlib import Path

import gradio as gr
import yaml
from huggingface_hub import CommitScheduler, HfApi

from src.utils import load_words, load_image_and_saliency, load_example_images
from src.style import css
from src.user import UserID

def main():
    """Build and launch the saliency-ranking Gradio app (experiment 1).

    For each class index the app shows the target image with its Grad-CAM,
    LIME, SIDU and RISE saliency maps, plus example images of the same class.
    The user ranks the four methods with dropdowns.

    "Submit" records the ranking and advances to the next class. On the last
    class, "Finish" is shown instead. It records the final ranking, saves all
    rankings and redirects the browser to the end page.

    Reads ``config/config.yaml`` and uses these keys:

    - ``dataset.name``
    - ``dataset.path``
    - ``dataset.<name>.class_names``
    - ``dataset.<name>.n_classes``
    - ``results.exp1_dir``
    """
    with open("config/config.yaml", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    options = ['-', '1', '2', '3', '4']
    dataset_name = config['dataset']['name']
    dataset_cfg = config['dataset'][dataset_name]
    class_names = dataset_cfg['class_names']
    n_classes = dataset_cfg['n_classes']
    data_dir = os.path.join(config['dataset']['path'], dataset_name)

    # Saliency methods, in the order their maps follow the target image in
    # load_image_and_saliency's result and in which rankings are recorded.
    methods = ['grad-cam', 'lime', 'sidu', 'rise']
    n_main_images = 1 + len(methods)  # target image + one map per method
    n_examples = 16  # number of same-class example images displayed
    first_index = 0

    def target_label(index):
        """Markdown label naming the class of the target image at *index*."""
        return f"Target image: **{class_names[index]}**"

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        # Main App Components
        gr.Markdown("# Saliency evaluation - experiment 1")
        # Index of the class currently displayed; always a plain int.
        user_state = gr.State(first_index)
        # One ranking [grad-cam, lime, sidu, rise] per answered class.
        answers = gr.State([])

        with gr.Row():
            target_img_label = gr.Markdown(target_label(first_index))
            gr.Markdown("Grad-cam")
            gr.Markdown("Lime")
            gr.Markdown("Sidu")
            gr.Markdown("Rise")

        with gr.Row():
            images = load_image_and_saliency(first_index, data_dir)
            # [target, grad-cam, lime, sidu, rise]
            main_images = [
                gr.Image(images[i], elem_classes="main-image")
                for i in range(n_main_images)
            ]

        with gr.Row():
            # Start at '-' so the initial value matches the post-submit reset.
            dropdowns = [
                gr.Dropdown(choices=options, value=options[0], label=method)
                for method in methods
            ]

        gr.Markdown("### Image examples of the same class")
        with gr.Row():
            examples = load_example_images(first_index, data_dir)
            example_images = [gr.Image(examples[i]) for i in range(n_examples)]

        # With a single class the first screen is already the last one.
        starts_on_last = first_index == n_classes - 1
        submit_button = gr.Button("Submit", visible=not starts_on_last)
        finish_button = gr.Button("Finish", visible=starts_on_last)

        image_outputs = main_images + example_images

        def update_state(state):
            """Advance to the next class index."""
            return state + 1

        def add_answer(dropdown1, dropdown2, dropdown3, dropdown4, previous):
            """Return a new answer list with the current ranking appended.

            A new list is built rather than mutating the State value in place.
            """
            return previous + [[dropdown1, dropdown2, dropdown3, dropdown4]]

        def update_img_label(state):
            """Relabel the target image for class *state*.

            Leaves the label unchanged if *state* is past the last class.
            """
            if state >= n_classes:
                return gr.update()
            return target_label(state)

        def update_buttons(state):
            """Show "Finish" only on the last class and "Submit" otherwise."""
            is_last = state == n_classes - 1
            return (
                gr.Button("Submit", visible=not is_last),
                gr.Button("Finish", visible=is_last),
            )

        def update_images(state):
            """Load the target, saliency and example images for class *state*.

            Returns values in the same order as ``image_outputs``. Past the
            last class, every image is left unchanged.
            """
            if state >= n_classes:
                return [gr.update()] * len(image_outputs)
            new_main = load_image_and_saliency(state, data_dir)
            new_examples = load_example_images(state, data_dir)
            return (
                [new_main[i] for i in range(n_main_images)]
                + [new_examples[i] for i in range(n_examples)]
            )

        def reset_dropdowns():
            """Reset every ranking dropdown to '-'."""
            return [options[0]] * len(dropdowns)

        def redirect():
            # The navigation itself is performed by the `js` snippet attached
            # to this step; the Python side has nothing to do.
            pass

        def save_results(collected):
            """Save one participant's rankings and optionally upload them.

            The rankings are appended as one JSON line to
            ``json_dataset/<results.exp1_dir>``. If a repo id is configured,
            the file is then uploaded to a Hugging Face dataset repo.
            """
            results_name = config['results']['exp1_dir']
            results_path = Path("json_dataset") / results_name
            results_path.parent.mkdir(parents=True, exist_ok=True)

            record = {
                "user_id": time.time(),
                "answer": dict(enumerate(collected)),
            }
            # JSON-lines: one object per participant, appended to one file.
            with results_path.open("a", encoding="utf-8") as results_file:
                json.dump(record, results_file)
                results_file.write("\n")

            # NOTE(review): `results.repo_id` is an assumed, optional config
            # key. Confirm the target dataset repo. Without it, results are
            # only saved locally.
            repo_id = config['results'].get('repo_id')
            if repo_id:
                HfApi().upload_file(
                    path_or_fileobj=results_path,
                    path_in_repo=results_name,
                    repo_id=repo_id,
                    repo_type="dataset",
                )

        # Outputs are ordered lists: handlers return positional values, and a
        # set of outputs would have no defined order.
        submit_button.click(
            update_state,
            inputs=user_state,
            outputs=user_state,
        ).then(
            add_answer,
            inputs=[*dropdowns, answers],
            outputs=answers,
        ).then(
            update_img_label,
            inputs=user_state,
            outputs=target_img_label,
        ).then(
            update_buttons,
            inputs=user_state,
            outputs=[submit_button, finish_button],
        ).then(
            update_images,
            inputs=user_state,
            outputs=image_outputs,
        ).then(
            reset_dropdowns,
            outputs=dropdowns,
        )

        finish_button.click(
            add_answer,
            inputs=[*dropdowns, answers],
            outputs=answers,
        ).then(
            save_results,
            inputs=answers,
        ).then(
            redirect,
            js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'",
        )

    demo.launch()

# Script entry point: build and launch the app only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()