File size: 5,985 Bytes
957f5fa
 
 
 
 
 
0317f8c
957f5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0317f8c
957f5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# t2v_cpu_gradio.py
import functools
import math
import os
import tempfile
import time
import uuid
from pathlib import Path
from typing import List

import gradio as gr
import imageio
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline
from PIL import Image
from tqdm import tqdm

# -------------------------
# CONFIG (tune for speed)
# -------------------------
MODEL_ID = "runwayml/stable-diffusion-v1-5"  # or other small sd model
DEVICE = "cpu"
STEPS = 20                # small # of steps -> much faster but lower quality
WIDTH, HEIGHT = 512, 320  # smaller height helps speed; keep multiples of 8
NUM_FRAMES = 8            # base key frames to generate
INTERPOLATION_FACTOR = 2  # default interp factor: frames per keyframe pair and output FPS multiplier
SEED = None               # None = random (NOTE(review): not read anywhere in this file; the UI seed box is used instead)
OUTPUT_DIR = Path("outputs")  # keyframes, interpolated frames and videos are all written here
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
# -------------------------


def default_num_threads(cpu_count):
    """Return half of ``cpu_count`` (but at least 1) for torch's intra-op thread pool.

    ``os.cpu_count()`` returns None when the core count cannot be determined;
    that case falls back to a single thread instead of raising TypeError.
    """
    return max(1, (cpu_count or 1) // 2)


# limit threads to avoid oversubscription
torch.set_num_threads(default_num_threads(os.cpu_count()))

@functools.lru_cache(maxsize=1)
def make_pipeline():
    """Build the Stable Diffusion pipeline for MODEL_ID on DEVICE, loading it only once.

    ``generate_frames`` calls this on every request. Loading the weights and
    moving them to DEVICE is by far the most expensive step, so the pipeline is
    cached for the life of the process. Call ``make_pipeline.cache_clear()`` to
    force a reload, e.g. after changing MODEL_ID.

    Returns:
        A StableDiffusionPipeline with a DDIM scheduler, no safety checker,
        float32 weights, and attention slicing enabled.
    """
    # Use DDIMScheduler for faster sampling (fewer steps usually OK)
    scheduler = DDIMScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")
    pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, safety_checker=None, torch_dtype=torch.float32)
    pipe.scheduler = scheduler
    pipe = pipe.to(DEVICE)
    # Compute attention in slices to lower peak memory (useful on CPU).
    pipe.enable_attention_slicing()
    return pipe

def seed_context(seed):
    """Return a ``torch.Generator`` on DEVICE seeded with ``seed``.

    Despite the name, this is a plain function, not a context manager.

    Args:
        seed: Integer seed, or None to draw a random one from ``os.urandom``
            (2 bytes, so 0..65535). Integral floats, such as the value of a
            ``gr.Number`` box, are converted to int, because
            ``Generator.manual_seed`` rejects non-integer seeds.

    Returns:
        The seeded generator.
    """
    if seed is None:
        seed = int.from_bytes(os.urandom(2), "big")  # small random seed
    return torch.Generator(device=DEVICE).manual_seed(int(seed))

def generate_frames(prompt: str, steps:int, width:int, height:int, n_frames:int, seed:int, progress: gr.Progress):
    """Render ``n_frames`` keyframes for ``prompt`` and save them as PNGs in OUTPUT_DIR.

    Each frame is a separate text-to-image call on the same pipeline. One
    generator is seeded once and shared by every call, so successive frames
    start from different latents while the whole sequence stays reproducible
    for a given seed.

    Args:
        prompt: Base text prompt.
        steps: Denoising steps per frame.
        width, height: Output size in pixels.
        n_frames: Number of keyframes to render.
        seed: Seed for the shared generator, or None for a random one.
        progress: Gradio progress tracker; reported as a 0..1 fraction plus a description.

    Returns:
        Paths of the saved frames, in order (``frame_000.png``, ``frame_001.png``, ...).
        Files from previous runs with the same names are overwritten.
    """
    pipe = make_pipeline()
    gen = seed_context(seed)
    frame_paths = []
    for i in range(n_frames):
        # gr.Progress takes a float fraction plus ``desc=``; a tuple would be
        # read as (steps_done, total_steps) and the message would never show.
        progress(i / max(1, n_frames), desc=f"Generating frame {i+1}/{n_frames}")
        # Small prompt-schedule example. Stable Diffusion has no flag syntax, so
        # "--frame:i" is encoded as ordinary prompt text; it just nudges each frame.
        prompt_i = f"{prompt} --frame:{i}"
        out = pipe(prompt_i, num_inference_steps=steps, width=width, height=height, generator=gen)
        img = out.images[0].convert("RGB")
        fname = OUTPUT_DIR / f"frame_{i:03d}.png"
        img.save(fname)
        frame_paths.append(fname)
        # tiny sleep so UI shows updates smoothly on busy CPUs
        time.sleep(0.05)
    progress(1.0, desc="Done generating keyframes")
    return frame_paths

