import os
import shutil
from enum import Enum

import cv2
import einops
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from blendmodes.blend import BlendType, blendLayers
from PIL import Image
from pytorch_lightning import seed_everything
from safetensors.torch import load_file
from skimage import exposure

import src.import_util  # noqa: F401
from ControlNet.annotator.canny import CannyDetector
from ControlNet.annotator.hed import HEDdetector
from ControlNet.annotator.midas import MidasDetector
from ControlNet.annotator.util import HWC3
from ControlNet.cldm.model import create_model, load_state_dict
from gmflow_module.gmflow.gmflow import GMFlow
from flow.flow_utils import get_warped_and_mask
from sd_model_cfg import model_dict
from src.config import RerenderConfig
from src.controller import AttentionControl
from src.ddim_v_hacked import DDIMVSampler
from src.img_util import find_flat_region, numpy2tensor
from src.video_util import (frame_to_video, get_fps, get_frame_count,
                            prepare_frames)

import huggingface_hub

REPO_NAME = 'Anonymous-sub/Rerender'

# Example clips stored in the Space repo; fetched into ./videos at start-up.
EXAMPLE_VIDEOS = (
    'pexels-koolshooters-7322716.mp4',
    'pexels-antoni-shkraba-8048492-540x960-25fps.mp4',
    'pexels-cottonbro-studio-6649832-960x506-25fps.mp4',
)


def _download_example_videos():
    """Download every file in ``EXAMPLE_VIDEOS`` from ``REPO_NAME``."""
    for video_name in EXAMPLE_VIDEOS:
        huggingface_hub.hf_hub_download(REPO_NAME,
                                        video_name,
                                        local_dir='videos')


_download_example_videos()

# Reverse lookup: model path -> display name (used to map a saved config
# back to the UI choice). A comprehension avoids leaking loop variables
# into the module namespace.
inversed_model_dict = {path: name for name, path in model_dict.items()}

to_tensor = T.PILToTensor()
# Gaussian blur used to feather the dilated occlusion masks.
blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class ProcessingState(Enum):
    """How far the current video has progressed through the pipeline."""

    NULL = 0  # nothing generated yet
    FIRST_IMG = 1  # the first key frame has been translated
    KEY_IMGS = 2  # all key frames have been translated


# Maximum key-frame count (default 8); override via the MAX_KEYFRAME env var.
MAX_KEYFRAME = float(os.getenv('MAX_KEYFRAME', '8'))


