""" Lumina 2.0 ZeroGPU demo – Hugging Face Spaces Run locally with: python app.py Push to HF and pick the “ZeroGPU” hardware tier. """ import os, gc, json, random, time import numpy as np import torch import gradio as gr import spaces from torchvision.transforms.functional import to_pil_image from tqdm import tqdm from diffusers.models import AutoencoderKL from transformers import AutoModel, AutoTokenizer from huggingface_hub import hf_hub_download import functools import models # local from transport import Sampler, create_transport # ------------------------------------------------------------------ # CONSTANTS / ENV CONFIG # ------------------------------------------------------------------ REPO_ID = "OnomaAIResearch/Illustrious-Lumina-v0.03" HF_TOKEN = os.getenv("HF_TOKEN") CKPT_FILE = "consolidated_ema.00-of-01.pth" ARGS_FILE = "model_args.pth" VAE_TYPE = os.getenv("VAE_TYPE", "flux") # flux | ema | mse | sdxl PRECISION = os.getenv("PRECISION", "bf16") # bf16 | fp16 | fp32 TEXT_ENCODER_MODEL = os.getenv("TEXT_ENCODER_MODEL", "google/gemma-2-2B") DTYPE = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[PRECISION] # ------------------------------------------------------------------ # GLOBALS (kept on **CPU**) # ------------------------------------------------------------------ tokenizer = text_encoder = vae = model = sampler = transport = None train_args = cap_feat_dim = None # ------------------------------------------------------------------ # HELPERS # ------------------------------------------------------------------ def encode_prompt(batch, enc, tok, dev, dtype): """Temporarily moves the text‑encoder to *dev* to extract embeddings.""" captions = [(c if isinstance(c, str) else c[0]) for c in batch] enc.to(dev) with torch.no_grad(), torch.autocast(device_type=dev.split(":")[0], dtype=dtype): inputs = tok( captions, padding=True, pad_to_multiple_of=8, max_length=256, truncation=True, return_tensors="pt", ).to(dev) out = enc(**inputs, output_hidden_states=True).hidden_states[-2] enc.to("cpu"); gc.collect() return out.cpu(), inputs.attention_mask.cpu() def none_or_str(v): # used by create_transport in original repo return None if v in (None, "None") else str(v) def load_models(): global tokenizer, text_encoder, vae, model, sampler, transport, train_args, cap_feat_dim torch.set_grad_enabled(False) # ---------------- Hub downloads (cached after first run) ------ ckpt_path = hf_hub_download(REPO_ID, filename=CKPT_FILE, token=HF_TOKEN) args_path = hf_hub_download(REPO_ID, filename=ARGS_FILE, token=HF_TOKEN) # --------------- training args -------------------------------- train_args = torch.load(args_path, map_location="cpu", weights_only=False) # --------------- tokenizer / text encoder --------------------- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", token=HF_TOKEN, padding_side="right") text_encoder = AutoModel.from_pretrained( "google/gemma-2-2b", torch_dtype=DTYPE, token=HF_TOKEN ).eval().cpu() cap_feat_dim = text_encoder.config.hidden_size # --------------- VAE ------------------------------------------ vae = AutoencoderKL.from_pretrained( "black-forest-labs/FLUX.1-dev", subfolder="vae", token=HF_TOKEN, torch_dtype=DTYPE ).eval().cpu() # --------------- DiT backbone -------------------------------- dit_cls = getattr(models, train_args.model) model = dit_cls(in_channels=16, qk_norm=getattr(train_args, "qk_norm", True), cap_feat_dim=cap_feat_dim).eval().cpu() state = torch.load(ckpt_path, map_location="cpu") state = {k[len("module."):] if 
k.startswith("module.") else k: v for k, v in state.items()} model.load_state_dict(state, strict=False) # --------------- Sampler -------------------------------------- transport = create_transport("Linear", "velocity", None, None, None) sampler = Sampler(transport) print("🔄 Loading models to CPU …") load_models() print("✅ Models loaded on CPU") # ------------------------------------------------------------------ # INFERENCE (GPU on‑demand) # ------------------------------------------------------------------ @spaces.GPU(duration=120) # << ZeroGPU magic def generate_image( prompt, negative_prompt, system_type, solver, resolution_str, guidance_scale, num_steps, seed, time_shifting_factor, t_shift, atol=1e-6, rtol=1e-3, ): """Runs every time a user clicks “Generate”. ZeroGPU will attach an A100; GPU is released on return.""" # ---- pick GPU if present (after decorator) ---- device = "cuda" if torch.cuda.is_available() else "cpu" dtype = DTYPE if device == "cuda" else torch.float32 # ---- seed ---- seed = random.randint(0, 2**32 - 1) if int(seed) == -1 else int(seed) torch.manual_seed(seed); np.random.seed(seed); random.seed(seed) # ---- prompts ---- sys_prompts = { "align": "You are an assistant designed to generate high-quality images with the highest degree of image‑text alignment based on textual prompts. ", "base": "You are an assistant designed to generate high-quality images based on user prompts. ", "aesthetics": "You are an assistant designed to generate high-quality images with highest degree of aesthetics based on user prompts. ", "real": "You are an assistant designed to generate superior images with the superior degree of image‑text alignment based on textual prompts or user prompts. ", "4grid": "You are an assistant designed to generate four high-quality images with highest degree of aesthetics arranged in 2x2 grids based on user prompts. ", "tags": "You are an assistant designed to generate high-quality images based on user prompts based on danbooru tags. 
", "empty": "", } full_prompt = sys_prompts.get(system_type, sys_prompts["base"]) + prompt full_neg = (sys_prompts.get(system_type, "") + negative_prompt) if negative_prompt else "" # ---- resolution ---- w, h = map(int, resolution_str.split("x")); lat_w, lat_h = w // 8, h // 8 # ---- encode prompts ---- cap_feats_cpu, cap_mask_cpu = encode_prompt( [full_prompt, full_neg], text_encoder, tokenizer, device, dtype ) # ---------------- SAMPLING ---------------- model.to(device) z = torch.randn([1, 16, lat_h, lat_w], device=device, dtype=dtype).repeat(2,1,1,1) model_kwargs = dict( cap_feats = cap_feats_cpu.to(device, dtype=dtype), cap_mask = cap_mask_cpu.to(device), cfg_scale = guidance_scale, ); del cap_feats_cpu, cap_mask_cpu with torch.no_grad(), torch.autocast(device_type="cuda", dtype=dtype): if solver == "dpm": _sampler = Sampler(create_transport("Linear", "velocity")) samples = _sampler.sample_dpm( model.forward_with_cfg, model_kwargs=model_kwargs )(z, steps=num_steps, order=2, skip_type="time_uniform_flow", method="multistep", flow_shift=time_shifting_factor) else: samples = sampler.sample_ode( sampling_method=solver, num_steps=num_steps, atol=atol, rtol=rtol, time_shifting_factor=t_shift )(z, model.forward_with_cfg, **model_kwargs)[-1] samples = samples[:1] # keep positive branch # ---------------- DECODE ---------------- vae.to(device) with torch.no_grad(), torch.autocast(device_type="cuda", dtype=dtype): sf, sh = vae.config.scaling_factor, vae.config.shift_factor img = vae.decode(samples / sf + sh)[0] vae.to("cpu"); model.to("cpu"); torch.cuda.empty_cache(); gc.collect() img = ((img.cpu() + 1) / 2).clamp(0,1) pil = to_pil_image(img[0].float()) return pil, f"Seed {seed}", seed # ------------------------------------------------------------------ # GRADIO UI # ------------------------------------------------------------------ with gr.Blocks() as demo: gr.Markdown("# Lumina 2.0 (ZeroGPU)") with gr.Row(): with gr.Column(scale=2): prompt = gr.Textbox(label="Prompt", value="1girl, kita ikuyo, bocchi the rock!, solo, backlighting, blurry, depth of field, bloom, light particles, transparent, blurry foreground, indoors, upper body, red hair, yellow flower, school uniform, white shirt, hair between eyes, lily \(flower\), floating hair, serafuku, chromatic aberration, white lily, green eyes, looking at viewer, red neckerchief, flower, pink flower, sunlight, day, neckerchief, sailor collar, grey sailor collar, white flower, lens flare abuse, medium hair, holding, closed mouth, one side up, long sleeves, holding bouquet, arms at sides, light smile, shirt, blurry background, bouquet") negative_prompt = gr.Textbox(label="Negative prompt", value="mutated, worst quality, blurry, bad anatomy, bad hands") system_type = gr.Dropdown(choices=["align","base","aesthetics","real","4grid","tags","empty"], value="tags", label="System prompt") resolution = gr.Dropdown(choices=["1024x1024","1280x768","768x1280","1536x1024","1024x1536"], value="1024x1024", label="Resolution") solver = gr.Dropdown(choices=["dpm","euler","midpoint","heun","rk4"], value="euler", label="Solver") run_btn = gr.Button("Generate", variant="primary") with gr.Column(scale=1): guidance_scale = gr.Slider(1.0, 15.0, step=0.5, value=4.0, label="CFG scale") num_steps = gr.Slider(10, 200, step=1, value=50, label="Steps") seed = gr.Number(value=-1, precision=0, label="Seed (-1=random)") time_shifting_factor = gr.Slider(0.0, 10.0, step=0.1, value=1.0, label="Time‑shift (DPM)") t_shift = gr.Slider(0, 10, step=1, value=4, label="T‑shift (ODE)") 
            out_img = gr.Image(label="Output")
            out_txt = gr.Textbox(label="Status")
            out_seed = gr.Number(label="Seed used", interactive=False)

    run_btn.click(
        generate_image,
        inputs=[prompt, negative_prompt, system_type, solver, resolution,
                guidance_scale, num_steps, seed, time_shifting_factor, t_shift],
        outputs=[out_img, out_txt, out_seed],
        concurrency_limit=1,  # helps stay within 1 GPU / user
    )

if __name__ == "__main__":
    demo.launch()
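
# ------------------------------------------------------------------
# OPTIONAL: minimal smoke test (a sketch, not part of the Space).
# Assumes the checkpoints above download successfully and a GPU (or a
# very patient CPU) is available. Uncomment and run in place of
# demo.launch() to render one image without the UI.
# ------------------------------------------------------------------
# pil, status, used_seed = generate_image(
#     prompt="1girl, solo, looking at viewer",
#     negative_prompt="worst quality, blurry",
#     system_type="tags",
#     solver="euler",
#     resolution_str="1024x1024",
#     guidance_scale=4.0,
#     num_steps=30,
#     seed=42,
#     time_shifting_factor=1.0,
#     t_shift=4,
# )
# pil.save("sample.png")
# print(status, used_seed)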