import spaces
import gradio as gr
import os
import torch
import uuid

from PIL import Image
from enhance_utils import enhance_image


# Prompts pre-filled into the UI text boxes: the description of the source
# image, and the description the edited image should match.
DEFAULT_SRC_PROMPT = "a woman"
DEFAULT_EDIT_PROMPT = "a woman, with red lips, 8k, high quality"

# Torch device name: "cuda" whenever torch reports a usable GPU, else "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Modes Pillow's JPEG writer accepts. Anything else (RGBA, LA, P, I, ...)
# cannot be stored as JPEG, so it is written as PNG instead.
_JPEG_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

# Directory the edited images are written to before Gradio serves them.
_OUTPUT_DIR = "/tmp/gradio/output/"


def _output_extension(image: Image.Image) -> str:
    """Return the file extension ("jpg" or "png") to save *image* with.

    JPEG is used when the image mode is JPEG-compatible. Every other mode,
    including any mode with an alpha channel, falls back to PNG.
    """
    return "jpg" if image.mode in _JPEG_MODES else "png"


def _save_output(image: Image.Image, icc_profile) -> str:
    """Write *image* to a uniquely named file in ``_OUTPUT_DIR`` and return its path.

    ``icc_profile`` is the raw ICC profile read from the input image (or
    ``None``). It is embedded in the saved file.
    """
    # exist_ok avoids a check-then-create race between concurrent requests.
    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    path = os.path.join(_OUTPUT_DIR, f"{uuid.uuid4()}.{_output_extension(image)}")
    # `quality` is a JPEG option; the PNG writer does not use it.
    image.save(path, quality=100, icc_profile=icc_profile)
    return path


def create_demo() -> gr.Blocks:
    """Build the Gradio UI for prompt-guided image editing.

    The "Edit Image" button runs ``image_to_image`` on the uploaded file. The
    enhanced result is offered as a downloadable file.
    """
    # Imported inside the function, so the adapter module is only loaded
    # when the demo is built.
    from inversion_run_adapter import run as adapter_run

    @spaces.GPU(duration=15)
    def image_to_image(
        input_image_path: str,
        input_image_prompt: str,
        edit_prompt: str,
        seed: int,
        w1: float,
        num_steps: int,
        start_step: int,
        guidance_scale: float,
        generate_size: int,
        lineart_scale: float,
        canny_scale: float,
        lineart_detect: float,
        canny_detect: float,
    ) -> str:
        """Edit the uploaded image according to the prompts and return the output path.

        Every parameter besides ``input_image_path`` is forwarded to the
        inversion adapter. Integer-typed values are coerced with ``int()``,
        because ``gr.Number`` delivers floats by default.

        Args:
            input_image_path: Path of the uploaded file, or ``None`` if nothing
                was uploaded.
            input_image_prompt: Text describing the source image.
            edit_prompt: Text describing the desired edited image.
            seed: Random seed for the adapter.
            w1: Weight passed to the adapter (its partner ``w2`` is fixed at 1.0).
            num_steps: "Num Steps" slider value.
            start_step: "Start Step" slider value.
            guidance_scale: "Guidance Scale" slider value.
            generate_size: "Generate Size" value.
            lineart_scale: "Lineart Weights" slider value.
            canny_scale: "Canny Weights" slider value.
            lineart_detect: "Lineart Detect" value.
            canny_detect: "Canny Detect" value.

        Returns:
            Path of the saved, enhanced output image.

        Raises:
            gr.Error: If no input image was uploaded.
        """
        if input_image_path is None:
            raise gr.Error("Please upload an input image.")

        input_image = Image.open(input_image_path)
        # Remember the source colour profile so it can be embedded in the output.
        icc_profile = input_image.info.get("icc_profile")

        w2 = 1.0  # not exposed in the UI
        generated_image = adapter_run(
            input_image,
            input_image_prompt,
            edit_prompt,
            int(generate_size),
            int(seed),
            w1,
            w2,
            int(num_steps),
            int(start_step),
            guidance_scale,
            lineart_scale,
            canny_scale,
            lineart_detect,
            canny_detect,
        )
        enhanced_image = enhance_image(generated_image, False)
        return _save_output(enhanced_image, icc_profile)

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                input_image_path = gr.File(label="Input Image")
            with gr.Column():
                generated_image_path = gr.File(label="Download the segment image", interactive=False)
        with gr.Row():
            with gr.Column():
                input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
                edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
                with gr.Accordion("Advanced Options", open=False):
                    guidance_scale = gr.Slider(minimum=0, maximum=20, value=0, step=0.5, label="Guidance Scale")
                    seed = gr.Number(label="Seed", value=8)
                    generate_size = gr.Number(label="Generate Size", value=1024)
                    lineart_scale = gr.Slider(minimum=0, maximum=5, value=0.8, step=0.1, label="Lineart Weights", visible=True)
                    canny_scale = gr.Slider(minimum=0, maximum=5, value=0.4, step=0.1, label="Canny Weights", visible=True)
                    lineart_detect = gr.Number(label="Lineart Detect", value=0.375, visible=True)
                    canny_detect = gr.Number(label="Canny Detect", value=0.375, visible=True)
            with gr.Column():
                num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
                start_step = gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Start Step")
                w1 = gr.Number(label="W1", value=2)
                g_btn = gr.Button("Edit Image")

        # The order of `inputs` must match image_to_image's parameter order.
        g_btn.click(
            fn=image_to_image,
            inputs=[
                input_image_path,
                input_image_prompt,
                edit_prompt,
                seed,
                w1,
                num_steps,
                start_step,
                guidance_scale,
                generate_size,
                lineart_scale,
                canny_scale,
                lineart_detect,
                canny_detect,
            ],
            outputs=[generated_image_path],
        )

    return demo