# t2v_cpu_gradio.py
import os, uuid, time
from pathlib import Path
from typing import List
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from PIL import Image
import imageio
import gradio as gr
# -------------------------
# CONFIG (tune for speed)
# -------------------------
MODEL_ID = "runwayml/stable-diffusion-v1-5" # or other small sd model
DEVICE = "cpu"
STEPS = 20 # small # of steps -> much faster but lower quality
WIDTH, HEIGHT = 512, 320 # smaller height helps speed; keep multiples of 8
NUM_FRAMES = 8 # base key frames to generate
INTERPOLATION_FACTOR = 2 # output FPS multiplier (optional, via RIFE later)
SEED = None # None = random
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
# -------------------------
torch.set_num_threads(max(1, os.cpu_count()//2)) # limit threads to avoid oversubscription
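# Optional tuning note (assumption, not part of the original script): the BLAS /
# OpenMP thread pools can also be capped from the environment before Python starts,
# which sometimes helps on shared CPU hosts, e.g.:
#   OMP_NUM_THREADS=4 MKL_NUM_THREADS=4 python app.py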
def make_pipeline():
    # Use DDIMScheduler for faster sampling (fewer steps are usually OK)
    scheduler = DDIMScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")
    pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, safety_checker=None, torch_dtype=torch.float32)
    pipe.scheduler = scheduler
    pipe = pipe.to(DEVICE)
    pipe.enable_attention_slicing()  # reduces peak memory usage (useful on CPU)
    # silence diffusers' per-step progress bar; the Gradio progress callback is used instead
    pipe.set_progress_bar_config(disable=True)
    return pipe
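# Optional variant (a sketch, assuming a recent diffusers release): DPM-Solver++
# often keeps acceptable quality at even lower step counts than DDIM on CPU.
# To try it, swap the scheduler after loading the pipeline:
#
#   from diffusers import DPMSolverMultistepScheduler
#   pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)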
def seed_context(seed):
    if seed is None:
        seed = int.from_bytes(os.urandom(2), "big")  # small random seed
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    return generator
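# Hypothetical helper (not called anywhere in this file): derive one generator per
# frame from a base seed so every keyframe is individually reproducible, instead of
# reusing a single generator whose state advances from frame to frame. Sketch only.
def per_frame_generators(base_seed: int, n_frames: int) -> List[torch.Generator]:
    # offset the base seed so each frame gets a distinct but repeatable generator
    return [torch.Generator(device=DEVICE).manual_seed(base_seed + i) for i in range(n_frames)]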
def generate_frames(prompt: str, steps: int, width: int, height: int, n_frames: int, seed: int, progress: gr.Progress):
    pipe = make_pipeline()
    gen = seed_context(seed)
    # one generator is reused across frames; its advancing state gives each frame some variance
    frame_paths = []
    for i in range(n_frames):
        prog = i / max(1, n_frames)
        progress(prog, desc=f"Generating frame {i+1}/{n_frames}")
        # per-frame prompt tag, a simple stand-in for a real motion/prompt schedule
        prompt_i = f"{prompt} --frame:{i}"
        out = pipe(prompt_i, num_inference_steps=steps, width=width, height=height, generator=gen)
        img = out.images[0].convert("RGB")
        fname = OUTPUT_DIR / f"frame_{i:03d}.png"
        img.save(fname)
        frame_paths.append(fname)
        # tiny sleep so the UI shows updates smoothly on busy CPUs
        time.sleep(0.05)
    progress(1.0, desc="Done generating keyframes")
    return frame_paths
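# Hypothetical optimisation (not wired into generate_frames above): loading the
# pipeline is by far the slowest part of a request, so a module-level cache would
# avoid re-reading the weights on every click. A minimal sketch:
_PIPE_CACHE = None
def get_pipeline():
    global _PIPE_CACHE
    if _PIPE_CACHE is None:
        _PIPE_CACHE = make_pipeline()  # built once, reused by later requests
    return _PIPE_CACHE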
def simple_rife_interp(frame_paths: List[Path], factor: int, progress: gr.Progress):
    """
    Placeholder: would call out to a RIFE binary or Python function if available.
    NOTE: RIFE usually wants a GPU; CPU builds exist but are slow.
    Since RIFE is not wired in here, perform a simple linear crossfade (cheap, low quality).
    """
    interp_frames = []
    total_pairs = len(frame_paths) - 1
    for idx in range(total_pairs):
        progress(idx / total_pairs, desc=f"Interpolating pair {idx+1}/{total_pairs}")
        a = Image.open(frame_paths[idx]).convert("RGB")
        b = Image.open(frame_paths[idx + 1]).convert("RGB")
        interp_frames.append(frame_paths[idx])  # keep the first frame of the pair
        # linear crossfade steps (very cheap)
        for t in range(1, factor):
            alpha = t / factor
            im = Image.blend(a, b, alpha)
            temp = OUTPUT_DIR / f"interp_{idx:03d}_{t:02d}.png"
            im.save(temp)
            interp_frames.append(temp)
    interp_frames.append(frame_paths[-1])
    progress(1.0, desc="Done interpolating")
    return interp_frames
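# Hypothetical extra (not used by run_pipeline below): for "looping scene" prompts,
# appending the frames in reverse makes a back-and-forth loop that hides the hard
# cut at the end of a very short clip. Sketch only; would be called before assembly.
def pingpong(frame_paths: List[Path]) -> List[Path]:
    # forward pass plus the same frames reversed, without duplicating the endpoints
    return list(frame_paths) + list(reversed(frame_paths[1:-1]))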
def assemble_video(frame_paths: List[Path], fps: int = 8):
    out_vid = OUTPUT_DIR / f"video_{uuid.uuid4().hex[:8]}.mp4"
    frames = [imageio.imread(str(p)) for p in frame_paths]
    imageio.mimsave(str(out_vid), frames, fps=fps)
    return out_vid
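# Optional variant (assumption, same imageio API as above): write the frames as a
# GIF preview instead of an MP4; `duration` here is seconds per frame.
def assemble_gif(frame_paths: List[Path], fps: int = 8):
    out_gif = OUTPUT_DIR / f"video_{uuid.uuid4().hex[:8]}.gif"
    frames = [imageio.imread(str(p)) for p in frame_paths]
    imageio.mimsave(str(out_gif), frames, duration=1 / fps)
    return out_gif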
# -------------------------
# Gradio UI
# -------------------------
def run_pipeline(prompt: str, steps: int, width: int, height: int, n_frames: int, interp_factor: int, seed_input: int, progress=gr.Progress()):
    start = time.time()
    seed = int(seed_input) if seed_input and seed_input > 0 else None
    # generate keyframes
    frames = generate_frames(prompt, steps, width, height, n_frames, seed, progress)
    # interpolation (cheap crossfade fallback implemented above)
    if interp_factor and interp_factor > 1:
        frames_interp = simple_rife_interp(frames, interp_factor, progress)
    else:
        frames_interp = frames
    # assemble
    progress(0.95, desc="Assembling video...")
    vid = assemble_video(frames_interp, fps=4 * interp_factor if interp_factor > 0 else 4)
    elapsed = time.time() - start
    msg = f"Finished in {elapsed:.1f}s -> {vid.name}"
    progress(1.0, desc=msg)
    return str(vid), msg
with gr.Blocks() as demo:
    gr.Markdown("## CPU Text→Video (fast settings) — Stable Diffusion + Gradio progress")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A cinematic short looping scene, 3D lighting, minimal text")
            steps = gr.Slider(label="Steps (lower = faster)", minimum=5, maximum=50, value=STEPS, step=1)
            width_in = gr.Dropdown([256, 384, 512], value=WIDTH, label="Width")
            height_in = gr.Dropdown([192, 256, 320], value=HEIGHT, label="Height")
            n_frames = gr.Slider(label="Base frames", minimum=2, maximum=12, value=NUM_FRAMES, step=1)
            interp = gr.Slider(label="Interp factor (optional)", minimum=1, maximum=6, value=INTERPOLATION_FACTOR, step=1)
            seed_box = gr.Number(label="Seed (0 = random)", value=0)
            run_btn = gr.Button("Generate")
        with gr.Column():
            out_video = gr.Video(label="Result video")
            logs = gr.Textbox(label="Log (last message shown)", interactive=False)
    # attach the function; gr.Progress is injected via the default parameter of run_pipeline
    run_btn.click(
        fn=run_pipeline,
        inputs=[prompt, steps, width_in, height_in, n_frames, interp, seed_box],
        outputs=[out_video, logs],
        api_name="generate",
    )
# progress tracking requires the request queue to be enabled
demo.queue().launch()