import os
import subprocess
from pathlib import Path

import gradio as gr
import torch

from demo import SdmCompressionDemo

if __name__ == "__main__":
    # Entry point: build a Gradio UI that runs the original and the compressed
    # Stable Diffusion pipelines side by side on the same prompt, then serve it.

    # Run on GPU when one is available; otherwise fall back to CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    servicer = SdmCompressionDemo(device)
    example_list = servicer.get_example_list()

    # Markdown fragments are resolved relative to the current working
    # directory. They are decoded explicitly as UTF-8 so the page renders the
    # same regardless of the platform's default locale encoding.
    docs_dir = Path('docs')

    with gr.Blocks(theme='nota-ai/theme') as demo:
        gr.Markdown((docs_dir / 'header.md').read_text(encoding='utf-8'))
        gr.Markdown((docs_dir / 'description.md').read_text(encoding='utf-8'))
        with gr.Row():
            # Left column: prompt input, generate buttons, advanced settings
            # and example prompts.
            with gr.Column(variant='panel', scale=30):

                text = gr.Textbox(label="Input Prompt", max_lines=5, placeholder="Enter your prompt")

                # NOTE(review): `.style()` is deprecated in later Gradio 3.x and
                # removed in 4.x (use constructor kwargs there) — confirm the
                # pinned Gradio version before upgrading.
                with gr.Row().style(equal_height=True):
                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")

                with gr.Accordion("Advanced Settings", open=False):
                    negative = gr.Textbox(label='Negative Prompt', placeholder='Enter aspects to remove (e.g., low quality)')
                    with gr.Row():
                        guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, minimum=4, maximum=11, step=0.5)
                        steps = gr.Slider(label="Denoising Steps", value=25, minimum=10, maximum=75, step=5)
                        seed = gr.Slider(0, 999999, label='Random Seed', value=1234, step=1)

                with gr.Tab("Example Prompts"):
                    # Clicking an example fills the prompt textbox.
                    gr.Examples(examples=example_list, inputs=[text])

            # Middle column: output of the original (uncompressed) pipeline.
            with gr.Column(variant='panel', scale=35):
                gr.Markdown('<h2 align="center">Original Stable Diffusion 1.4</h2>')
                original_model_output = gr.Image(label="Original Model")
                with gr.Row().style(equal_height=True):
                    with gr.Column():
                        original_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
                        # Parameter count is computed once at UI build time.
                        original_model_params = gr.Textbox(value=servicer.get_sdm_params(servicer.pipe_original), label="# Parameters")
                    original_model_error = gr.Markdown()

            # Right column: output of the compressed pipeline.
            with gr.Column(variant='panel', scale=35):
                gr.Markdown('<h2 align="center">Compressed Stable Diffusion (Ours)</h2>')
                compressed_model_output = gr.Image(label="Compressed Model")
                with gr.Row().style(equal_height=True):
                    with gr.Column():
                        compressed_model_test_time = gr.Textbox(value="", label="Inference Time (sec)")
                        compressed_model_params = gr.Textbox(value=servicer.get_sdm_params(servicer.pipe_compressed), label="# Parameters")
                    compressed_model_error = gr.Markdown()

        # Both inference handlers receive the same inputs, in this order.
        inputs = [text, negative, guidance_scale, steps, seed]

        # Original model: triggered by its button, or by pressing Enter in the
        # prompt box.
        original_model_outputs = [original_model_output, original_model_error, original_model_test_time]
        text.submit(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)
        generate_original_button.click(servicer.infer_original_model, inputs=inputs, outputs=original_model_outputs)

        # Compressed model: same wiring. Pressing Enter therefore triggers
        # BOTH models, since two submit listeners are registered on `text`.
        compressed_model_outputs = [compressed_model_output, compressed_model_error, compressed_model_test_time]
        text.submit(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)
        generate_compressed_button.click(servicer.infer_compressed_model, inputs=inputs, outputs=compressed_model_outputs)

        gr.Markdown((docs_dir / 'footer.md').read_text(encoding='utf-8'))

    # With the Gradio 3.x queue API, concurrency_count=1 processes queued
    # events one at a time. NOTE(review): this argument was removed in
    # Gradio 4.x — confirm the pinned version.
    demo.queue(concurrency_count=1)
    demo.launch()