from controlnet_aux import OpenposeDetector
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers import UniPCMultistepScheduler
import gradio as gr
import torch
import base64
from io import BytesIO
from PIL import Image
# Live conditioning: front-end snippets used by the "webcam" input mode.

# HTML for a <pose-canvas> custom element with id 'canvas-root'. toggle()
# injects it into the gr.HTML component when "webcam" is selected.
# NOTE(review): the element is presumably defined by the pose-gradio.js script
# loaded below -- confirm.
canvas_html = "<pose-canvas id='canvas-root' style='display:flex;max-width: 500px;margin: 0 auto;'></pose-canvas>"
# JavaScript passed to blocks.load(): fetches the pose-gradio.js source as
# text, wraps it in a Blob, and appends it to <head> as a module <script>
# pointing at the Blob's object URL. The fetch promise is not awaited or
# returned by the function.
load_js = """
async () => {
  const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/pose-gradio.js"
  fetch(url)
    .then(res => res.text())
    .then(text => {
      const script = document.createElement('script');
      script.type = "module"
      script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
      document.head.appendChild(script);
    });
}
"""
# JavaScript passed to run_button.click(): forwards the first three inputs
# unchanged and replaces the fourth (live_conditioning) with the `_data`
# property of the 'canvas-root' element, or null when that element is absent.
get_js_image = """
async (image_in_img, prompt, image_file_live_opt, live_conditioning) => {
  const canvasEl = document.getElementById("canvas-root");
  const data = canvasEl? canvasEl._data : null;
  return [image_in_img, prompt, image_file_live_opt, data]
}
"""

# Constants
# NOTE(review): neither threshold is referenced anywhere in the visible code;
# they look like leftovers from an edge-detection variant of this demo --
# confirm nothing else imports them before removing.
low_threshold = 100
high_threshold = 200

# Models
# OpenPose detector; get_pose() calls it on uploaded images.
pose_model = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
# OpenPose-conditioned ControlNet weights, loaded in float16.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
)
# Stable Diffusion v1.5 pipeline with the ControlNet attached, also float16.
# safety_checker=None disables the pipeline's safety checker.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
# Replace the default scheduler with UniPC, built from the existing
# scheduler's config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# This command loads the individual model components on GPU on-demand. So, we don't
# need to explicitly call pipe.to("cuda").
pipe.enable_model_cpu_offload()

# Enable xformers memory-efficient attention on the pipeline.
pipe.enable_xformers_memory_efficient_attention()

# Seed torch's default RNG with 0 and keep the returned generator. The same
# generator object is passed to every pipe() call in generate_images().
generator = torch.manual_seed(0)


def get_pose(image):
    """Return the output of the module-level OpenPose detector for ``image``.

    Args:
        image: The picture to analyse; passed straight to ``pose_model``.

    Returns:
        Whatever ``pose_model`` returns for that input.
    """
    detected_pose = pose_model(image)
    return detected_pose


def generate_images(image, prompt, image_file_live_opt='file', live_conditioning=None):
    """Generate images from ``prompt`` conditioned on a pose image.

    The pose image comes from one of two places, selected by
    ``image_file_live_opt``:

    * ``'file'``   -- ``image`` is run through ``get_pose``.
    * ``'webcam'`` -- ``live_conditioning['image']`` is decoded from base64
      and used directly as the pose image, resized to 512x512.

    Args:
        image: Uploaded picture; required in ``'file'`` mode.
        prompt: Text prompt passed to the diffusion pipeline.
        image_file_live_opt: Input mode, ``'file'`` or ``'webcam'``.
        live_conditioning: Mapping with an ``'image'`` entry holding the
            base64-encoded canvas image; required in ``'webcam'`` mode. May be
            ``None`` (the default, and what the front-end sends when the
            canvas element is missing).

    Returns:
        A list whose first item is the pose image, followed by the images
        produced by the pipeline (3 are requested per prompt).

    Raises:
        gr.Error: If the input required by the selected mode is missing, the
            mode is unknown, or anything fails while building the pose image
            or running the pipeline.
    """
    # Validate per mode, before touching any model. The previous combined
    # check evaluated `'image' not in live_conditioning` even when
    # live_conditioning was None (TypeError), and let through requests whose
    # selected mode had no usable input.
    live_image = None
    if image_file_live_opt == 'file':
        if image is None:
            raise gr.Error("Please provide an image")
    elif image_file_live_opt == 'webcam':
        if isinstance(live_conditioning, dict):
            live_image = live_conditioning.get('image')
        if not live_image:
            raise gr.Error("Please provide an image")
    else:
        # Previously fell through with `pose` unbound.
        raise gr.Error(f"Unknown input option: {image_file_live_opt!r}")

    try:
        if image_file_live_opt == 'file':
            pose = get_pose(image)
        else:
            # Presumably a data URL ("data:image/...;base64,<payload>"): keep
            # only what follows the first comma. Taking the last piece also
            # accepts a bare base64 string that has no prefix.
            payload = live_image.split(',', 1)[-1]
            image_data = base64.b64decode(payload)
            pose = Image.open(BytesIO(image_data)).convert(
                'RGB').resize((512, 512))
        # `generator` is the shared module-level torch generator.
        output = pipe(
            prompt,
            pose,
            generator=generator,
            num_images_per_prompt=3,
            num_inference_steps=20,
        )
        # Pose first so the gallery shows the conditioning next to the results.
        return [pose, *output.images]
    except Exception as e:
        # Surface every failure in the UI, keeping the original as the cause.
        raise gr.Error(str(e)) from e


