import os
import tempfile

import streamlit as st
import torch
from diffusers.utils import export_to_video
from diffusers.utils import load_image

try:
    from diffusers import CogVideoXImageToVideoPipeline
    pipeline_available = True
except ImportError:
    # `diffusers.utils` already imported above, so diffusers itself is
    # installed; a missing pipeline class means the installed version is too old.
    pipeline_available = False

MODEL_ID = "THUDM/CogVideoX1.5-5B-I2V"
CACHE_DIR = "./huggingface_cache"
DEFAULT_PROMPT = "A little girl is riding a bicycle at high speed. Focused, detailed, realistic."
NUM_FRAMES = 81
NUM_INFERENCE_STEPS = 50
GUIDANCE_SCALE = 6
SEED = 42
# Playback rate used when encoding the frames to MP4. CogVideoX1.5 is
# described upstream as 16 fps output (81 frames ~= 5 s) — adjust if playback
# looks too fast or too slow.
VIDEO_FPS = 16


@st.cache_resource(show_spinner="Loading model... This may take a while.")
def _load_pipeline():
    """Load the CogVideoX image-to-video pipeline once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    caching the pipeline as a resource keeps the multi-GB model in memory
    across reruns instead of reloading it each time.
    """
    pipe = CogVideoXImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        cache_dir=CACHE_DIR,
    )
    if torch.cuda.is_available():
        # Keeps weights on the CPU and moves submodules to the GPU only while
        # they run, trading speed for much lower VRAM use. Requires CUDA.
        pipe.enable_sequential_cpu_offload()
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()
    return pipe


def _generate_video(pipe, uploaded_file, prompt):
    """Run image-to-video generation and return the encoded MP4 as bytes.

    The uploaded image and the rendered MP4 are written to a temporary
    directory that is removed before this function returns.

    Args:
        pipe: The loaded CogVideoX image-to-video pipeline.
        uploaded_file: The Streamlit upload holding the input image.
        prompt: Text prompt describing the desired motion/content.

    Returns:
        The MP4 file contents as ``bytes``.

    Raises:
        Any exception raised while loading the image, running the pipeline,
        or encoding the video is propagated to the caller.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the real extension (jpg/jpeg/png) rather than forcing ".jpg".
        suffix = os.path.splitext(uploaded_file.name)[1].lower() or ".png"
        image_path = os.path.join(tmp_dir, f"input{suffix}")
        with open(image_path, "wb") as f:
            # getvalue() returns the full upload regardless of stream position.
            f.write(uploaded_file.getvalue())

        image = load_image(image_path)
        video_frames = pipe(
            prompt=prompt,
            image=image,
            num_videos_per_prompt=1,
            num_inference_steps=NUM_INFERENCE_STEPS,
            num_frames=NUM_FRAMES,
            guidance_scale=GUIDANCE_SCALE,
            generator=torch.Generator(device=device).manual_seed(SEED),
        ).frames[0]

        video_path = export_to_video(
            video_frames, os.path.join(tmp_dir, "output.mp4"), fps=VIDEO_FPS
        )
        # Read the bytes before the temporary directory is deleted.
        with open(video_path, "rb") as f:
            return f.read()


st.title("Image to Video with Hugging Face")
st.write("Upload an image and provide a prompt to generate a video.")

if not pipeline_available:
    st.error(
        "Failed to import `CogVideoXImageToVideoPipeline`. "
        "Please upgrade diffusers: `pip install --upgrade diffusers`."
    )
    st.stop()

uploaded_file = st.file_uploader("Upload an image (JPG or PNG):", type=["jpg", "jpeg", "png"])
prompt = st.text_input("Enter your prompt:", DEFAULT_PROMPT)

if uploaded_file is not None:
    st.image(uploaded_file.getvalue(), caption="Input image")

# Generation takes minutes, so only run it on an explicit click instead of on
# every rerun triggered by editing the prompt or re-uploading.
inputs_ready = uploaded_file is not None and bool(prompt.strip())
if st.button("Generate video", disabled=not inputs_ready):
    if not torch.cuda.is_available():
        st.warning("No CUDA GPU detected; generation on CPU will be extremely slow.")

    try:
        pipe = _load_pipeline()
    except Exception as e:
        st.error(f"Failed to load the model: {e}")
    else:
        try:
            with st.spinner("Generating video... This may take a while."):
                video_bytes = _generate_video(pipe, uploaded_file, prompt)
        except Exception as e:
            # Top-level UI boundary: surface the failure instead of crashing.
            st.error(f"An error occurred during video generation: {e}")
        else:
            st.video(video_bytes)
            st.download_button(
                "Download video",
                data=video_bytes,
                file_name="generated_video.mp4",
                mime="video/mp4",
            )