import spaces
import gradio as gr
import sys
import os
import torch
import numpy as np
from os.path import join as pjoin
import utils.paramUtil as paramUtil
from utils.plot_script import *
from utils.utils import *
from utils.motion_process import recover_from_ric
from accelerate.utils import set_seed
from models.gaussian_diffusion import DiffusePipeline
from options.generate_options import GenerateOptions
from utils.model_load import load_model_weights
from motion_loader import get_dataset_loader
from models import build_models
import yaml
import time
from box import Box
import hashlib
from huggingface_hub import hf_hub_download

# Local directory into which the released MotionCLR files are downloaded.
ckptdir = './checkpoints/t2m/release'
os.makedirs(ckptdir, exist_ok=True)


# Gradio temp / allowed-path settings. Rendered videos are written to this
# same directory by the generation functions below.
os.environ['GRADIO_TEMP_DIR'] = "/home/user/app/temp"
os.environ['GRADIO_ALLOWED_PATHS'] = "/home/user/app/temp"


def _download_release_file(filename):
    """Download ``filename`` from the ``EvanTHU/MotionCLR`` Hub repo into ``ckptdir``.

    Returns the local path reported by ``hf_hub_download``.
    """
    return hf_hub_download(
        repo_id="EvanTHU/MotionCLR",
        filename=filename,
        local_dir=ckptdir,
        local_dir_use_symlinks=False,
    )


# Normalisation statistics (used to de-normalise generated motions),
# model weights and the options file, fetched in the original order.
mean_path = _download_release_file("meta/mean.npy")
std_path = _download_release_file("meta/std.npy")
model_path = _download_release_file("model/latest.tar")
opt_path = _download_release_file("opt.txt")


os.makedirs("/home/user/app/temp", exist_ok=True)

def generate_md5(input_string):
    """Return the hexadecimal MD5 digest of ``input_string`` (UTF-8 encoded).

    In this app it only turns timestamp strings into file names; it is not
    used for anything security-sensitive.
    """
    return hashlib.md5(input_string.encode('utf-8')).hexdigest()

def set_all_use_to_false(data):
    """Recursively set every ``'use'`` entry of a nested config to ``False``.

    Walks ``data`` in place. A value that is itself a mapping is descended
    into. ``Box`` is a ``dict`` subclass, and plain nested ``dict`` nodes are
    handled as well. Any other value stored under the key ``'use'`` is
    replaced by ``False``. All other keys are left untouched.

    Args:
        data: A mapping, e.g. the ``Box`` loaded from ``options/edit.yaml``.

    Returns:
        The same ``data`` object, mutated in place.
    """
    for key, value in data.items():
        if isinstance(value, dict):
            set_all_use_to_false(value)
        elif key == 'use':
            # Re-assigning an existing key does not resize the mapping, so it
            # is safe to do while iterating over items().
            data[key] = False
    return data

def yaml_to_box(yaml_file):
    """Parse ``yaml_file`` with ``yaml.safe_load`` and wrap the result in a ``Box``.

    The returned ``Box`` allows attribute-style access to the config,
    e.g. ``cfg.reweighting_attn.use``.
    """
    with open(yaml_file, 'r') as handle:
        return Box(yaml.safe_load(handle))

