"""
Lumina 2.0 ZeroGPU demo – Hugging Face Spaces
Run locally with: python app.py
Push to HF and pick the “ZeroGPU” hardware tier.
"""
import os, gc, json, random, time
import numpy as np
import torch
import gradio as gr
import spaces
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
import functools
import models # local
from transport import Sampler, create_transport
# ------------------------------------------------------------------
# CONSTANTS / ENV CONFIG
# ------------------------------------------------------------------
REPO_ID = "OnomaAIResearch/Illustrious-Lumina-v0.03"
HF_TOKEN = os.getenv("HF_TOKEN")
CKPT_FILE = "consolidated_ema.00-of-01.pth"
ARGS_FILE = "model_args.pth"
VAE_TYPE = os.getenv("VAE_TYPE", "flux") # flux | ema | mse | sdxl (only "flux" is wired up below)
PRECISION = os.getenv("PRECISION", "bf16") # bf16 | fp16 | fp32
TEXT_ENCODER_MODEL = os.getenv("TEXT_ENCODER_MODEL", "google/gemma-2-2b")
DTYPE = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[PRECISION]
# ------------------------------------------------------------------
# GLOBALS (kept on **CPU**)
# ------------------------------------------------------------------
tokenizer = text_encoder = vae = model = sampler = transport = None
train_args = cap_feat_dim = None
# ------------------------------------------------------------------
# HELPERS
# ------------------------------------------------------------------
def encode_prompt(batch, enc, tok, dev, dtype):
"""Temporarily moves the text‑encoder to *dev* to extract embeddings."""
captions = [(c if isinstance(c, str) else c[0]) for c in batch]
enc.to(dev)
with torch.no_grad(), torch.autocast(device_type=dev.split(":")[0], dtype=dtype):
inputs = tok(
captions,
padding=True,
pad_to_multiple_of=8,
max_length=256,
truncation=True,
return_tensors="pt",
).to(dev)
out = enc(**inputs, output_hidden_states=True).hidden_states[-2]
enc.to("cpu"); gc.collect()
return out.cpu(), inputs.attention_mask.cpu()
def none_or_str(v): # used by create_transport in original repo
    return None if v in (None, "None") else str(v)
def load_models():
    global tokenizer, text_encoder, vae, model, sampler, transport, train_args, cap_feat_dim
    torch.set_grad_enabled(False)
    # ---------------- Hub downloads (cached after first run) ------
    ckpt_path = hf_hub_download(REPO_ID, filename=CKPT_FILE, token=HF_TOKEN)
    args_path = hf_hub_download(REPO_ID, filename=ARGS_FILE, token=HF_TOKEN)
    # --------------- training args --------------------------------
    train_args = torch.load(args_path, map_location="cpu", weights_only=False)
    # --------------- tokenizer / text encoder ---------------------
    tokenizer = AutoTokenizer.from_pretrained(TEXT_ENCODER_MODEL, token=HF_TOKEN, padding_side="right")
    text_encoder = AutoModel.from_pretrained(
        TEXT_ENCODER_MODEL, torch_dtype=DTYPE, token=HF_TOKEN
    ).eval().cpu()
    cap_feat_dim = text_encoder.config.hidden_size
    # --------------- VAE ------------------------------------------
    vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="vae",
        token=HF_TOKEN,
        torch_dtype=DTYPE
    ).eval().cpu()
    # --------------- DiT backbone --------------------------------
    dit_cls = getattr(models, train_args.model)
    model = dit_cls(in_channels=16, qk_norm=getattr(train_args, "qk_norm", True),
                    cap_feat_dim=cap_feat_dim).eval().cpu()
    state = torch.load(ckpt_path, map_location="cpu")
    state = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state.items()}
    model.load_state_dict(state, strict=False)
    # --------------- Sampler --------------------------------------
    transport = create_transport("Linear", "velocity", None, None, None)
    sampler = Sampler(transport)
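# Everything above stays on CPU: load_models() runs once at import time (i.e. when
# the Space boots), and generate_image() below moves the DiT / VAE to the GPU only
# for the duration of each @spaces.GPU call.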
print("🔄 Loading models to CPU …")
load_models()
print("✅ Models loaded on CPU")
# ------------------------------------------------------------------
# INFERENCE (GPU on‑demand)
# ------------------------------------------------------------------
@spaces.GPU(duration=120) # << ZeroGPU magic
def generate_image(
    prompt, negative_prompt, system_type, solver, resolution_str,
    guidance_scale, num_steps, seed, time_shifting_factor, t_shift,
    atol=1e-6, rtol=1e-3,
):
    """Runs every time a user clicks “Generate”.
    ZeroGPU will attach an A100; GPU is released on return."""
    # ---- pick GPU if present (after decorator) ----
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = DTYPE if device == "cuda" else torch.float32
    # ---- seed ----
    seed = random.randint(0, 2**32 - 1) if int(seed) == -1 else int(seed)
    torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
    # ---- prompts ----
    sys_prompts = {
        "align": "You are an assistant designed to generate high-quality images with the highest degree of image‑text alignment based on textual prompts. <Prompt Start> ",
        "base": "You are an assistant designed to generate high-quality images based on user prompts. <Prompt Start> ",
        "aesthetics": "You are an assistant designed to generate high-quality images with highest degree of aesthetics based on user prompts. <Prompt Start> ",
        "real": "You are an assistant designed to generate superior images with the superior degree of image‑text alignment based on textual prompts or user prompts. <Prompt Start> ",
        "4grid": "You are an assistant designed to generate four high-quality images with highest degree of aesthetics arranged in 2x2 grids based on user prompts. <Prompt Start> ",
        "tags": "You are an assistant designed to generate high-quality images based on user prompts based on danbooru tags. <Prompt Start> ",
        "empty": "",
    }
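    # Every non-empty system prompt ends with the "<Prompt Start> " delimiter; the
    # user prompt (and the negative prompt, when given) is appended right after it.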
    full_prompt = sys_prompts.get(system_type, sys_prompts["base"]) + prompt
    full_neg = (sys_prompts.get(system_type, "") + negative_prompt) if negative_prompt else ""
    # ---- resolution ----
    w, h = map(int, resolution_str.split("x")); lat_w, lat_h = w // 8, h // 8
    # ---- encode prompts ----
    cap_feats_cpu, cap_mask_cpu = encode_prompt(
        [full_prompt, full_neg], text_encoder, tokenizer, device, dtype
    )
    # ---------------- SAMPLING ----------------
    model.to(device)
    z = torch.randn([1, 16, lat_h, lat_w], device=device, dtype=dtype).repeat(2, 1, 1, 1)
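    # z is duplicated along the batch dimension so the two caption embeddings
    # (prompt / negative prompt) drive the classifier-free-guidance pass inside
    # forward_with_cfg; samples[:1] below keeps only the conditional image.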
    model_kwargs = dict(
        cap_feats=cap_feats_cpu.to(device, dtype=dtype),
        cap_mask=cap_mask_cpu.to(device),
        cfg_scale=guidance_scale,
    ); del cap_feats_cpu, cap_mask_cpu
    with torch.no_grad(), torch.autocast(device_type=device, dtype=dtype):  # CPU autocast when no GPU is attached
        if solver == "dpm":
            _sampler = Sampler(create_transport("Linear", "velocity"))
            samples = _sampler.sample_dpm(
                model.forward_with_cfg, model_kwargs=model_kwargs
            )(z, steps=num_steps, order=2,
              skip_type="time_uniform_flow", method="multistep",
              flow_shift=time_shifting_factor)
        else:
            samples = sampler.sample_ode(
                sampling_method=solver, num_steps=num_steps,
                atol=atol, rtol=rtol, time_shifting_factor=t_shift
            )(z, model.forward_with_cfg, **model_kwargs)[-1]
    samples = samples[:1]  # keep positive branch
    # ---------------- DECODE ----------------
    vae.to(device)
    with torch.no_grad(), torch.autocast(device_type=device, dtype=dtype):
        sf, sh = vae.config.scaling_factor, vae.config.shift_factor
        img = vae.decode(samples / sf + sh)[0]
    vae.to("cpu"); model.to("cpu"); torch.cuda.empty_cache(); gc.collect()
    img = ((img.cpu() + 1) / 2).clamp(0, 1)
    pil = to_pil_image(img[0].float())
    return pil, f"Seed {seed}", seed
# ------------------------------------------------------------------
# GRADIO UI
# ------------------------------------------------------------------
with gr.Blocks() as demo:
gr.Markdown("# Lumina 2.0 (ZeroGPU)")
with gr.Row():
with gr.Column(scale=2):
prompt = gr.Textbox(label="Prompt", value="1girl, kita ikuyo, bocchi the rock!, solo, backlighting, blurry, depth of field, bloom, light particles, transparent, blurry foreground, indoors, upper body, red hair, yellow flower, school uniform, white shirt, hair between eyes, lily \(flower\), floating hair, serafuku, chromatic aberration, white lily, green eyes, looking at viewer, red neckerchief, flower, pink flower, sunlight, day, neckerchief, sailor collar, grey sailor collar, white flower, lens flare abuse, medium hair, holding, closed mouth, one side up, long sleeves, holding bouquet, arms at sides, light smile, shirt, blurry background, bouquet")
negative_prompt = gr.Textbox(label="Negative prompt", value="mutated, worst quality, blurry, bad anatomy, bad hands")
system_type = gr.Dropdown(choices=["align","base","aesthetics","real","4grid","tags","empty"],
value="tags", label="System prompt")
resolution = gr.Dropdown(choices=["1024x1024","1280x768","768x1280","1536x1024","1024x1536"],
value="1024x1024", label="Resolution")
solver = gr.Dropdown(choices=["dpm","euler","midpoint","heun","rk4"],
value="euler", label="Solver")
run_btn = gr.Button("Generate", variant="primary")
with gr.Column(scale=1):
guidance_scale = gr.Slider(1.0, 15.0, step=0.5, value=4.0, label="CFG scale")
num_steps = gr.Slider(10, 200, step=1, value=50, label="Steps")
seed = gr.Number(value=-1, precision=0, label="Seed (-1=random)")
time_shifting_factor = gr.Slider(0.0, 10.0, step=0.1, value=1.0, label="Time‑shift (DPM)")
t_shift = gr.Slider(0, 10, step=1, value=4, label="T‑shift (ODE)")
out_img = gr.Image(label="Output")
out_txt = gr.Textbox(label="Status")
out_seed = gr.Number(label="Seed used", interactive=False)
run_btn.click(
generate_image,
inputs=[prompt,negative_prompt,system_type,solver,resolution,
guidance_scale,num_steps,seed,time_shifting_factor,t_shift],
outputs=[out_img,out_txt,out_seed],
concurrency_limit=1 # helps stay within 1 GPU / user
)
if __name__ == "__main__":
    demo.launch()