def simple_rife_interp(frame_paths: List[Path], factor:int, progress: gr.Progress):
    """Insert linearly crossfaded in-between frames between consecutive keyframes.

    Despite the name, no RIFE model is invoked. Every in-between frame is a cheap
    ``Image.blend`` of its two neighbouring keyframes, a low-quality stand-in
    until a real frame interpolator is wired in.

    Args:
        frame_paths: Keyframe image paths, in playback order.
        factor: Output frames per keyframe pair. Each pair contributes its first
            keyframe plus ``factor - 1`` blends; ``factor <= 1`` adds nothing.
        progress: Gradio progress tracker; reported as a 0..1 fraction plus a description.

    Returns:
        The expanded path list: ``(len(frame_paths) - 1) * factor + 1`` entries for
        non-empty input, or an empty list for empty input. Blended frames are
        written to ``OUTPUT_DIR/interp_<pair>_<step>.png``.
    """
    if not frame_paths:
        return []
    factor = int(factor)  # range() below needs an int even if the UI sends 2.0
    interp_frames = []
    total_pairs = len(frame_paths) - 1
    for idx in range(total_pairs):
        progress(idx / total_pairs, desc=f"Interpolating pair {idx+1}/{total_pairs}")
        # `with` closes the source files as soon as the RGB copies are made.
        with Image.open(frame_paths[idx]) as fa, Image.open(frame_paths[idx + 1]) as fb:
            a = fa.convert("RGB")
            b = fb.convert("RGB")
        interp_frames.append(frame_paths[idx])  # keep first of pair
        # linear crossfade steps (very cheap)
        for t in range(1, factor):
            im = Image.blend(a, b, t / factor)
            temp = OUTPUT_DIR / f"interp_{idx:03d}_{t:02d}.png"
            im.save(temp)
            interp_frames.append(temp)
    interp_frames.append(frame_paths[-1])
    progress(1.0, desc="Done interpolation")
    return interp_frames

def assemble_video(frame_paths: List[Path], fps:int=8):
    """Encode the given image files, in order, into a new MP4 in OUTPUT_DIR.

    Args:
        frame_paths: Image files to use as video frames, in playback order.
        fps: Output frame rate.

    Returns:
        Path of the written ``video_<8 hex chars>.mp4``.

    Raises:
        ValueError: If ``frame_paths`` is empty.
    """
    if not frame_paths:
        raise ValueError("assemble_video needs at least one frame")
    out_vid = OUTPUT_DIR / f"video_{uuid.uuid4().hex[:8]}.mp4"
    # Stream frames into the writer one at a time rather than decoding every
    # frame into memory first. The context manager finalizes and closes the
    # file even if a frame fails to load.
    with imageio.get_writer(str(out_vid), fps=fps) as writer:
        for p in frame_paths:
            writer.append_data(imageio.imread(str(p)))
    return out_vid

# -------------------------
# Gradio UI
# -------------------------
def run_pipeline(prompt: str, steps: int, width: int, height: int, n_frames: int, interp_factor: int, seed_input: int, progress=gr.Progress()):
    """Gradio handler: keyframes -> optional crossfade interpolation -> MP4.

    Args:
        prompt: Text prompt for every keyframe.
        steps, width, height, n_frames, interp_factor: UI values. They are
            converted to int here because Gradio sliders/numbers may arrive as floats.
        seed_input: Seed from the UI; 0, negative or None (cleared box) means random.
        progress: Injected by Gradio; shown as a 0..1 fraction plus a description.

    Returns:
        Path of the generated video as a string, for the ``gr.Video`` output.
    """
    start = time.perf_counter()
    seed = int(seed_input or 0)
    seed = seed if seed > 0 else None
    steps, width, height, n_frames = int(steps), int(width), int(height), int(n_frames)
    interp_factor = int(interp_factor) if interp_factor else 1
    # generate keyframes
    frames = generate_frames(prompt, steps, width, height, n_frames, seed, progress)
    # interpolation (cheap crossfade fallback)
    if interp_factor > 1:
        frames_interp = simple_rife_interp(frames, interp_factor, progress)
    else:
        frames_interp = frames
    # assemble: 4 fps base rate, scaled so the clip keeps its duration after interpolation
    progress(0.95, desc="Assembling video...")
    vid = assemble_video(frames_interp, fps=4 * max(1, interp_factor))
    elapsed = time.perf_counter() - start
    progress(1.0, desc=f"Finished in {elapsed:.1f}s -> {vid.name}")
    return str(vid)

with gr.Blocks() as demo:
    gr.Markdown("## CPU Text→Video (fast settings) — Stable Diffusion + Gradio progress")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A cinematic short looping scene, 3D lighting, minimal text")
            steps = gr.Slider(label="Steps (lower=faster)", minimum=5, maximum=50, value=STEPS, step=1)
            width_in = gr.Dropdown([256,384,512], value=WIDTH, label="Width")
            height_in = gr.Dropdown([192,256,320], value=HEIGHT, label="Height")
            n_frames = gr.Slider(label="Base frames", minimum=2, maximum=12, value=NUM_FRAMES, step=1)
            interp = gr.Slider(label="Interp factor (optional)", minimum=1, maximum=6, value=INTERPOLATION_FACTOR, step=1)
            # precision=0 makes Gradio round the value and pass it on as an int (torch seeds must be ints)
            seed_box = gr.Number(label="Seed (0=random)", value=0, precision=0)
            run_btn = gr.Button("Generate")
        with gr.Column():
            out_video = gr.Video(label="Result video")
            # NOTE(review): not connected to any event; run_pipeline only returns the video path.
            logs = gr.Textbox(label="Log (last message shown)", interactive=False)
    # run_pipeline's `progress=gr.Progress()` default makes Gradio inject a tracker
    run_btn.click(fn=run_pipeline, inputs=[prompt, steps, width_in, height_in, n_frames, interp, seed_box], outputs=[out_video], api_name="generate")

if __name__ == "__main__":
    # Progress updates are delivered through the request queue (required on
    # Gradio 3.x, default on 4.x). Launch only when run as a script, not on import.
    demo.queue().launch()