# HTML header rendered at the top of the demo page.
# The arXiv badge / PDF links point at MotionCLR's paper (arXiv:2410.18977).
# NOTE(review): the homepage links for Wenxun Dai and Xuan Ju point at
# Shunlin Lu's site (https://shunlinlu.github.io). This looks like a
# copy-paste slip. TODO: replace them with their own pages.
HEAD = ("""<div>
<div class="embed_hidden" style="text-align: center;">
    <h1>MotionCLR: Motion Generation and Training-free Editing via Understanding Attention Mechanisms</h1>
    <h3>
        <a href="https://lhchen.top" target="_blank" rel="noopener noreferrer">Ling-Hao Chen</a><sup>1, 2</sup>,
        <a href="https://shunlinlu.github.io" target="_blank" rel="noopener noreferrer">Wenxun Dai</a><sup>2</sup>,
        <a href="https://shunlinlu.github.io" target="_blank" rel="noopener noreferrer">Xuan Ju</a><sup>3</sup>,
        <a href="https://shunlinlu.github.io" target="_blank" rel="noopener noreferrer">Shunlin Lu</a><sup>4</sup>,
        <a href="https://leizhang.org" target="_blank" rel="noopener noreferrer">Lei Zhang</a><sup>🤗 2</sup>
    </h3>
    <h3><sup>🤗</sup><i>Corresponding author.</i></h3>
    <h3>
        <sup>1</sup>THU &emsp;
        <sup>2</sup>IDEA Research &emsp;
        <sup>3</sup>CUHK  &emsp;
        <sup>4</sup>CUHK (SZ)
    </h3>
</div>
<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
<a href='https://arxiv.org/abs/2410.18977'><img src='https://img.shields.io/badge/Arxiv-2410.18977-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a> 
<a href='https://arxiv.org/pdf/2410.18977.pdf'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a> 
<a href='https://lhchen.top/MotionCLR'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a> 
<a href='https://research.lhchen.top/blogpost/motionclr'><img src='https://img.shields.io/badge/Blog-post-4EABE6?style=flat&logoColor=4EABE6'></a>
<a href='https://github.com/IDEA-Research/MotionCLR'><img src='https://img.shields.io/badge/GitHub-Code-black?style=flat&logo=github&logoColor=white'></a> 
<a href='https://huggingface.co/spaces/EvanTHU/MotionCLR'><img src='https://img.shields.io/badge/gradio-demo-red.svg'></a> 
<a href='LICENSE'><img src='https://img.shields.io/badge/License-IDEA-blue.svg'></a> 
<a href="https://huggingface.co/spaces/EvanTHU/MotionCLR" target='_blank'><img src="https://visitor-badge.laobi.icu/badge?page_id=IDEA-Research.MotionCLR&left_color=gray&right_color=%2342b983"></a> 
</div>
</div>
""")


# Extra CSS for the demo. `.retrieved_video` is the class applied to the
# example-based <video> snippets; `.contour_video` is kept for layout use.
CSS = """
.retrieved_video {
    position: relative;
    margin: 0;
    box-shadow: var(--block-shadow);
    border-width: var(--block-border-width);
    border-color: #000000;
    border-radius: var(--block-radius);
    background: var(--block-background-fill);
    width: 100%;
    line-height: var(--line-sm);
}
.contour_video {
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
    z-index: var(--layer-5);
    border-radius: var(--block-radius);
    background: var(--background-fill-primary);
    padding: 0 var(--size-6);
    max-height: var(--size-screen-h);
    overflow: hidden;
}
"""

# Shared, mutable editing configuration. The edit handlers toggle its
# `use` flags and reset them via set_all_use_to_false().
edit_config = yaml_to_box('options/edit.yaml')

