Spaces:
Running
Running
import gradio as gr | |
import numpy as np | |
import random | |
import torch | |
from diffusers import DiffusionPipeline, StableVideoDiffusionPipeline | |
from PIL import Image | |
import imageio | |
# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Upper bound for randomly drawn seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max

# Half precision is only used on CUDA; CPU inference stays in float32.
_use_half = device == "cuda"
_torch_dtype = torch.float16 if _use_half else torch.float32
# On CUDA, request the published fp16 weight files for BOTH pipelines so the
# checkpoints are not fetched in float32 and then cast down at load time.
_weights_variant = "fp16" if _use_half else None

# Load SDXL for image generation
sdxl_model_id = "stabilityai/sdxl-turbo"
image_pipe = DiffusionPipeline.from_pretrained(
    sdxl_model_id,
    torch_dtype=_torch_dtype,
    variant=_weights_variant,
).to(device)

# Load Stable Video Diffusion for video generation
svd_model_id = "stabilityai/stable-video-diffusion-img2vid"
video_pipe = StableVideoDiffusionPipeline.from_pretrained(
    svd_model_id,
    torch_dtype=_torch_dtype,
    variant=_weights_variant,
)
if device == "cuda":
    # The video pipeline is not moved to the GPU wholesale; diffusers' model
    # CPU offload is enabled instead (see the diffusers memory docs).
    video_pipe.enable_model_cpu_offload()
def generate_video_from_text(prompt, seed=0, randomize_seed=True):
    """Generate a still image from *prompt* and animate it into a short video.

    Args:
        prompt: Text description passed to the image pipeline.
        seed: Seed for the random generator. Ignored when ``randomize_seed``
            is true. A missing (``None``) seed is replaced by a random one,
            and non-integer values are truncated to ``int``.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.

    Returns:
        A ``(video_path, image, seed)`` tuple: the path of the MP4 written
        under ``/tmp``, the 512x512 image the video was conditioned on, and
        the integer seed that was actually used.
    """
    if randomize_seed or seed is None:
        seed = random.randint(0, MAX_SEED)
    # UI number widgets may deliver floats; manual_seed() requires an int.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Generate image
    image = image_pipe(
        prompt=prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=2,
        width=1024,
        height=1024,
    ).images[0]

    # Resize for SVD
    image = image.resize((512, 512))

    # Generate video. The seeded generator is reused so that the reported
    # seed reproduces the video as well as the still image.
    video_frames = video_pipe(image, generator=generator).frames[0]

    # Hand imageio plain arrays rather than relying on implicit conversion
    # of whatever frame objects the pipeline returned.
    frame_arrays = [np.asarray(frame) for frame in video_frames]
    video_path = f"/tmp/generated_{seed}.mp4"
    imageio.mimsave(video_path, frame_arrays, fps=7)
    return video_path, image, seed
# Use Interface instead of Blocks
demo = gr.Interface(
    fn=generate_video_from_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Describe your scene..."),
        # precision=0 makes the widget round to, and return, an integer seed.
        gr.Number(label="Seed", value=0, precision=0),
        gr.Checkbox(label="Randomize Seed", value=True),
    ],
    outputs=[
        gr.Video(label="Generated Video"),
        gr.Image(label="Generated Image"),
        gr.Number(label="Seed Used", precision=0),
    ],
    # Expose endpoint: the name is passed at construction, which is where
    # Interface wires up its submit event, instead of being assigned later.
    api_name="predict",
)
demo.launch()