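"""Gradio demo for the HyperNoise adapter on Sana Sprint 0.6B.

The adapter (lucaeyring/HyperNoise_Sana_Sprint_0.6B) is loaded as a LoRA on top of the
Sana Sprint transformer and used once, in a single forward pass, to modulate the initial
noise; the final image is then generated from those modulated latents with the adapter
disabled.
"""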
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import SanaSprintPipeline
import peft
from peft.tuners.lora.layer import Linear as LoraLinear
import types
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
adapter_name = "hypernoise_adapter"
# Load the pipeline and adapter
pipe = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers",
    torch_dtype=dtype,
).to(device, dtype)
pipe.transformer = peft.PeftModel.from_pretrained(
    pipe.transformer,
    "lucaeyring/HyperNoise_Sana_Sprint_0.6B",
    adapter_name=adapter_name,
    dtype=dtype,
).to(device, dtype)
# Custom forward for the LoRA-wrapped layer: with the adapter enabled it returns only the
# scaled LoRA path (the base projection is skipped); with adapters disabled it falls back
# to the plain base layer
def scaled_base_lora_forward(self, x, *args, **kwargs):
    if self.disable_adapters:
        return self.base_layer(x, *args, **kwargs)
    return self.lora_B[adapter_name](self.lora_A[adapter_name](x)) * self.scaling[adapter_name]
# Apply the custom forward to proj_out module
for name, module in pipe.transformer.base_model.model.named_modules():
    if name == "proj_out" and isinstance(module, LoraLinear):
        module.forward = types.MethodType(scaled_base_lora_forward, module)
        break
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024 # Sana Sprint is optimized for 1024px
@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
          num_inference_steps=4, guidance_scale=4.5, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Set random seed for reproducibility
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Calculate latent dimensions based on image size
    # Sana uses 32x downsampling factor
    latent_height = height // 32
    latent_width = width // 32

    with torch.inference_mode():
        # Encode the prompt
        prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
            [prompt],
            device=device
        )

        # Generate initial random latents with correct dimensions
        init_latents = torch.randn(
            [1, 32, latent_height, latent_width],
            device=device,
            dtype=dtype
        )

        # Apply HyperNoise modulation with adapter enabled (single forward pass)
        pipe.transformer.enable_adapter_layers()
        modulated_latents = pipe.transformer(
            hidden_states=init_latents,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_attention_mask,
            guidance=torch.tensor([guidance_scale], device=device, dtype=dtype) * 0.1,
            timestep=torch.tensor([1.0], device=device, dtype=dtype),
        ).sample + init_latents
        # Generate final image with adapter disabled
        pipe.transformer.disable_adapter_layers()

        # For SCM scheduler, we need to handle the timesteps carefully
        # The pipeline expects intermediate_timesteps only when num_inference_steps=2
        # For other values, we use the workaround from the original code
        if num_inference_steps == 2:
            # Use the default pipeline behavior for 2 steps
            image = pipe(
                latents=modulated_latents,
                prompt_embeds=prompt_embeds,
                prompt_attention_mask=prompt_attention_mask,
                num_inference_steps=num_inference_steps,
            ).images[0]
        else:
            # For num_inference_steps != 2, we need to work around the restriction
            # by directly calling the denoising loop
            pipe.scheduler.set_timesteps(
                num_inference_steps,
                device=device,
                timesteps=torch.linspace(1.57080, 0, num_inference_steps + 1, device=device)
            )
            # Run the denoising loop manually
            latents = modulated_latents
            for i, t in enumerate(pipe.scheduler.timesteps[:-1]):
                # Expand timestep to match batch dimension
                timestep = t.expand(latents.shape[0])

                # Predict noise
                noise_pred = pipe.transformer(
                    hidden_states=latents,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    timestep=timestep,
                    guidance=torch.tensor([0.0], device=device, dtype=dtype),  # No guidance for denoising
                    return_dict=False,
                )[0]

                # Compute previous noisy sample
                latents = pipe.scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    return_dict=False
                )[0]
            # Decode latents to image
            # Sana latents are already spatial ([1, 32, H/32, W/32]), so no unpacking is needed,
            # and the DC-AE VAE only applies a scaling factor (it has no shift factor)
            latents = latents.to(pipe.vae.dtype)
            image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
            image = pipe.image_processor.postprocess(image, output_type="pil")[0]

    return image, seed
examples = [
    "A smiling slice of pizza doing yoga on a mountain top",
    "A fluffy cat wearing a wizard hat casting spells",
    "A robot painting a self-portrait in Van Gogh style",
    "A tiny dragon sleeping in a teacup",
    "An astronaut riding a unicorn through a rainbow",
]
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
gr.Markdown("""# HyperNoise Sana Sprint 0.6B
Fast text-to-image generation with HyperNoise adapter for Sana Sprint model.
[[Sana Sprint Model](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers)]
[[HyperNoise Adapter](https://huggingface.co/lucaeyring/HyperNoise_Sana_Sprint_0.6B)]
""")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1024,
                )

            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=20,
                    step=1,
                    value=4,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                    value=4.5,
                )

        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples="lazy"
        )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps, guidance_scale],
        outputs=[result, seed]
    )
demo.launch()