@spaces.GPU
def generate_video_from_text(text, opt, pipeline):
    """Generate a single motion for ``text`` and render it to an MP4 file.

    Args:
        text: Text prompt describing the motion.
        opt: Parsed app options. Reads ``motion_length`` (multiplied by
            ``fps`` to get a frame count), ``fps``, ``meta_dir``,
            ``joints_num``, ``dataset_name`` and ``radius``.
        pipeline: The ``DiffusePipeline`` used for generation.

    Returns:
        ``(save_path, style_dis, video_dis, tabs_update)``:
        - the path of the rendered MP4;
        - two HTML ``<video>`` snippets pointing at it, where ``style_dis``
          is captioned "Content Reference";
        - a ``gr.update`` that makes the editing tabs visible.
    """
    global edit_config

    width = 500
    texts = [text]

    save_dir = '/home/user/app/temp/'
    os.makedirs(save_dir, exist_ok=True)
    save_path = pjoin(save_dir, generate_md5(str(time.time())) + ".mp4")

    try:
        # One frame count per prompt. round() first, so a float slider value
        # such as 8.2999999 does not truncate to one frame short.
        n_frames = int(round(opt.motion_length * opt.fps))
        motion_lens = torch.LongTensor([n_frames] * len(texts))

        start_time = time.perf_counter()
        gr.Info("Generating motion...", duration=3)
        pred_motions, _ = pipeline.generate(texts, motion_lens)
        elapsed = time.perf_counter() - start_time
        gr.Info(f"Generating time cost: {elapsed:.2f} s, rendering starts...", duration=3)

        start_time = time.perf_counter()
        mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
        std = np.load(pjoin(opt.meta_dir, 'std.npy'))

        # De-normalise, then recover 3D joint positions.
        motion = pred_motions[0].cpu().numpy() * std + mean
        joints = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
        # Shift the whole clip so its lowest Y coordinate sits on the floor
        # (assumes joints is (frames, joints, xyz) -- TODO confirm).
        floor_height = joints.min(dim=0)[0].min(dim=0)[0][1]
        joints[:, :, 1] -= floor_height
        # Smooth temporal jitter.
        joints = motion_temporal_filter(joints.numpy(), sigma=1)

        kinematic_tree = paramUtil.t2m_kinematic_chain if opt.dataset_name == 't2m' else paramUtil.kit_kinematic_chain
        plot_3d_motion(save_path, kinematic_tree, joints, title=texts[0], fps=opt.fps, radius=opt.radius)

        gr.Info("Rendered motion...", duration=3)
        elapsed = time.perf_counter() - start_time
        gr.Info(f"Rendering time cost: {elapsed:.2f} s", duration=3)
    finally:
        # Always leave the shared edit config with every edit disabled.
        edit_config = set_all_use_to_false(edit_config)

    video_dis = f'<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_path}"></video>'
    style_dis = video_dis + """<br> <p align="center"> Content Reference </p>"""
    return save_path, style_dis, video_dis, gr.update(visible=True)

@spaces.GPU
def reweighting(text, idx, weight, opt, pipeline):
    """Re-generate ``text`` with attention re-weighting on one prompt word.

    Enables ``edit_config.reweighting_attn`` with the given word index and
    weight, then builds a fresh model/pipeline. ``edit_config`` is passed to
    ``build_models``, so the start-up ``pipeline`` is not reused. The prompt
    is generated twice in one batch and the second sample is rendered.

    Args:
        text: Text prompt.
        idx: Index of the word to (de-)emphasise. Cast to ``int`` because
            ``gr.Number`` delivers floats unless ``precision=0``.
        weight: Re-weighting strength (the UI slider spans [-1, 1]).
        opt: Parsed app options.
        pipeline: Unused; kept for interface compatibility.

    Returns:
        HTML ``<video>`` snippet showing the edited motion.
    """
    global edit_config

    width = 500
    # NOTE(review): only the second sample is rendered; the first is
    # presumably an un-edited reference for the attention edit -- confirm.
    texts = [text, text]

    save_dir = '/home/user/app/temp/'
    os.makedirs(save_dir, exist_ok=True)
    save_path = pjoin(save_dir, generate_md5(str(time.time())) + ".mp4")

    try:
        edit_config.reweighting_attn.use = True
        edit_config.reweighting_attn.idx = int(idx)
        edit_config.reweighting_attn.reweighting_attn_weight = weight

        gr.Info("Loading Configurations...", duration=3)
        model = build_models(opt, edit_config=edit_config)
        ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar')
        load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)

        edit_pipeline = DiffusePipeline(
            opt=opt,
            model=model,
            diffuser_name=opt.diffuser_name,
            device=opt.device,
            num_inference_steps=opt.num_inference_steps,
            torch_dtype=torch.float16,
        )

        print(edit_config)

        # One frame count per prompt; round() guards against float slider noise.
        n_frames = int(round(opt.motion_length * opt.fps))
        motion_lens = torch.LongTensor([n_frames] * len(texts))

        start_time = time.perf_counter()
        gr.Info("Generating motion...", duration=3)
        pred_motions, _ = edit_pipeline.generate(texts, motion_lens)
        elapsed = time.perf_counter() - start_time
        gr.Info(f"Generating time cost: {elapsed:.2f} s, rendering starts...", duration=3)

        start_time = time.perf_counter()
        mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
        std = np.load(pjoin(opt.meta_dir, 'std.npy'))

        # Post-process only the sample that is rendered: de-normalise,
        # recover joints, put on the floor (Y axis), remove jitter.
        shown = 1
        motion = pred_motions[shown].cpu().numpy() * std + mean
        joints = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
        floor_height = joints.min(dim=0)[0].min(dim=0)[0][1]
        joints[:, :, 1] -= floor_height
        joints = motion_temporal_filter(joints.numpy(), sigma=1)

        kinematic_tree = paramUtil.t2m_kinematic_chain if opt.dataset_name == 't2m' else paramUtil.kit_kinematic_chain
        plot_3d_motion(save_path, kinematic_tree, joints, title=texts[shown], fps=opt.fps, radius=opt.radius)

        gr.Info("Rendered motion...", duration=3)
        elapsed = time.perf_counter() - start_time
        gr.Info(f"Rendering time cost: {elapsed:.2f} s", duration=3)
    finally:
        # Reset even on failure, so use=True never leaks into later requests.
        edit_config = set_all_use_to_false(edit_config)

    video_dis = f'<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_path}"></video>'
    return video_dis