class GlobalState:
    """Process-wide state shared by the Gradio callbacks.

    Holds the (lazily rebuilt) ControlNet/Stable Diffusion DDIM sampler, the
    control-map detector, the attention controller and the GMFlow optical
    flow model, plus the first-frame results that ``process1`` hands over to
    ``process2``.
    """

    # File in the ``lllyasviel/ControlNet`` Hugging Face repo holding the
    # ControlNet weights for each supported control type.
    CONTROLNET_CKPTS = {
        'HED': 'models/control_sd15_hed.pth',
        'canny': 'models/control_sd15_canny.pth',
        'depth': 'models/control_sd15_depth.pth',
    }

    def __init__(self):
        self.sd_model = None
        # Control type the current sampler's ControlNet weights belong to.
        self.control_type = None
        self.ddim_v_sampler = None
        self.detector_type = None
        # (control_type, canny_low, canny_high) that self.detector was built
        # for; used as the detector cache key.
        self.detector_key = None
        self.detector = None
        self.controller = None
        self.processing_state = ProcessingState.NULL
        # Written by the first-frame pass, read by the key-frame pass.
        self.color_corrections = None
        self.first_result = None
        self.first_img = None
        flow_model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type='swin',
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        ).to(device)

        # Keep tensors on their loaded (CPU) storage regardless of the device
        # the checkpoint was saved from.
        checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth',
                                map_location=lambda storage, loc: storage)
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        flow_model.load_state_dict(weights, strict=False)
        flow_model.eval()
        self.flow_model = flow_model

    def update_controller(self, inner_strength, mask_period, cross_period,
                          ada_period, warp_period):
        """Replace the attention controller with one built from these settings."""
        self.controller = AttentionControl(inner_strength, mask_period,
                                           cross_period, ada_period,
                                           warp_period)

    def update_sd_model(self, sd_model, control_type):
        """(Re)build the ControlNet + SD model and its DDIM sampler.

        The rebuild is skipped only when both ``sd_model`` and
        ``control_type`` match the last successful build. The ControlNet
        weights depend on the control type, so a change there must also
        trigger a reload.

        Args:
            sd_model: key into ``model_dict``. A value of the form
                ``"user/repo"`` is also used as the Hugging Face repo id.
            control_type: one of the keys of ``CONTROLNET_CKPTS``.

        Raises:
            ValueError: if ``control_type`` is not supported.
        """
        if sd_model == self.sd_model and control_type == self.control_type:
            return
        try:
            controlnet_ckpt = self.CONTROLNET_CKPTS[control_type]
        except KeyError:
            raise ValueError(
                f'Unknown control type: {control_type!r}') from None

        model = create_model('./ControlNet/models/cldm_v15.yaml').cpu()
        model.load_state_dict(
            load_state_dict(huggingface_hub.hf_hub_download(
                'lllyasviel/ControlNet', controlnet_ckpt),
                location=device))
        model.to(device)

        sd_model_path = model_dict[sd_model]
        if sd_model_path:
            # sd_model may itself be a "user/repo" id; otherwise the weights
            # live in this Space's repo.
            repo_name = sd_model if sd_model.count('/') == 1 else REPO_NAME
            model_ext = os.path.splitext(sd_model_path)[1]
            downloaded_model = huggingface_hub.hf_hub_download(
                repo_name, sd_model_path)
            if model_ext == '.safetensors':
                model.load_state_dict(load_file(downloaded_model),
                                      strict=False)
            elif model_ext in ('.ckpt', '.pth'):
                # map_location: checkpoints saved from CUDA must still load on
                # CPU-only hosts; load_state_dict copies onto the model device.
                model.load_state_dict(
                    torch.load(downloaded_model,
                               map_location='cpu')['state_dict'],
                    strict=False)

        try:
            vae_path = huggingface_hub.hf_hub_download(
                'stabilityai/sd-vae-ft-mse-original',
                'vae-ft-mse-840000-ema-pruned.ckpt')
            model.first_stage_model.load_state_dict(
                torch.load(vae_path, map_location='cpu')['state_dict'],
                strict=False)
        except Exception:
            # Best effort: the fine-tuned VAE only improves quality.
            print('Warning: We suggest you download the fine-tuned VAE',
                  'otherwise the generation quality will be degraded')

        self.ddim_v_sampler = DDIMVSampler(model)
        # Record the configuration only after a successful build so that a
        # failed download/load is retried on the next call.
        self.sd_model = sd_model
        self.control_type = control_type

    def clear_sd_model(self):
        """Drop the sampler (and its model) and release cached GPU memory."""
        self.sd_model = None
        self.control_type = None
        self.ddim_v_sampler = None
        if device == 'cuda':
            torch.cuda.empty_cache()

    def update_detector(self, control_type, canny_low=100, canny_high=200):
        """Build (or reuse) the function mapping an RGB frame to a control map.

        The canny thresholds are baked into the canny closure, so they are
        part of the cache key together with ``control_type``.

        Raises:
            ValueError: if ``control_type`` is not 'HED', 'canny' or 'depth'.
        """
        detector_key = (control_type, canny_low, canny_high)
        if self.detector_key == detector_key:
            return

        if control_type == 'HED':
            detector = HEDdetector()
        elif control_type == 'canny':
            canny_detector = CannyDetector()

            def apply_canny(x):
                return canny_detector(x, canny_low, canny_high)

            detector = apply_canny
        elif control_type == 'depth':
            midas = MidasDetector()

            def apply_midas(x):
                detected_map, _ = midas(x)
                return detected_map

            detector = apply_midas
        else:
            raise ValueError(f'Unknown control type: {control_type!r}')

        self.detector = detector
        self.detector_type = control_type
        self.detector_key = detector_key


# Module-level singletons shared by all Gradio callbacks; the video path is
# filled in once the user provides an input video.
global_state = GlobalState()
global_video_path = video_frame_count = None


def create_cfg(input_path, prompt, image_resolution, control_strength,
               color_preserve, left_crop, right_crop, top_crop, bottom_crop,
               control_type, low_threshold, high_threshold, ddim_steps, scale,
               seed, sd_model, a_prompt, n_prompt, interval, keyframe_count,
               x0_strength, use_constraints, cross_start, cross_end,
               style_update_freq, warp_start, warp_end, mask_start, mask_end,
               ada_start, ada_end, mask_strength, inner_strength,
               smooth_boundary):
    """Build a ``RerenderConfig`` from the raw UI values.

    Parameters mirror the UI controls one-to-one. ``use_constraints`` is the
    list of enabled constraint labels. Each disabled constraint gets the
    sentinel period ``(1, 0)`` (start after end). Results go to
    ``result/<video name>/``.

    Raises:
        gr.Error: if no input video has been provided.
    """
    if not input_path:
        raise gr.Error('Please upload an input video first')

    # Mark unchecked constraints with the sentinel period (1, 0).
    if 'shape-aware fusion' not in use_constraints:
        warp_start, warp_end = 1, 0
    if 'pixel-aware fusion' not in use_constraints:
        mask_start, mask_end = 1, 0
    if 'color-aware AdaIN' not in use_constraints:
        ada_start, ada_end = 1, 0

    # splitext only strips the extension, so 'a.b.mp4' -> 'a.b' (not 'a').
    input_name = os.path.splitext(os.path.basename(input_path))[0]
    # Slider values may arrive as floats; frame counts and strides are ints.
    interval = int(interval)
    frame_count = 2 + int(keyframe_count) * interval
    cfg = RerenderConfig()
    cfg.create_from_parameters(
        input_path,
        os.path.join('result', input_name, 'blend.mp4'),
        prompt,
        a_prompt=a_prompt,
        n_prompt=n_prompt,
        frame_count=frame_count,
        interval=interval,
        crop=[left_crop, right_crop, top_crop, bottom_crop],
        sd_model=sd_model,
        ddim_steps=ddim_steps,
        scale=scale,
        control_type=control_type,
        control_strength=control_strength,
        canny_low=low_threshold,
        canny_high=high_threshold,
        seed=seed,
        image_resolution=image_resolution,
        x0_strength=x0_strength,
        style_update_freq=style_update_freq,
        cross_period=(cross_start, cross_end),
        warp_period=(warp_start, warp_end),
        mask_period=(mask_start, mask_end),
        ada_period=(ada_start, ada_end),
        mask_strength=mask_strength,
        inner_strength=inner_strength,
        smooth_boundary=smooth_boundary,
        color_preserve=color_preserve)
    return cfg


