import argparse
import os
import subprocess
import sys
import time

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

# Download model weights before the heavy imports below. Run the downloader
# with the current interpreter (not whatever ``python`` is first on PATH) and
# without a shell. Like the previous ``os.system`` call this is best-effort:
# a non-zero exit is reported but does not abort start-up.
_download = subprocess.run([sys.executable, 'scripts/download_models.py'], check=False)
if _download.returncode != 0:
    print(f'[warning] scripts/download_models.py exited with code {_download.returncode}')

import gradio as gr
from PIL import Image
import numpy as np
import torch
from typing import List, Literal, Dict, Optional
from draw_utils import draw_points_on_image, draw_mask_on_image
import cv2


from models.streamdiffusion.wrapper import StreamDiffusionWrapper

from models.animatediff.pipelines import I2VPipeline
from omegaconf import OmegaConf

from models.draggan.viz.renderer import Renderer
from models.draggan.gan_inv.lpips.util import PerceptualLoss
import models.draggan.dnnlib as dnnlib
from models.draggan.gan_inv.inversion import PTI

import imageio
import torchvision
from einops import rearrange

# =========================== Model Implementation Start ===================================

def save_videos_grid_255(videos: torch.Tensor, path: str, n_rows=6, fps=8):
    """Write a batch of videos, tiled into one grid per frame, to an animation file.

    Args:
        videos: tensor laid out as ``(b, c, t, h, w)`` (see the ``rearrange``
            pattern). Values are cast straight to ``uint8`` with no rescaling,
            so they should already be in ``[0, 255]``.
        path: output file; its parent directory is created when needed.
        n_rows: forwarded to ``torchvision.utils.make_grid`` as ``nrow``.
        fps: frame rate passed to ``imageio.mimsave``.
    """
    # Iterate over time so that each output frame is a grid of all ``b`` videos.
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for frame_batch in videos:
        grid = torchvision.utils.make_grid(frame_batch, nrow=n_rows)
        # (c, H, W) -> (H, W, c); single-channel grids collapse to (H, W).
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        # ``.cpu()`` so GPU tensors can be converted as well.
        outputs.append(grid.cpu().numpy().astype(np.uint8))

    # ``os.path.dirname`` is "" for a bare filename, and ``os.makedirs("")`` raises.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)

def reverse_point_pairs(points):
    """Swap the two coordinates of every point (``[a, b] -> [b, a]``).

    Returns a new list of two-element lists; the input is not modified.
    """
    return [[point[1], point[0]] for point in points]

def render_view_image(img, drag_markers, show_mask=False):
    """Draw the drag markers (and optionally the mask) on ``img``, then add an opaque alpha channel.

    Args:
        img: image passed to ``draw_points_on_image``.
        drag_markers: dict with a ``'points'`` entry and, when ``show_mask`` is
            set, a ``'mask'`` entry.
        show_mask: overlay ``drag_markers['mask']`` when true.

    Returns:
        A PIL image made of the drawn pixels plus a constant-255 alpha channel
        (assumes the drawn image has 3 channels — TODO confirm in draw_utils).
    """
    canvas = draw_points_on_image(img, drag_markers['points'])
    if show_mask:
        canvas = draw_mask_on_image(canvas, drag_markers['mask'])
    pixels = np.array(canvas).astype(np.uint8)
    rows, cols = pixels.shape[0], pixels.shape[1]
    alpha = 255 * np.ones((rows, cols, 1), dtype=pixels.dtype)
    return Image.fromarray(np.concatenate([pixels, alpha], axis=2))


def update_state_image(state):
    """Re-render the preview of ``state['generated_image']`` using the first marker set.

    The rendered image is cached in ``state['generated_image_show']`` and returned.
    """
    first_markers = state['drag_markers'][0]
    preview = render_view_image(state['generated_image'], first_markers, state['is_show_mask'])
    state['generated_image_show'] = preview
    return preview


