# Provenance (copied from the Hugging Face file view of this Space):
#   author: bmarci · commit 35c1a87 "adjustable cfg" · file size 9.45 kB
import gradio as gr
import numpy as np
import spaces
from PIL import Image
import torch
from torch.amp import autocast
from transformers import AutoTokenizer, AutoModel
from models.gen_pipeline import NextStepPipeline
# Hub repository that provides both the tokenizer and the model weights.
HF_HUB = "stepfun-ai/NextStep-1-Large"

# Prefer the GPU when torch can see one; otherwise everything runs on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Both loaders may download from the Hub and must run the repository's
# custom (remote) code, hence trust_remote_code=True.
tokenizer = AutoTokenizer.from_pretrained(
    HF_HUB,
    local_files_only=False,
    trust_remote_code=True,
)
model = AutoModel.from_pretrained(
    HF_HUB,
    local_files_only=False,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model = model.to(device)

# Wrap tokenizer + model in the generation pipeline, placed on the same
# device and in bfloat16.
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model)
pipeline = pipeline.to(device=device, dtype=torch.bfloat16)

# Upper bound of the seed slider: the largest int16 value (32767).
MAX_SEED = np.iinfo(np.int16).max

# Defaults forwarded to the pipeline when no extra prompt / guidance is given.
DEFAULT_POSITIVE_PROMPT = None
DEFAULT_NEGATIVE_PROMPT = None
DEFAULT_CFG = 7.5
def _ensure_pil(x):
    """Ensure the returned image is a ``PIL.Image.Image``.

    Accepted inputs:
      * a ``PIL.Image.Image``, which is returned unchanged;
      * an object with ``.detach()`` (e.g. a torch tensor), which is cast to
        float, clamped to [0, 1], moved to the CPU and converted to NumPy;
      * a ``numpy.ndarray``. Any non-uint8 array is assumed to hold values in
        [0, 1]; it is scaled by 255, clipped and cast to uint8.

    A 3-D array whose first axis has size 1, 3 or 4 is treated as
    channels-first (CHW) and moved to channels-last (HWC). A trailing
    single channel is dropped so PIL receives a 2-D grayscale array;
    ``Image.fromarray`` cannot build an image from an ``(H, W, 1)`` array.

    Raises:
        TypeError: if ``x`` is none of the supported types.
    """
    if isinstance(x, Image.Image):
        return x
    if hasattr(x, "detach"):
        x = x.detach().float().clamp(0, 1).cpu().numpy()
    if not isinstance(x, np.ndarray):
        raise TypeError("Unsupported image type returned by pipeline.")

    if x.dtype != np.uint8:
        x = (x * 255.0).clip(0, 255).astype(np.uint8)
    if x.ndim == 3 and x.shape[0] in (1, 3, 4):  # CHW -> HWC
        x = np.moveaxis(x, 0, -1)
    if x.ndim == 3 and x.shape[-1] == 1:
        # (H, W, 1) -> (H, W): PIL only accepts grayscale as a 2-D array.
        x = x[..., 0]
    return Image.fromarray(x)
def _none_if_blank(text):
    """Return ``None`` if ``text`` is ``None``, empty or whitespace-only; else ``text`` unchanged.

    An empty Gradio textbox yields an empty string rather than ``None``. This
    maps such blank UI input onto the ``None`` defaults used by ``infer``.
    """
    if text is None or not str(text).strip():
        return None
    return text


@spaces.GPU(duration=300)
def infer(
    prompt=None,
    seed=0,
    width=512,
    height=512,
    num_inference_steps=28,
    cfg=DEFAULT_CFG,
    positive_prompt=DEFAULT_POSITIVE_PROMPT,
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    progress=gr.Progress(track_tqdm=True),
):
    """Run inference at exactly (width, height) and return a single PIL image.

    Args:
        prompt: Text prompt. ``None``, empty and whitespace-only prompts are
            rejected with a Gradio warning, and ``None`` is returned.
        seed: Seed forwarded to the pipeline (cast to ``int``).
        width: Output width in pixels (cast to ``int``).
        height: Output height in pixels (cast to ``int``).
        num_inference_steps: Number of sampling steps (cast to ``int``).
        cfg: Text guidance scale forwarded as ``cfg`` (cast to ``float``).
        positive_prompt: Optional extra positive prompt. Blank strings are
            normalized to ``None``, matching ``DEFAULT_POSITIVE_PROMPT``.
        negative_prompt: Optional negative prompt. Blank strings are
            normalized to ``None``, matching ``DEFAULT_NEGATIVE_PROMPT``.
        progress: Gradio progress tracker with ``track_tqdm=True``. It is not
            referenced directly in this function.

    Returns:
        The first generated image converted via ``_ensure_pil``, or ``None``
        when no usable prompt was given.
    """
    if _none_if_blank(prompt) is None:
        gr.Warning("⚠️ Please enter a prompt!")
        return None

    # Treat blank optional prompts like the module defaults.
    # NOTE(review): assumes the pipeline handles None as "no extra prompt",
    # which DEFAULT_POSITIVE_PROMPT / DEFAULT_NEGATIVE_PROMPT already rely on.
    positive_prompt = _none_if_blank(positive_prompt)
    negative_prompt = _none_if_blank(negative_prompt)

    # torch.device(...).type gives "cuda"/"cpu" even for indexed devices
    # such as "cuda:0", which is what autocast's device_type expects.
    with autocast(device_type=torch.device(device).type, dtype=torch.bfloat16):
        imgs = pipeline.generate_image(
            prompt,
            hw=(int(height), int(width)),  # passed in (height, width) order
            num_images_per_caption=1,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            cfg=float(cfg),
            cfg_img=1.0,
            cfg_schedule="constant",
            use_norm=False,
            num_sampling_steps=int(num_inference_steps),
            timesteps_shift=1.0,
            seed=int(seed),
            progress=True,
        )
    # Exactly one image is requested above; return it as generated (no resizing).
    return _ensure_pil(imgs[0])
