Spaces:
Runtime error
Runtime error
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py | |
import argparse | |
import os | |
import random | |
import socket | |
import tempfile | |
import time | |
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from app import safety_check | |
from app.sana_controlnet_pipeline import SanaControlNetPipeline | |
STYLES = { | |
"None": "{prompt}", | |
"Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", | |
"3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting", | |
"Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed", | |
"Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed", | |
"Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed", | |
"Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics", | |
"Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy", | |
"Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional", | |
"Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style", | |
} | |
DEFAULT_STYLE_NAME = "None" | |
STYLE_NAMES = list(STYLES.keys()) | |
MAX_SEED = 1000000000 | |
DEFAULT_SKETCH_GUIDANCE = 0.28 | |
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432")) | |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255)) | |
def get_args(): | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--config", type=str, help="config") | |
parser.add_argument( | |
"--model_path", | |
nargs="?", | |
default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth", | |
type=str, | |
help="Path to the model file (positional)", | |
) | |
parser.add_argument("--output", default="./", type=str) | |
parser.add_argument("--bs", default=1, type=int) | |
parser.add_argument("--image_size", default=1024, type=int) | |
parser.add_argument("--cfg_scale", default=5.0, type=float) | |
parser.add_argument("--pag_scale", default=2.0, type=float) | |
parser.add_argument("--seed", default=42, type=int) | |
parser.add_argument("--step", default=-1, type=int) | |
parser.add_argument("--custom_image_size", default=None, type=int) | |
parser.add_argument("--share", action="store_true") | |
parser.add_argument( | |
"--shield_model_path", | |
type=str, | |
help="The path to shield model, we employ ShieldGemma-2B by default.", | |
default="google/shieldgemma-2b", | |
) | |
return parser.parse_known_args()[0] | |
args = get_args() | |
if torch.cuda.is_available(): | |
model_path = args.model_path | |
pipe = SanaControlNetPipeline(args.config) | |
pipe.from_pretrained(model_path) | |
pipe.register_progress_bar(gr.Progress()) | |
# safety checker | |
safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path) | |
safety_checker_model = AutoModelForCausalLM.from_pretrained( | |
args.shield_model_path, | |
device_map="auto", | |
torch_dtype=torch.bfloat16, | |
).to(device) | |
def save_image(img): | |
if isinstance(img, dict): | |
img = img["composite"] | |
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) | |
img.save(temp_file.name) | |
return temp_file.name | |
def norm_ip(img, low, high): | |
img.clamp_(min=low, max=high) | |
img.sub_(low).div_(max(high - low, 1e-5)) | |
return img | |
def run( | |
image, | |
prompt: str, | |
prompt_template: str, | |
sketch_thickness: int, | |
guidance_scale: float, | |
inference_steps: int, | |
seed: int, | |
blend_alpha: float, | |
) -> tuple[Image, str]: | |
print(f"Prompt: {prompt}") | |
image_numpy = np.array(image["composite"].convert("RGB")) | |
if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628): | |
return blank_image, "Please input the prompt or draw something." | |
if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2): | |
prompt = "A red heart." | |
prompt = prompt_template.format(prompt=prompt) | |
pipe.set_blend_alpha(blend_alpha) | |
start_time = time.time() | |
images = pipe( | |
prompt=prompt, | |
ref_image=image["composite"], | |
guidance_scale=guidance_scale, | |
num_inference_steps=inference_steps, | |
num_images_per_prompt=1, | |
sketch_thickness=sketch_thickness, | |
generator=torch.Generator(device=device).manual_seed(seed), | |
) | |
latency = time.time() - start_time | |
if latency < 1: | |
latency = latency * 1000 | |
latency_str = f"{latency:.2f}ms" | |
else: | |
latency_str = f"{latency:.2f}s" | |
torch.cuda.empty_cache() | |
img = [ | |
Image.fromarray( | |
norm_ip(img, -1, 1) | |
.mul(255) | |
.add_(0.5) | |
.clamp_(0, 255) | |
.permute(1, 2, 0) | |
.to("cpu", torch.uint8) | |
.numpy() | |
.astype(np.uint8) | |
) | |
for img in images | |
] | |
img = img[0] | |
return img, latency_str | |
model_size = "1.6" if "1600M" in args.model_path else "0.6" | |
title = f""" | |
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'> | |
<img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/> | |
</div> | |
""" | |
DESCRIPTION = f""" | |
<p><span style="font-size: 36px; font-weight: bold;">Sana-ControlNet-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p> | |
<p style="font-size: 18px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p> | |
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p> | |
<p style="font-size: 18px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, </p>running on node {socket.gethostname()}. | |
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p> | |
""" | |
if model_size == "0.6": | |
DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>" | |
if not torch.cuda.is_available(): | |
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>" | |
with gr.Blocks(css_paths="asset/app_styles/controlnet_app_style.css", title=f"Sana Sketch-to-Image Demo") as demo: | |
gr.Markdown(title) | |
gr.HTML(DESCRIPTION) | |
with gr.Row(elem_id="main_row"): | |
with gr.Column(elem_id="column_input"): | |
gr.Markdown("## INPUT", elem_id="input_header") | |
with gr.Group(): | |
canvas = gr.Sketchpad( | |
value=blank_image, | |
height=640, | |
image_mode="RGB", | |
sources=["upload", "clipboard"], | |
type="pil", | |
label="Sketch", | |
show_label=False, | |
show_download_button=True, | |
interactive=True, | |
transforms=[], | |
canvas_size=(1024, 1024), | |
scale=1, | |
brush=gr.Brush(default_size=3, colors=["#000000"], color_mode="fixed"), | |
format="png", | |
layers=False, | |
) | |
with gr.Row(): | |
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6) | |
run_button = gr.Button("Run", scale=1, elem_id="run_button") | |
download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch") | |
with gr.Row(): | |
style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1) | |
prompt_template = gr.Textbox( | |
label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1 | |
) | |
with gr.Row(): | |
sketch_thickness = gr.Slider( | |
label="Sketch Thickness", | |
minimum=1, | |
maximum=4, | |
step=1, | |
value=2, | |
) | |
with gr.Row(): | |
inference_steps = gr.Slider( | |
label="Sampling steps", | |
minimum=5, | |
maximum=40, | |
step=1, | |
value=20, | |
) | |
guidance_scale = gr.Slider( | |
label="CFG Guidance scale", | |
minimum=1, | |
maximum=10, | |
step=0.1, | |
value=4.5, | |
) | |
blend_alpha = gr.Slider( | |
label="Blend Alpha", | |
minimum=0, | |
maximum=1, | |
step=0.1, | |
value=0, | |
) | |
with gr.Row(): | |
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4) | |
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed") | |
with gr.Column(elem_id="column_output"): | |
gr.Markdown("## OUTPUT", elem_id="output_header") | |
with gr.Group(): | |
result = gr.Image( | |
format="png", | |
height=640, | |
image_mode="RGB", | |
type="pil", | |
label="Result", | |
show_label=False, | |
show_download_button=True, | |
interactive=False, | |
elem_id="output_image", | |
) | |
latency_result = gr.Text(label="Inference Latency", show_label=True) | |
download_result = gr.DownloadButton("Download Result", elem_id="download_result") | |
gr.Markdown("### Instructions") | |
gr.Markdown("**1**. Enter a text prompt (e.g. a cat)") | |
gr.Markdown("**2**. Start sketching or upload a reference image") | |
gr.Markdown("**3**. Change the image style using a style template") | |
gr.Markdown("**4**. Try different seeds to generate different results") | |
run_inputs = [canvas, prompt, prompt_template, sketch_thickness, guidance_scale, inference_steps, seed, blend_alpha] | |
run_outputs = [result, latency_result] | |
randomize_seed.click( | |
lambda: random.randint(0, MAX_SEED), | |
inputs=[], | |
outputs=seed, | |
api_name=False, | |
queue=False, | |
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False) | |
style.change( | |
lambda x: STYLES[x], | |
inputs=[style], | |
outputs=[prompt_template], | |
api_name=False, | |
queue=False, | |
).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False) | |
gr.on( | |
triggers=[prompt.submit, run_button.click, canvas.change], | |
fn=run, | |
inputs=run_inputs, | |
outputs=run_outputs, | |
api_name=False, | |
) | |
download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch) | |
download_result.click(fn=save_image, inputs=result, outputs=download_result) | |
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility") | |
if __name__ == "__main__": | |
demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share) | |