import os
import time
from pathlib import Path
from loguru import logger
from datetime import datetime
import gradio as gr
import random
import spaces
import torch

from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.utils.preprocess_text_encoder_tokenizer_utils import preprocess_text_encoder_tokenizer
from hyvideo.config import parse_args
from hyvideo.inference import HunyuanVideoSampler
from hyvideo.constants import NEGATIVE_PROMPT

from huggingface_hub import snapshot_download

def _fetch_checkpoints():
    """Download and prepare every checkpoint the sampler loads from ``ckpts/``.

    Steps, in order:

    1. ``tencent/HunyuanVideo`` into ``ckpts``.
    2. ``xtuner/llava-llama-3-8b-v1_1-transformers`` into its own sub-folder.
    3. ``preprocess_text_encoder_tokenizer`` converts that sub-folder into
       ``ckpts/text_encoder``.
    4. ``openai/clip-vit-large-patch14`` into ``ckpts/text_encoder_2``.

    Every download passes ``force_download=True``, so files are fetched again
    even if they already exist locally.
    """
    llava_dir = "ckpts/llava-llama-3-8b-v1_1-transformers"

    def fetch(repo, destination):
        snapshot_download(repo_id=repo, repo_type="model", local_dir=destination, force_download=True)

    fetch("tencent/HunyuanVideo", "ckpts")
    fetch("xtuner/llava-llama-3-8b-v1_1-transformers", llava_dir)
    preprocess_text_encoder_tokenizer(input_dir=llava_dir, output_dir="ckpts/text_encoder")
    fetch("openai/clip-vit-large-patch14", "ckpts/text_encoder_2")


# Only fetch the weights when a CUDA device is visible; CPU-only hosts skip
# this step and the UI tells the user to switch to GPU hardware instead.
if torch.cuda.device_count() > 0:
    _fetch_checkpoints()