def cfg_to_input(filename):
    """Load a saved config file and turn it back into UI input values.

    Inverse of ``create_cfg``: the returned list follows the order of
    ``create_cfg``'s parameters, with the crop tuple and each
    ``(start, end)`` period expanded in place.
    """
    cfg = RerenderConfig()
    cfg.create_from_path(filename)
    keyframe_count = (cfg.frame_count - 2) // cfg.interval

    # create_cfg stores a disabled constraint as the period (1, 0), i.e.
    # start > end. Recover the checkbox state from that convention instead
    # of always re-enabling every constraint.
    constraint_periods = (
        ('shape-aware fusion', cfg.warp_period),
        ('pixel-aware fusion', cfg.mask_period),
        ('color-aware AdaIN', cfg.ada_period),
    )
    use_constraints = [
        label for label, (start, end) in constraint_periods if start <= end
    ]

    sd_model = inversed_model_dict.get(cfg.sd_model, 'Stable Diffusion 1.5')

    args = [
        cfg.input_path, cfg.prompt, cfg.image_resolution, cfg.control_strength,
        cfg.color_preserve, *cfg.crop, cfg.control_type, cfg.canny_low,
        cfg.canny_high, cfg.ddim_steps, cfg.scale, cfg.seed, sd_model,
        cfg.a_prompt, cfg.n_prompt, cfg.interval, keyframe_count,
        cfg.x0_strength, use_constraints, *cfg.cross_period,
        cfg.style_update_freq, *cfg.warp_period, *cfg.mask_period,
        *cfg.ada_period, cfg.mask_strength, cfg.inner_strength,
        cfg.smooth_boundary
    ]
    return args


def setup_color_correction(image):
    """Return a LAB copy of the RGB ``image`` to use as a colour target.

    The image is copied before conversion, so the caller's object is left
    untouched.
    """
    rgb_pixels = np.asarray(image.copy())
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2LAB)


def apply_color_correction(correction, original_image):
    """Recolour ``original_image`` towards the LAB target ``correction``.

    The image's LAB histogram is matched to ``correction`` per channel. The
    result is then blended back onto ``original_image`` in luminosity mode.
    Returns the blended image produced by ``blendLayers``.
    """
    lab_pixels = cv2.cvtColor(np.asarray(original_image), cv2.COLOR_RGB2LAB)
    matched_lab = exposure.match_histograms(lab_pixels,
                                            correction,
                                            channel_axis=2)
    matched_rgb = cv2.cvtColor(matched_lab,
                               cv2.COLOR_LAB2RGB).astype('uint8')
    recoloured = Image.fromarray(matched_rgb)
    return blendLayers(recoloured, original_image, BlendType.LUMINOSITY)


@torch.no_grad()
def process(*args):
    """Translate the first frame, then all key frames, with the same args.

    Returns a ``(first_frame, key_video_path)`` pair.
    """
    # Tuple elements are evaluated left to right: process1 runs first.
    return process1(*args), process2(*args)


@torch.no_grad()
def process0(*args):
    """Store the video path (``args[0]``) globally, then run ``process``.

    All remaining positional arguments are forwarded unchanged.
    """
    global global_video_path
    global_video_path = args[0]
    forwarded_args = args[1:]
    return process(*forwarded_args)