# Centres the main column and caps its width (applied via elem_id below).
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# NextStep-1-Large — Image generation")

        # Prompt box with the Run / Cancel buttons beside it.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
            cancel_button = gr.Button("Cancel", scale=0, variant="secondary")

        with gr.Row():
            with gr.Accordion("Advanced Settings", open=True):
                positive_prompt = gr.Text(
                    label="Positive Prompt",
                    show_label=True,
                    max_lines=1,
                    placeholder="Optional: add positives",
                    container=True,
                )
                negative_prompt = gr.Text(
                    label="Negative Prompt",
                    show_label=True,
                    max_lines=2,
                    placeholder="Optional: add negatives",
                    container=True,
                )
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=3407,
                    )
                    num_inference_steps = gr.Slider(
                        label="Sampling steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=28,
                    )
                # Output size: 256-512 px in 64 px steps.
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=512,
                        step=64,
                        value=512,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=512,
                        step=64,
                        value=512,
                    )
                cfg = gr.Slider(
                    label="CFG (guidance scale)",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.5,
                    value=DEFAULT_CFG,
                    info="Higher = closer to text, lower = more creative",
                )

        with gr.Row():
            result_1 = gr.Image(
                label="Result",
                show_label=True,
                container=True,
                interactive=False,
                format="png",
            )

        # Input components in the exact order of infer()'s positional
        # parameters. Shared by gr.Examples and the generation event so the
        # two can never drift apart.
        generation_inputs = [
            prompt,
            seed,
            width,
            height,
            num_inference_steps,
            cfg,
            positive_prompt,
            negative_prompt,
        ]

        # Each row follows generation_inputs:
        # prompt, seed, width, height, steps, cfg, positive, negative.
        examples = [
            [
                "Studio portrait of an elderly sailor with a weathered face, dramatic Rembrandt lighting, shallow depth of field",
                101, 512, 512, 32, 7.5,
                "photorealistic, sharp eyes, detailed skin texture, soft rim light, 85mm lens",
                "over-smoothed skin, plastic look, extra limbs, watermark",
            ],
            [
                "Isometric cozy coffee shop interior with hanging plants and warm Edison bulbs",
                202, 512, 384, 30, 8.5,
                "isometric view, clean lines, stylized, warm ambience, detailed furniture",
                "text, logo, watermark, perspective distortion",
            ],
            [
                "Ultra-wide desert canyon at golden hour with long shadows and dust in the air",
                303, 512, 320, 28, 7.0,
                "cinematic, volumetric light, natural colors, high dynamic range",
                "over-saturated, haze artifacts, blown highlights",
            ],
            [
                "Cute red panda astronaut sticker, chibi style, white background",
                404, 384, 384, 24, 9.0,
                "vector look, bold outlines, high contrast, die-cut silhouette",
                "background clutter, drop shadow, gradients, text",
            ],
            [
                "Product render of matte-black wireless headphones on reflective glass with soft studio lighting",
                505, 512, 384, 28, 7.0,
                "clean backdrop, realistic reflections, subtle bloom, high detail",
                "noise, fingerprints, text, label",
            ],
            [
                "Graphic poster in Bauhaus style with geometric shapes and bold typography placeholders",
                606, 512, 512, 22, 6.0,
                "flat colors, minimal palette, crisp edges, balanced composition",
                "photo realism, gradients, noisy texture",
            ],
            [
                "Oil painting of a stormy sea with a lighthouse, thick impasto brushwork",
                707, 384, 512, 34, 7.0,
                "textured canvas, visible brush strokes, dramatic sky, moody lighting",
                "smooth digital look, airbrush, neon colors",
            ],
            [
                "Architectural concept art: glass pavilion in a pine forest at dawn, ground fog",
                808, 512, 384, 30, 8.0,
                "physically-based rendering, soft fog, realistic materials, scale figures",
                "tilt, skew, warped geometry, chromatic aberration",
            ],
            [
                "Fantasy creature: bioluminescent jellyfish dragon swimming through a dark ocean trench",
                909, 512, 512, 32, 8.5,
                "glowing tendrils, soft caustics, particles, high detail",
                "washed out, murky, low contrast, extra heads",
            ],
            [
                "Line art coloring page of a city skyline with hot air balloons",
                111, 512, 512, 18, 5.5,
                "clean black outlines, uniform stroke weight, high contrast, no shading",
                "gray fill, gradients, cross-hatching, text",
            ],
        ]
        gr.Examples(
            examples=examples,
            inputs=generation_inputs,
            label="Click & Fill Examples (Exact Size)",
        )

        # Generate on the Run button or on Enter in the prompt box.
        generation_event = gr.on(
            triggers=[run_button.click, prompt.submit],
            fn=infer,
            inputs=generation_inputs,
            outputs=[result_1],
        )
        # Cancel aborts an in-flight generation event; it runs no function itself.
        cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[generation_event])

if __name__ == "__main__":
    demo.launch()