import os
import sys
import shutil
import uuid
import subprocess
import gradio as gr
import cv2  # 用于检查视频帧数
from glob import glob

from huggingface_hub import snapshot_download, hf_hub_download

# Download models
# Everything below runs at module import time, before the UI is built, and
# fetches the required checkpoints into ./pretrained_weights.
os.makedirs("pretrained_weights", exist_ok=True)

# List of subdirectories to create inside "pretrained_weights"
subfolders = [
    "stable-video-diffusion-img2vid-xt"
]

# Create each subdirectory
for subfolder in subfolders:
    os.makedirs(os.path.join("pretrained_weights", subfolder), exist_ok=True)

# Stable Video Diffusion base weights.
# NOTE(review): the repo is "stable-video-diffusion-img2vid" (not "-xt") but it
# is stored in a folder named "...-img2vid-xt"; presumably the inference script
# loads that folder name — confirm before renaming either side.
snapshot_download(
    repo_id="stabilityai/stable-video-diffusion-img2vid",
    local_dir="./pretrained_weights/stable-video-diffusion-img2vid-xt"
)

# AniDoc weights, placed directly in ./pretrained_weights.
snapshot_download(
    repo_id="Yhmeng1106/anidoc",
    local_dir="./pretrained_weights"
)

# Single CoTracker checkpoint file; presumably used by the inference script's
# --tracking option — verify against scripts_infer/anidoc_inference.py.
hf_hub_download(
    repo_id="facebook/cotracker",
    filename="cotracker2.pth",
    local_dir="./pretrained_weights"
)

def normalize_path(path: str) -> str:
    """Return an absolute version of *path* that uses ``/`` as separator.

    Relative paths are resolved against the current working directory.
    Only the platform's native separator (``os.sep``) is rewritten, so on
    Windows backslashes become forward slashes, while on POSIX a literal
    backslash that is part of a file name is left intact.

    Args:
        path: File or directory path, absolute or relative.

    Returns:
        The normalized absolute path.
    """
    return os.path.abspath(path).replace(os.sep, "/")

def check_video_frames(video_path: str) -> int:
    """Return the frame count OpenCV reports for the video at *video_path*.

    The path is passed through ``normalize_path`` before opening. Whether the
    file could actually be opened is not checked here.
    NOTE(review): CAP_PROP_FRAME_COUNT may be an estimate for some containers,
    and is presumably 0 for an unopenable file — confirm if exact counts matter.
    """
    capture = cv2.VideoCapture(normalize_path(video_path))
    frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    capture.release()
    return frames

def preprocess_video(video_path: str) -> str:
    """Resample a control video to 14 frames via ``process_video_to_14frames.py``.

    The helper script runs in a subprocess using the current interpreter and
    writes its result into a fresh ``outputs/processed_<uuid>`` directory.

    Args:
        video_path: Path to the input video.

    Returns:
        Normalized path of the first ``.mp4`` (by sorted name) found in the
        output directory.

    Raises:
        gr.Error: If the script exits non-zero, produces no ``.mp4``, or any
            other error occurs while preparing or running it.
    """
    try:
        video_path = normalize_path(video_path)
        output_dir = normalize_path(
            os.path.join("outputs", f"processed_{uuid.uuid4()}")
        )
        os.makedirs(output_dir, exist_ok=True)

        print(f"Processing video: {video_path}")
        print(f"Output directory: {output_dir}")

        # Use the running interpreter rather than whatever "python" is on
        # PATH, so the helper sees the same installed packages as this app.
        result = subprocess.run(
            [
                sys.executable, "process_video_to_14frames.py",
                "--input", video_path,
                "--output", output_dir,
            ],
            check=True,
            capture_output=True,
            text=True,
        )

        if result.stdout:
            print(f"Preprocessing stdout: {result.stdout}")
        if result.stderr:
            print(f"Preprocessing stderr: {result.stderr}")

        # Sort so the chosen file is deterministic if several are written.
        processed_videos = sorted(glob(os.path.join(output_dir, "*.mp4")))
        if not processed_videos:
            raise gr.Error("Failed to process video: No output video found")
        return normalize_path(processed_videos[0])
    except gr.Error:
        # Already user-facing; don't let the generic handler re-wrap it.
        raise
    except subprocess.CalledProcessError as e:
        print(f"Preprocessing stderr: {e.stderr}")
        raise gr.Error(f"Failed to preprocess video: {e.stderr}") from e
    except Exception as e:
        raise gr.Error(f"Error during video preprocessing: {str(e)}") from e

