import gradio as gr
from utils import func_generate
import random

# Module-level state shared by the UI callbacks.
# NOTE(review): these are process-wide globals, so every browser session
# served by this process reads and writes the same lists.
is_clicked = False

# One slot per image style, in the same order as the style dropdown choices.
out_img_list = [''] * 5      # returned as-is by fn_refresh
out_state_list = [''] * 5    # prompt last used to generate each style slot
out_state_list2 = [''] * 5   # second per-style prompt record, same layout
seed_values = [0] * 5

def fn_query_on_load():
    """Return the prompt text used as the initial value of the prompt box."""
    default_prompt = "Cats at sunset"
    return default_prompt

def fn_refresh():
    """Return the module-level ``out_img_list`` (the list object itself, not a copy)."""
    current_images = out_img_list
    return current_images

with gr.Blocks() as app:
    # Style names, in dropdown order; the same names label the output images.
    style_labels = [
        'Oil Painting',
        'Low Poly HD Style',
        'Matrix Style',
        'Dreamy Painting',
        'Depth Map Style',
    ]

    # Page header.
    with gr.Row():
        gr.Markdown(
            """
            # Stable Diffusion Image Generation
            ### Enter prompt to generate images in various styles
            """)

    # Prompt box placed above the tabs; its initial value comes from
    # fn_query_on_load.
    with gr.Row(visible=True):
        search_text = gr.Textbox(
            value=fn_query_on_load,
            placeholder='Enter image prompt..',
            label='Enter Image Prompt',
        )

    # ---- Tab 1: plain generation, one output image per style ----
    with gr.Tab('Generate Image in various styles'):

        with gr.Row():
            # type="index" makes the dropdown hand the callback the position
            # of the selected choice instead of its text.
            concept_index = gr.Dropdown(
                label='Select image style',
                value=style_labels[0],
                type="index",
                choices=list(style_labels),
            )

        with gr.Row(visible=True):
            submit_btn = gr.Button("Submit", variant='primary')
            clear_btn = gr.ClearButton()

        with gr.Row(visible=True):
            # Read-only image slots, one per style, each starting from a file path.
            out1 = gr.Image(value="out1.png", interactive=False, label=style_labels[0])
            out2 = gr.Image(value="out2.png", interactive=False, label=style_labels[1])
            out3 = gr.Image(value="out3.png", interactive=False, label=style_labels[2])
            out4 = gr.Image(value="out4.png", interactive=False, label=style_labels[3])
            out5 = gr.Image(value="out5.png", interactive=False, label=style_labels[4])

    # ---- Tab 2: generation with a contrast adjustment slider ----
    with gr.Tab("Additional Guidance with Contrast Adjustment"):

        with gr.Row():
            gr.Markdown(
                """
                ### Experiment with contrast based additional guidance to view how it affects the output
                """)

        with gr.Row():
            concept_index2 = gr.Dropdown(
                label='Select image style',
                value=style_labels[0],
                type="index",
                choices=list(style_labels),
            )
            contrast_perc = gr.Slider(
                value=90,
                minimum=-100,
                maximum=100,
                label='Contrast Adjustment',
            )

        with gr.Row(visible=True):
            submit_btn2 = gr.Button("Submit", variant='primary')
            clear_btn2 = gr.ClearButton()

        with gr.Row(visible=True):
            # Read-only image slots for this tab, one per style.
            out11 = gr.Image(value="out11.png", interactive=False, label=style_labels[0])
            out12 = gr.Image(value="out12.png", interactive=False, label=style_labels[1])
            out13 = gr.Image(value="out13.png", interactive=False, label=style_labels[2])
            out14 = gr.Image(value="out14.png", interactive=False, label=style_labels[3])
            out15 = gr.Image(value="out15.png", interactive=False, label=style_labels[4])

        
    def clear_data():
        """Return an update dict that blanks the first tab's images and the prompt.

        Every component key maps to ``None``.
        """
        cleared_components = (out1, out2, out3, out4, out5, search_text)
        return dict.fromkeys(cleared_components)

    def clear_data2():
        """Return an update dict that blanks the five contrast-tab images.

        Every component key maps to ``None``.
        """
        cleared_components = (out11, out12, out13, out14, out15)
        return dict.fromkeys(cleared_components)


    # Clear buttons: each resets its own tab's images; the first tab's button
    # also resets the prompt box. Neither handler takes any input.
    clear_btn.click(
        fn=clear_data,
        inputs=None,
        outputs=[out1, out2, out3, out4, out5, search_text],
    )
    clear_btn2.click(
        fn=clear_data2,
        inputs=None,
        outputs=[out11, out12, out13, out14, out15],
    )


    def generate_image(query, con_idx, o1, o2, o3, o4, o5, contrast=None):
        """Generate an image for the selected style and refresh the other slots.

        Args:
            query: Prompt text from the prompt box.
            con_idx: Zero-based index of the selected style (the dropdown is
                created with ``type="index"``).
            o1, o2, o3, o4, o5: Current values of the five output images, in
                slot order.
            contrast: Not used by this handler. It is optional because the
                Submit button on this tab supplies only the first seven inputs.

        Returns:
            A dict keyed by ``out1``..``out5``. The selected style's slot gets
            the newly generated image. Every other slot keeps its current
            image if that slot was last generated from the same prompt, and is
            cleared (``None``) otherwise.

        Raises:
            gr.Error: If the prompt is empty or no style is selected.
        """
        if not query:
            raise gr.Error("No prompt provided")
        if con_idx is None:
            # An index-type dropdown yields None when nothing is selected;
            # report it instead of failing on the arithmetic below.
            raise gr.Error("No image style selected")

        # Third argument is derived from the style index (0, 30, 60, ...);
        # presumably a seed -- confirm against utils.func_generate.
        out = func_generate(query, con_idx, con_idx*30)
        out_state_list[con_idx] = query

        slots = (out1, out2, out3, out4, out5)
        current_images = (o1, o2, o3, o4, o5)

        # Pair each slot with ITS OWN current image so that a kept slot is
        # never handed a neighbouring slot's picture.
        updates = {}
        for idx, (slot, current) in enumerate(zip(slots, current_images)):
            if idx == con_idx:
                updates[slot] = out
            elif out_state_list[idx] == query:
                updates[slot] = current
            else:
                updates[slot] = None
        return updates
                
            
        
    # First-tab Submit: sends the prompt, the selected style index and the
    # five current images; the handler's result updates the five images.
    submit_btn.click(
        fn=generate_image,
        inputs=[search_text, concept_index, out1, out2, out3, out4, out5],
        outputs=[out1, out2, out3, out4, out5],
    )
    
    def generate_image_with_contrast_loss_guidance(query, con_idx, o1, o2, o3, o4, o5, contrast):
        """Generate a contrast-guided image for the selected style and refresh the other slots.

        Args:
            query: Prompt text from the prompt box.
            con_idx: Zero-based index of the selected style (the dropdown is
                created with ``type="index"``).
            o1, o2, o3, o4, o5: Current values of the five contrast-tab output
                images, in slot order.
            contrast: Value of the contrast slider, forwarded to
                ``func_generate`` as ``contrast_perc``.

        Returns:
            A dict keyed by ``out11``..``out15``. The selected style's slot
            gets the newly generated image. Every other slot keeps its current
            image if that slot was last generated from the same prompt, and is
            cleared (``None``) otherwise.

        Raises:
            gr.Error: If the prompt is empty or no style is selected.
        """
        if not query:
            raise gr.Error("No prompt provided")
        if con_idx is None:
            # An index-type dropdown yields None when nothing is selected;
            # report it instead of failing on the arithmetic below.
            raise gr.Error("No image style selected")

        # Third argument is derived from the style index (0, 30, 60, ...);
        # presumably a seed -- confirm against utils.func_generate.
        out = func_generate(query, con_idx, con_idx*30, contrast_loss=True, contrast_perc=contrast)

        # Record prompts in out_state_list2 so this tab's bookkeeping is kept
        # apart from out_state_list; sharing one list would let images
        # generated elsewhere count as "same prompt" for these slots.
        out_state_list2[con_idx] = query

        slots = (out11, out12, out13, out14, out15)
        current_images = (o1, o2, o3, o4, o5)

        # Pair each slot with ITS OWN current image so that a kept slot is
        # never handed a neighbouring slot's picture.
        updates = {}
        for idx, (slot, current) in enumerate(zip(slots, current_images)):
            if idx == con_idx:
                updates[slot] = out
            elif out_state_list2[idx] == query:
                updates[slot] = current
            else:
                updates[slot] = None
        return updates
            
    # Contrast-tab Submit: sends the prompt, the selected style index, the
    # five current images and the slider value; the result updates the images.
    submit_btn2.click(
        fn=generate_image_with_contrast_loss_guidance,
        inputs=[search_text, concept_index2, out11, out12, out13, out14, out15, contrast_perc],
        outputs=[out11, out12, out13, out14, out15],
    )

# Launch the app: configure the request queue, then start the server.
queued_app = app.queue(concurrency_count=4, max_size=6)
queued_app.launch(max_threads=8)