class GeneratePipeline:
    """Hold the three generators used by the demo and the settings that build them.

    * ``img2img_model`` - ``StreamDiffusionWrapper`` in ``img2img`` mode,
    * ``img2video_model`` - PIA ``I2VPipeline`` configured from
      ``demo/configs/i2v_config.yaml``,
    * ``drag_model`` - DragGAN ``Renderer``.

    The constructor only stores settings and loads the I2V config; call
    :meth:`init_model` to actually build the models.
    """

    def __init__(
        self,
        i2i_body_ckpt: str = "checkpoints/diffusion_body/kohaku-v2.1",
        # i2i_body_ckpt: str = "checkpoints/diffusion_body/stable-diffusion-v1-5",
        i2i_lora_dict: Optional[Dict[str, float]] = {'checkpoints/i2i/lora/lcm-lora-sdv1-5.safetensors': 1.0},
        prompt: str = "",
        negative_prompt: str = "low quality, bad quality, blurry, low resolution",
        frame_buffer_size: int = 1,
        width: int = 512,
        height: int = 512,
        acceleration: Literal["none", "xformers", "tensorrt"] = "xformers",
        use_denoising_batch: bool = True,
        seed: int = 2,
        cfg_type: Literal["none", "full", "self", "initialize"] = "self",
        guidance_scale: float = 1.4,
        delta: float = 0.5,
        do_add_noise: bool = False,
        enable_similar_image_filter: bool = True,
        similar_image_filter_threshold: float = 0.99,
        similar_image_filter_max_skip_frame: float = 10,
    ):
        """Record generation settings and read the image-to-video config.

        Raises:
            ValueError: if ``frame_buffer_size`` is not 1, the only supported value.
        """
        super(GeneratePipeline, self).__init__()
        # Validate with a real exception (``assert`` is stripped under ``-O``)
        # and before any config file is read.
        if frame_buffer_size != 1:
            raise ValueError(f"frame_buffer_size must be 1, got {frame_buffer_size!r}")
        if not torch.cuda.is_available():
            # No GPU: disable the acceleration backend.
            acceleration = None

        # Models are created by init_model().
        self.img2img_model = None
        self.img2video_model = None
        self.img2video_generator = None
        self.sim_ranges = None
        self.drag_model = None

        # StreamDiffusion (image-to-image) settings.
        self.i2i_body_ckpt = i2i_body_ckpt
        # Copy, so the shared default dict can never be mutated through an instance.
        self.i2i_lora_dict = dict(i2i_lora_dict) if i2i_lora_dict is not None else None
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.frame_buffer_size = frame_buffer_size
        self.width = width
        self.height = height
        self.acceleration = acceleration
        self.use_denoising_batch = use_denoising_batch
        self.seed = seed
        self.cfg_type = cfg_type
        self.guidance_scale = guidance_scale
        self.delta = delta
        self.do_add_noise = do_add_noise
        self.enable_similar_image_filter = enable_similar_image_filter
        self.similar_image_filter_threshold = similar_image_filter_threshold
        self.similar_image_filter_max_skip_frame = similar_image_filter_max_skip_frame

        # PIA (image-to-video) settings come from the YAML config.
        self.i2v_config = OmegaConf.load('demo/configs/i2v_config.yaml')
        self.i2v_body_ckpt = self.i2v_config.pretrained_model_path
        self.i2v_unet_path = self.i2v_config.generate.model_path
        self.i2v_dreambooth_ckpt = self.i2v_config.generate.db_path

        self.lora_alpha = 0

    def init_model(self):
        """Build the StreamDiffusion, PIA and DragGAN models described by the stored settings."""
        # StreamDiffusion
        self.img2img_model = StreamDiffusionWrapper(
            model_id_or_path=self.i2i_body_ckpt,
            lora_dict=self.i2i_lora_dict,
            t_index_list=[32, 45],
            frame_buffer_size=self.frame_buffer_size,
            width=self.width,
            height=self.height,
            warmup=10,
            acceleration=self.acceleration,
            do_add_noise=self.do_add_noise,
            enable_similar_image_filter=self.enable_similar_image_filter,
            similar_image_filter_threshold=self.similar_image_filter_threshold,
            similar_image_filter_max_skip_frame=self.similar_image_filter_max_skip_frame,
            mode="img2img",
            use_denoising_batch=self.use_denoising_batch,
            cfg_type=self.cfg_type,
            seed=self.seed,
            use_lcm_lora=False,
        )
        self.img2img_model.prepare(
            prompt=self.prompt,
            negative_prompt=self.negative_prompt,
            num_inference_steps=50,
            guidance_scale=self.guidance_scale,
            delta=self.delta,
        )

        # PIA
        self.img2video_model = I2VPipeline.build_pipeline(
            self.i2v_config,
            self.i2v_body_ckpt,
            self.i2v_unet_path,
            self.i2v_dreambooth_ckpt,
            None,  # lora path
            self.lora_alpha,
        )
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.img2video_generator = torch.Generator(device=device)
        self.img2video_generator.manual_seed(self.i2v_config.generate.global_seed)
        self.sim_ranges = self.i2v_config.validation_data.mask_sim_range

        # DragGAN
        self.drag_model = Renderer(disable_timing=True)

    def generate_image(self, image, text, start_time=None):
        """Run one StreamDiffusion img2img step on ``image``.

        Args:
            image: input tensor. It is wrapped in a one-element list, passed
                through ``torch.cat`` and moved to the model's device/dtype.
            text: ``(positive, negative)`` prompt pair used to re-``prepare``
                the model first, or ``None`` to keep the current prompts.
            start_time: currently unused; kept for call compatibility.

        Returns:
            The model output moved to the CPU.
        """
        if text is not None:
            pos_prompt, neg_prompt = text
            self.img2img_model.prepare(
                prompt=pos_prompt,
                negative_prompt=neg_prompt,
                num_inference_steps=50,
                guidance_scale=self.guidance_scale,
                delta=self.delta,
            )
        input_batch = torch.cat([image])
        output_images = self.img2img_model.stream(
            input_batch.to(device=self.img2img_model.device, dtype=self.img2img_model.dtype)
        )
        return output_images.cpu()

    def generate_video(self, image, text, height=None, width=None):
        """Animate ``image`` with the PIA pipeline.

        Args:
            image: passed to the pipeline as ``image``.
            text: ``(positive, negative)`` prompt pair.
            height, width: output size; ``None`` falls back to the config's
                ``generate.sample_height`` / ``generate.sample_width``.

        Returns:
            The ``.videos`` attribute of the pipeline result.
        """
        pos_prompt, neg_prompt = text
        # Only the first configured similarity range is used. It is written back
        # into the config so the ``**validation_data`` expansion below carries it too.
        sim_range = self.sim_ranges[0]
        print(f"using sim_range : {sim_range}")
        self.i2v_config.validation_data.mask_sim_range = sim_range
        sample = self.img2video_model(
            image=image,
            prompt=pos_prompt,
            generator=self.img2video_generator,
            video_length=self.i2v_config.generate.video_length,
            height=height if height is not None else self.i2v_config.generate.sample_height,
            width=width if width is not None else self.i2v_config.generate.sample_width,
            negative_prompt=neg_prompt,
            mask_sim_template_idx=self.i2v_config.validation_data.mask_sim_range,
            **self.i2v_config.validation_data,
        ).videos
        return sample

    def prepare_drag_model(
        self,
        custom_image: "Image.Image",
        latent_space = 'w+',
        trunc_psi = 0.7,
        trunc_cutoff = None,
        seed = 0,
        lr = 0.001,
        generator_params = None,
        pretrained_weight = 'stylegan2_lions_512_pytorch',
    ):
        """Load a DragGAN generator and invert ``custom_image`` into its latent space with PTI.

        Args:
            custom_image: PIL image to invert.
            latent_space: ``'w+'`` passes ``w_plus=True`` to the network and to
                PTI; any other value passes ``False``.
            trunc_psi, trunc_cutoff, seed, lr: forwarded to ``Renderer.init_network``
                (``trunc_psi`` / ``trunc_cutoff`` also to ``set_latent``).
            generator_params: dict-like object that receives the inverted
                ``image`` (PIL) and latent ``w`` (numpy). A fresh
                ``dnnlib.EasyDict`` is created when omitted.
            pretrained_weight: checkpoint handed to ``init_network``.

        Returns:
            ``(generator_params, mask)``, where ``mask`` is an all-ones ``uint8``
            array of the inverted image's ``(height, width)``.
        """
        if generator_params is None:
            # A shared ``dnnlib.EasyDict()`` default would leak the image/latent
            # written below from one call into the next.
            generator_params = dnnlib.EasyDict()
        w_plus = latent_space == 'w+'

        self.drag_model.init_network(
            generator_params,  # res
            pretrained_weight,  # pkl
            seed,  # w0_seed
            None,  # w_load
            w_plus,  # w_plus
            'const',
            trunc_psi,  # trunc_psi
            trunc_cutoff,  # trunc_cutoff
            None,  # input_transform
            lr,  # lr
        )

        percept = PerceptualLoss(model="net-lin", net="vgg", use_gpu=torch.cuda.is_available())

        pti = PTI(self.drag_model.G, percept, max_pti_step=400)
        inversed_img, w_pivot = pti.train(custom_image, w_plus)
        # First image of the batch, CHW -> HWC uint8 (generator output presumably in [-1, 1]).
        inversed_img = (inversed_img[0] * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
        inversed_img = Image.fromarray(inversed_img.cpu().numpy())
        mask = np.ones((inversed_img.height, inversed_img.width), dtype=np.uint8)
        generator_params.image = inversed_img
        generator_params.w = w_pivot.detach().cpu().numpy()
        self.drag_model.set_latent(w_pivot, trunc_psi, trunc_cutoff)

        # Free the inversion helpers (VGG loss network, PTI state) promptly.
        del percept
        del pti
        print('inverse end')

        return generator_params, mask

    def drag_image(
        self,
        points,
        mask,
        motion_lambda = 20,
        r1_in_pixels = 3,
        r2_in_pixels = 12,
        trunc_psi = 0.7,
        draw_interval = 1,
        generator_params = None,
    ):
        """Run one DragGAN step and write the updated points back into ``points``.

        Args:
            points: ``{id: {'start': [x, y], 'start_temp': [x, y], 'target': [x, y]}}``.
                ``start_temp`` (current handle position) takes precedence over
                ``start``; pairs missing a start or target (absent or ``None``)
                are skipped.
            mask: array-like; ``1 - mask`` is handed to the renderer as its mask.
            motion_lambda, r1_in_pixels, r2_in_pixels, trunc_psi: forwarded to
                ``Renderer._render_drag_impl`` (``lambda_mask``, ``r1``, ``r2``, ``trunc_psi``).
            draw_interval: kept for interface compatibility. Each call performs
                exactly one step, so the points are always written back (the old
                ``0 % draw_interval`` check was always true and crashed for 0).
            generator_params: dict-like renderer state; a fresh
                ``dnnlib.EasyDict`` is used when omitted.

        Returns:
            ``generator_params['image']`` after the step.
        """
        if generator_params is None:
            generator_params = dnnlib.EasyDict()

        p_in_pixels = []
        t_in_pixels = []
        valid_points = []
        for key_point, point in points.items():
            try:
                # ``point["start"]`` is evaluated eagerly, so a pair lacking
                # "start" is skipped even when it has "start_temp".
                p_start = point.get("start_temp", point["start"])
                p_end = point["target"]
            except KeyError:
                continue
            if p_start is None or p_end is None:
                continue
            p_in_pixels.append(p_start)
            t_in_pixels.append(p_end)
            valid_points.append(key_point)

        drag_mask = 1 - torch.as_tensor(mask).float()

        # The renderer works in the swapped coordinate order.
        p_to_opt = reverse_point_pairs(p_in_pixels)
        t_to_opt = reverse_point_pairs(t_in_pixels)

        self.drag_model._render_drag_impl(
            generator_params,
            p_to_opt,  # point
            t_to_opt,  # target
            drag_mask,  # mask
            motion_lambda,  # lambda_mask
            reg = 0,
            feature_idx = 5,  # NOTE: do not support change for now
            r1 = r1_in_pixels,
            r2 = r2_in_pixels,
            trunc_psi = trunc_psi,
            is_drag=True,
            to_pil=True
        )

        # Presumably ``_render_drag_impl`` moves the handles in ``p_to_opt`` in
        # place (TODO confirm in Renderer); write them back in the caller's order.
        for key_point, p_i, t_i in zip(valid_points, p_to_opt, t_to_opt):
            points[key_point]["start_temp"] = [p_i[1], p_i[0]]
            points[key_point]["target"] = [t_i[1], t_i[0]]

        return generator_params['image']

# ============================= Model Implementation ENd ===================================


# Command-line options. ``--share`` is on by default; the old
# ``action='store_true', default='True'`` (a string) made it impossible to
# turn off, so ``--no-share`` is provided to disable the public link.
parser = argparse.ArgumentParser()
parser.add_argument('--share', dest='share', action='store_true', default=True,
                    help='create a public Gradio share link (default: on)')
parser.add_argument('--no-share', dest='share', action='store_false',
                    help='do not create a public Gradio share link')
parser.add_argument('--cache-dir', type=str, default='./checkpoints')
parser.add_argument(
    "--listen",
    action="store_true",
    help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests",
)
args = parser.parse_args()


class CustomImageMask(gr.Image):
    """Sketch-enabled ``gr.Image`` that also accepts a bare image.

    The sketch tool normally delivers ``{'image': ..., 'mask': ...}``. When a
    plain base64 image arrives instead (upload/webcam sources), ``preprocess``
    pairs it with a synthetic mask before delegating to ``gr.Image``.
    """

    is_template = True

    def __init__(
        self,
        source='upload',
        tool='sketch',
        elem_id="image_upload",
        label='Generated Image',
        type="pil",
        mask_opacity=0.5,
        brush_color='#FFFFFF',
        height=400,
        interactive=True,
        **kwargs
    ):
        # Forward every option unchanged to gr.Image.
        options = dict(
            source=source,
            tool=tool,
            elem_id=elem_id,
            label=label,
            type=type,
            mask_opacity=mask_opacity,
            brush_color=brush_color,
            height=height,
            interactive=interactive,
        )
        super().__init__(**options, **kwargs)

    def preprocess(self, x):
        if x is None:
            return x
        is_bare_sketch_input = (
            self.tool == 'sketch'
            and self.source in ['upload', 'webcam']
            and type(x) != dict
        )
        if is_bare_sketch_input:
            decoded = gr.processing_utils.decode_base64_to_image(x)
            img_w, img_h = decoded.size
            # Synthetic mask: RGB channels = 1, alpha = 255.
            blank_mask = np.ones((img_h, img_w, 4), dtype=np.uint8)
            blank_mask[..., -1] = 255
            x = {'image': x, 'mask': self.postprocess(blank_mask)}
        return super().preprocess(x)


# DragGAN checkpoints offered in the UI, sorted for a stable order. Fail fast
# with a clear message when there are none: the UI below indexes
# ``draggan_ckpts[0]`` and would otherwise die with a bare IndexError, and only
# after the slow model construction.
draggan_ckpts = sorted(os.listdir('checkpoints/drag'))
if not draggan_ckpts:
    raise FileNotFoundError("no DragGAN checkpoints found in 'checkpoints/drag'")

# One pipeline shared by every Gradio callback; init_model() loads all weights.
generate_pipeline = GeneratePipeline()
generate_pipeline.init_model()


with gr.Blocks() as demo:
    # Initial (positive, negative) prompt pairs. Each "default_*" entry and the
    # matching current "*_text_prompts" entry start as the same tuple; tuples
    # are immutable and the prompt handlers always build new ones, so sharing is safe.
    _initial_image_prompts = ('', 'low quality, bad quality, blurry, low resolution')
    _initial_video_prompts = ('', 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg')

    # Per-session UI state shared by all callbacks.
    global_state = gr.State(
        {
            # Mode flags.
            'is_image_generation': True,
            'is_image_text_prompt_up-to-date': True,
            'is_show_mask': False,
            'is_dragging': False,
            # Raw generator output and its annotated preview.
            'generated_image': None,
            'generated_image_show': None,
            # Drag annotations: point pairs keyed by int id, plus an optional mask.
            'drag_markers': [
                {
                    'points': {},
                    'mask': None,
                },
            ],
            'generator_params': dnnlib.EasyDict(),
            'default_image_text_prompts': _initial_image_prompts,
            'default_video_text_prompts': _initial_video_prompts,
            'image_text_prompts': _initial_image_prompts,
            'video_text_prompts': _initial_video_prompts,
            'params': {
                'seed': 0,
                'motion_lambda': 20,
                'r1_in_pixels': 3,
                'r2_in_pixels': 12,
                'magnitude_direction_in_pixels': 1.0,
                'latent_space': 'w+',
                'trunc_psi': 0.7,
                'trunc_cutoff': None,
                'lr': 0.001,
            },
            'device': None,
            'draw_interval': 1,
            'points': {},
            'curr_point': None,
            'curr_type_point': 'start',
            'editing_state': 'add_points',
            'pretrained_weight': draggan_ckpts[0],
            'video_preview_resolution': '512 x 512',
            'viewer_height': 300,
            'viewer_width': 300,
        }
    )

    with gr.Column():
        # Top row: tabbed controls (left, wide) and action buttons (right).
        with gr.Row():
            with gr.Column(scale=8, min_width=10):
                # Prompt editors; their values are applied by the submit / Confirm Text handlers.
                with gr.Tab('Image Text Prompts'):
                    image_pos_text_prompt_editor = gr.Textbox(placeholder='Positive Prompts', label='Positive', min_width=10)
                    image_neg_text_prompt_editor = gr.Textbox(placeholder='Negative Prompts', label='Negative', min_width=10)
                with gr.Tab('Video Text Prompts'):
                    video_pos_text_prompt_editor = gr.Textbox(placeholder='Positive Prompts', label='Positive', min_width=10)
                    video_neg_text_prompt_editor = gr.Textbox(placeholder='Negative Prompts', label='Negative', min_width=10)
                # DragGAN controls: mode toggle + checkpoint, run/points, mask options.
                with gr.Tab('Drag Image'):
                    with gr.Row():
                        with gr.Column(scale=1, min_width=10):
                            drag_mode_on_button = gr.Button('Drag Mode On', size='sm', min_width=10)
                            drag_mode_off_button = gr.Button('Drag Mode Off', size='sm', min_width=10)
                            drag_checkpoint_dropdown = gr.Dropdown(choices=draggan_ckpts, value=draggan_ckpts[0], label='checkpoint', min_width=10)
                        with gr.Column(scale=1, min_width=10):
                            with gr.Row():
                                drag_start_button = gr.Button('start', size='sm', min_width=10)
                                drag_stop_button = gr.Button('stop', size='sm', min_width=10)
                            with gr.Row():
                                add_point_button = gr.Button('add point', size='sm', min_width=10)
                                reset_point_button = gr.Button('reset point', size='sm', min_width=10)
                            with gr.Row():
                                # Read-only counter updated while dragging.
                                steps_number = gr.Number(0, label='steps', interactive=False)
                        with gr.Column(scale=1, min_width=10):
                            with gr.Row():
                                draw_mask_button = gr.Button('draw mask', size='sm', min_width=10)
                                reset_mask_button = gr.Button('reset mask', size='sm', min_width=10)
                            with gr.Row():
                                show_mask_checkbox = gr.Checkbox(value=False, label='show mask', min_width=10, interactive=True)
                            with gr.Row():
                                motion_lambda_number = gr.Number(20, label='Motion Lambda', minimum=1, maximum=100, step=1, interactive=True)
                # Misc settings: video preview size and sample images
                # (samples/canvas.jpg plus samples/sample01.jpg .. sample07.jpg).
                with gr.Tab('More'):
                    with gr.Row():
                        with gr.Column(scale=2, min_width=10):
                            video_preview_resolution_dropdown = gr.Dropdown(choices=['256 x 256', '512 x 512'], value='512 x 512', label='Video Preview Resolution', min_width=10)
                            sample_image_dropdown = gr.Dropdown(choices=['samples/canvas.jpg'] + ['samples/sample{:>02d}.jpg'.format(i) for i in range(1, 8)], value=None, label='Choose A Sample Image', min_width=10)
            with gr.Column(scale=1, min_width=10):
                confirm_text_button = gr.Button('Confirm Text', size='sm', min_width=10)
                generate_video_button = gr.Button('Generate Video', size='sm', min_width=10)
                clear_video_button = gr.Button('Clear Video', size='sm', min_width=10)
        # Bottom row: drawing input, generated image (sketch-mask capable), generated video.
        with gr.Row():
            captured_image_viewer = gr.Image(source='upload', tool='color-sketch', type='pil', label='Image Drawer', height=global_state.value['viewer_height'], width=global_state.value['viewer_width'], interactive=True, shape=(global_state.value['viewer_width'], global_state.value['viewer_height']))  # free-hand drawing input
            generated_image_viewer = CustomImageMask(source='upload', tool='sketch', elem_id="image_upload", label='Generated Image', type="pil", mask_opacity=0.5, brush_color='#FFFFFF', height=global_state.value['viewer_height'], width=global_state.value['viewer_width'], interactive=True)
            generated_video_viewer = gr.Video(source='upload', label='Generated Video', height=global_state.value['viewer_height'], width=global_state.value['viewer_width'], interactive=False)
    
    # User-facing usage notes rendered under the viewers.
    gr.Markdown(
        """
            ## Quick Start

            1. Select one sample image in the `More` tab.
            2. Draw on the sample image in the left-most image viewer to edit it.
            3. Click `Generate Video` and enjoy it!

            ## Note
            Due to limitations of the Gradio implementation, the image-to-image result may appear with a noticeable delay after the model has finished generating.
            For a better experience, we recommend our local demo on [GitHub](https://github.com/invictus717/InteractiveVideo).

            ## Advanced Usage

            1. **Try different text prompts.** Enter positive or negative prompts for image / video generation, and
            click `Confirm Text` to apply your prompts.
            2. **Drag images.** Go to the `Drag Image` tab, choose a suitable checkpoint and click `Drag Mode On`.
            Preparation may take about a minute. Add points and draw masks as needed, then click `start` to
            start dragging. When you are happy with the result, click `stop`.
            3. **Adjust video resolution** in the `More` tab.
            4. **Draw from scratch** by choosing `canvas.jpg` in the `More` tab and enjoy yourself!
        """
    )

    # ========================= Main Function Start =============================
    def on_captured_image_viewer_update(state, image):
        """Run image-to-image generation on the drawer's image and refresh the preview.

        Pending prompt edits (``is_image_text_prompt_up-to-date`` is False) are
        passed to the model once and then marked as applied.

        Returns:
            ``(state, update)`` for ``[global_state, generated_image_viewer]``.
        """
        if image is None:
            return state, gr.Image.update(None)

        prompts_pending = not state['is_image_text_prompt_up-to-date']
        text_prompts = state['image_text_prompts'] if prompts_pending else None
        state['is_image_text_prompt_up-to-date'] = True

        # HWC pixels in [0, 255] -> CHW float tensor in [-1, 1].
        scaled = (np.array(image).astype(np.float32) / 255 - 0.5) * 2
        pixels = torch.tensor(scaled).permute([2, 0, 1])
        noise = torch.randn_like(pixels)

        # The first call feeds the drawing (and any new prompts) into the stream
        # and its result is discarded; the frame shown is the one returned by a
        # second call on pure noise. (Original TODO: find a more elegant way.)
        _ = generate_pipeline.generate_image(pixels, text_prompts)[0]
        frame = generate_pipeline.generate_image(noise, None)[0]

        # CHW in [-1, 1] -> HWC uint8 PIL image.
        frame = (frame.permute([1, 2, 0]) / 2 + 0.5).clamp(0, 1) * 255
        state['generated_image'] = Image.fromarray(frame.to(torch.uint8).cpu().numpy())

        preview = update_state_image(state)
        return state, gr.Image.update(preview, interactive=False)

    captured_image_viewer.change(
        fn=on_captured_image_viewer_update,
        inputs=[global_state, captured_image_viewer],
        outputs=[global_state, generated_image_viewer],
    )

    def on_generated_image_viewer_edit(state, data_dict):
        """Store the sketched mask (first channel, integer-divided by 255) and redraw the preview."""
        sketched = np.array(data_dict['mask'])
        state['drag_markers'][0]['mask'] = sketched[:, :, 0] // 255
        return state, update_state_image(state)

    generated_image_viewer.edit(
        fn=on_generated_image_viewer_edit,
        inputs=[global_state, generated_image_viewer],
        outputs=[global_state, generated_image_viewer],
    )

    def on_generate_video_click(state):
        """Animate the current generated image and write it to ``results/gradio_temp.mp4``.

        The preview resolution string (e.g. ``'512 x 512'``) is parsed as
        ``height x width``.

        Returns:
            ``(state, update)`` for ``[global_state, generated_video_viewer]``.
        """
        if state['generated_image'] is None:
            # Nothing to animate yet; leave the video viewer unchanged.
            return state, gr.Video.update()

        input_image = np.array(state['generated_image'])
        text_prompts = state['video_text_prompts']
        resolution = state['video_preview_resolution'].split('x')
        height = int(resolution[0].strip(' '))
        width = int(resolution[1].strip(' '))
        output_video = generate_pipeline.generate_video(
            input_image,
            text_prompts,
            height=height,
            width=width,
        )[0]
        # Layout (3, T, H, W); values scaled from [0, 1] to uint8.
        output_video = (output_video.clamp(0, 1) * 255).to(torch.uint8)
        print('[video generation done]')

        fps = 5
        video_path = 'results/gradio_temp.mp4'
        os.makedirs('results', exist_ok=True)
        fourcc = cv2.VideoWriter.fourcc(*'mp4v')
        # OpenCV's frameSize is (width, height); the old (height, width) only
        # worked because every offered resolution is square.
        video_writer = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
        try:
            # (3, T, H, W) -> (T, H, W, 3), moved to the CPU once.
            frames = output_video.permute(1, 2, 3, 0).cpu().numpy()
            for frame in frames:
                # Frames are presumably RGB (the input came from a PIL image);
                # OpenCV expects BGR.
                video_writer.write(cv2.cvtColor(np.ascontiguousarray(frame), cv2.COLOR_RGB2BGR))
        finally:
            video_writer.release()
        return state, gr.Video.update(video_path)

    generate_video_button.click(
        fn=on_generate_video_click,
        inputs=[global_state],
        outputs=[global_state, generated_video_viewer],
    )

    def on_clear_video_click(state):
        """Empty the video viewer; the session state passes through unchanged."""
        cleared = gr.Video.update(None)
        return state, cleared

    clear_video_button.click(
        fn=on_clear_video_click,
        inputs=[global_state],
        outputs=[global_state, generated_video_viewer],
    )

    def on_drag_mode_on_click(state):
        """Invert the current generated image into the selected DragGAN checkpoint.

        The inverted image replaces ``state['generated_image']`` and the viewer
        is made interactive again.

        Returns:
            ``(state, update)`` for ``[global_state, generated_image_viewer]``.
        """
        custom_image = state['generated_image']
        if custom_image is None:
            # Nothing to invert yet; leave the viewer unchanged instead of crashing in PTI.
            return state, gr.Image.update()
        checkpoint_path = os.path.join('checkpoints/drag/', state['pretrained_weight'])
        generate_pipeline.prepare_drag_model(
            custom_image,
            generator_params=state['generator_params'],
            pretrained_weight=checkpoint_path,
        )
        # prepare_drag_model stores the inverted image on generator_params.image.
        state['generated_image'] = state['generator_params'].image
        view_image = update_state_image(state)
        return state, gr.Image.update(view_image, interactive=True)

    drag_mode_on_button.click(
        fn=on_drag_mode_on_click,
        inputs=[global_state],
        outputs=[global_state, generated_image_viewer],
    )

    def on_drag_mode_off_click(state, image):
        """Leave drag mode by regenerating the preview from the drawer's image."""
        refreshed_state, viewer_update = on_captured_image_viewer_update(state, image)
        return refreshed_state, viewer_update

    drag_mode_off_button.click(
        fn=on_drag_mode_off_click,
        inputs=[global_state, captured_image_viewer],
        outputs=[global_state, generated_image_viewer],
    )

    def on_drag_start_click(state):
        """Run DragGAN steps until ``state['is_dragging']`` is cleared by the stop button.

        Yields ``(state, image_update, step_update)`` after every step so the UI
        streams progress. The final yield re-enables editing of the viewer.
        """
        generator_params = state['generator_params']
        if generator_params.get('image') is None:
            # Drag mode has not been switched on, so there is nothing to drag.
            yield state, gr.Image.update(), gr.Number.update()
            return

        state['is_dragging'] = True
        points = state['drag_markers'][0]['points']
        mask = state['drag_markers'][0]['mask']
        if mask is None:
            # No mask drawn: use an all-ones mask the size of the inverted image.
            base_image = generator_params.image
            mask = np.ones((base_image.height, base_image.width), dtype=np.uint8)

        cur_step = 0
        while state['is_dragging']:
            generated_image = generate_pipeline.drag_image(
                points,
                mask,
                motion_lambda=state['params']['motion_lambda'],
                generator_params=generator_params,
            )
            state['drag_markers'] = [{'points': points, 'mask': mask}]
            state['generated_image'] = generated_image
            cur_step += 1
            view_image = update_state_image(state)
            if cur_step % 50 == 0:
                print('[{} / {}]'.format(cur_step, 'inf'))
            yield (
                state,
                gr.Image.update(view_image, interactive=False),  # generated image viewer
                gr.Number.update(cur_step),  # step
            )

        # Gradio ignores a generator's ``return`` value, so the final state must
        # be yielded for the viewer to become interactive again.
        view_image = update_state_image(state)
        yield (
            state,
            gr.Image.update(view_image, interactive=True),
            gr.Number.update(cur_step),
        )

    drag_start_button.click(
        fn=on_drag_start_click,
        inputs=[global_state],
        outputs=[global_state, generated_image_viewer, steps_number],
    )

    def on_drag_stop_click(state):
        """Ask a running drag loop to stop after its current step."""
        state.update(is_dragging=False)
        return state

    drag_stop_button.click(
        fn=on_drag_stop_click,
        inputs=[global_state],
        outputs=[global_state],
    )

    # ========================= Main Function End =============================

    # ====================== Update Text Prompts Start ====================
    def _merged_prompts(prompts, defaults, index, text):
        """Return ``prompts`` with slot ``index`` (0 = positive, 1 = negative) replaced.

        An empty (or ``None``) ``text`` restores ``defaults[index]`` instead.
        """
        merged = list(prompts)
        merged[index] = text if text else defaults[index]
        return tuple(merged)

    def on_image_pos_text_prompt_editor_submit(state, text):
        """Set the positive image prompt and flag the image prompts as changed."""
        state['image_text_prompts'] = _merged_prompts(
            state['image_text_prompts'], state['default_image_text_prompts'], 0, text)
        state['is_image_text_prompt_up-to-date'] = False
        return state

    # The state is returned explicitly rather than relying on in-place mutation alone.
    image_pos_text_prompt_editor.submit(
        fn=on_image_pos_text_prompt_editor_submit,
        inputs=[global_state, image_pos_text_prompt_editor],
        outputs=[global_state],
    )

    def on_image_neg_text_prompt_editor_submit(state, text):
        """Set the negative image prompt and flag the image prompts as changed."""
        state['image_text_prompts'] = _merged_prompts(
            state['image_text_prompts'], state['default_image_text_prompts'], 1, text)
        state['is_image_text_prompt_up-to-date'] = False
        return state

    image_neg_text_prompt_editor.submit(
        fn=on_image_neg_text_prompt_editor_submit,
        inputs=[global_state, image_neg_text_prompt_editor],
        outputs=[global_state],
    )

    def on_video_pos_text_prompt_editor_submit(state, text):
        """Set the positive video prompt (used by the next video generation)."""
        state['video_text_prompts'] = _merged_prompts(
            state['video_text_prompts'], state['default_video_text_prompts'], 0, text)
        return state

    video_pos_text_prompt_editor.submit(
        fn=on_video_pos_text_prompt_editor_submit,
        inputs=[global_state, video_pos_text_prompt_editor],
        outputs=[global_state],
    )

    def on_video_neg_text_prompt_editor_submit(state, text):
        """Set the negative video prompt (used by the next video generation)."""
        state['video_text_prompts'] = _merged_prompts(
            state['video_text_prompts'], state['default_video_text_prompts'], 1, text)
        return state

    video_neg_text_prompt_editor.submit(
        fn=on_video_neg_text_prompt_editor_submit,
        inputs=[global_state, video_neg_text_prompt_editor],
        outputs=[global_state],
    )

    def on_confirm_text_click(state, image, img_pos_t, img_neg_t, vid_pos_t, vid_neg_t):
        """Apply all four prompt boxes at once, then regenerate the image preview."""
        state = on_image_pos_text_prompt_editor_submit(state, img_pos_t)
        state = on_image_neg_text_prompt_editor_submit(state, img_neg_t)
        state = on_video_pos_text_prompt_editor_submit(state, vid_pos_t)
        state = on_video_neg_text_prompt_editor_submit(state, vid_neg_t)
        return on_captured_image_viewer_update(state, image)

    confirm_text_button.click(
        fn=on_confirm_text_click,
        inputs=[global_state, captured_image_viewer, image_pos_text_prompt_editor, image_neg_text_prompt_editor,
                video_pos_text_prompt_editor, video_neg_text_prompt_editor],
        outputs=[global_state, generated_image_viewer],
    )

    # ====================== Update Text Prompts End ====================

    # ======================= Drag Point Edit Start =========================

    def on_image_clicked(state, evt: gr.SelectData):
        """Turn a click on the generated image into a drag start or target point.

        Points live in ``state['drag_markers'][0]['points']``, keyed by integer
        id. If the highest-id pair still has no target, the click becomes that
        target; otherwise (including when there are no pairs yet) the click
        starts a new pair with the next id. Only point selection is handled
        here. The viewer is returned non-interactive.
        """
        x, y = evt.index
        markers = state['drag_markers']
        points = markers[0]['points']

        last_id = max(points) if points else None
        if last_id is not None and points[last_id]['target'] is None:
            points[last_id]['target'] = [x, y]
        else:
            next_id = 0 if last_id is None else last_id + 1
            # 'start_temp' and 'start' get separate list objects on purpose.
            points[next_id] = {
                'start_temp': [x, y],
                'start': [x, y],
                'target': None,
            }

        state['drag_markers'] = markers
        rendered = update_state_image(state)
        return state, gr.Image.update(rendered, interactive=False)

    # Selecting a pixel on the generated image adds or completes a point pair.
    generated_image_viewer.select(
        fn=on_image_clicked,
        inputs=[global_state],
        outputs=[global_state, generated_image_viewer],
    )

    def on_add_point_click(state):
        """Re-show the rendered image with the viewer set to non-interactive.

        This is the point-picking mode used by ``on_image_clicked``. The state is
        not modified.
        """
        shown = state['generated_image_show']
        return gr.Image.update(shown, interactive=False)

    # "Add point" only switches the viewer mode; global state is untouched.
    add_point_button.click(
        fn=on_add_point_click,
        inputs=[global_state],
        outputs=[generated_image_viewer],
    )

    def on_reset_point_click(state):
        """Remove every drag point pair and re-render the generated image."""
        markers = state['drag_markers']
        first = markers[0]
        first['points'] = {}
        state['drag_markers'] = markers
        return state, gr.Image.update(update_state_image(state))

    # Clearing points changes both the state and the displayed overlay.
    reset_point_button.click(
        fn=on_reset_point_click,
        inputs=[global_state],
        outputs=[global_state, generated_image_viewer],
    )

    # ======================= Drag Point Edit End =========================

    # ======================= Drag Mask Edit Start =========================

    def on_draw_mask_click(state):
        """Re-show the rendered image with the viewer made interactive.

        The interactive viewer lets the user draw on it. The state is not
        modified.
        """
        shown = state['generated_image_show']
        return gr.Image.update(shown, interactive=True)

    # "Draw mask" only flips the viewer to interactive mode.
    draw_mask_button.click(
        fn=on_draw_mask_click,
        inputs=[global_state],
        outputs=[generated_image_viewer],
    )

    def on_reset_mask_click(state):
        """Reset the drag mask to all ones and re-render the generated image."""
        markers = state['drag_markers']
        # ones_like keeps the existing mask's shape and dtype.
        markers[0]['mask'] = np.ones_like(markers[0]['mask'])
        state['drag_markers'] = markers
        return state, gr.Image.update(update_state_image(state))

    # Resetting the mask updates both the state and the displayed overlay.
    reset_mask_button.click(
        fn=on_reset_mask_click,
        inputs=[global_state],
        outputs=[global_state, generated_image_viewer],
    )

    def on_show_mask_click(state, evt: gr.SelectData):
        """Toggle the mask overlay to follow the checkbox's new selected value.

        Returns the state and the freshly rendered image.
        """
        state['is_show_mask'] = evt.selected
        return state, update_state_image(state)

    # The checkbox's select event carries the new checked/unchecked value.
    show_mask_checkbox.select(
        fn=on_show_mask_click,
        inputs=[global_state],
        outputs=[global_state, generated_image_viewer],
    )

    # ======================= Drag Mask Edit End =========================

    # ======================= Drag Setting Start =========================

    def on_motion_lambda_change(state, number):
        """Store the motion-lambda value typed in the UI into ``state['params']``.

        The state dict is mutated in place and returned.
        """
        # NOTE(review): the value is stored under the key 'number'. If the drag
        # code reads a differently named key (e.g. 'motion_lambda'), this input
        # has no effect. Confirm against where state['params'] is built.
        state['params'].update(number=number)
        return state
    
    # Fires on user edits of the motion-lambda number box.
    motion_lambda_number.input(
        fn=on_motion_lambda_change,
        inputs=[global_state, motion_lambda_number],
        outputs=[global_state],
    )

    def on_drag_checkpoint_change(state, checkpoint):
        """Record the checkpoint chosen in the drag checkpoint dropdown.

        Only ``state['pretrained_weight']`` is updated; this handler does not
        load any weights itself. The state dict is mutated in place and
        returned.
        """
        state['pretrained_weight'] = checkpoint
        return state
    
    # Any change of the checkpoint dropdown is mirrored into the global state.
    drag_checkpoint_dropdown.change(
        fn=on_drag_checkpoint_change,
        inputs=[global_state, drag_checkpoint_dropdown],
        outputs=[global_state],
    )

    # ======================= Drag Setting End =========================

    # ======================= General Setting Start =========================

    def on_video_preview_resolution_change(state, resolution):
        """Remember the preview resolution selected in the dropdown.

        Sets ``state['video_preview_resolution']`` in place and returns the state.
        """
        state.update(video_preview_resolution=resolution)
        return state
    
    # Any change of the resolution dropdown is mirrored into the global state.
    video_preview_resolution_dropdown.change(
        fn=on_video_preview_resolution_change,
        inputs=[global_state, video_preview_resolution_dropdown],
        outputs=[global_state],
    )

    def on_sample_image_change(state, image):
        """Show the dropdown's selected sample in the captured-image viewer.

        The state is passed through unchanged.
        """
        viewer_update = gr.Image.update(image)
        return state, viewer_update

    # Picking a sample replaces the captured image shown to the user.
    sample_image_dropdown.change(
        fn=on_sample_image_change,
        inputs=[global_state, sample_image_dropdown],
        outputs=[global_state, captured_image_viewer],
    )

    # ======================= General Setting End =========================


# Enable the request queue: 3 worker threads process events concurrently and at
# most 20 events are held in the queue at once.
demo.queue(max_size=20, concurrency_count=3)
# Uses Gradio's default host; pass server_name="0.0.0.0" to expose on the network.
demo.launch()