def initialize_model(model_path):
    """Load the HunyuanVideo sampler from a local checkpoint directory.

    Args:
        model_path: Directory holding the downloaded checkpoints; a ``str``
            or any path-like object accepted by ``pathlib.Path``.

    Returns:
        A ``HunyuanVideoSampler`` created with ``from_pretrained``, or ``None``
        when no CUDA device is visible (the model is never loaded on CPU).

    Raises:
        ValueError: If ``model_path`` does not exist (only checked when a
            CUDA device is present).
    """
    # f-strings rather than ``+`` so a ``pathlib.Path`` is accepted as well as
    # a plain string; output for ``str`` input is unchanged.
    print(f'initialize_model: {model_path}')
    if torch.cuda.device_count() == 0:
        return None

    # Sampler options come from ``parse_args()`` (presumably CLI defaults /
    # sys.argv — confirm against hyvideo.config).
    args = parse_args()
    models_root_path = Path(model_path)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` not exists: {models_root_path}")

    print(f"`models_root` exists: {models_root_path}")
    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
    print(f'Model initialized: {model_path}')
    return hunyuan_video_sampler

def _prompt_to_filename_part(prompt, max_chars=100, max_bytes=150):
    """Return a file-name-safe fragment taken from the start of ``prompt``.

    The first ``max_chars`` characters are kept. Every character that is not
    alphanumeric or one of `` -_.,`` is then dropped, so user text cannot
    inject path separators, backslashes, control characters, or characters
    that are invalid in file names on some platforms.

    The result is finally capped at ``max_bytes`` bytes of UTF-8. Many file
    systems limit a single file name to 255 bytes, and 100 characters of a
    non-Latin script can exceed that on their own.
    """
    kept = "".join(ch for ch in prompt[:max_chars] if ch.isalnum() or ch in " -_.,")
    # Slicing the encoded bytes may split a multi-byte character; "ignore"
    # drops that trailing partial sequence instead of raising.
    return kept.encode("utf-8")[:max_bytes].decode("utf-8", errors="ignore")


@spaces.GPU(duration=120)
def generate_video(
    model,
    prompt,
    resolution,
    video_length,
    seed,
    num_inference_steps,
    guidance_scale,
    flow_shift,
    embedded_guidance_scale
):
    """Generate one video for ``prompt`` and save it as an MP4 file.

    Args:
        model: Sampler returned by ``initialize_model``. It is ``None`` on
            CPU-only hosts, in which case it is never used.
        prompt: Text prompt.
        resolution: ``"<width>x<height>"`` string, e.g. ``"832x624"``.
        video_length: Number of frames, forwarded as ``video_length``.
        seed: Seed; ``-1`` (or ``None``) is passed on as ``None``. Per the UI
            label, this presumably means a random seed.
        num_inference_steps: Forwarded as ``infer_steps``.
        guidance_scale: Forwarded unchanged to ``model.predict``.
        flow_shift: Forwarded unchanged to ``model.predict``.
        embedded_guidance_scale: Forwarded unchanged to ``model.predict``.

    Returns:
        Path of the saved MP4 under ``./gradio_outputs``, or ``None`` when no
        CUDA device is available (a Gradio warning is shown instead).
    """
    if torch.cuda.device_count() == 0:
        gr.Warning('Set this space to GPU config to make it work.')
        return None

    # Gradio sliders document their value as a float; cast so integer-only
    # arguments (seed, step count) never receive a value such as ``42.0``.
    seed = None if seed is None or int(seed) == -1 else int(seed)
    num_inference_steps = int(num_inference_steps)
    width, height = resolution.split("x")
    width, height = int(width), int(height)
    negative_prompt = "" # not applicable in the inference

    outputs = model.predict(
        prompt=prompt,
        height=height,
        width=width,
        video_length=video_length,
        seed=seed,
        negative_prompt=negative_prompt,
        infer_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_videos_per_prompt=1,
        flow_shift=flow_shift,
        batch_size=1,
        embedded_guidance_scale=embedded_guidance_scale
    )

    # One video is requested. Keep the first sample and re-add a leading
    # size-1 dimension (presumably the batch axis save_videos_grid expects).
    sample = outputs['samples'][0].unsqueeze(0)

    save_path = "./gradio_outputs"
    os.makedirs(save_path, exist_ok=True)

    # No ':' in the timestamp: colons are not allowed in Windows file names.
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H-%M-%S")
    prompt_part = _prompt_to_filename_part(outputs['prompts'][0])
    video_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][0]}_{prompt_part}.mp4"
    save_videos_grid(sample, video_path, fps=24)
    logger.info(f'Sample saved to: {video_path}')

    return video_path

def create_demo(model_path):
    """Build the Gradio Blocks UI for HunyuanVideo text-to-video generation.

    The sampler is loaded once, here, via ``initialize_model``. It is then
    captured by the button's click handler, so every request reuses the same
    model object. On a machine without a CUDA device, ``initialize_model``
    returns ``None``. In that case a red banner asks the user to duplicate
    the Space onto GPU hardware.

    Args:
        model_path: Checkpoint directory forwarded to ``initialize_model``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).
    """
    model = initialize_model(model_path)
    
    # Components appear in the order they are created inside the ``with``
    # contexts below, so statement order here defines the page layout.
    with gr.Blocks() as demo:
        # CPU-only host: show a warning banner above everything else.
        if torch.cuda.device_count() == 0:
            with gr.Row():
                gr.HTML("""
                    <p style="background-color: red;"><big><big><big><b>⚠️To use <i>Hunyuan Video</i>, <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/HunyuanVideo?duplicate=true">duplicate this space</a> and set a GPU with 80 GB VRAM.</b>
    
                    You can't use <i>Hunyuan Video</i> directly here because this space runs on a CPU, which is not enough for <i>Hunyuan Video</i>. Please provide <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/HunyuanVideo/discussions/new">feedback</a> if you have issues.
                    </big></big></big></p>
                    """)
        gr.Markdown("# Hunyuan Video Generation")
        
        with gr.Row():
            # Left column: prompt and generation settings.
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="A cat walks on the grass, realistic style.")
                with gr.Row():
                    # Choice values are "<width>x<height>" strings, which
                    # generate_video splits on "x".
                    resolution = gr.Dropdown(
                        choices=[
                            # 720p
                            ("1280x720 (16:9, 720p)", "1280x720"),
                            ("720x1280 (9:16, 720p)", "720x1280"), 
                            ("1104x832 (4:3, 720p)", "1104x832"),
                            ("832x1104 (3:4, 720p)", "832x1104"),
                            ("960x960 (1:1, 720p)", "960x960"),
                            # 540p
                            ("960x544 (16:9, 540p)", "960x544"),
                            ("544x960 (9:16, 540p)", "544x960"),
                            ("832x624 (4:3, 540p)", "832x624"), 
                            ("624x832 (3:4, 540p)", "624x832"),
                            ("720x720 (1:1, 540p)", "720x720"),
                        ],
                        value="832x624",
                        label="Resolution"
                    )
                    # Values are frame counts (65 or 129); the labels show the
                    # approximate duration in seconds.
                    video_length = gr.Dropdown(
                        label="Video Length",
                        choices=[
                            ("2s(65f)", 65),
                            ("5s(129f)", 129),
                        ],
                        value=65,
                    )
                num_inference_steps = gr.Slider(1, 100, value=5, step=1, label="Number of Inference Steps")
                
                with gr.Accordion("Advanced Options", open=False):
                    with gr.Column():
                        # NOTE(review): a maximum of 2**63 - 1 is above the largest
                        # integer a JavaScript number holds exactly (2**53 - 1), so
                        # very large seeds picked in the browser are likely rounded
                        # — confirm.
                        seed = gr.Slider(label="Seed (-1 for random)", value=-1, minimum=-1, maximum=2**63 - 1, step=1)
                        guidance_scale = gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Guidance Scale")
                        flow_shift = gr.Slider(0.0, 10.0, value=7.0, step=0.1, label="Flow Shift") 
                        embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale")

                generate_btn = gr.Button(value = "🚀 Generate Video", variant = "primary")
            
            # Right side of the row: the generated video player.
            with gr.Row():
                output = gr.Video(label = "Generated Video", autoplay = True)

        gr.Markdown("""
## **Alternatives**
If you can't use _Hunyuan Video_, you can use _[CogVideoX](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space)_ or _[LTX Video Playground](https://huggingface.co/spaces/Lightricks/LTX-Video-Playground)_ instead.
                    """)
        
        # ``model`` is bound through the lambda because Gradio passes only the
        # values of the listed input components. Their order must match
        # generate_video's parameters after ``model``.
        generate_btn.click(
            fn=lambda *inputs: generate_video(model, *inputs),
            inputs=[
                prompt,
                resolution,
                video_length,
                seed,
                num_inference_steps,
                guidance_scale,
                flow_shift,
                embedded_guidance_scale
            ],
            outputs=output
        )
    
    return demo

if __name__ == "__main__":
    # Gradio reads this variable when the Blocks app is created, so it must be
    # set before create_demo() is called.
    os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
    demo = create_demo("ckpts")
    # The queue limit is passed by keyword. Positionally, ``queue(10)`` binds
    # to ``concurrency_count`` in Gradio 3.x but to ``status_update_rate``
    # (seconds between queue-status pushes) in Gradio 4+. Neither of those
    # bounds how many requests may wait; ``max_size`` caps the waiting list
    # at 10 requests.
    demo.queue(max_size=10).launch()