import gradio as gr
import numpy as np
import os
import torch
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, AutoencoderTiny, DDIMScheduler
from peft import PeftModel
from rembg import remove
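# Gradio demo: text-to-image generation with Stable Diffusion 1.5 (or the distilled
# bk-sdm variants) plus fine-tuned LoRA adapters, with optional ControlNet and
# IP-Adapter conditioning, a lightweight VAE / DDIM scheduler, and rembg-based
# background removal.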


device = "cuda" if torch.cuda.is_available() else "cpu"
model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"

if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    width=512,
    height=512,
    model_id=model_id_default,
    seed=42,
    guidance_scale=7.0,
    lora_scale=1.0,
    num_inference_steps=20,
    controlnet_checkbox=False,
    controlnet_strength=0.0,
    controlnet_mode="edge_detection",
    controlnet_image=None,
    ip_adapter_checkbox=False,
    ip_adapter_scale=0.0,
    ip_adapter_image=None,

    tiny_vae=False,
    ddim=False,
    
    del_background=False,
    alpha_matting=False, 
    alpha_matting_foreground_threshold=240, 
    alpha_matting_background_threshold=10, 
    alpha_matting_erode_size=10, 
    post_process_mask=False,
    
    progress=gr.Progress(track_tqdm=True),    
):  
    if model_id is None:
        raise ValueError("Please specify the base model name or path")

    # Pick the LoRA checkpoint directory that matches the selected base model
    if model_id == model_id_default:
        ckpt_dir = './model_output'
    elif 'base' in model_id:
        ckpt_dir = './model_output_distilled_base'
    else:
        ckpt_dir = './model_output_distilled_small'

    unet_sub_dir = os.path.join(ckpt_dir, "unet")
    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")

    generator = torch.Generator(device).manual_seed(int(seed))
    params = {'prompt': prompt,
              'negative_prompt': negative_prompt,
              'guidance_scale': guidance_scale,
              'num_inference_steps': num_inference_steps,
              'width': width,
              'height': height,
              'generator': generator,
              'cross_attention_kwargs': {"scale": lora_scale}
             }

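    # When ControlNet is enabled, load the variant matching the selected
    # conditioning mode (Canny edge detection is the fallback)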
    if controlnet_checkbox:
        if controlnet_mode == "depth_map":
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-depth",
                cache_dir="./models_cache",
                torch_dtype=torch_dtype
            )
        elif controlnet_mode == "pose_estimation":
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-openpose",
                cache_dir="./models_cache",
                torch_dtype=torch_dtype
            )
        elif controlnet_mode == "normal_map":
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-normal",
                cache_dir="./models_cache",
                torch_dtype=torch_dtype
            )
        elif controlnet_mode == "scribbles":
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-scribble",
                cache_dir="./models_cache",
                torch_dtype=torch_dtype
            )
        else:
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-canny",
                cache_dir="./models_cache",
                torch_dtype=torch_dtype
            )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id, 
                                                                 controlnet=controlnet,
                                                                 torch_dtype=torch_dtype, 
                                                                 safety_checker=None).to(device)
        params['image'] = controlnet_image
        params['controlnet_conditioning_scale'] = float(controlnet_strength)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(model_id, 
                                                       torch_dtype=torch_dtype, 
                                                       safety_checker=None).to(device)

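    # Attach the fine-tuned LoRA adapters to the UNet and text encoder via PEFT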
    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)

    # pipe.unet.add_weighted_adapter(['default'], [lora_scale], 'lora')
    # pipe.text_encoder.add_weighted_adapter(['default'], [lora_scale], 'lora')

    # pipe.unet.load_state_dict({k: lora_scale*v for k, v in pipe.unet.state_dict().items()})
    # pipe.text_encoder.load_state_dict({k: lora_scale*v for k, v in pipe.text_encoder.state_dict().items()})

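    # Optional components: the lightweight TAESD VAE and the DDIM scheduler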
    if tiny_vae:
        pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch_dtype)

    if ddim:
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    
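    # Keep the LoRA-wrapped UNet and text encoder in the pipeline's half-precision dtype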
    if torch_dtype in (torch.float16, torch.bfloat16):
        pipe.unet.half()
        pipe.text_encoder.half()

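    # Optional IP-Adapter image conditioning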
    if ip_adapter_checkbox:
        pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
        pipe.set_ip_adapter_scale(ip_adapter_scale)
        params['ip_adapter_image'] = ip_adapter_image

    pipe.to(device)

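    # Generate the image and, if requested, strip the background with rembg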
    if del_background:
        return remove(pipe(**params).images[0],
                      alpha_matting=alpha_matting, 
                      alpha_matting_foreground_threshold=alpha_matting_foreground_threshold, 
                      alpha_matting_background_threshold=alpha_matting_background_threshold, 
                      alpha_matting_erode_size=alpha_matting_erode_size, 
                      post_process_mask=post_process_mask
                     )
    else:
        return pipe(**params).images[0]


css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

def toggle_visibility(show_extra):
    return gr.update(visible=show_extra)
    
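# Gradio UI: model selection, prompts, sampler settings, and conditioning controls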
with gr.Blocks(css=css, fill_height=True) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Image demo")

        with gr.Row():
            model_id = gr.Dropdown(
                    label="Model ID",
                    choices=[model_id_default,
                             "nota-ai/bk-sdm-v2-base",
                             "nota-ai/bk-sdm-v2-small"],
                    value=model_id_default,
                )
            

        prompt = gr.Textbox(
            label="Prompt",
            max_lines=1,
            placeholder="Enter your prompt",
        )
        
        negative_prompt = gr.Textbox(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter your negative prompt",
        )
        
        with gr.Row():
            seed = gr.Number(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=30.0,
                step=0.1,
                value=7.0,
            )
        with gr.Row():
            lora_scale = gr.Slider(
                label="LoRA scale",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=1.0,
            )

            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=100,
                step=1,
                value=20,
            )
        with gr.Row():
            tiny_vae = gr.Checkbox(
                label="Use AutoencoderTiny?",
                value=False
            )
            ddim = gr.Checkbox(
                label="Use DDIMScheduler?",
                value=False
            )
            
        with gr.Row():
            del_background = gr.Checkbox(
                label="Delete background?",
                value=False
            )
            with gr.Column(visible=False) as rembg_params:
                alpha_matting = gr.Checkbox(
                    label="alpha_matting",
                    value=False
                )
                with gr.Column(visible=False) as alpha_params:
                    alpha_matting_foreground_threshold = gr.Slider(
                        label="alpha_matting_foreground_threshold",
                        minimum=0,
                        maximum=255,
                        step=1,
                        value=240,  
                    )
                    alpha_matting_background_threshold = gr.Slider(
                        label="alpha_matting_background_threshold",
                        minimum=0,
                        maximum=255,
                        step=1,
                        value=10,  
                    )
                    alpha_matting_erode_size = gr.Slider(
                        label="alpha_matting_erode_size",
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=10,  
                    )
                alpha_matting.change(
                    fn=toggle_visibility,
                    inputs=alpha_matting,
                    outputs=alpha_params
                )
                post_process_mask = gr.Checkbox(
                    label="post_process_mask",
                    value=False
                )
            del_background.change(
                fn=toggle_visibility,
                inputs=del_background,
                outputs=rembg_params
            )
            
        with gr.Row():
            controlnet_checkbox = gr.Checkbox(
                label="ControlNet",
                value=False
            )
            with gr.Column(visible=False) as controlnet_params:
                controlnet_strength = gr.Slider(
                    label="ControlNet conditioning scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=1.0,  
                )
                controlnet_mode = gr.Dropdown(
                    label="ControlNet mode",
                    choices=["edge_detection",
                             "depth_map",
                             "pose_estimation",
                             "normal_map",
                             "scribbles"],
                    value="edge_detection",
                )
                controlnet_image = gr.Image(
                    label="ControlNet condition image",
                    type="pil",
                    format="png"
                )
            controlnet_checkbox.change(
                fn=toggle_visibility,
                inputs=controlnet_checkbox,
                outputs=controlnet_params
            )

        with gr.Row():
            ip_adapter_checkbox = gr.Checkbox(
                label="IPAdapter",
                value=False
            )
            with gr.Column(visible=False) as ip_adapter_params:
                ip_adapter_scale = gr.Slider(
                    label="IPAdapter scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=1.0,  
                )
                ip_adapter_image = gr.Image(
                    label="IPAdapter condition image",
                    type="pil"
                )
            ip_adapter_checkbox.change(
                fn=toggle_visibility,
                inputs=ip_adapter_checkbox,
                outputs=ip_adapter_params
            )
            
        with gr.Accordion("Optional Settings", open=False):
            
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
        
        run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
            
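    # Run inference when the button is clicked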
    gr.on(
        triggers=[run_button.click],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            width,
            height,
            model_id,
            seed,
            guidance_scale,      
            lora_scale,
            num_inference_steps,
            controlnet_checkbox,
            controlnet_strength,
            controlnet_mode,
            controlnet_image,
            ip_adapter_checkbox,
            ip_adapter_scale,
            ip_adapter_image,  
            tiny_vae,
            ddim,
            del_background,
            alpha_matting, 
            alpha_matting_foreground_threshold, 
            alpha_matting_background_threshold, 
            alpha_matting_erode_size, 
            post_process_mask,
        ],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()