@spaces.GPU
def generate_example_based_motion(text, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion, opt, pipeline):
    """Generate ``num_motion`` example-based variations of ``text``.

    Enables ``edit_config.example_based`` with the given settings, then
    builds a fresh model/pipeline, because ``edit_config`` is passed to
    ``build_models``. It generates ``num_motion`` copies of the prompt and
    renders each one to its own MP4.

    Args:
        text: Text prompt.
        chunk_size: Chunk size in frames (cast to ``int``).
        example_based_steps_end: Ending step of the manipulation (cast to ``int``).
        temp_seed: Random seed for the edit (cast to ``int``).
        temp_seed_bar: "Seed for random bar" slider value, passed through as-is.
        num_motion: Number of motions to generate (UI choices 4..24).
        opt: Parsed app options.
        pipeline: Unused; kept for interface compatibility.

    Returns:
        A list with one entry per HTML output slot (24 in the UI). It holds
        an autoplaying ``<video>`` snippet per generated motion, padded
        with ``None``.
    """
    global edit_config

    width = 500
    height = 500
    num_slots = 24  # main() lays out 6 rows x 4 gr.HTML outputs
    num_motion = int(num_motion)
    texts = [text] * num_motion

    save_dir = '/home/user/app/temp/'
    os.makedirs(save_dir, exist_ok=True)
    # Hash the index together with one timestamp. Repeated time.time() calls
    # in a tight loop can return the same value, which would make every
    # sample write to (and display) the same file.
    stamp = str(time.time())
    save_paths = [pjoin(save_dir, generate_md5(f"{stamp}-{i}") + ".mp4") for i in range(num_motion)]

    try:
        edit_config.example_based.use = True
        # gr.Number delivers floats unless precision=0; these are integer settings.
        edit_config.example_based.chunk_size = int(chunk_size)
        edit_config.example_based.example_based_steps_end = int(example_based_steps_end)
        edit_config.example_based.temp_seed = int(temp_seed)
        edit_config.example_based.temp_seed_bar = temp_seed_bar

        gr.Info("Loading Configurations...", duration=3)
        model = build_models(opt, edit_config=edit_config)
        ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar')
        load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)

        edit_pipeline = DiffusePipeline(
            opt=opt,
            model=model,
            diffuser_name=opt.diffuser_name,
            device=opt.device,
            num_inference_steps=opt.num_inference_steps,
            torch_dtype=torch.float16,
        )

        # One frame count per prompt; round() guards against float slider noise.
        n_frames = int(round(opt.motion_length * opt.fps))
        motion_lens = torch.LongTensor([n_frames] * len(texts))

        start_time = time.perf_counter()
        gr.Info("Generating motion...", duration=3)
        pred_motions, _ = edit_pipeline.generate(texts, motion_lens)
        elapsed = time.perf_counter() - start_time
        gr.Info(f"Generating time cost: {elapsed:.2f} s, rendering starts...", duration=3)

        start_time = time.perf_counter()
        mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
        std = np.load(pjoin(opt.meta_dir, 'std.npy'))

        progress = gr.Progress()
        progress(0, desc="Starting...")

        # De-normalise, recover joints, put on the floor (Y axis), remove jitter.
        samples = []
        for pred in pred_motions:
            motion = pred.cpu().numpy() * std + mean
            joints = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
            floor_height = joints.min(dim=0)[0].min(dim=0)[0][1]
            joints[:, :, 1] -= floor_height
            samples.append(motion_temporal_filter(joints.numpy(), sigma=1))

        kinematic_tree = paramUtil.t2m_kinematic_chain if opt.dataset_name == 't2m' else paramUtil.kit_kinematic_chain
        video_dis = []
        for i, title in enumerate(progress.tqdm(texts)):
            plot_3d_motion(save_paths[i], kinematic_tree, samples[i], title=title, fps=opt.fps, radius=opt.radius)
            video_html = f'''
        <video class="retrieved_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()" autoplay loop disablepictureinpicture src="./file={save_paths[i]}"> </video>
        '''
            video_dis.append(video_html)

        # Fill the remaining UI slots with empty content.
        video_dis.extend([None] * (num_slots - len(video_dis)))

        gr.Info("Rendered motion...", duration=3)
        elapsed = time.perf_counter() - start_time
        gr.Info(f"Rendering time cost: {elapsed:.2f} s", duration=3)
    finally:
        # Reset even on failure, so use=True never leaks into later requests.
        edit_config = set_all_use_to_false(edit_config)

    return video_dis