@torch.no_grad()
def process1(*args):
    """Translate the first frame of the current video ("Run 1st Key Frame").

    ``args`` are the UI values accepted by ``create_cfg`` after
    ``input_path``; the video path itself comes from ``global_video_path``.
    The function:
      * updates the model, controller and detector;
      * extracts the frames;
      * translates frame 0;
      * stores the result in ``global_state`` for ``process2``;
      * saves it to ``<cfg.first_dir>/first.jpg``.

    Returns:
        The translated first frame as a uint8 ``(H, W, C)`` array.

    Raises:
        gr.Error: if no frames could be extracted from the video.
    """
    cfg = create_cfg(global_video_path, *args)
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_controller(cfg.inner_strength, cfg.mask_period,
                                   cfg.cross_period, cfg.ada_period,
                                   cfg.warp_period)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    # Invalidate any earlier first frame until this run succeeds, so that
    # process2 cannot run on top of a failed first pass.
    global_state.processing_state = ProcessingState.NULL

    prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution,
                   cfg.crop)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    # Same strength for each of the 13 control scales.
    model.control_scales = [cfg.control_strength] * 13
    model.to(device)
    model.cond_stage_model.device = device

    num_samples = 1
    eta = 0.0
    imgs = sorted(os.listdir(cfg.input_dir))
    imgs = [os.path.join(cfg.input_dir, img) for img in imgs]
    if not imgs:
        raise gr.Error('No frames could be extracted from the input video')

    frame = cv2.cvtColor(cv2.imread(imgs[0]), cv2.COLOR_BGR2RGB)
    img = HWC3(frame)
    H, W = img.shape[:2]

    def generate_first_img(img_, strength):
        """Sample one translation of ``img`` starting from latent of ``img_``.

        Returns the decoded samples as a tensor and as a uint8 BHWC array.
        """
        encoder_posterior = model.encode_first_stage(img_.to(device))
        x0 = model.get_first_stage_encoding(encoder_posterior).detach()

        detected_map = HWC3(detector(img))
        control = torch.from_numpy(
            detected_map.copy()).float().to(device) / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
        cond = {
            'c_concat': [control],
            'c_crossattn': [
                model.get_learned_conditioning(
                    [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
            ]
        }
        un_cond = {
            'c_concat': [control],
            'c_crossattn':
            [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
        }
        shape = (4, H // 8, W // 8)

        controller.set_task('initfirst')
        seed_everything(cfg.seed)

        samples, _ = ddim_v_sampler.sample(
            cfg.ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg.scale,
            unconditional_conditioning=un_cond,
            controller=controller,
            x0=x0,
            strength=strength)
        x_samples = model.decode_first_stage(samples)
        x_samples_np = (
            einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
            127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        return x_samples, x_samples_np

    img_ = numpy2tensor(img)
    if cfg.color_preserve:
        x_samples, x_samples_np = generate_first_img(img_,
                                                     1 - cfg.x0_strength)
        # No colour target in this mode; drop any left over from a
        # previous run so process2 cannot pick it up.
        global_state.color_corrections = None
    else:
        # When not preserving colour, first draw a free frame
        # (strength=-1), then use its colours to redraw the first frame.
        _, free_np = generate_first_img(img_, -1)
        color_corrections = setup_color_correction(
            Image.fromarray(free_np[0]))
        global_state.color_corrections = color_corrections
        img_ = apply_color_correction(color_corrections, Image.fromarray(img))
        img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
        x_samples, x_samples_np = generate_first_img(img_,
                                                     1 - cfg.x0_strength)

    global_state.first_result = x_samples
    global_state.first_img = img

    Image.fromarray(x_samples_np[0]).save(
        os.path.join(cfg.first_dir, 'first.jpg'))
    global_state.processing_state = ProcessingState.FIRST_IMG

    return x_samples_np[0]


def _occlusion_blend_mask(occlusion):
    """Soften an occlusion mask into a blend weight clamped to [0, 1].

    The mask is dilated with a 9x9 max-pool, blurred with the module-level
    ``blur`` and added back onto itself (assumes ``occlusion`` holds 0/1
    values, so occluded pixels keep weight 1 — TODO confirm).
    """
    soft = blur(F.max_pool2d(occlusion, kernel_size=9, stride=1, padding=4))
    return torch.clamp(soft + occlusion, 0, 1)


@torch.no_grad()
def process2(*args):
    """Translate all key frames ("Run Key Frames").

    Requires a successful ``process1`` run. ``args`` are the UI values
    accepted by ``create_cfg`` after ``input_path``. Every ``cfg.interval``-th
    frame is translated, guided by optical flow from the first frame and
    from the previous key frame. Key frames are written to ``cfg.key_dir``
    and assembled into ``<cfg.work_dir>/key.mp4``.

    Returns:
        Path of the key-frame video.

    Raises:
        gr.Error: if the first key frame has not been generated.
    """
    if global_state.processing_state != ProcessingState.FIRST_IMG:
        raise gr.Error('Please generate the first key image before generating'
                       ' all key images')

    cfg = create_cfg(global_video_path, *args)
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)

    # Start from an empty key-frame directory (it may not exist yet).
    if os.path.isdir(cfg.key_dir):
        shutil.rmtree(cfg.key_dir)
    os.makedirs(cfg.key_dir, exist_ok=True)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    flow_model = global_state.flow_model
    model.control_scales = [cfg.control_strength] * 13
    color_corrections = getattr(global_state, 'color_corrections', None)

    num_samples = 1
    eta = 0.0
    firstx0 = True
    pixelfusion = cfg.use_mask
    imgs = sorted(os.listdir(cfg.input_dir))
    imgs = [os.path.join(cfg.input_dir, img) for img in imgs]

    first_result = global_state.first_result
    first_img = global_state.first_img
    pre_result = first_result
    pre_img = first_img

    # Frame index `i + 1` is read, so stop at the last frame actually
    # extracted in case the video is shorter than requested.
    last_index = min(len(imgs), cfg.frame_count) - 1
    for i in range(0, last_index, cfg.interval):
        cid = i + 1
        frame = cv2.imread(imgs[cid])
        print(cid)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = HWC3(frame)
        H, W = img.shape[:2]

        if cfg.color_preserve or color_corrections is None:
            img_ = numpy2tensor(img)
        else:
            img_ = apply_color_correction(color_corrections,
                                          Image.fromarray(img))
            img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
        encoder_posterior = model.encode_first_stage(img_.to(device))
        x0 = model.get_first_stage_encoding(encoder_posterior).detach()

        detected_map = HWC3(detector(img))
        control = torch.from_numpy(
            detected_map.copy()).float().to(device) / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
        cond = {
            'c_concat': [control],
            'c_crossattn': [
                model.get_learned_conditioning(
                    [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
            ]
        }
        un_cond = {
            'c_concat': [control],
            'c_crossattn':
            [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
        }
        shape = (4, H // 8, W // 8)

        # Warp the previous key frame's result onto the current frame.
        image2 = torch.from_numpy(img).permute(2, 0, 1).float()
        image1 = torch.from_numpy(pre_img).permute(2, 0, 1).float()
        warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
            flow_model, image1, image2, pre_result, False)
        blend_mask_pre = _occlusion_blend_mask(bwd_occ_pre)

        # Warp the first frame's result onto the current frame.
        image1 = torch.from_numpy(first_img).permute(2, 0, 1).float()
        warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
            flow_model, image1, image2, first_result, False)
        blend_mask_0 = _occlusion_blend_mask(bwd_occ_0)

        if firstx0:
            ref_blend_mask, ref_flow = blend_mask_0, bwd_flow_0
        else:
            ref_blend_mask, ref_flow = blend_mask_pre, bwd_flow_pre
        mask = 1 - F.max_pool2d(ref_blend_mask, kernel_size=8)
        controller.set_warp(
            F.interpolate(ref_flow / 8.0, scale_factor=1. / 8,
                          mode='bilinear'), mask)

        controller.set_task('keepx0, keepstyle')
        seed_everything(cfg.seed)
        samples, _ = ddim_v_sampler.sample(
            cfg.ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg.scale,
            unconditional_conditioning=un_cond,
            controller=controller,
            x0=x0,
            strength=1 - cfg.x0_strength)
        result = model.decode_first_stage(samples)

        if pixelfusion:
            # Fuse the warped previous/first results into the direct result
            # outside their occluded regions, then resample towards it.
            blend_results = (1 - blend_mask_pre
                             ) * warped_pre + blend_mask_pre * result
            blend_results = (
                1 - blend_mask_0) * warped_0 + blend_mask_0 * blend_results

            bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1)
            blend_mask = 1 - _occlusion_blend_mask(bwd_occ)

            encoder_posterior = model.encode_first_stage(blend_results)
            xtrg = model.get_first_stage_encoding(
                encoder_posterior).detach()  # * mask
            blend_results_rec = model.decode_first_stage(xtrg)
            encoder_posterior = model.encode_first_stage(blend_results_rec)
            xtrg_rec = model.get_first_stage_encoding(
                encoder_posterior).detach()
            xtrg_ = (xtrg + 1 * (xtrg - xtrg_rec))  # * mask
            blend_results_rec_new = model.decode_first_stage(xtrg_)
            tmp = (abs(blend_results_rec_new - blend_results).mean(
                dim=1, keepdims=True) > 0.25).float()
            mask_x = F.max_pool2d((F.interpolate(tmp,
                                                 scale_factor=1 / 8.,
                                                 mode='bilinear') > 0).float(),
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

            mask = (1 - F.max_pool2d(1 - blend_mask, kernel_size=8)
                    )  # * (1-mask_x)

            if cfg.smooth_boundary:
                noise_rescale = find_flat_region(mask)
            else:
                noise_rescale = torch.ones_like(mask)
            # Per-step masks: only active inside cfg.mask_period. Uses
            # `step`, not `i`, so the key-frame index below is not clobbered.
            masks = []
            for step in range(cfg.ddim_steps):
                if (step <= cfg.ddim_steps * cfg.mask_period[0]
                        or step >= cfg.ddim_steps * cfg.mask_period[1]):
                    masks.append(None)
                else:
                    masks.append(mask * cfg.mask_strength)

            # mask 3
            # xtrg = ((1-mask_x) *
            #         (xtrg + xtrg - xtrg_rec) + mask_x * samples) * mask
            # mask 2
            # xtrg = (xtrg + 1 * (xtrg - xtrg_rec)) * mask
            xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask  # mask 1

            tasks = 'keepstyle, keepx0'
            if not firstx0:
                tasks += ', updatex0'
            # `i` is this key frame's offset into the video.
            if i % cfg.style_update_freq == 0:
                tasks += ', updatestyle'
            controller.set_task(tasks, 1.0)

            seed_everything(cfg.seed)
            samples, _ = ddim_v_sampler.sample(
                cfg.ddim_steps,
                num_samples,
                shape,
                cond,
                verbose=False,
                eta=eta,
                unconditional_guidance_scale=cfg.scale,
                unconditional_conditioning=un_cond,
                controller=controller,
                x0=x0,
                strength=1 - cfg.x0_strength,
                xtrg=xtrg,
                mask=masks,
                noise_rescale=noise_rescale)
            result = model.decode_first_stage(samples)

        pre_result = result
        pre_img = img
        viz = (einops.rearrange(result, 'b c h w -> b h w c') * 127.5 +
               127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        Image.fromarray(viz[0]).save(
            os.path.join(cfg.key_dir, f'{cid:04d}.png'))

    key_video_path = os.path.join(cfg.work_dir, 'key.mp4')
    fps = get_fps(cfg.input_path)
    # Key frames are `interval` frames apart; keep a positive frame rate even
    # when interval exceeds the source fps.
    fps = max(fps // cfg.interval, 1)
    frame_to_video(key_video_path, cfg.key_dir, fps, False)

    # Only mark completion once every key frame and the video exist.
    global_state.processing_state = ProcessingState.KEY_IMGS
    return key_video_path


DESCRIPTION = '''
## [Rerender A Video](https://github.com/williamyang1991/Rerender_A_Video)
### This Space provides key frame translation. The full code for whole-video translation will be released upon publication of the paper.
### To avoid overload, we set limitations to the **maximum frame number** (8) and the maximum frame resolution (512x768). 
### The running time of a video of size 512x640 is about 1 minute per keyframe under T4 GPU.
### How to use:
1. **Run 1st Key Frame**: only translate the first frame, so you can adjust the prompts/models/parameters to find your ideal output appearance before running the whole video.
2. **Run Key Frames**: translate all the key frames based on the settings of the first frame
3. **Run All**: **Run 1st Key Frame** and **Run Key Frames**
4. **Run Propagation**: propagate the key frames to other frames for full video translation. This function is supported [here](https://github.com/williamyang1991/Rerender_A_Video#webui-recommended)
### Tips: 
1. This method cannot handle large or quick motions where the optical flow is hard to estimate. **Videos with stable motions are preferred**.
2. Pixel-aware fusion may not work for large or quick motions.
3. Try different color-aware AdaIN settings, or disable it entirely, to avoid color jittering.
4. Use the `revAnimated_v11` model for non-photorealistic styles and the `realisticVisionV20_v20` model for photorealistic styles.
5. To use your own SD/LoRA model, you may clone the space and specify your model with [sd_model_cfg.py](https://huggingface.co/spaces/Anonymous-sub/Rerender/blob/main/sd_model_cfg.py).
6. This method is based on the original SD model. You may need to [convert](https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py) Diffusers/Automatic1111 models to the original format.

**This code is for research purpose and non-commercial use only.**

[![Duplicate this Space](https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm-dark.svg)](https://huggingface.co/spaces/Anonymous-sub/Rerender?duplicate=true) for no queue on your own hardware.
'''


# Markdown rendered below the demo: GitHub link, BibTeX citation, license and
# contact details.
# NOTE: the blank line before ``---`` matters -- without it Markdown treats the
# preceding paragraph as a setext heading instead of drawing a rule.
ARTICLE = r"""
If Rerender-A-Video is helpful, please help to ⭐ the <a href='https://github.com/williamyang1991/Rerender_A_Video' target='_blank'>GitHub Repo</a>. Thanks! 
[![GitHub Stars](https://img.shields.io/github/stars/williamyang1991/Rerender_A_Video?style=social)](https://github.com/williamyang1991/Rerender_A_Video)

---
📝 **Citation**
If our work is useful for your research, please consider citing:
```bibtex
@inproceedings{yang2023rerender,
  title = {Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation},
  author = {Yang, Shuai and Zhou, Yifan and Liu, Ziwei and Loy, Chen Change},
  booktitle = {ACM SIGGRAPH Asia Conference Proceedings},
  year = {2023},
}
```
📋 **License**
This project is licensed under <a rel="license" href="https://github.com/williamyang1991/Rerender_A_Video/blob/main/LICENSE.md">S-Lab License 1.0</a>. 
Redistribution and use for non-commercial purposes should follow this license.

📧 **Contact**
If you have any questions, please feel free to reach out to me at <b>williamyang@pku.edu.cn</b>.
"""

# HTML snippet appended at the very bottom of the page: a centred
# visitor-counter badge image.
FOOTER = ('<div align=center>'
          '<img id="visitor-badge" alt="visitor badge" '
          'src="https://visitor-badge.laobi.icu/badge?page_id=williamyang1991/Rerender_A_Video" />'
          '</div>')


# Assemble the Gradio UI. Components created inside ``with block:`` are laid
# out in nesting order; their event handlers are attached further below.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_path = gr.Video(label='Input Video',
                                  source='upload',
                                  format='mp4',
                                  visible=True)
            prompt = gr.Textbox(label='Prompt')
            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=2147483647,  # 2**31 - 1
                             step=1,
                             value=0,
                             randomize=True)
            run_button = gr.Button(value='Run All')
            with gr.Row():
                run_button1 = gr.Button(value='Run 1st Key Frame')
                run_button2 = gr.Button(value='Run Key Frames')
                run_button3 = gr.Button(value='Run Propagation')
            with gr.Accordion('Advanced options for the 1st frame translation',
                              open=False):
                image_resolution = gr.Slider(
                    label='Frame resolution',
                    minimum=256,
                    maximum=512,
                    value=512,
                    step=64,
                    info='To avoid overload, maximum 512')
                control_strength = gr.Slider(label='ControlNet strength',
                                             minimum=0.0,
                                             maximum=2.0,
                                             value=1.0,
                                             step=0.01)
                x0_strength = gr.Slider(
                    label='Denoising strength',
                    minimum=0.00,
                    maximum=1.05,
                    value=0.75,
                    step=0.05,
                    # Adjacent literals are concatenated verbatim, so the
                    # space between the two sentences must be explicit.
                    info=('0: fully recover the input. '
                          '1.05: fully rerender the input.'))
                color_preserve = gr.Checkbox(
                    label='Preserve color',
                    value=True,
                    info='Keep the color of the input video')
                with gr.Row():
                    left_crop = gr.Slider(label='Left crop length',
                                          minimum=0,
                                          maximum=512,
                                          value=0,
                                          step=1)
                    right_crop = gr.Slider(label='Right crop length',
                                           minimum=0,
                                           maximum=512,
                                           value=0,
                                           step=1)
                with gr.Row():
                    top_crop = gr.Slider(label='Top crop length',
                                         minimum=0,
                                         maximum=512,
                                         value=0,
                                         step=1)
                    bottom_crop = gr.Slider(label='Bottom crop length',
                                            minimum=0,
                                            maximum=512,
                                            value=0,
                                            step=1)
                with gr.Row():
                    control_type = gr.Dropdown(['HED', 'canny', 'depth'],
                                               label='Control type',
                                               value='HED')
                    low_threshold = gr.Slider(label='Canny low threshold',
                                              minimum=1,
                                              maximum=255,
                                              value=100,
                                              step=1)
                    high_threshold = gr.Slider(label='Canny high threshold',
                                               minimum=1,
                                               maximum=255,
                                               value=200,
                                               step=1)
                ddim_steps = gr.Slider(label='Steps',
                                       minimum=1,
                                       maximum=20,
                                       value=20,
                                       step=1,
                                       info='To avoid overload, maximum 20')
                scale = gr.Slider(label='CFG scale',
                                  minimum=0.1,
                                  maximum=30.0,
                                  value=7.5,
                                  step=0.1)
                sd_model_list = list(model_dict.keys())
                sd_model = gr.Dropdown(sd_model_list,
                                       label='Base model',
                                       value='Stable Diffusion 1.5')
                a_prompt = gr.Textbox(label='Added prompt',
                                      value='best quality, extremely detailed')
                n_prompt = gr.Textbox(
                    label='Negative prompt',
                    value=('longbody, lowres, bad anatomy, bad hands, '
                           'missing fingers, extra digit, fewer digits, '
                           'cropped, worst quality, low quality'))
            with gr.Accordion('Advanced options for the key fame translation',
                              open=False):
                interval = gr.Slider(
                    label='Key frame frequency (K)',
                    minimum=1,
                    maximum=MAX_KEYFRAME,
                    value=1,
                    step=1,
                    info='Uniformly sample the key frames every K frames')
                keyframe_count = gr.Slider(
                    label='Number of key frames',
                    minimum=1,
                    maximum=MAX_KEYFRAME,
                    value=1,
                    step=1,
                    info='To avoid overload, maximum 8 key frames')

                use_constraints = gr.CheckboxGroup(
                    [
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ],
                    label='Select the cross-frame contraints to be used',
                    value=[
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ]),
                with gr.Row():
                    cross_start = gr.Slider(
                        label='Cross-frame attention start',
                        minimum=0,
                        maximum=1,
                        value=0,
                        step=0.05)
                    cross_end = gr.Slider(label='Cross-frame attention end',
                                          minimum=0,
                                          maximum=1,
                                          value=1,
                                          step=0.05)
                style_update_freq = gr.Slider(
                    label='Cross-frame attention update frequency',
                    minimum=1,
                    maximum=100,
                    value=1,
                    step=1,
                    info=('Update the key and value for '
                          'cross-frame attention every N key frames (recommend N*K>=10)'
                          ))
                with gr.Row():
                    warp_start = gr.Slider(label='Shape-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0,
                                           step=0.05)
                    warp_end = gr.Slider(label='Shape-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.1,
                                         step=0.05)
                with gr.Row():
                    mask_start = gr.Slider(label='Pixel-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0.5,
                                           step=0.05)
                    mask_end = gr.Slider(label='Pixel-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.8,
                                         step=0.05)
                with gr.Row():
                    ada_start = gr.Slider(label='Color-aware AdaIN start',
                                          minimum=0,
                                          maximum=1,
                                          value=0.8,
                                          step=0.05)
                    ada_end = gr.Slider(label='Color-aware AdaIN end',
                                        minimum=0,
                                        maximum=1,
                                        value=1,
                                        step=0.05)
                mask_strength = gr.Slider(label='Pixel-aware fusion stength',
                                          minimum=0,
                                          maximum=1,
                                          value=0.5,
                                          step=0.01)
                inner_strength = gr.Slider(
                    label='Pixel-aware fusion detail level',
                    minimum=0.5,
                    maximum=1,
                    value=0.9,
                    step=0.01,
                    info='Use a low value to prevent artifacts')
                smooth_boundary = gr.Checkbox(
                    label='Smooth fusion boundary',
                    value=True,
                    info='Select to prevent artifacts at boundary')

            with gr.Accordion('Example configs', open=True):
                config_dir = 'config'
                config_list = os.listdir(config_dir)
                args_list = []
                for config in config_list:
                    try:
                        config_path = os.path.join(config_dir, config)
                        args = cfg_to_input(config_path)
                        args_list.append(args)
                    except FileNotFoundError:
                        # The video file does not exist, skipped
                        pass

                ips = [
                    prompt, image_resolution, control_strength, color_preserve,
                    left_crop, right_crop, top_crop, bottom_crop, control_type,
                    low_threshold, high_threshold, ddim_steps, scale, seed,
                    sd_model, a_prompt, n_prompt, interval, keyframe_count,
                    x0_strength, use_constraints[0], cross_start, cross_end,
                    style_update_freq, warp_start, warp_end, mask_start,
                    mask_end, ada_start, ada_end, mask_strength,
                    inner_strength, smooth_boundary
                ]

        with gr.Column():
            result_image = gr.Image(label='Output first frame',
                                    type='numpy',
                                    interactive=False)
            result_keyframe = gr.Video(label='Output key frame video',
                                       format='mp4',
                                       interactive=False)
    with gr.Row():
        gr.Examples(examples=args_list,
                    inputs=[input_path, *ips],
                    fn=process0,
                    outputs=[result_image, result_keyframe],
                    cache_examples=True)

    gr.Markdown(ARTICLE)
    gr.Markdown(FOOTER)

    def _max_keyframe_count(frame_count, interval):
        """Return the key-frame slider limit when sampling every ``interval`` frames.

        Uses ``frame_count - 2`` frames as the sampling span and caps the
        result at ``MAX_KEYFRAME``. The result is never below 1 (the slider's
        minimum), so the slider is never given a maximum below its minimum.
        """
        return max(1, min((frame_count - 2) // interval, MAX_KEYFRAME))

    def _keyframe_limits(frame_count):
        """Return ``(default_interval, max_keyframe)`` for a video.

        Callers must ensure ``frame_count > 2`` so the default interval is at
        least 1.
        """
        default_interval = min(10, frame_count - 2)
        return default_interval, _max_keyframe_count(frame_count,
                                                      default_interval)

    def _remember_video(path, frame_count):
        """Store the current video path and frame count in module globals."""
        global video_frame_count, global_video_path
        video_frame_count = frame_count
        global_video_path = path

    def input_uploaded(path):
        """Validate a freshly uploaded video and reset the key-frame sliders.

        Returns updates for the interval slider and the key-frame-count
        slider, in that order.

        Raises:
            gr.Error: if the video has 2 frames or fewer.
        """
        frame_count = get_frame_count(path)
        if frame_count <= 2:
            raise gr.Error('The input video is too short! '
                           'Please input another video.')

        default_interval, max_keyframe = _keyframe_limits(frame_count)
        _remember_video(path, frame_count)

        return (gr.Slider.update(value=default_interval,
                                 maximum=frame_count - 2),
                gr.Slider.update(value=max_keyframe, maximum=max_keyframe))

    def input_changed(path):
        """Refresh the slider ranges whenever the input video value changes.

        Unlike ``input_uploaded`` this never raises: an empty or too-short
        video just collapses both sliders to a maximum of 1.
        """
        # NOTE(review): gr.Video presumably passes None when the video is
        # cleared; avoid probing a non-existent file in that case.
        if not path:
            return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1)

        frame_count = get_frame_count(path)
        if frame_count <= 2:
            return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1)

        default_interval, max_keyframe = _keyframe_limits(frame_count)
        _remember_video(path, frame_count)

        return (gr.Slider.update(value=default_interval,
                                 maximum=frame_count - 2),
                gr.Slider.update(maximum=max_keyframe))

    def interval_changed(interval):
        """Recompute the key-frame count limit after the interval K changes.

        No-op until a video has been loaded (``video_frame_count is None``).
        """
        if video_frame_count is None:
            return gr.Slider.update()

        max_keyframe = _max_keyframe_count(video_frame_count, interval)
        return gr.Slider.update(value=max_keyframe, maximum=max_keyframe)

    # Keep the key-frame sliders consistent with the selected video and the
    # chosen sampling interval.
    input_path.change(fn=input_changed,
                      inputs=input_path,
                      outputs=[interval, keyframe_count])
    input_path.upload(fn=input_uploaded,
                      inputs=input_path,
                      outputs=[interval, keyframe_count])
    interval.change(fn=interval_changed,
                    inputs=interval,
                    outputs=keyframe_count)

    # Every translation button consumes the same ordered input list.
    run_button.click(fn=process,
                     inputs=ips,
                     outputs=[result_image, result_keyframe])
    run_button1.click(fn=process1,
                      inputs=ips,
                      outputs=[result_image])
    run_button2.click(fn=process2,
                      inputs=ips,
                      outputs=[result_keyframe])

    def process3():
        """Handle the 'Run Propagation' button.

        Full-video propagation is not run in this demo, so this always raises
        a ``gr.Error`` that points users to the project's WebUI, where the
        feature is available.
        """
        raise gr.Error(
            'Run Propagation is not available in this demo. For full video '
            'translation, please use the WebUI: '
            'https://github.com/williamyang1991/Rerender_A_Video#webui-recommended')

    run_button3.click(fn=process3, outputs=[result_keyframe])

# Process at most one queued event at a time and hold at most 20 pending
# requests, then serve on all network interfaces.
block.queue(concurrency_count=1, max_size=20).launch(server_name='0.0.0.0')