def generate(control_sequence, ref_image):
    """Colorize a sketch video with AniDoc, guided by a reference image.

    Clips that do not have exactly 14 frames are first resampled with
    ``preprocess_video``. Then ``scripts_infer/anidoc_inference.py`` runs
    with matching and tracking enabled, writing into ``results_<uuid>``.

    Args:
        control_sequence: Path to the control (sketch) video from the
            ``gr.Video`` input; falsy when nothing was provided.
        ref_image: Path to the reference image from the
            ``gr.Image(type="filepath")`` input; falsy when nothing was provided.

    Returns:
        Path of the first ``.mp4`` (by sorted name) in the output directory,
        or ``None`` if the inference script produced none.

    Raises:
        gr.Error: If an input is missing, preprocessing fails, the inference
            subprocess exits non-zero, or any other error occurs.
    """
    # Fail early with a clear message instead of an obscure OpenCV/subprocess error.
    if not control_sequence:
        raise gr.Error("Please provide a control sequence video.")
    if not ref_image:
        raise gr.Error("Please provide a reference image.")

    control_image = control_sequence
    output_dir = f"results_{uuid.uuid4()}"

    try:
        frame_count = check_video_frames(control_image)
        if frame_count != 14:
            print(f"Video has {frame_count} frames, preprocessing to 14 frames...")
            control_image = preprocess_video(control_image)
            print(f"Preprocessed video saved to: {control_image}")

        # Use the running interpreter so the script sees this app's packages.
        subprocess.run(
            [
                sys.executable, "scripts_infer/anidoc_inference.py",
                "--all_sketch",
                "--matching",
                "--tracking",
                "--control_image", str(control_image),
                "--ref_image", str(ref_image),
                "--output_dir", output_dir,
                "--max_point", "10",
            ],
            check=True,
        )

        # Sort so the returned file is deterministic if several are written.
        output_video = sorted(glob(os.path.join(output_dir, "*.mp4")))
        print(output_video)

        output_video_path = output_video[0] if output_video else None
        print(output_video_path)
        return output_video_path

    except gr.Error:
        # Already user-facing (e.g. from preprocess_video); don't re-wrap it.
        raise
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error during inference: {str(e)}") from e
    except Exception as e:
        raise gr.Error(f"Error: {str(e)}") from e

# Stylesheet for the Blocks app: centres the element with id "col-container"
# horizontally and caps its width at 982px.
css = (
    "\n"
    "div#col-container{\n"
    "    margin: 0 auto;\n"
    "    max-width: 982px;\n"
    "}\n"
)

# Build the UI: a centred column (styled by `css` via elem_id) with a header,
# then inputs on the left and the result plus examples on the right.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# AniDoc: Animation Creation Made Easier")
        gr.Markdown("AniDoc colorizes a sequence of sketches based on a character design reference with high fidelity, even when the sketches significantly differ in pose and scale.")
        # Badge links to the repo, project page, paper, and Space.
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/yihao-meng/AniDoc">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a> 
            <a href="https://yihao-meng.github.io/AniDoc_demo/">
                <img src='https://img.shields.io/badge/Project-Page-green'>
            </a>
            <a href="https://arxiv.org/pdf/2412.14173">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/AniDoc?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
            <a href="https://huggingface.co/fffiloni">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
            </a>
        </div>
        """)
        with gr.Row():
            # Left column: user inputs (sketch video + reference image as a file path).
            with gr.Column():
                control_sequence = gr.Video(label="Control Sequence", format="mp4")
                ref_image = gr.Image(label="Reference Image", type="filepath")
                submit_btn = gr.Button("Submit")
            # Right column: generated video and example (video, image) pairs
            # that populate the two inputs when clicked.
            with gr.Column():
                video_result = gr.Video(label="Result")

                gr.Examples(
                    examples=[
                        ["data_test/sample5.mp4", "data_test/sample5.png"],
                        ["data_test/sample6.mp4", "data_test/sample6.png"],
                        ["data_test/sample1.mp4", "data_test/sample1.png"],
                        ["data_test/sample2.mp4", "data_test/sample2.png"],
                        ["data_test/sample3.mp4", "data_test/sample3.png"],
                        ["data_test/sample4.mp4", "data_test/sample4.png"]
                    ],
                    inputs=[control_sequence, ref_image]
                )

    # Run the full pipeline (frame check/resample, then inference) on click;
    # the returned file path is shown in the "Result" video component.
    submit_btn.click(
        fn=generate,
        inputs=[control_sequence, ref_image],
        outputs=[video_result]
    )

# Enable request queuing, then launch. show_error=True surfaces raised errors
# (e.g. gr.Error) in the browser; share=True requests a public share link.
demo.queue().launch(show_api=False, show_error=True, share=True)