def toggle(choice):
    """Build the visibility updates for switching between input modes.

    Args:
        choice: Selected input mode, ``"file"`` or ``"webcam"``.

    Returns:
        A pair of ``gr.update`` results, in the order (upload image component,
        canvas HTML component). Exactly one is made visible; the upload value
        is always cleared, and the canvas receives ``canvas_html`` in webcam
        mode or ``None`` in file mode. Any other ``choice`` yields ``None``.
    """
    if choice not in ("file", "webcam"):
        return None

    show_canvas = choice == "webcam"
    upload_update = gr.update(visible=not show_canvas, value=None)
    canvas_update = gr.update(
        visible=show_canvas,
        value=canvas_html if show_canvas else None,
    )
    return upload_update, canvas_update


# UI definition. Components are created in layout order inside the Blocks
# context, so the statement order below determines the page layout.
with gr.Blocks() as blocks:
    gr.Markdown("""
    ## Generate controlled outputs with ControlNet and Stable Diffusion
    This Space uses pose estimated lines as the additional conditioning
    [Check out our blog to see how this was done (and train your own controlnet)](https://huggingface.co/blog/train-your-controlnet)
    """)
    with gr.Row():
        # Hidden input slot: get_js_image replaces its value client-side with
        # the canvas element's `_data` when the Generate button is clicked.
        live_conditioning = gr.JSON(value={}, visible=False)
        # Left column: input mode, image/canvas, prompt and the run button.
        with gr.Column():
            image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
                                           label="How would you like to upload your image?")
            image_in_img = gr.Image(source="upload", visible=True, type="pil")
            # Starts hidden and empty; toggle() fills it with canvas_html.
            canvas = gr.HTML(None, elem_id="canvas_html", visible=False)

            # Swap which of the two input widgets is visible when the mode
            # changes.
            image_file_live_opt.change(fn=toggle,
                                       inputs=[image_file_live_opt],
                                       outputs=[image_in_img, canvas],
                                       queue=False)
            prompt = gr.Textbox(
                label="Enter your prompt",
                max_lines=1,
                placeholder="best quality, extremely detailed",
            )
            run_button = gr.Button("Generate")
        # Right column: the pose image and the generated results.
        with gr.Column():
            gallery = gr.Gallery().style(grid=[2], height="auto")
    # get_js_image is passed as the click handler's _js hook, so the canvas
    # data is substituted into the inputs that generate_images receives.
    run_button.click(fn=generate_images,
                     inputs=[image_in_img, prompt,
                             image_file_live_opt, live_conditioning],
                     outputs=[gallery],
                     _js=get_js_image)
    # No Python callback: this load event only runs load_js, which injects
    # the pose-gradio.js script into the page.
    blocks.load(None, None, None, _js=load_js)

    # Only the image and prompt are wired as inputs here, so generate_images
    # presumably runs with its default mode ('file') and no live conditioning
    # for examples. NOTE(review): with cache_examples=True Gradio is expected
    # to precompute the example output -- confirm for the installed version.
    gr.Examples(fn=generate_images,
                examples=[
                    ["./yoga1.jpeg",
                        "best quality, extremely detailed"]
                ],
                inputs=[image_in_img, prompt],
                outputs=[gallery],
                cache_examples=True)

# Start the app; runs unconditionally whenever this module is executed.
blocks.launch(debug=True)