# t2v_cpu_gradio.py
import os, time, uuid
from pathlib import Path
from typing import List

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from PIL import Image
import imageio
import gradio as gr
# -------------------------
# CONFIG (tune for speed)
# -------------------------
MODEL_ID = "runwayml/stable-diffusion-v1-5"  # or another small SD model
DEVICE = "cpu"
STEPS = 20                # fewer steps -> much faster, but lower quality
WIDTH, HEIGHT = 512, 320  # smaller frames help speed; keep multiples of 8
NUM_FRAMES = 8            # base keyframes to generate
INTERPOLATION_FACTOR = 2  # frame-count multiplier (cheap crossfade below; RIFE optional)
SEED = None               # None = random
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
# -------------------------
torch.set_num_threads(max(1, os.cpu_count() // 2))  # limit threads to avoid oversubscription
def make_pipeline():
    # DDIM tolerates low step counts better than the default scheduler.
    scheduler = DDIMScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")
    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID, safety_checker=None, torch_dtype=torch.float32
    )
    pipe.scheduler = scheduler
    pipe = pipe.to(DEVICE)
    pipe.enable_attention_slicing()             # reduces peak memory (useful on CPU)
    pipe.set_progress_bar_config(disable=True)  # silence diffusers' own tqdm bars
    return pipe
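
# Optional variant (a sketch, not benchmarked here): enable_vae_slicing() is
# a real diffusers call that decodes the VAE one slice at a time; it trims
# peak memory further on low-RAM boxes at a small speed cost.
def make_pipeline_lowmem():
    pipe = make_pipeline()
    pipe.enable_vae_slicing()
    return pipe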
def seed_context(seed):
    if seed is None:
        seed = int.from_bytes(os.urandom(2), "big")  # small random seed (0-65535)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    return generator
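
# Sketch of a per-frame alternative (hypothetical helper, unused below):
# deriving seed+i per frame makes each frame independently reproducible,
# whereas reusing one generator ties frame i's noise to frames 0..i-1.
def per_frame_generators(base_seed: int, n_frames: int) -> List[torch.Generator]:
    return [torch.Generator(device=DEVICE).manual_seed(base_seed + i)
            for i in range(n_frames)]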
def generate_frames(prompt: str, steps: int, width: int, height: int,
                    n_frames: int, seed: int, progress: gr.Progress):
    pipe = make_pipeline()
    gen = seed_context(seed)
    frame_paths = []
    for i in range(n_frames):
        progress(i / max(1, n_frames), desc=f"Generating frame {i+1}/{n_frames}")
        # One shared generator gives frame-to-frame variance; see the
        # per_frame_generators sketch above for independently seeded frames.
        # Small prompt schedule example — could add motion prompts here.
        prompt_i = f"{prompt} --frame:{i}"
        out = pipe(prompt_i, num_inference_steps=steps, width=width,
                   height=height, generator=gen)
        img = out.images[0].convert("RGB")
        fname = OUTPUT_DIR / f"frame_{i:03d}.png"
        img.save(fname)
        frame_paths.append(fname)
        time.sleep(0.05)  # tiny sleep so the UI shows updates smoothly on busy CPUs
    progress(1.0, desc="Done generating keyframes")
    return frame_paths
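
# Sketch: for smoother motion, a common trick is to pass the same initial
# latent to every frame and vary only the prompt, so frames share composition.
# `latents=` is a real StableDiffusionPipeline argument; the shape assumes
# SD-1.5's 4 latent channels and 8x VAE downsampling.
def fixed_latents(seed: int, width: int, height: int) -> torch.Tensor:
    g = torch.Generator(device=DEVICE).manual_seed(seed)
    return torch.randn((1, 4, height // 8, width // 8), generator=g)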
def simple_rife_interp(frame_paths: List[Path], factor: int, progress: gr.Progress):
    """
    Placeholder: call out to a RIFE binary or Python function if available.
    NOTE: RIFE usually wants a GPU; CPU builds exist but are slow.
    Without RIFE, fall back to a linear crossfade (cheap, low quality).
    """
    interp_frames = []
    total_pairs = len(frame_paths) - 1
    for idx in range(total_pairs):
        progress(idx / total_pairs, desc=f"Interpolating pair {idx+1}/{total_pairs}")
        a = Image.open(frame_paths[idx]).convert("RGB")
        b = Image.open(frame_paths[idx + 1]).convert("RGB")
        interp_frames.append(frame_paths[idx])  # keep the first frame of the pair
        # Linear crossfade steps (very cheap; ghosts on large motion).
        for t in range(1, factor):
            alpha = t / factor
            im = Image.blend(a, b, alpha)
            temp = OUTPUT_DIR / f"interp_{idx:03d}_{t:02d}.png"
            im.save(temp)
            interp_frames.append(temp)
    interp_frames.append(frame_paths[-1])
    progress(1.0, desc="Done interpolating")
    return interp_frames
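
# Hypothetical RIFE hook (illustrative only): the binary name and flags below
# are placeholders modeled on rife-ncnn-vulkan — check your build's actual CLI
# before relying on this. Returns False so callers can keep the crossfade
# above as the fallback.
import subprocess

def try_external_rife(in_dir: Path, out_dir: Path) -> bool:
    try:
        subprocess.run(
            ["rife-ncnn-vulkan", "-i", str(in_dir), "-o", str(out_dir)],  # placeholder flags
            check=True,
        )
        return True
    except (FileNotFoundError, subprocess.CalledProcessError):
        return False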
def assemble_video(frame_paths: List[Path], fps: int = 8):
    out_vid = OUTPUT_DIR / f"video_{uuid.uuid4().hex[:8]}.mp4"
    frames = [imageio.imread(str(p)) for p in frame_paths]
    imageio.mimsave(str(out_vid), frames, fps=fps)
    return out_vid
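
# Variant sketch: stream frames into the writer instead of holding them all
# in memory at once (imageio.get_writer is the standard API; writing .mp4
# needs the imageio-ffmpeg backend installed).
def assemble_video_streaming(frame_paths: List[Path], fps: int = 8):
    out_vid = OUTPUT_DIR / f"video_{uuid.uuid4().hex[:8]}.mp4"
    with imageio.get_writer(str(out_vid), fps=fps) as writer:
        for p in frame_paths:
            writer.append_data(imageio.imread(str(p)))
    return out_vid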
# -------------------------
# Gradio UI
# -------------------------
def run_pipeline(prompt: str, steps: int, width: int, height: int, n_frames: int,
                 interp_factor: int, seed_input: int, progress=gr.Progress()):
    start = time.time()
    seed = int(seed_input) if seed_input and seed_input > 0 else None  # gr.Number yields floats
    # generate keyframes
    frames = generate_frames(prompt, steps, width, height, n_frames, seed, progress)
    # interpolation (cheap crossfade fallback implemented above)
    if interp_factor and interp_factor > 1:
        frames_interp = simple_rife_interp(frames, int(interp_factor), progress)
    else:
        frames_interp = frames
    # assemble
    progress(0.95, desc="Assembling video...")
    vid = assemble_video(frames_interp, fps=4 * int(interp_factor) if interp_factor > 0 else 4)
    elapsed = time.time() - start
    msg = f"Finished in {elapsed:.1f}s -> {vid.name}"
    progress(1.0, desc=msg)
    return str(vid), msg
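
# Minimal stand-in for gr.Progress when driving run_pipeline from a plain
# script (hypothetical helper; it just prints what the UI would show).
class PrintProgress:
    def __call__(self, value, desc=None):
        print(f"[{value:5.0%}] {desc or ''}")

# Example: run_pipeline("a foggy forest at dawn", 15, 384, 256, 6, 2, 0,
#                       progress=PrintProgress())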
with gr.Blocks() as demo:
    gr.Markdown("## CPU Text→Video (fast settings) — Stable Diffusion + Gradio progress")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt",
                                value="A cinematic short looping scene, 3D lighting, minimal text")
            steps = gr.Slider(label="Steps (lower = faster)", minimum=5, maximum=50, value=STEPS, step=1)
            width_in = gr.Dropdown([256, 384, 512], value=WIDTH, label="Width")
            height_in = gr.Dropdown([192, 256, 320], value=HEIGHT, label="Height")
            n_frames = gr.Slider(label="Base frames", minimum=2, maximum=12, value=NUM_FRAMES, step=1)
            interp = gr.Slider(label="Interp factor (optional)", minimum=1, maximum=6,
                               value=INTERPOLATION_FACTOR, step=1)
            seed_box = gr.Number(label="Seed (0 = random)", value=0)
            run_btn = gr.Button("Generate")
        with gr.Column():
            out_video = gr.Video(label="Result video")
            logs = gr.Textbox(label="Log (last message shown)", interactive=False)
    # Wire the button; `logs` receives the final status message from run_pipeline.
    run_btn.click(fn=run_pipeline,
                  inputs=[prompt, steps, width_in, height_in, n_frames, interp, seed_box],
                  outputs=[out_video, logs], api_name="generate")

demo.queue().launch()  # queue() is required for gr.Progress updates on older Gradio versions
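
# Rough dependency sketch (versions untested; pin to whatever you verify):
#   pip install torch diffusers transformers accelerate gradio imageio imageio-ffmpeg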