import os
import cv2
from moviepy import VideoFileClip, concatenate_videoclips
import gradio as gr
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
import spaces
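# `spaces` provides the @spaces.GPU decorator used below to request ZeroGPU time on Hugging Face Spaces.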
"""
@spaces.GPU(duration=5)
def test_gpu():
print("Testing GPU availability...")
print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
if torch.cuda.is_available():
print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
else:
print("No GPU available.")
test_gpu()
"""
# Function to get the last frame of a video
def get_last_frame(video_path):
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
    ret, frame = cap.read()
    if ret:
        last_frame_path = "last_frame.jpg"
        cv2.imwrite(last_frame_path, frame)
        cap.release()
        return last_frame_path
    cap.release()
    raise ValueError("Failed to extract the last frame from the video.")
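# Note: seeking via CAP_PROP_POS_FRAMES relies on the container's frame index and can be off by a
# frame (or fail) for some codecs. A minimal fallback sketch that walks the stream sequentially is
# shown below; `get_last_frame_sequential` is a hypothetical helper and is not wired into the app.
def get_last_frame_sequential(video_path, output_path="last_frame.jpg"):
    cap = cv2.VideoCapture(video_path)
    last_frame = None
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        last_frame = frame  # keep the most recently decoded frame
    cap.release()
    if last_frame is None:
        raise ValueError("Failed to decode any frame from the video.")
    cv2.imwrite(output_path, last_frame)
    return output_path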
# Function to combine multiple videos into a single video
def combine_videos(video_paths, output_file):
    clips = [VideoFileClip(video) for video in video_paths]
    final_clip = concatenate_videoclips(clips)
    final_clip.write_videofile(output_file, codec="libx264")
    for clip in clips:
        clip.close()
    return output_file
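# Note: concatenate_videoclips defaults to method="chain", which assumes all clips share the same
# frame size. That holds here because every clip comes from the same SVD pipeline (1024x576);
# for mixed sources, method="compose" would be the safer choice.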
# Video generation function using Stable Video Diffusion
@spaces.GPU(duration=120)  # Allocate ZeroGPU
def generate_video(image_path, seed):
    try:
        # Load Stable Video Diffusion pipeline
        pipeline = StableVideoDiffusionPipeline.from_pretrained(
            "stabilityai/stable-video-diffusion-img2vid-xt",
            torch_dtype=torch.float16,
            variant="fp16"
        )
        pipeline.to("cuda")

        # Load and preprocess the input image
        image = load_image(image_path)
        image = image.resize((1024, 576))

        # Generate video
        generator = torch.Generator(device="cuda").manual_seed(seed)
        frames = pipeline(image, decode_chunk_size=8, generator=generator).frames[0]

        # Export frames to video
        output_video_path = f"generated_video_seed_{seed}.mp4"
        export_to_video(frames, output_video_path, fps=7)

        # Release GPU resources
        del frames, generator
        torch.cuda.empty_cache()

        return output_video_path
    except Exception as e:
        return f"Error during video generation: {str(e)}"
# Main function for iterative video generation
@spaces.GPU(duration=299)
def iterative_video_generation(initial_image, iterations):
    try:
        seed = 42  # Fixed seed for reproducibility
        current_image = initial_image
        video_paths = []
        for i in range(int(iterations)):  # gr.Number delivers a float, so cast to int
            print(f"Iteration {i+1}: Generating video...")
            video_path = generate_video(current_image, seed + i)
            if not os.path.exists(video_path):
                # generate_video returns an error message instead of a file path on failure
                raise RuntimeError(video_path)
            video_paths.append(video_path)
            print("Extracting last frame for next generation...")
            current_image = get_last_frame(video_path)
        print("Combining all videos into a single output...")
        final_video_path = "final_combined_video.mp4"
        combine_videos(video_paths, final_video_path)
        return final_video_path
    except Exception as e:
        return f"Error during iterative video generation: {str(e)}"
# Gradio interface for the iterative video generation
iface = gr.Interface(
    fn=iterative_video_generation,
    inputs=[
        gr.Image(type="filepath", label="Upload Initial Image"),
        gr.Number(label="Number of Iterations", value=3, precision=0)
    ],
    outputs=gr.Video(label="Final Combined Video"),
    title="Iterative Stable Video Diffusion",
    description=(
        "MAX ITERATIONS SET TO THREE FOR EVALUATION BY TEACHERS. DO NOT EXCEED THIS LIMIT OR THE VIDEO GENERATION WILL NOT WORK. "
        "Generate videos iteratively using Stable Video Diffusion (stabilityai/stable-video-diffusion-img2vid-xt). "
        "Each generated video's last frame is used as the input for the next generation, "
        "and the final output is a single combined video. "
        "Each iteration has a maximum ZeroGPU compute time of 120 s (it generally takes around 90 s). "
        "Note that if you run out of ZeroGPU quota mid-run, the video cannot be generated and your compute is wasted. "
        "Ten iterations take around 15-20 minutes of compute time; more than that will likely result in an incomplete generation."
    )
)
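# Hugging Face Spaces expect the app to listen on 0.0.0.0:7860, which matches the launch() call below.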
if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch(server_name="0.0.0.0", server_port=7860)