@spaces.GPU
def transfer_style(text, style_text, style_transfer_steps_end, opt, pipeline):
    """Transfer the style of ``style_text``'s motion onto ``text``'s motion.

    Enables the style-transfer edit, then builds a fresh model/pipeline,
    because ``edit_config`` is passed to ``build_models``. It generates the
    batch ``[style_text, text, text]``, whose entries are the style reference,
    the transferred result and the content reference. Only the first two
    are rendered, because only they are returned.

    Args:
        text: Content prompt.
        style_text: Style-reference prompt.
        style_transfer_steps_end: Ending diffusion step of the transfer
            (cast to ``int`` because ``gr.Number`` delivers floats).
        opt: Parsed app options.
        pipeline: Unused; kept for interface compatibility.

    Returns:
        ``(style_html, result_html)``: HTML ``<video>`` snippets for the
        style reference and the transferred result.
    """
    global edit_config

    width = 500
    texts = [style_text, text, text]
    style_idx, result_idx = 0, 1

    save_dir = '/home/user/app/temp/'
    os.makedirs(save_dir, exist_ok=True)
    # Hash the index together with one timestamp. Back-to-back time.time()
    # calls can be equal, which would make both videos share one file.
    stamp = str(time.time())
    save_paths = {i: pjoin(save_dir, generate_md5(f"{stamp}-{i}") + ".mp4") for i in (style_idx, result_idx)}

    try:
        # The attribute is spelled "style_tranfer", presumably matching the
        # key in options/edit.yaml -- do not "fix" it here alone.
        edit_config.style_tranfer.use = True
        edit_config.style_tranfer.style_transfer_steps_end = int(style_transfer_steps_end)

        gr.Info("Loading Configurations...", duration=3)
        model = build_models(opt, edit_config=edit_config)
        ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar')
        load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)

        edit_pipeline = DiffusePipeline(
            opt=opt,
            model=model,
            diffuser_name=opt.diffuser_name,
            device=opt.device,
            num_inference_steps=opt.num_inference_steps,
            torch_dtype=torch.float16,
        )

        print(edit_config)

        # One frame count per prompt; round() guards against float slider noise.
        n_frames = int(round(opt.motion_length * opt.fps))
        motion_lens = torch.LongTensor([n_frames] * len(texts))

        start_time = time.perf_counter()
        gr.Info("Generating motion...", duration=3)
        pred_motions, _ = edit_pipeline.generate(texts, motion_lens)
        elapsed = time.perf_counter() - start_time
        gr.Info(f"Generating time cost: {elapsed:.2f} s, rendering starts...", duration=3)

        start_time = time.perf_counter()
        mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
        std = np.load(pjoin(opt.meta_dir, 'std.npy'))
        kinematic_tree = paramUtil.t2m_kinematic_chain if opt.dataset_name == 't2m' else paramUtil.kit_kinematic_chain

        for i in (style_idx, result_idx):
            # De-normalise, recover joints, put on the floor (Y axis), remove jitter.
            motion = pred_motions[i].cpu().numpy() * std + mean
            joints = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
            floor_height = joints.min(dim=0)[0].min(dim=0)[0][1]
            joints[:, :, 1] -= floor_height
            joints = motion_temporal_filter(joints.numpy(), sigma=1)
            plot_3d_motion(save_paths[i], kinematic_tree, joints, title=texts[i], fps=opt.fps, radius=opt.radius)

        gr.Info("Rendered motion...", duration=3)
        elapsed = time.perf_counter() - start_time
        gr.Info(f"Rendering time cost: {elapsed:.2f} s", duration=3)
    finally:
        # Reset even on failure, so use=True never leaks into later requests.
        edit_config = set_all_use_to_false(edit_config)

    style_html = f"""<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_paths[style_idx]}"></video> <br> <p align="center"> Style Reference </p>"""
    result_html = f"""<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_paths[result_idx]}"></video> <br> <p align="center"> Transferred Result </p>"""
    return style_html, result_html


