import os
import shutil  
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import torch.optim as optim
import random
import imageio
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import time
import scipy.interpolate
from tqdm import tqdm

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from annotator.midas import MidasDetector
from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler
from stablevideo.atlas_data import AtlasData
from stablevideo.atlas_utils import get_grid_indices, get_atlas_bounding_box
from stablevideo.aggnet import AGGNet


class StableVideo:
    """Edit a video by editing its foreground/background atlases with ControlNet.

    The canny ControlNet edits the foreground atlas (from a few keyframe
    crops), the depth ControlNet edits the background atlas, and ``render``
    samples both atlases through the per-pixel UVs of ``AtlasData`` and
    alpha-composites them into an output video.
    """

    def __init__(self, base_cfg, canny_model_cfg, depth_model_cfg, save_memory=False):
        """Store configuration; models are loaded separately via ``load_*_model``.

        Args:
            base_cfg: path of the model config passed to ``create_model``.
            canny_model_cfg: path of the canny ControlNet checkpoint.
            depth_model_cfg: path of the depth ControlNet checkpoint.
            save_memory: if True, move a model to the CPU after each edit.
        """
        self.base_cfg = base_cfg
        self.canny_model_cfg = canny_model_cfg
        self.depth_model_cfg = depth_model_cfg
        self.img2img_model = None
        self.canny_model = None
        self.depth_model = None
        self.b_atlas = None   # last edited background atlas (PIL image)
        self.f_atlas = None   # last edited foreground atlas (PIL image)
        self.data = None      # AtlasData of the loaded video
        self.crops = None     # keyframe crops used by the last foreground edit
        self.save_memory = save_memory

    def load_canny_model(
        self,
        base_cfg=None,
        canny_model_cfg=None,
    ):
        """Load the canny detector, the canny ControlNet and its DDIM sampler.

        Args:
            base_cfg: model config path; defaults to the constructor's
                ``base_cfg`` (falling back to ``'ckpt/cldm_v15.yaml'``).
            canny_model_cfg: checkpoint path; defaults to the constructor's
                ``canny_model_cfg`` (falling back to
                ``'ckpt/control_sd15_canny.pth'``).
        """
        base_cfg = base_cfg or self.base_cfg or 'ckpt/cldm_v15.yaml'
        canny_model_cfg = canny_model_cfg or self.canny_model_cfg or 'ckpt/control_sd15_canny.pth'
        self.apply_canny = CannyDetector()
        canny_model = create_model(base_cfg).cpu()
        canny_model.load_state_dict(load_state_dict(canny_model_cfg, location='cpu'), strict=False)
        self.canny_ddim_sampler = DDIMSampler(canny_model)
        self.canny_model = canny_model

    def load_depth_model(
        self,
        base_cfg=None,
        depth_model_cfg=None,
    ):
        """Load the MiDaS detector, the depth ControlNet and its DDIM sampler.

        Args:
            base_cfg: model config path; defaults to the constructor's
                ``base_cfg`` (falling back to ``'ckpt/cldm_v15.yaml'``).
            depth_model_cfg: checkpoint path; defaults to the constructor's
                ``depth_model_cfg`` (falling back to
                ``'ckpt/control_sd15_depth.pth'``).
        """
        base_cfg = base_cfg or self.base_cfg or 'ckpt/cldm_v15.yaml'
        depth_model_cfg = depth_model_cfg or self.depth_model_cfg or 'ckpt/control_sd15_depth.pth'
        self.apply_midas = MidasDetector()
        depth_model = create_model(base_cfg).cpu()
        depth_model.load_state_dict(load_state_dict(depth_model_cfg, location='cpu'), strict=False)
        self.depth_ddim_sampler = DDIMSampler(depth_model)
        self.depth_model = depth_model

    def load_video(self, video_name):
        """Load ``AtlasData`` for ``video_name`` and its original atlases.

        Writes ``data/<name>/<name>.mp4`` from the loaded frames if it does
        not exist yet.

        Returns:
            (video path, original foreground atlas, original background atlas),
            the atlases as PIL images.
        """
        self.data = AtlasData(video_name)
        save_name = f"data/{video_name}/{video_name}.mp4"
        if not os.path.exists(save_name):
            imageio.mimwrite(save_name, self.data.original_video.cpu().permute(0, 2, 3, 1))
            print("original video saved.")
        toIMG = transforms.ToPILImage()
        self.f_atlas_origin = toIMG(self.data.cropped_foreground_atlas[0])
        self.b_atlas_origin = toIMG(self.data.background_grid_atlas[0])
        return save_name, self.f_atlas_origin, self.b_atlas_origin

    @torch.no_grad()
    def depth_edit(self, input_image=None,
                   prompt="",
                   a_prompt="best quality, extremely detailed",
                   n_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                   image_resolution=512,
                   detect_resolution=384,
                   ddim_steps=20,
                   scale=9,
                   seed=-1,
                   eta=0,
                   num_samples=1):
        """Re-generate ``input_image`` with the depth ControlNet.

        A MiDaS depth map of the input is used as control; the first sample
        is resized back to the input size and stored in ``self.b_atlas``.

        Args:
            input_image: PIL image to edit.
            seed: RNG seed; ``-1`` picks a random one in [0, 65535].
            num_samples: number of samples drawn (only the first is kept).

        Returns:
            The edited image (``self.b_atlas``).

        Raises:
            ValueError: if ``input_image`` is None.
        """
        if input_image is None:
            raise ValueError("depth_edit requires an input image")
        size = input_image.size
        model = self.depth_model
        ddim_sampler = self.depth_ddim_sampler
        apply_midas = self.apply_midas

        input_image = HWC3(np.array(input_image))
        detected_map, _ = apply_midas(resize_image(input_image, detect_resolution))
        detected_map = HWC3(detected_map)
        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        control = torch.from_numpy(detected_map.copy()).float() / 255.0
        # One control map per sample so the c_concat batch matches c_crossattn.
        control = torch.stack([control] * num_samples, dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)

        samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
                                         shape, cond, verbose=False, eta=eta,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=un_cond)

        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        self.b_atlas = Image.fromarray(x_samples[0]).resize(size)
        return self.b_atlas

    @torch.no_grad()
    def edit_background(self, *args, **kwargs):
        """Edit the original background atlas; arguments are forwarded to ``depth_edit``.

        Raises:
            ValueError: if no video has been loaded.
        """
        if self.data is None:
            raise ValueError("Please load video first")
        self.depth_edit(self.b_atlas_origin, *args, **kwargs)

        if self.save_memory:
            self.depth_model = self.depth_model.cpu()
        return self.b_atlas

    @torch.no_grad()
    def advanced_edit_foreground(self,
                                 keyframes="0",
                                 res=2000,
                                 prompt="",
                                 a_prompt="best quality, extremely detailed",
                                 n_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                                 image_resolution=512,
                                 low_threshold=100,
                                 high_threshold=200,
                                 ddim_steps=20,
                                 s=0.9,
                                 scale=9,
                                 seed=-1,
                                 eta=0,
                                 if_net=False,
                                 num_samples=1):
        """Edit the foreground atlas from keyframe crops with the canny ControlNet.

        The first keyframe is sampled from noise. Every later keyframe starts
        from the previous keyframe's atlas mapped into its crop, encoded with
        ``stochastic_encode`` at step ``int(ddim_steps * s)`` and decoded
        again. Each (alpha-masked) result is scattered into a ``res x res``
        grid with linear ``griddata`` interpolation; the per-keyframe grids
        are merged with a per-pixel median and optionally refined by AGGNet.

        Args:
            keyframes: comma-separated frame indices, e.g. ``"20, 40, 60"``.
            res: side length of the atlas grid.
            s: fraction of ``ddim_steps`` used to re-noise later keyframes.
            seed: RNG seed; ``-1`` picks a random one in [0, 65535].
            if_net: refine the median atlas with an AGGNet.

        Returns:
            The edited foreground atlas as a PIL image (also ``self.f_atlas``).

        Raises:
            ValueError: if no video is loaded or no keyframe is given.
        """
        if self.data is None:
            raise ValueError("Please load video first")
        keyframes = [int(x) for x in keyframes.split(",") if x.strip()]
        if not keyframes:
            raise ValueError("Please specify at least one keyframe")

        self.crops = self.data.get_global_crops_multi(keyframes, res)
        n_keyframes = len(keyframes)
        indices = get_grid_indices(0, 0, res, res)
        grid_points = indices.reshape(-1, 2).cpu().numpy()
        f_atlas = torch.zeros(size=(n_keyframes, res, res, 3,)).to("cuda")
        align_corners = self.data.config["align_corners"]

        to_pil = transforms.ToPILImage()
        img_list = [to_pil(crop[0]) for crop in self.crops['original_foreground_crops']]
        result_list = []

        # initial setting
        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        self.canny_ddim_sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=eta, verbose=False)
        c_crossattn = [self.canny_model.get_learned_conditioning([prompt + ', ' + a_prompt])]
        uc_crossattn = [self.canny_model.get_learned_conditioning([n_prompt])]

        for i in range(n_keyframes):
            # get current keyframe
            current_img = img_list[i]
            img = resize_image(HWC3(np.array(current_img)), image_resolution)
            H, W, C = img.shape
            shape = (4, H // 8, W // 8)
            # get canny control
            detected_map = HWC3(self.apply_canny(img, low_threshold, high_threshold))
            control = torch.from_numpy(detected_map.copy()).float() / 255.0
            control = einops.rearrange(control.unsqueeze(0), 'b h w c -> b c h w').clone()

            cond = {"c_concat": [control], "c_crossattn": c_crossattn}
            un_cond = {"c_concat": [control], "c_crossattn": uc_crossattn}

            if i == 0:
                # first keyframe: sample from pure noise
                latent = torch.randn((1, 4, H // 8, W // 8))
                samples, _ = self.canny_ddim_sampler.sample(ddim_steps, num_samples,
                                                            shape, cond, verbose=False, eta=eta,
                                                            unconditional_guidance_scale=scale,
                                                            unconditional_conditioning=un_cond,
                                                            x_T=latent)
            else:
                # later keyframes: start from the previous keyframe's atlas
                last_atlas = f_atlas[i-1:i].permute(0, 3, 2, 1)
                mapped_img = F.grid_sample(last_atlas, self.crops['foreground_uvs'][i].reshape(1, -1, 1, 2), mode="bilinear", align_corners=align_corners).clamp(min=0.0, max=1.0).reshape((3, current_img.size[1], current_img.size[0]))
                mapped_img = to_pil(mapped_img).resize((W, H))
                mapped_img = np.array(mapped_img).astype(np.float32) / 255.0
                mapped_img = torch.from_numpy(mapped_img[None].transpose(0, 3, 1, 2))
                mapped_img = 2. * mapped_img - 1.
                latent = self.canny_model.get_first_stage_encoding(self.canny_model.encode_first_stage(mapped_img))

                t_enc = int(ddim_steps * s)
                latent = self.canny_ddim_sampler.stochastic_encode(latent, torch.tensor([t_enc]).to("cuda"))
                samples = self.canny_ddim_sampler.decode(x_latent=latent,
                                                         cond=cond,
                                                         t_start=t_enc,
                                                         unconditional_guidance_scale=scale,
                                                         unconditional_conditioning=un_cond)

            x_samples = self.canny_model.decode_first_stage(samples)
            result = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
            result = Image.fromarray(result[0]).resize(current_img.size)
            result = transforms.ToTensor()(result)
            # times alpha
            alpha = self.crops['foreground_alpha'][i][0].cpu()
            result = alpha * result

            # keep the edited crop as a reconstruction target for AGGNet
            result_list.append(result.clone())

            # map to atlas
            uv = (self.crops['foreground_uvs'][i].reshape(-1, 2) * 0.5 + 0.5) * res
            uv_points = uv.cpu().numpy()
            for c in range(3):
                interpolated = scipy.interpolate.griddata(
                    points=uv_points,
                    values=result[c].reshape(-1, 1).cpu().numpy(),
                    xi=grid_points,
                    method="linear",
                ).reshape(res, res)
                interpolated = torch.from_numpy(interpolated).float()
                # griddata yields NaN outside the convex hull of the crop's UVs
                interpolated[interpolated.isnan()] = 0.0
                f_atlas[i, :, :, c] = interpolated

        # channels first, spatial axes transposed
        f_atlas = f_atlas.permute(0, 3, 2, 1)

        # aggregate via per-pixel median; for an even keyframe count
        # torch.median returns the lower of the two middle values
        agg_atlas, _ = torch.median(f_atlas, dim=0)

        if if_net:
            agg_atlas = self._refine_atlas_with_aggnet(agg_atlas, result_list)

        agg_atlas, _ = get_atlas_bounding_box(self.data.mask_boundaries, agg_atlas, self.data.foreground_all_uvs)
        self.f_atlas = transforms.ToPILImage()(agg_atlas)

        if self.save_memory:
            self.canny_model = self.canny_model.cpu()

        return self.f_atlas

    def _refine_atlas_with_aggnet(self, agg_atlas, targets, lr=1e-3, n_epoch=500):
        """Train an AGGNet so the re-sampled atlas reproduces each edited keyframe.

        The public entry points run under ``@torch.no_grad()``, so training is
        wrapped in ``torch.enable_grad()``; otherwise ``loss.backward()`` has
        no graph to differentiate. Net and targets are moved to the atlas's
        device so the loss never mixes devices.

        Args:
            agg_atlas: median-aggregated atlas tensor fed to the net.
            targets: edited keyframe crops, indexed like ``self.crops``.

        Returns:
            The refined atlas, computed without gradient tracking.
        """
        device = agg_atlas.device
        agg_net = AGGNet().to(device)
        loss_fn = nn.L1Loss()
        optimizer = optim.SGD(agg_net.parameters(), lr=lr, momentum=0.9)
        align_corners = self.data.config["align_corners"]
        # targets are fixed images: no gradient needed, only the right device
        targets = [t.detach().to(device) for t in targets]

        with torch.enable_grad():
            for _ in range(n_epoch):
                # the net input does not change within an epoch: one forward pass
                temp_agg_atlas = agg_net(agg_atlas)
                loss = 0.
                for i, e_img in enumerate(targets):
                    rec_img = F.grid_sample(temp_agg_atlas[None],
                                            self.crops['foreground_uvs'][i].reshape(1, -1, 1, 2),
                                            mode="bilinear",
                                            align_corners=align_corners)
                    rec_img = rec_img.clamp(min=0.0, max=1.0).reshape(e_img.shape)
                    loss += loss_fn(rec_img, e_img)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        with torch.no_grad():
            return agg_net(agg_atlas)

    def _sample_atlas(self, atlas, uvs):
        """Sample a (1, 3, H, W) atlas at ``uvs`` and return (frames, 3, resy, resx) in [0, 1]."""
        config = self.data.config
        frames = F.grid_sample(
            atlas, uvs, mode="bilinear", align_corners=config["align_corners"]
        ).clamp(min=0.0, max=1.0)
        frames = frames.squeeze().t()  # shape (batch, 3)
        return (
            frames.reshape(config["maximum_number_of_frames"], config["resy"], config["resx"], 3)
            .permute(0, 3, 1, 2)
            .clamp(min=0.0, max=1.0)
        )

    @torch.no_grad()
    def render(self, f_atlas, b_atlas):
        """Composite the (edited) atlases back into a video and save it.

        Args:
            f_atlas: None to use the original foreground atlas, or a dict with
                ``"image"`` and ``"mask"`` PIL images; masked pixels are
                restored from the original foreground atlas.
            b_atlas: PIL background atlas, or None for the original one.

        Returns:
            Path of the written ``log/<timestamp>/video.mp4``.
        """
        to_tensor = transforms.ToTensor()
        # foreground
        f_atlas_origin = to_tensor(self.f_atlas_origin).unsqueeze(0)
        if f_atlas is None:
            f_atlas = f_atlas_origin
        else:
            image, mask = f_atlas["image"], f_atlas["mask"]
            f_atlas = to_tensor(image).unsqueeze(0)
            mask = to_tensor(mask).unsqueeze(0)
            if f_atlas.shape != mask.shape:
                print("Warning: truncating mask to atlas shape {}".format(f_atlas.shape))
                mask = mask[:f_atlas.shape[0], :f_atlas.shape[1], :f_atlas.shape[2], :f_atlas.shape[3]]
            f_atlas = f_atlas * (1 - mask) + f_atlas_origin * mask

        # pad the cropped foreground atlas back to the full atlas grid
        bbox = self.data.foreground_atlas_bbox
        grid_shape = self.data.foreground_grid_atlas.shape
        f_atlas = F.pad(
            f_atlas,
            pad=(
                bbox[1],
                grid_shape[-1] - (bbox[1] + bbox[3]),
                bbox[0],
                grid_shape[-2] - (bbox[0] + bbox[2]),
            ),
            mode="replicate",
        )
        foreground_edit = self._sample_atlas(f_atlas, self.data.scaled_foreground_uvs)

        # background
        if b_atlas is None:
            b_atlas = self.b_atlas_origin
        b_atlas = to_tensor(b_atlas).unsqueeze(0)
        background_edit = self._sample_atlas(b_atlas, self.data.scaled_background_uvs)

        output_video = (
            self.data.all_alpha * foreground_edit
            + (1 - self.data.all_alpha) * background_edit
        )
        run_id = time.time()
        out_dir = f"log/{run_id}"
        # makedirs also creates log/ itself on a fresh checkout
        os.makedirs(out_dir, exist_ok=True)
        save_name = f"{out_dir}/video.mp4"
        imageio.mimwrite(save_name, (255 * output_video.detach().cpu()).to(torch.uint8).permute(0, 2, 3, 1))

        return save_name

if __name__ == '__main__':
    # Build the models once, then serve the Gradio editing UI.
    stablevideo = StableVideo(base_cfg="ckpt/cldm_v15.yaml",
                              canny_model_cfg="ckpt/control_sd15_canny.pth",
                              depth_model_cfg="ckpt/control_sd15_depth.pth",
                              save_memory=True)
    stablevideo.load_canny_model()
    stablevideo.load_depth_model()

    # Every sub-directory of data/ is offered as an example video; close the
    # scandir iterator and sort for a stable order in the UI.
    with os.scandir("data") as entries:
        avail_video = sorted(entry.name for entry in entries if entry.is_dir())
    # Fall back to the first available video if the default example is missing.
    if "car-turn" in avail_video:
        default_video = "car-turn"
    else:
        default_video = avail_video[0] if avail_video else None

    block = gr.Blocks().queue()
    with block:
        with gr.Row():
            gr.Markdown("## StableVideo")
        with gr.Row():
            with gr.Column():
                original_video = gr.Video(label="Original Video", interactive=False)
                with gr.Row():
                    foreground_atlas = gr.Image(label="Foreground Atlas", type="pil")
                    background_atlas = gr.Image(label="Background Atlas", type="pil")
                gr.Markdown("### Step 1. select one example video and click **Load Video** button and wait for 10 sec.")
                video_name = gr.Radio(choices=avail_video,
                                      label="Select Example Videos",
                                      value=default_video)
                load_video_button = gr.Button("Load Video")
                gr.Markdown("### Step 2. write text prompt and advanced options for background and foreground.")
                with gr.Row():
                    f_prompt = gr.Textbox(label="Foreground Prompt", value="a picture of an orange suv")
                    b_prompt = gr.Textbox(label="Background Prompt", value="winter scene, snowy scene, beautiful snow")
                with gr.Row():
                    with gr.Accordion("Advanced Foreground Options", open=False):
                        adv_keyframes = gr.Textbox(label="keyframe", value="20, 40, 60")
                        adv_atlas_resolution = gr.Slider(label="Atlas Resolution", minimum=1000, maximum=3000, value=2000, step=100)
                        adv_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                        adv_low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
                        adv_high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
                        adv_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                        adv_s = gr.Slider(label="Noise Scale", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
                        adv_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=15.0, value=9.0, step=0.1)
                        adv_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                        adv_eta = gr.Number(label="eta (DDIM)", value=0.0)
                        adv_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, no background')
                        adv_n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
                        adv_if_net = gr.Checkbox(label="if use agg net", value=False)

                    with gr.Accordion("Background Options", open=False):
                        b_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                        b_detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=512, step=1)
                        b_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                        b_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                        b_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                        b_eta = gr.Number(label="eta (DDIM)", value=0.0)
                        b_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                        b_n_prompt = gr.Textbox(label="Negative Prompt",
                                                value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
                gr.Markdown("### Step 3. edit each one and render.")
                with gr.Row():
                    f_advance_run_button = gr.Button("Advanced Edit Foreground (slower, better)")
                    b_run_button = gr.Button("Edit Background")
                run_button = gr.Button("Render")
            with gr.Column():
                output_video = gr.Video(label="Output Video", interactive=False)
                # editable so users can mask regions back to the original atlas
                output_foreground_atlas = gr.ImageMask(label="Editable Output Foreground Atlas", type="pil", tool="sketch", interactive=True)
                output_background_atlas = gr.Image(label="Output Background Atlas", type="pil", interactive=False)

        # edit params; order must match the positional parameters of
        # advanced_edit_foreground / depth_edit
        f_adv_edit_param = [adv_keyframes,
                            adv_atlas_resolution,
                            f_prompt,
                            adv_a_prompt,
                            adv_n_prompt,
                            adv_image_resolution,
                            adv_low_threshold,
                            adv_high_threshold,
                            adv_ddim_steps,
                            adv_s,
                            adv_scale,
                            adv_seed,
                            adv_eta,
                            adv_if_net]
        b_edit_param = [b_prompt,
                        b_a_prompt,
                        b_n_prompt,
                        b_image_resolution,
                        b_detect_resolution,
                        b_ddim_steps,
                        b_scale,
                        b_seed,
                        b_eta]
        # action
        load_video_button.click(fn=stablevideo.load_video, inputs=video_name, outputs=[original_video, foreground_atlas, background_atlas])
        f_advance_run_button.click(fn=stablevideo.advanced_edit_foreground, inputs=f_adv_edit_param, outputs=[output_foreground_atlas])
        b_run_button.click(fn=stablevideo.edit_background, inputs=b_edit_param, outputs=[output_background_atlas])
        run_button.click(fn=stablevideo.render, inputs=[output_foreground_atlas, output_background_atlas], outputs=[output_video])

    block.launch()