import cv2
import einops
import gradio as gr
import numpy as np
import torch

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers import UniPCMultistepScheduler
from PIL import Image
from controlnet_aux import OpenposeDetector

# Constants
# Canny edge-detector thresholds; only referenced by the (currently disabled)
# canny code further down.
low_threshold = 100
high_threshold = 200


# Models
# controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
# pipe_canny = StableDiffusionControlNetPipeline.from_pretrained(
#     "runwayml/stable-diffusion-v1-5", controlnet=controlnet_canny, safety_checker=None, torch_dtype=torch.float16
# )
# pipe_canny.scheduler = UniPCMultistepScheduler.from_config(pipe_canny.scheduler.config)

# # This command loads the individual model components on GPU on-demand. So, we don't
# # need to explicitly call pipe.to("cuda").
# pipe_canny.enable_model_cpu_offload()

# pipe_canny.enable_xformers_memory_efficient_attention()

# Generator seeded once at start-up and shared by every generation call, so its
# state advances from call to call (results are reproducible only as a sequence
# starting from process launch, not per request).
generator = torch.manual_seed(0)

torch_dtype = torch.float16  # or torch.bfloat16 (if needed)

pose_model = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

controlnet_pose = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-openpose", torch_dtype=torch_dtype
)

pipe_pose = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet_pose,
    safety_checker=None,
    torch_dtype=torch_dtype,
)
pipe_pose.scheduler = UniPCMultistepScheduler.from_config(pipe_pose.scheduler.config)

# Model CPU offload is disabled, so the WHOLE pipeline (UNet, VAE, text encoder
# and ControlNet) has to be placed on the GPU explicitly. Moving only the
# ControlNet left the other fp16 components on the CPU, which mixes devices
# during inference.
pipe_pose.to("cuda")
# pipe_pose.enable_model_cpu_offload()

# xformers is only a memory optimisation: if it is not installed or cannot be
# used, keep the default attention implementation instead of failing start-up.
try:
    pipe_pose.enable_xformers_memory_efficient_attention()
except (ImportError, ValueError) as err:
    print(f"xformers memory-efficient attention unavailable, using default attention: {err}")


# def get_canny_filter(image):
    
#     if not isinstance(image, np.ndarray):
#         image = np.array(image) 
        
#     image = cv2.Canny(image, low_threshold, high_threshold)
#     image = image[:, :, None]
#     image = np.concatenate([image, image, image], axis=2)
#     canny_image = Image.fromarray(image)
#     return canny_image

def get_pose(image):
    """Run the module-level OpenPose detector on *image*.

    The detector's output is returned unchanged; it is later used as the
    conditioning image for ``pipe_pose`` and shown in the output gallery.
    """
    detected_pose = pose_model(image)
    return detected_pose
    
def process(input_image, prompt, input_control="Pose"):
    """Dispatch a generation request to the pipeline for the chosen control task.

    Args:
        input_image: Image from the ``gr.Image(type="numpy")`` component.
        prompt: Text prompt for the diffusion model.
        input_control: Name of the control task. It defaults to ``"Pose"``
            because the UI currently wires only ``[input_image, prompt]`` as
            inputs (the task dropdown is disabled). Without the default, every
            click and the example caching failed with a missing-argument
            ``TypeError``. The value is accepted but ignored, since only the
            pose pipeline is loaded.

    Returns:
        The result of ``process_pose``: ``[pose_image, generated_image]``.
    """
    # TODO: Add other control tasks (e.g. canny) and branch on input_control.
    return process_pose(input_image, prompt)

# def process_canny(input_image, prompt):
#     canny_image = get_canny_filter(input_image)
#     output = pipe_canny(
#         prompt,
#         canny_image,
#         generator=generator,
#         num_images_per_prompt=1,
#         num_inference_steps=20,
#     )
#     return [canny_image,output.images[0]]


def process_pose(input_image, prompt):
    """Generate an image conditioned on the pose detected in *input_image*.

    Args:
        input_image: Source image for the OpenPose detector (provided by the
            ``gr.Image(type="numpy")`` component).
        prompt: Text prompt for the diffusion model.

    Returns:
        A two-item list for the output gallery: the detected pose map and the
        first generated image.

    Raises:
        gr.Error: If no image was provided, so the user sees a readable message
            in the UI instead of a failure inside the pose detector.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    pose_image = get_pose(input_image)
    output = pipe_pose(
        prompt,
        pose_image,
        generator=generator,
        num_images_per_prompt=1,
        num_inference_steps=20,
    )
    return [pose_image, output.images[0]]
    
    
block = gr.Blocks().queue()
# Control tasks the UI is meant to offer. The task selector is currently
# disabled and only the "Pose" pipeline is loaded.
control_task_list = [
    "Canny Edge Map",
    "Pose",
]
with block:
    gr.Markdown("## Adding Conditional Control to Text-to-Image Diffusion Models")
    gr.HTML('''
     <p style="margin-bottom: 10px; font-size: 94%">
                This is an unofficial demo for ControlNet, which is a neural network structure to control diffusion models by adding extra conditions such as canny edge detection. The demo is based on the <a href="https://github.com/lllyasviel/ControlNet" style="text-decoration: underline;" target="_blank"> Github </a> implementation. 
              </p>
              ''')
    gr.HTML("<p>You can duplicate this Space to run it privately without a queue and load additional checkpoints.  : <a style='display:inline-block' href='https://huggingface.co/spaces/RamAnanth1/ControlNet?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14' alt='Duplicate Space'></a> </p>")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources=['upload'], type="numpy")
            # input_control = gr.Dropdown(control_task_list, value="Pose", label="Control Task")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(value="Run")

        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=2, height='auto')

    # Input components shared by the Run button and the examples. Every example
    # row below must contain exactly one value per component in this list.
    ips = [input_image, prompt]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
    examples_list = [
        # ["bird.png", "bird"],  # canny example; re-enable together with the canny pipeline
        ["pose1.png", "Chef in the Kitchen"],
    ]
    # cache_examples=True runs `process` on every example while the app is
    # being built, so the example rows must be valid inputs for it.
    examples = gr.Examples(
        examples=examples_list,
        inputs=ips,
        outputs=[result_gallery],
        cache_examples=True,
        fn=process,
    )
    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=RamAnanth1.ControlNet)")

block.launch(debug=True)