def main():
    """Build the MotionCLR pipeline and launch the Gradio demo.

    Steps:
    1. Parse the app options and seed the RNGs.
    2. Load the checkpoint into a ``DiffusePipeline`` on the chosen device.
    3. Lay out the UI: a prompt/length row, a result video, and three editing
       tabs (de-/emphasising, example-based generation, style transfer).
    """
    parser = GenerateOptions()
    opt = parser.parse_app()
    set_seed(opt.seed)
    device_id = opt.gpu_id
    device = torch.device('cuda:%d' % device_id if torch.cuda.is_available() else 'cpu')
    opt.device = device
    print(device)

    # load model
    model = build_models(opt, edit_config=edit_config)
    ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar')
    load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)

    pipeline = DiffusePipeline(
        opt=opt,
        model=model,
        diffuser_name=opt.diffuser_name,
        device=device,
        num_inference_steps=opt.num_inference_steps,
        torch_dtype=torch.float16,
    )

    # Pass CSS so the `retrieved_video` class used by the example-based
    # <video> snippets is actually styled (it was defined but never applied).
    with gr.Blocks(theme=gr.themes.Glass(), css=CSS) as demo:
        gr.HTML(HEAD)
        with gr.Row():
            with gr.Column(scale=7):
                text_input = gr.Textbox(label="Input the text prompt to generate motion...")
            with gr.Column(scale=3):
                sequence_length = gr.Slider(minimum=1, maximum=9.6, step=0.1, label="Motion length", value=8)
        with gr.Row():
            generate_button = gr.Button("Generate motion")

        with gr.Row():
            video_display = gr.Video()

        tabs = gr.Tabs(visible=True)
        with tabs:
            with gr.Tab("Motion (de-)emphasizing"):
                with gr.Row():
                    # precision=0 makes gr.Number return ints for integer settings.
                    int_input = gr.Number(label="Editing word index", minimum=0, maximum=70, precision=0)
                    weight_input = gr.Slider(minimum=-1, maximum=1, step=0.01, label="Input weight for (de-)emphasizing [-1, 1]", value=0)

                trim_button = gr.Button("Edit reweighting")

                with gr.Row():
                    original_video1 = gr.HTML(label="before editing", visible=False)
                    edited_video = gr.HTML(label="after editing")

                trim_button.click(
                    fn=lambda text, idx, weight: reweighting(text, idx, weight, opt, pipeline),
                    inputs=[text_input, int_input, weight_input],
                    outputs=edited_video,
                )

            with gr.Tab("Example-based motion generation"):
                with gr.Row():
                    with gr.Column(scale=4):
                        chunk_size = gr.Number(minimum=10, maximum=20, step=10, label="Chunk size (#frames)", value=20, precision=0)
                        example_based_steps_end = gr.Number(minimum=0, maximum=9, label="Ending step of manipulation", value=6, precision=0)
                    with gr.Column(scale=3):
                        temp_seed = gr.Number(label="Seed for random", value=200, minimum=0, precision=0)
                        temp_seed_bar = gr.Slider(minimum=0, maximum=100, step=1, label="Seed for random bar", value=15)
                    with gr.Column(scale=3):
                        num_motion = gr.Radio(choices=[4, 8, 12, 16, 24], value=8, label="Select number of motions")

                gen_button = gr.Button("Generate example-based motion")

                # 6 rows x 4 columns = 24 output slots (the largest num_motion choice).
                example_video_display = []
                for _ in range(6):
                    with gr.Row():
                        for _ in range(4):
                            video = gr.HTML(label="Example-based motion", visible=True)
                            example_video_display.append(video)

                gen_button.click(
                    fn=lambda text, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion: generate_example_based_motion(text, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion, opt, pipeline),
                    inputs=[text_input, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion],
                    outputs=example_video_display,
                )

            with gr.Tab("Style transfer"):
                with gr.Row():
                    style_text = gr.Textbox(label="Reference prompt (e.g. 'a man walks.')", value="a man walks.")
                    style_transfer_steps_end = gr.Number(label="The end step of diffusion (0~9)", minimum=0, maximum=9, value=5, precision=0)

                style_transfer_button = gr.Button("Transfer style")

                with gr.Row():
                    style_reference = gr.HTML(label="style reference")
                    original_video4 = gr.HTML(label="before style transfer", visible=False)
                    styled_video = gr.HTML(label="after style transfer")

                style_transfer_button.click(
                    fn=lambda text, style_text, style_transfer_steps_end: transfer_style(text, style_text, style_transfer_steps_end, opt, pipeline),
                    inputs=[text_input, style_text, style_transfer_steps_end],
                    outputs=[style_reference, styled_video],
                )

        def update_motion_length(sequence_length):
            # NOTE(review): mutates the shared `opt`, so the length is global to
            # all sessions and also used by the edit tabs.
            opt.motion_length = sequence_length

        def on_generate(text, length, pipeline):
            update_motion_length(length)
            return generate_video_from_text(text, opt, pipeline)

        generate_button.click(
            fn=lambda text, length: on_generate(text, length, pipeline),
            inputs=[text_input, sequence_length],
            outputs=[
                video_display,
                original_video1,
                original_video4,
                tabs,
            ],
            show_progress=True,
        )

        generate_button.click(
            fn=lambda: [gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)],
            inputs=None,
            outputs=[video_display, original_video1, original_video4],
        )

    demo.launch()


if __name__ == "__main__":
    # Launch the demo only when run as a script, not when imported.
    main()