import os

os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')
result_dir = os.path.join('./', 'results')
os.makedirs(result_dir, exist_ok=True)


import functools
import os
import random
import gradio as gr
import numpy as np
import torch
import wd14tagger
import uuid

from PIL import Image
from diffusers_helper.code_cond import unet_add_coded_conds
from diffusers_helper.cat_cond import unet_add_concat_conds
from diffusers_helper.k_diffusion import KDiffusionSampler
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
import spaces

# Disable gradients globally
torch.set_grad_enabled(False)

class ModifiedUNet(UNet2DConditionModel):
    @classmethod
    def from_config(cls, *args, **kwargs):
        m = super().from_config(*args, **kwargs)
        unet_add_concat_conds(unet=m, new_channels=4)
        unet_add_coded_conds(unet=m, added_number_count=1)
        return m


model_name = 'lllyasviel/paints_undo_single_frame'
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16).to("cuda")
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16).to("cuda")  # bfloat16 vae
unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16).to("cuda")

unet.set_attn_processor(AttnProcessor2_0())
vae.set_attn_processor(AttnProcessor2_0())

video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
    'lllyasviel/paints_undo_multi_frame',
    fp16=True
).to("cuda")

k_sampler = KDiffusionSampler(
    unet=unet,
    timesteps=1000,
    linear_start=0.00085,
    linear_end=0.020,
    linear=True
)


def find_best_bucket(h, w, options):
    min_metric = float('inf')
    best_bucket = None
    for (bucket_h, bucket_w) in options:
        metric = abs(h * bucket_w - w * bucket_h)
        if metric <= min_metric:
            min_metric = metric
            best_bucket = (bucket_h, bucket_w)
    return best_bucket


def encode_cropped_prompt_77tokens(txt: str):
    cond_ids = tokenizer(txt,
                         padding="max_length",
                         max_length=tokenizer.model_max_length,
                         truncation=True,
                         return_tensors="pt").input_ids.to(device="cuda")
    text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
    return text_cond


def pytorch2numpy(imgs):
    results = []
    for x in imgs:
        y = x.movedim(0, -1)
        y = y * 127.5 + 127.5
        y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
        results.append(y)
    return results


def numpy2pytorch(imgs):
    h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
    h = h.movedim(-1, 1)
    return h


def resize_without_crop(image, target_width, target_height):
    pil_image = Image.fromarray(image)
    resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized_image)


def interrogator_process(x):
    image_description = wd14tagger.default_interrogator(x)
    return image_description


@spaces.GPU()
def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
            progress=gr.Progress()):
    rng = torch.Generator(device="cuda").manual_seed(int(seed))

    fg = resize_and_center_crop(input_fg, image_width, image_height)
    concat_conds = numpy2pytorch([fg]).clone().detach().to(device="cuda", dtype=vae.dtype)
    concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor

    conds = encode_cropped_prompt_77tokens(prompt)
    unconds = encode_cropped_prompt_77tokens(n_prompt)

    fs = torch.tensor(input_undo_steps).to(device="cuda", dtype=torch.long)
    initial_latents = torch.zeros_like(concat_conds)
    concat_conds = concat_conds.to(device="cuda", dtype=unet.dtype)
    latents = k_sampler(
        initial_latent=initial_latents,
        strength=1.0,
        num_inference_steps=steps,
        guidance_scale=cfg,
        batch_size=len(input_undo_steps),
        generator=rng,
        prompt_embeds=conds,
        negative_prompt_embeds=unconds,
        cross_attention_kwargs={'concat_conds': concat_conds, 'coded_conds': fs},
        same_noise_in_batch=True,
        progress_tqdm=functools.partial(progress.tqdm, desc='Generating Key Frames')
    ).to(vae.dtype) / vae.config.scaling_factor

    pixels = vae.decode(latents).sample
    pixels = pytorch2numpy(pixels)
    pixels = [fg] + pixels + [np.zeros_like(fg) + 255]

    return pixels


def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    frames = 16

    target_height, target_width = find_best_bucket(
        image_1.shape[0], image_1.shape[1],
        options=[(320, 512), (384, 448), (448, 384), (512, 320)]
    )

    image_1 = resize_and_center_crop(image_1, target_width=target_width, target_height=target_height)
    image_2 = resize_and_center_crop(image_2, target_width=target_width, target_height=target_height)
    input_frames = numpy2pytorch([image_1, image_2])
    input_frames = input_frames.unsqueeze(0).movedim(1, 2)

    positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
    negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")

    input_frames = input_frames.to(device="cuda", dtype=video_pipe.image_encoder.dtype)
    positive_image_cond = video_pipe.encode_clip_vision(input_frames)
    positive_image_cond = video_pipe.image_projection(positive_image_cond)
    negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
    negative_image_cond = video_pipe.image_projection(negative_image_cond)

    input_frames = input_frames.to(device="cuda", dtype=video_pipe.vae.dtype)
    input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
    first_frame = input_frame_latents[:, :, 0]
    last_frame = input_frame_latents[:, :, 1]
    concat_cond = torch.stack([first_frame] + [torch.zeros_like(first_frame)] * (frames - 2) + [last_frame], dim=2)

    latents = video_pipe(
        batch_size=1,
        steps=int(steps),
        guidance_scale=cfg_scale,
        positive_text_cond=positive_text_cond,
        negative_text_cond=negative_text_cond,
        positive_image_cond=positive_image_cond,
        negative_image_cond=negative_image_cond,
        concat_cond=concat_cond,
        fs=fs,
        progress_tqdm=progress_tqdm
    )

    video = video_pipe.decode_latents(latents, vae_hidden_states)
    return video, image_1, image_2


@spaces.GPU(duration=333360)
def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()):
    result_frames = []
    cropped_images = []

    for i, (im1, im2) in enumerate(zip(keyframes[:-1], keyframes[1:])):
        im1 = np.array(Image.open(im1[0]))
        im2 = np.array(Image.open(im2[0]))
        frames, im1, im2 = process_video_inner(
            im1, im2, prompt, seed=seed + i, steps=steps, cfg_scale=cfg, fs=3,
            progress_tqdm=functools.partial(progress.tqdm, desc=f'Generating Videos ({i + 1}/{len(keyframes) - 1})')
        )
        result_frames.append(frames[:, :, :-1, :, :])
        cropped_images.append([im1, im2])

    video = torch.cat(result_frames, dim=2)
    video = torch.flip(video, dims=[2])

    uuid_name = str(uuid.uuid4())
    output_filename = os.path.join(result_dir, uuid_name + '.mp4')
    Image.fromarray(cropped_images[0][0]).save(os.path.join(result_dir, uuid_name + '.png'))
    video = save_bcthw_as_mp4(video, output_filename, fps=fps)
    video = [x.cpu().numpy() for x in video]
    return output_filename, video


block = gr.Blocks().queue()
with block:
    gr.HTML("<center><h1>Paints Undo</h1></center>")
    gr.HTML("<center><h3>Paints Undo is a tool that generates key frames and videos from a single image.</h3></center>")
    gr.HTML("<center><a href='https://github.com/lllyasviel/Paints-UNDO'>GitHub</a> - <a href='https://lllyasviel.github.io/pages/paints_undo/'>Project Page</a></center>")

    with gr.Accordion(label='Step 1: Upload Image and Generate Prompt', open=True):
        with gr.Row():
            with gr.Column():
                input_fg = gr.Image(sources=['upload'], type="numpy", label="Image", height=512)
            with gr.Column():
                prompt_gen_button = gr.Button(value="Generate Prompt", interactive=False)
                prompt = gr.Textbox(label="Output Prompt", interactive=True)

    with gr.Accordion(label='Step 2: Generate Key Frames', open=True):
        with gr.Row():
            with gr.Column():
                input_undo_steps = gr.Dropdown(label="Operation Steps", value=[400, 600, 800, 900, 950, 999],
                                               choices=list(range(1000)), multiselect=True)
                seed = gr.Slider(label='Stage 1 Seed', minimum=0, maximum=50000, step=1, value=12345)
                image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
                image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
                cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=3.0, step=0.01)
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='lowres, bad anatomy, bad hands, cropped, worst quality')

            with gr.Column():
                key_gen_button = gr.Button(value="Generate Key Frames", interactive=False)
                result_gallery = gr.Gallery(height=512, object_fit='contain', label='Outputs', columns=4)

    with gr.Accordion(label='Step 3: Generate All Videos', open=True):
        with gr.Row():
            with gr.Column():
                i2v_input_text = gr.Text(label='Prompts', value='1girl, masterpiece, best quality')
                i2v_seed = gr.Slider(label='Stage 2 Seed', minimum=0, maximum=50000, step=1, value=123)
                i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5,
                                          elem_id="i2v_cfg_scale")
                i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps",
                                      label="Sampling steps", value=50)
                i2v_fps = gr.Slider(minimum=1, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=4)
            with gr.Column():
                i2v_end_btn = gr.Button("Generate Video", interactive=False)
                i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True,
                                            show_share_button=True, height=512)
        with gr.Row():
            i2v_output_images = gr.Gallery(height=512, label="Output Frames", object_fit="contain", columns=8)

    input_fg.change(lambda: ["", gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False)],
                    outputs=[prompt, prompt_gen_button, key_gen_button, i2v_end_btn])

    prompt_gen_button.click(
        fn=interrogator_process,
        inputs=[input_fg],
        outputs=[prompt]
    ).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False)],
           outputs=[prompt_gen_button, key_gen_button, i2v_end_btn])

    key_gen_button.click(
        fn=process,
        inputs=[input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg],
        outputs=[result_gallery]
    ).then(lambda: [gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)],
           outputs=[prompt_gen_button, key_gen_button, i2v_end_btn])

    i2v_end_btn.click(
        inputs=[result_gallery, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_fps, i2v_seed],
        outputs=[i2v_output_video, i2v_output_images],
        fn=process_video
    )

    dbs = [
        ['./imgs/1.jpg', 12345, 123],
        ['./imgs/2.jpg', 37000, 12345],
        ['./imgs/3.jpg', 3000, 3000],
    ]

    gr.Examples(
        examples=dbs,
        inputs=[input_fg, seed, i2v_seed],
        examples_per_page=1024
    )

block.queue().launch(share=False)