zzc0208 committed on
Commit 0bc6b41 · verified · 1 Parent(s): 4a96787

Delete app

app/app_sana.py DELETED
@@ -1,502 +0,0 @@
1
- #!/usr/bin/env python
2
- # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- # SPDX-License-Identifier: Apache-2.0
17
- from __future__ import annotations
18
-
19
- import argparse
20
- import os
21
- import random
22
- import socket
23
- import sqlite3
24
- import time
25
- import uuid
26
- from datetime import datetime
27
-
28
- import gradio as gr
29
- import numpy as np
30
- import spaces
31
- import torch
32
- from PIL import Image
33
- from torchvision.utils import make_grid, save_image
34
- from transformers import AutoModelForCausalLM, AutoTokenizer
35
-
36
- from app import safety_check
37
- from app.sana_pipeline import SanaPipeline
38
-
39
- MAX_SEED = np.iinfo(np.int32).max
40
- CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
41
- MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
42
- USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
43
- ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
44
- DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
45
- os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
46
- COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")
47
-
48
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
49
-
50
- style_list = [
51
- {
52
- "name": "(No style)",
53
- "prompt": "{prompt}",
54
- "negative_prompt": "",
55
- },
56
- {
57
- "name": "Cinematic",
58
- "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
59
- "cinemascope, moody, epic, gorgeous, film grain, grainy",
60
- "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
61
- },
62
- {
63
- "name": "Photographic",
64
- "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
65
- "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
66
- },
67
- {
68
- "name": "Anime",
69
- "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
70
- "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
71
- },
72
- {
73
- "name": "Manga",
74
- "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
75
- "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
76
- },
77
- {
78
- "name": "Digital Art",
79
- "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
80
- "negative_prompt": "photo, photorealistic, realism, ugly",
81
- },
82
- {
83
- "name": "Pixel art",
84
- "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
85
- "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
86
- },
87
- {
88
- "name": "Fantasy art",
89
- "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
90
- "majestic, magical, fantasy art, cover art, dreamy",
91
- "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
92
- "glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
93
- "disfigured, sloppy, duplicate, mutated, black and white",
94
- },
95
- {
96
- "name": "Neonpunk",
97
- "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
98
- "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
99
- "ultra detailed, intricate, professional",
100
- "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
101
- },
102
- {
103
- "name": "3D Model",
104
- "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
105
- "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
106
- },
107
- ]
108
-
109
- styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
110
- STYLE_NAMES = list(styles.keys())
111
- DEFAULT_STYLE_NAME = "(No style)"
112
- SCHEDULE_NAME = ["Flow_DPM_Solver"]
113
- DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
114
- NUM_IMAGES_PER_PROMPT = 1
115
- INFER_SPEED = 0
116
-
117
-
118
- def norm_ip(img, low, high):
119
- img.clamp_(min=low, max=high)
120
- img.sub_(low).div_(max(high - low, 1e-5))
121
- return img
122
-
123
-
124
- def open_db():
125
- db = sqlite3.connect(COUNTER_DB)
126
- db.execute("CREATE TABLE IF NOT EXISTS counter(app CHARS PRIMARY KEY UNIQUE, value INTEGER)")
127
- db.execute('INSERT OR IGNORE INTO counter(app, value) VALUES("Sana", 0)')
128
- return db
129
-
130
-
131
- def read_inference_count():
132
- with open_db() as db:
133
- cur = db.execute('SELECT value FROM counter WHERE app="Sana"')
134
- db.commit()
135
- return cur.fetchone()[0]
136
-
137
-
138
- def write_inference_count(count):
139
- count = max(0, int(count))
140
- with open_db() as db:
141
- db.execute(f'UPDATE counter SET value=value+{count} WHERE app="Sana"')
142
- db.commit()
143
-
144
-
145
- def run_inference(num_imgs=1):
146
- write_inference_count(num_imgs)
147
- count = read_inference_count()
148
-
149
- return (
150
- f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
151
- f"16px; color:red; font-weight: bold;'>{count}</span>"
152
- )
153
-
154
-
155
- def update_inference_count():
156
- count = read_inference_count()
157
- return (
158
- f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
159
- f"16px; color:red; font-weight: bold;'>{count}</span>"
160
- )
161
-
162
-
163
- def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
164
- p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
165
- if not negative:
166
- negative = ""
167
- return p.replace("{prompt}", positive), n + negative
168
-
169
-
170
- def get_args():
171
- parser = argparse.ArgumentParser()
172
- parser.add_argument("--config", type=str, help="config")
173
- parser.add_argument(
174
- "--model_path",
175
- nargs="?",
176
- default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
177
- type=str,
178
- help="Path to the model file (positional)",
179
- )
180
- parser.add_argument("--output", default="./", type=str)
181
- parser.add_argument("--bs", default=1, type=int)
182
- parser.add_argument("--image_size", default=1024, type=int)
183
- parser.add_argument("--cfg_scale", default=5.0, type=float)
184
- parser.add_argument("--pag_scale", default=2.0, type=float)
185
- parser.add_argument("--seed", default=42, type=int)
186
- parser.add_argument("--step", default=-1, type=int)
187
- parser.add_argument("--custom_image_size", default=None, type=int)
188
- parser.add_argument("--share", action="store_true")
189
- parser.add_argument(
190
- "--shield_model_path",
191
- type=str,
192
- help="The path to shield model, we employ ShieldGemma-2B by default.",
193
- default="google/shieldgemma-2b",
194
- )
195
-
196
- return parser.parse_known_args()[0]
197
-
198
-
199
- args = get_args()
200
-
201
- if torch.cuda.is_available():
202
- model_path = args.model_path
203
- pipe = SanaPipeline(args.config)
204
- pipe.from_pretrained(model_path)
205
- pipe.register_progress_bar(gr.Progress())
206
-
207
- # safety checker
208
- safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
209
- safety_checker_model = AutoModelForCausalLM.from_pretrained(
210
- args.shield_model_path,
211
- device_map="auto",
212
- torch_dtype=torch.bfloat16,
213
- ).to(device)
214
-
215
-
216
- def save_image_sana(img, seed="", save_img=False):
217
- unique_name = f"{str(uuid.uuid4())}_{seed}.png"
218
- save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
219
- os.umask(0o000) # file permission: 666; dir permission: 777
220
- os.makedirs(save_path, exist_ok=True)
221
- unique_name = os.path.join(save_path, unique_name)
222
- if save_img:
223
- save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
224
-
225
- return unique_name
226
-
227
-
228
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
229
- if randomize_seed:
230
- seed = random.randint(0, MAX_SEED)
231
- return seed
232
-
233
-
234
- @torch.no_grad()
235
- @torch.inference_mode()
236
- @spaces.GPU(enable_queue=True)
237
- def generate(
238
- prompt: str = None,
239
- negative_prompt: str = "",
240
- style: str = DEFAULT_STYLE_NAME,
241
- use_negative_prompt: bool = False,
242
- num_imgs: int = 1,
243
- seed: int = 0,
244
- height: int = 1024,
245
- width: int = 1024,
246
- flow_dpms_guidance_scale: float = 5.0,
247
- flow_dpms_pag_guidance_scale: float = 2.0,
248
- flow_dpms_inference_steps: int = 20,
249
- randomize_seed: bool = False,
250
- ):
251
- global INFER_SPEED
252
- # seed = 823753551
253
- box = run_inference(num_imgs)
254
- seed = int(randomize_seed_fn(seed, randomize_seed))
255
- generator = torch.Generator(device=device).manual_seed(seed)
256
- print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
257
- if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
258
- prompt = "A red heart."
259
-
260
- print(prompt)
261
-
262
- num_inference_steps = flow_dpms_inference_steps
263
- guidance_scale = flow_dpms_guidance_scale
264
- pag_guidance_scale = flow_dpms_pag_guidance_scale
265
-
266
- if not use_negative_prompt:
267
- negative_prompt = None # type: ignore
268
- prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
269
-
270
- pipe.progress_fn(0, desc="Sana Start")
271
-
272
- time_start = time.time()
273
- images = pipe(
274
- prompt=prompt,
275
- height=height,
276
- width=width,
277
- negative_prompt=negative_prompt,
278
- guidance_scale=guidance_scale,
279
- pag_guidance_scale=pag_guidance_scale,
280
- num_inference_steps=num_inference_steps,
281
- num_images_per_prompt=num_imgs,
282
- generator=generator,
283
- )
284
-
285
- pipe.progress_fn(1.0, desc="Sana End")
286
- INFER_SPEED = (time.time() - time_start) / num_imgs
287
-
288
- save_img = False
289
- if save_img:
290
- img = [save_image_sana(img, seed, save_img=save_image) for img in images]
291
- print(img)
292
- else:
293
- img = [
294
- Image.fromarray(
295
- norm_ip(img, -1, 1)
296
- .mul(255)
297
- .add_(0.5)
298
- .clamp_(0, 255)
299
- .permute(1, 2, 0)
300
- .to("cpu", torch.uint8)
301
- .numpy()
302
- .astype(np.uint8)
303
- )
304
- for img in images
305
- ]
306
-
307
- torch.cuda.empty_cache()
308
-
309
- return (
310
- img,
311
- seed,
312
- f"<span style='font-size: 16px; font-weight: bold;'>Inference Speed: {INFER_SPEED:.3f} s/Img</span>",
313
- box,
314
- )
315
-
316
-
317
- model_size = "1.6" if "1600M" in args.model_path else "0.6"
318
- title = f"""
319
- <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
320
- <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
321
- </div>
322
- """
323
- DESCRIPTION = f"""
324
- <p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
325
- <p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
326
- <p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
327
- <p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, </p>running on node {socket.gethostname()}.
328
- <p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
329
- """
330
- if model_size == "0.6":
331
- DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
332
- if not torch.cuda.is_available():
333
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
334
-
335
- examples = [
336
- 'a cyberpunk cat with a neon sign that says "Sana"',
337
- "A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
338
- "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
339
- "portrait photo of a girl, photograph, highly detailed face, depth of field",
340
- 'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
341
- "🐶 Wearing 🕶 flying on the 🌈",
342
- "👧 with 🌹 in the ❄️",
343
- "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
344
- "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
345
- "Astronaut in a jungle, cold color palette, muted colors, detailed",
346
- "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
347
- ]
348
-
349
- css = """
350
- .gradio-container{max-width: 640px !important}
351
- h1{text-align:center}
352
- """
353
- with gr.Blocks(css=css, title="Sana") as demo:
354
- gr.Markdown(title)
355
- gr.HTML(DESCRIPTION)
356
- gr.DuplicateButton(
357
- value="Duplicate Space for private use",
358
- elem_id="duplicate-button",
359
- visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
360
- )
361
- info_box = gr.Markdown(
362
- value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
363
- )
364
- demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
365
- # with gr.Row(equal_height=False):
366
- with gr.Group():
367
- with gr.Row():
368
- prompt = gr.Text(
369
- label="Prompt",
370
- show_label=False,
371
- max_lines=1,
372
- placeholder="Enter your prompt",
373
- container=False,
374
- )
375
- run_button = gr.Button("Run", scale=0)
376
- result = gr.Gallery(label="Result", show_label=False, columns=NUM_IMAGES_PER_PROMPT, format="png")
377
- speed_box = gr.Markdown(
378
- value=f"<span style='font-size: 16px; font-weight: bold;'>Inference speed: {INFER_SPEED} s/Img</span>"
379
- )
380
- with gr.Accordion("Advanced options", open=False):
381
- with gr.Group():
382
- with gr.Row(visible=True):
383
- height = gr.Slider(
384
- label="Height",
385
- minimum=256,
386
- maximum=MAX_IMAGE_SIZE,
387
- step=32,
388
- value=args.image_size,
389
- )
390
- width = gr.Slider(
391
- label="Width",
392
- minimum=256,
393
- maximum=MAX_IMAGE_SIZE,
394
- step=32,
395
- value=args.image_size,
396
- )
397
- with gr.Row():
398
- flow_dpms_inference_steps = gr.Slider(
399
- label="Sampling steps",
400
- minimum=5,
401
- maximum=40,
402
- step=1,
403
- value=20,
404
- )
405
- flow_dpms_guidance_scale = gr.Slider(
406
- label="CFG Guidance scale",
407
- minimum=1,
408
- maximum=10,
409
- step=0.1,
410
- value=4.5,
411
- )
412
- flow_dpms_pag_guidance_scale = gr.Slider(
413
- label="PAG Guidance scale",
414
- minimum=1,
415
- maximum=4,
416
- step=0.5,
417
- value=1.0,
418
- )
419
- with gr.Row():
420
- use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
421
- negative_prompt = gr.Text(
422
- label="Negative prompt",
423
- max_lines=1,
424
- placeholder="Enter a negative prompt",
425
- visible=True,
426
- )
427
- style_selection = gr.Radio(
428
- show_label=True,
429
- container=True,
430
- interactive=True,
431
- choices=STYLE_NAMES,
432
- value=DEFAULT_STYLE_NAME,
433
- label="Image Style",
434
- )
435
- seed = gr.Slider(
436
- label="Seed",
437
- minimum=0,
438
- maximum=MAX_SEED,
439
- step=1,
440
- value=0,
441
- )
442
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
443
- with gr.Row(visible=True):
444
- schedule = gr.Radio(
445
- show_label=True,
446
- container=True,
447
- interactive=True,
448
- choices=SCHEDULE_NAME,
449
- value=DEFAULT_SCHEDULE_NAME,
450
- label="Sampler Schedule",
451
- visible=True,
452
- )
453
- num_imgs = gr.Slider(
454
- label="Num Images",
455
- minimum=1,
456
- maximum=6,
457
- step=1,
458
- value=1,
459
- )
460
-
461
- gr.Examples(
462
- examples=examples,
463
- inputs=prompt,
464
- outputs=[result, seed],
465
- fn=generate,
466
- cache_examples=CACHE_EXAMPLES,
467
- )
468
-
469
- use_negative_prompt.change(
470
- fn=lambda x: gr.update(visible=x),
471
- inputs=use_negative_prompt,
472
- outputs=negative_prompt,
473
- api_name=False,
474
- )
475
-
476
- gr.on(
477
- triggers=[
478
- prompt.submit,
479
- negative_prompt.submit,
480
- run_button.click,
481
- ],
482
- fn=generate,
483
- inputs=[
484
- prompt,
485
- negative_prompt,
486
- style_selection,
487
- use_negative_prompt,
488
- num_imgs,
489
- seed,
490
- height,
491
- width,
492
- flow_dpms_guidance_scale,
493
- flow_dpms_pag_guidance_scale,
494
- flow_dpms_inference_steps,
495
- randomize_seed,
496
- ],
497
- outputs=[result, seed, speed_box, info_box],
498
- api_name="run",
499
- )
500
-
501
- if __name__ == "__main__":
502
- demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
app/app_sana_4bit.py DELETED
@@ -1,409 +0,0 @@
1
- #!/usr/bin/env python
2
- # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- #!/usr/bin/env python
6
- # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
7
- #
8
- # Licensed under the Apache License, Version 2.0 (the "License");
9
- # you may not use this file except in compliance with the License.
10
- # You may obtain a copy of the License at
11
- #
12
- # http://www.apache.org/licenses/LICENSE-2.0
13
- #
14
- # Unless required by applicable law or agreed to in writing, software
15
- # distributed under the License is distributed on an "AS IS" BASIS,
16
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
- # See the License for the specific language governing permissions and
18
- # limitations under the License.
19
- #
20
- # SPDX-License-Identifier: Apache-2.0
21
- from __future__ import annotations
22
-
23
- import argparse
24
- import os
25
- import random
26
- import time
27
- import uuid
28
- from datetime import datetime
29
-
30
- import gradio as gr
31
- import numpy as np
32
- import spaces
33
- import torch
34
- from diffusers import SanaPipeline
35
- from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
36
- from torchvision.utils import save_image
37
-
38
- MAX_SEED = np.iinfo(np.int32).max
39
- CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
40
- MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
41
- USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
42
- ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
43
- DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
44
- os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
45
- COUNTER_DB = os.getenv("COUNTER_DB", ".count.db")
46
- INFER_SPEED = 0
47
-
48
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
49
-
50
- style_list = [
51
- {
52
- "name": "(No style)",
53
- "prompt": "{prompt}",
54
- "negative_prompt": "",
55
- },
56
- {
57
- "name": "Cinematic",
58
- "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
59
- "cinemascope, moody, epic, gorgeous, film grain, grainy",
60
- "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
61
- },
62
- {
63
- "name": "Photographic",
64
- "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
65
- "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
66
- },
67
- {
68
- "name": "Anime",
69
- "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
70
- "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
71
- },
72
- {
73
- "name": "Manga",
74
- "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
75
- "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
76
- },
77
- {
78
- "name": "Digital Art",
79
- "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
80
- "negative_prompt": "photo, photorealistic, realism, ugly",
81
- },
82
- {
83
- "name": "Pixel art",
84
- "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
85
- "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
86
- },
87
- {
88
- "name": "Fantasy art",
89
- "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
90
- "majestic, magical, fantasy art, cover art, dreamy",
91
- "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
92
- "glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
93
- "disfigured, sloppy, duplicate, mutated, black and white",
94
- },
95
- {
96
- "name": "Neonpunk",
97
- "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
98
- "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
99
- "ultra detailed, intricate, professional",
100
- "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
101
- },
102
- {
103
- "name": "3D Model",
104
- "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
105
- "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
106
- },
107
- ]
108
-
109
- styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
110
- STYLE_NAMES = list(styles.keys())
111
- DEFAULT_STYLE_NAME = "(No style)"
112
- SCHEDULE_NAME = ["Flow_DPM_Solver"]
113
- DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
114
- NUM_IMAGES_PER_PROMPT = 1
115
-
116
-
117
- def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
118
- p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
119
- if not negative:
120
- negative = ""
121
- return p.replace("{prompt}", positive), n + negative
122
-
123
-
124
- def get_args():
125
- parser = argparse.ArgumentParser()
126
- parser.add_argument(
127
- "--model_path",
128
- nargs="?",
129
- default="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
130
- type=str,
131
- help="Path to the model file (positional)",
132
- )
133
- parser.add_argument("--share", action="store_true")
134
-
135
- return parser.parse_known_args()[0]
136
-
137
-
138
- args = get_args()
139
-
140
- if torch.cuda.is_available():
141
-
142
- transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
143
- pipe = SanaPipeline.from_pretrained(
144
- "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
145
- transformer=transformer,
146
- variant="bf16",
147
- torch_dtype=torch.bfloat16,
148
- ).to(device)
149
-
150
- pipe.text_encoder.to(torch.bfloat16)
151
- pipe.vae.to(torch.bfloat16)
152
-
153
-
154
- def save_image_sana(img, seed="", save_img=False):
155
- unique_name = f"{str(uuid.uuid4())}_{seed}.png"
156
- save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
157
- os.umask(0o000) # file permission: 666; dir permission: 777
158
- os.makedirs(save_path, exist_ok=True)
159
- unique_name = os.path.join(save_path, unique_name)
160
- if save_img:
161
- save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
162
-
163
- return unique_name
164
-
165
-
166
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
167
- if randomize_seed:
168
- seed = random.randint(0, MAX_SEED)
169
- return seed
170
-
171
-
172
- @torch.no_grad()
173
- @torch.inference_mode()
174
- @spaces.GPU(enable_queue=True)
175
- def generate(
176
- prompt: str = None,
177
- negative_prompt: str = "",
178
- style: str = DEFAULT_STYLE_NAME,
179
- use_negative_prompt: bool = False,
180
- num_imgs: int = 1,
181
- seed: int = 0,
182
- height: int = 1024,
183
- width: int = 1024,
184
- flow_dpms_guidance_scale: float = 5.0,
185
- flow_dpms_inference_steps: int = 20,
186
- randomize_seed: bool = False,
187
- ):
188
- global INFER_SPEED
189
- # seed = 823753551
190
- seed = int(randomize_seed_fn(seed, randomize_seed))
191
- generator = torch.Generator(device=device).manual_seed(seed)
192
- print(f"PORT: {DEMO_PORT}, model_path: {args.model_path}")
193
-
194
- print(prompt)
195
-
196
- num_inference_steps = flow_dpms_inference_steps
197
- guidance_scale = flow_dpms_guidance_scale
198
-
199
- if not use_negative_prompt:
200
- negative_prompt = None # type: ignore
201
- prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
202
-
203
- time_start = time.time()
204
- images = pipe(
205
- prompt=prompt,
206
- height=height,
207
- width=width,
208
- negative_prompt=negative_prompt,
209
- guidance_scale=guidance_scale,
210
- num_inference_steps=num_inference_steps,
211
- num_images_per_prompt=num_imgs,
212
- generator=generator,
213
- ).images
214
- INFER_SPEED = (time.time() - time_start) / num_imgs
215
-
216
- save_img = False
217
- if save_img:
218
- img = [save_image_sana(img, seed, save_img=save_image) for img in images]
219
- print(img)
220
- else:
221
- img = images
222
-
223
- torch.cuda.empty_cache()
224
-
225
- return (
226
- img,
227
- seed,
228
- f"<span style='font-size: 16px; font-weight: bold;'>Inference Speed: {INFER_SPEED:.3f} s/Img</span>",
229
- )
230
-
231
-
232
- model_size = "1.6" if "1600M" in args.model_path else "0.6"
233
- title = f"""
234
- <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
235
- <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="30%" alt="logo"/>
236
- </div>
237
- """
238
- DESCRIPTION = f"""
239
- <p style="font-size: 30px; font-weight: bold; text-align: center;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer (4bit version)</p>
240
- """
241
- if model_size == "0.6":
242
- DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
243
- if not torch.cuda.is_available():
244
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
245
-
246
- examples = [
247
- 'a cyberpunk cat with a neon sign that says "Sana"',
248
- "A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
249
- "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
250
- "portrait photo of a girl, photograph, highly detailed face, depth of field",
251
- 'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
252
- "🐶 Wearing 🕶 flying on the 🌈",
253
- "👧 with 🌹 in the ❄️",
254
- "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
255
- "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
256
- "Astronaut in a jungle, cold color palette, muted colors, detailed",
257
- "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
258
- ]
259
-
260
- css = """
261
- .gradio-container {max-width: 850px !important; height: auto !important;}
262
- h1 {text-align: center;}
263
- """
264
- theme = gr.themes.Base()
265
- with gr.Blocks(css=css, theme=theme, title="Sana") as demo:
266
- gr.Markdown(title)
267
- gr.HTML(DESCRIPTION)
268
- gr.DuplicateButton(
269
- value="Duplicate Space for private use",
270
- elem_id="duplicate-button",
271
- visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
272
- )
273
- # with gr.Row(equal_height=False):
274
- with gr.Group():
275
- with gr.Row():
276
- prompt = gr.Text(
277
- label="Prompt",
278
- show_label=False,
279
- max_lines=1,
280
- placeholder="Enter your prompt",
281
- container=False,
282
- )
283
- run_button = gr.Button("Run", scale=0)
284
- result = gr.Gallery(
285
- label="Result",
286
- show_label=False,
287
- height=750,
288
- columns=NUM_IMAGES_PER_PROMPT,
289
- format="jpeg",
290
- )
291
-
292
- speed_box = gr.Markdown(
293
- value=f"<span style='font-size: 16px; font-weight: bold;'>Inference speed: {INFER_SPEED} s/Img</span>"
294
- )
295
- with gr.Accordion("Advanced options", open=False):
296
- with gr.Group():
297
- with gr.Row(visible=True):
298
- height = gr.Slider(
299
- label="Height",
300
- minimum=256,
301
- maximum=MAX_IMAGE_SIZE,
302
- step=32,
303
- value=1024,
304
- )
305
- width = gr.Slider(
306
- label="Width",
307
- minimum=256,
308
- maximum=MAX_IMAGE_SIZE,
309
- step=32,
310
- value=1024,
311
- )
312
- with gr.Row():
313
- flow_dpms_inference_steps = gr.Slider(
314
- label="Sampling steps",
315
- minimum=5,
316
- maximum=40,
317
- step=1,
318
- value=20,
319
- )
320
- flow_dpms_guidance_scale = gr.Slider(
321
- label="CFG Guidance scale",
322
- minimum=1,
323
- maximum=10,
324
- step=0.1,
325
- value=4.5,
326
- )
327
- with gr.Row():
328
- use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
329
- negative_prompt = gr.Text(
330
- label="Negative prompt",
331
- max_lines=1,
332
- placeholder="Enter a negative prompt",
333
- visible=True,
334
- )
335
- style_selection = gr.Radio(
336
- show_label=True,
337
- container=True,
338
- interactive=True,
339
- choices=STYLE_NAMES,
340
- value=DEFAULT_STYLE_NAME,
341
- label="Image Style",
342
- )
343
- seed = gr.Slider(
344
- label="Seed",
345
- minimum=0,
346
- maximum=MAX_SEED,
347
- step=1,
348
- value=0,
349
- )
350
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
351
- with gr.Row(visible=True):
352
- schedule = gr.Radio(
353
- show_label=True,
354
- container=True,
355
- interactive=True,
356
- choices=SCHEDULE_NAME,
357
- value=DEFAULT_SCHEDULE_NAME,
358
- label="Sampler Schedule",
359
- visible=True,
360
- )
361
- num_imgs = gr.Slider(
362
- label="Num Images",
363
- minimum=1,
364
- maximum=6,
365
- step=1,
366
- value=1,
367
- )
368
-
369
- gr.Examples(
370
- examples=examples,
371
- inputs=prompt,
372
- outputs=[result, seed],
373
- fn=generate,
374
- cache_examples=CACHE_EXAMPLES,
375
- )
376
-
377
- use_negative_prompt.change(
378
- fn=lambda x: gr.update(visible=x),
379
- inputs=use_negative_prompt,
380
- outputs=negative_prompt,
381
- api_name=False,
382
- )
383
-
384
- gr.on(
385
- triggers=[
386
- prompt.submit,
387
- negative_prompt.submit,
388
- run_button.click,
389
- ],
390
- fn=generate,
391
- inputs=[
392
- prompt,
393
- negative_prompt,
394
- style_selection,
395
- use_negative_prompt,
396
- num_imgs,
397
- seed,
398
- height,
399
- width,
400
- flow_dpms_guidance_scale,
401
- flow_dpms_inference_steps,
402
- randomize_seed,
403
- ],
404
- outputs=[result, seed, speed_box],
405
- api_name="run",
406
- )
407
-
408
- if __name__ == "__main__":
409
- demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
app/app_sana_4bit_compare_bf16.py DELETED
@@ -1,313 +0,0 @@
1
- # Changed from https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py
2
- import argparse
3
- import os
4
- import random
5
- import time
6
- from datetime import datetime
7
-
8
- import GPUtil
9
-
10
- # import gradio last to avoid conflicts with other imports
11
- import gradio as gr
12
- import safety_check
13
- import spaces
14
- import torch
15
- from diffusers import SanaPipeline
16
- from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
17
- from transformers import AutoModelForCausalLM, AutoTokenizer
18
-
19
- MAX_IMAGE_SIZE = 2048
20
- MAX_SEED = 1000000000
21
-
22
- DEFAULT_HEIGHT = 1024
23
- DEFAULT_WIDTH = 1024
24
-
25
- # num_inference_steps, guidance_scale, seed
26
- EXAMPLES = [
27
- [
28
- "🐶 Wearing 🕶 flying on the 🌈",
29
- 1024,
30
- 1024,
31
- 20,
32
- 5,
33
- 2,
34
- ],
35
- [
36
- "大漠孤烟直, 长河落日圆",
37
- 1024,
38
- 1024,
39
- 20,
40
- 5,
41
- 23,
42
- ],
43
- [
44
- "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, "
45
- "volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, "
46
- "art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
47
- 1024,
48
- 1024,
49
- 20,
50
- 5,
51
- 233,
52
- ],
53
- [
54
- "A photo of a Eurasian lynx in a sunlit forest, with tufted ears and a spotted coat. The lynx should be "
55
- "sharply focused, gazing into the distance, while the background is softly blurred for depth. Use cinematic "
56
- "lighting with soft rays filtering through the trees, and capture the scene with a shallow depth of field "
57
- "for a natural, peaceful atmosphere. 8K resolution, highly detailed, photorealistic, "
58
- "cinematic lighting, ultra-HD.",
59
- 1024,
60
- 1024,
61
- 20,
62
- 5,
63
- 2333,
64
- ],
65
- [
66
- "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. "
67
- "She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. "
68
- "She wears sunglasses and red lipstick. She walks confidently and casually. "
69
- "The street is damp and reflective, creating a mirror effect of the colorful lights. "
70
- "Many pedestrians walk about.",
71
- 1024,
72
- 1024,
73
- 20,
74
- 5,
75
- 23333,
76
- ],
77
- [
78
- "Cozy bedroom with vintage wooden furniture and a large circular window covered in lush green vines, "
79
- "opening to a misty forest. Soft, ambient lighting highlights the bed with crumpled blankets, a bookshelf, "
80
- "and a desk. The atmosphere is serene and natural. 8K resolution, highly detailed, photorealistic, "
81
- "cinematic lighting, ultra-HD.",
82
- 1024,
83
- 1024,
84
- 20,
85
- 5,
86
- 233333,
87
- ],
88
- ]
89
-
90
-
91
- def hash_str_to_int(s: str) -> int:
92
- """Hash a string to an integer."""
93
- modulus = 10**9 + 7 # Large prime modulus
94
- hash_int = 0
95
- for char in s:
96
- hash_int = (hash_int * 31 + ord(char)) % modulus
97
- return hash_int
98
-
99
-
100
- def get_pipeline(
101
- precision: str, use_qencoder: bool = False, device: str | torch.device = "cuda", pipeline_init_kwargs: dict = {}
102
- ) -> SanaPipeline:
103
- if precision == "int4":
104
- assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
105
- transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
106
-
107
- pipeline_init_kwargs["transformer"] = transformer
108
- if use_qencoder:
109
- raise NotImplementedError("Quantized encoder not supported for Sana for now")
110
- else:
111
- assert precision == "bf16"
112
- pipeline = SanaPipeline.from_pretrained(
113
- "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
114
- variant="bf16",
115
- torch_dtype=torch.bfloat16,
116
- **pipeline_init_kwargs,
117
- )
118
-
119
- pipeline = pipeline.to(device)
120
- return pipeline
121
-
122
-
123
- def get_args() -> argparse.Namespace:
124
- parser = argparse.ArgumentParser()
125
- parser.add_argument(
126
- "-p",
127
- "--precisions",
128
- type=str,
129
- default=["int4"],
130
- nargs="*",
131
- choices=["int4", "bf16"],
132
- help="Which precisions to use",
133
- )
134
- parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
135
- parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
136
- parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
137
- return parser.parse_args()
138
-
139
-
140
- args = get_args()
141
-
142
-
143
- pipelines = []
144
- pipeline_init_kwargs = {}
145
- for i, precision in enumerate(args.precisions):
146
-
147
- pipeline = get_pipeline(
148
- precision=precision,
149
- use_qencoder=args.use_qencoder,
150
- device="cuda",
151
- pipeline_init_kwargs={**pipeline_init_kwargs},
152
- )
153
- pipelines.append(pipeline)
154
- if i == 0:
155
- pipeline_init_kwargs["vae"] = pipeline.vae
156
- pipeline_init_kwargs["text_encoder"] = pipeline.text_encoder
157
-
158
- # safety checker
159
- safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
160
- safety_checker_model = AutoModelForCausalLM.from_pretrained(
161
- args.shield_model_path,
162
- device_map="auto",
163
- torch_dtype=torch.bfloat16,
164
- ).to(pipeline.device)
165
-
166
-
167
- @spaces.GPU(enable_queue=True)
168
- def generate(
169
- prompt: str = None,
170
- height: int = 1024,
171
- width: int = 1024,
172
- num_inference_steps: int = 4,
173
- guidance_scale: float = 0,
174
- seed: int = 0,
175
- ):
176
- print(f"Prompt: {prompt}")
177
- is_unsafe_prompt = False
178
- if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
179
- prompt = "A peaceful world."
180
- images, latency_strs = [], []
181
- for i, pipeline in enumerate(pipelines):
182
- progress = gr.Progress(track_tqdm=True)
183
- start_time = time.time()
184
- image = pipeline(
185
- prompt=prompt,
186
- height=height,
187
- width=width,
188
- guidance_scale=guidance_scale,
189
- num_inference_steps=num_inference_steps,
190
- generator=torch.Generator().manual_seed(seed),
191
- ).images[0]
192
- end_time = time.time()
193
- latency = end_time - start_time
194
- if latency < 1:
195
- latency = latency * 1000
196
- latency_str = f"{latency:.2f}ms"
197
- else:
198
- latency_str = f"{latency:.2f}s"
199
- images.append(image)
200
- latency_strs.append(latency_str)
201
- if is_unsafe_prompt:
202
- for i in range(len(latency_strs)):
203
- latency_strs[i] += " (Unsafe prompt detected)"
204
- torch.cuda.empty_cache()
205
-
206
- if args.count_use:
207
- if os.path.exists("use_count.txt"):
208
- with open("use_count.txt") as f:
209
- count = int(f.read())
210
- else:
211
- count = 0
212
- count += 1
213
- current_time = datetime.now()
214
- print(f"{current_time}: {count}")
215
- with open("use_count.txt", "w") as f:
216
- f.write(str(count))
217
- with open("use_record.txt", "a") as f:
218
- f.write(f"{current_time}: {count}\n")
219
-
220
- return *images, *latency_strs
221
-
222
-
223
- with open("./assets/description.html") as f:
224
- DESCRIPTION = f.read()
225
- gpus = GPUtil.getGPUs()
226
- if len(gpus) > 0:
227
- gpu = gpus[0]
228
- memory = gpu.memoryTotal / 1024
229
- device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
230
- else:
231
- device_info = "Running on CPU 🥶 This demo does not work on CPU."
232
- notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
233
-
234
- with gr.Blocks(
235
- css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
236
- title=f"SVDQuant SANA-1600M Demo",
237
- ) as demo:
238
-
239
- def get_header_str():
240
-
241
- if args.count_use:
242
- if os.path.exists("use_count.txt"):
243
- with open("use_count.txt") as f:
244
- count = int(f.read())
245
- else:
246
- count = 0
247
- count_info = (
248
- f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
249
- f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
250
- f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
251
- )
252
- else:
253
- count_info = ""
254
- header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
255
- return header_str
256
-
257
- header = gr.HTML(get_header_str())
258
- demo.load(fn=get_header_str, outputs=header)
259
-
260
- with gr.Row():
261
- image_results, latency_results = [], []
262
- for i, precision in enumerate(args.precisions):
263
- with gr.Column():
264
- gr.Markdown(f"# {precision.upper()}", elem_id="image_header")
265
- with gr.Group():
266
- image_result = gr.Image(
267
- format="png",
268
- image_mode="RGB",
269
- label="Result",
270
- show_label=False,
271
- show_download_button=True,
272
- interactive=False,
273
- )
274
- latency_result = gr.Text(label="Inference Latency", show_label=True)
275
- image_results.append(image_result)
276
- latency_results.append(latency_result)
277
- with gr.Row():
278
- prompt = gr.Text(
279
- label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, scale=4
280
- )
281
- run_button = gr.Button("Run", scale=1)
282
-
283
- with gr.Row():
284
- seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
285
- randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
286
- with gr.Accordion("Advanced options", open=False):
287
- with gr.Group():
288
- height = gr.Slider(label="Height", minimum=256, maximum=4096, step=32, value=1024)
289
- width = gr.Slider(label="Width", minimum=256, maximum=4096, step=32, value=1024)
290
- with gr.Group():
291
- num_inference_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, step=1, value=20)
292
- guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=5)
293
-
294
- input_args = [prompt, height, width, num_inference_steps, guidance_scale, seed]
295
-
296
- gr.Examples(examples=EXAMPLES, inputs=input_args, outputs=[*image_results, *latency_results], fn=generate)
297
-
298
- gr.on(
299
- triggers=[prompt.submit, run_button.click],
300
- fn=generate,
301
- inputs=input_args,
302
- outputs=[*image_results, *latency_results],
303
- api_name="run",
304
- )
305
- randomize_seed.click(
306
- lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
307
- ).then(fn=generate, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False)
308
-
309
- gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
310
-
311
-
312
- if __name__ == "__main__":
313
- demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)
app/app_sana_controlnet_hed.py DELETED
@@ -1,306 +0,0 @@
1
- # Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
2
- import argparse
3
- import os
4
- import random
5
- import socket
6
- import tempfile
7
- import time
8
-
9
- import gradio as gr
10
- import numpy as np
11
- import torch
12
- from PIL import Image
13
- from transformers import AutoModelForCausalLM, AutoTokenizer
14
-
15
- from app import safety_check
16
- from app.sana_controlnet_pipeline import SanaControlNetPipeline
17
-
18
- STYLES = {
19
- "None": "{prompt}",
20
- "Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
21
- "3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
22
- "Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
23
- "Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
24
- "Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
25
- "Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
26
- "Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
27
- "Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
28
- "Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
29
- }
30
- DEFAULT_STYLE_NAME = "None"
31
- STYLE_NAMES = list(STYLES.keys())
32
-
33
- MAX_SEED = 1000000000
34
- DEFAULT_SKETCH_GUIDANCE = 0.28
35
- DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
36
-
37
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
38
-
39
- blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
40
-
41
-
42
- def get_args():
43
- parser = argparse.ArgumentParser()
44
- parser.add_argument("--config", type=str, help="config")
45
- parser.add_argument(
46
- "--model_path",
47
- nargs="?",
48
- default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
49
- type=str,
50
- help="Path to the model file (positional)",
51
- )
52
- parser.add_argument("--output", default="./", type=str)
53
- parser.add_argument("--bs", default=1, type=int)
54
- parser.add_argument("--image_size", default=1024, type=int)
55
- parser.add_argument("--cfg_scale", default=5.0, type=float)
56
- parser.add_argument("--pag_scale", default=2.0, type=float)
57
- parser.add_argument("--seed", default=42, type=int)
58
- parser.add_argument("--step", default=-1, type=int)
59
- parser.add_argument("--custom_image_size", default=None, type=int)
60
- parser.add_argument("--share", action="store_true")
61
- parser.add_argument(
62
- "--shield_model_path",
63
- type=str,
64
- help="The path to shield model, we employ ShieldGemma-2B by default.",
65
- default="google/shieldgemma-2b",
66
- )
67
-
68
- return parser.parse_known_args()[0]
69
-
70
-
71
- args = get_args()
72
-
73
- if torch.cuda.is_available():
74
- model_path = args.model_path
75
- pipe = SanaControlNetPipeline(args.config)
76
- pipe.from_pretrained(model_path)
77
- pipe.register_progress_bar(gr.Progress())
78
-
79
- # safety checker
80
- safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
81
- safety_checker_model = AutoModelForCausalLM.from_pretrained(
82
- args.shield_model_path,
83
- device_map="auto",
84
- torch_dtype=torch.bfloat16,
85
- ).to(device)
86
-
87
-
88
- def save_image(img):
89
- if isinstance(img, dict):
90
- img = img["composite"]
91
- temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
92
- img.save(temp_file.name)
93
- return temp_file.name
94
-
95
-
96
- def norm_ip(img, low, high):
97
- img.clamp_(min=low, max=high)
98
- img.sub_(low).div_(max(high - low, 1e-5))
99
- return img
100
-
101
-
102
- @torch.no_grad()
103
- @torch.inference_mode()
104
- def run(
105
- image,
106
- prompt: str,
107
- prompt_template: str,
108
- sketch_thickness: int,
109
- guidance_scale: float,
110
- inference_steps: int,
111
- seed: int,
112
- blend_alpha: float,
113
- ) -> tuple[Image, str]:
114
-
115
- print(f"Prompt: {prompt}")
116
- image_numpy = np.array(image["composite"].convert("RGB"))
117
-
118
- if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
119
- return blank_image, "Please input the prompt or draw something."
120
-
121
- if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
122
- prompt = "A red heart."
123
-
124
- prompt = prompt_template.format(prompt=prompt)
125
- pipe.set_blend_alpha(blend_alpha)
126
- start_time = time.time()
127
- images = pipe(
128
- prompt=prompt,
129
- ref_image=image["composite"],
130
- guidance_scale=guidance_scale,
131
- num_inference_steps=inference_steps,
132
- num_images_per_prompt=1,
133
- sketch_thickness=sketch_thickness,
134
- generator=torch.Generator(device=device).manual_seed(seed),
135
- )
136
-
137
- latency = time.time() - start_time
138
-
139
- if latency < 1:
140
- latency = latency * 1000
141
- latency_str = f"{latency:.2f}ms"
142
- else:
143
- latency_str = f"{latency:.2f}s"
144
- torch.cuda.empty_cache()
145
-
146
- img = [
147
- Image.fromarray(
148
- norm_ip(img, -1, 1)
149
- .mul(255)
150
- .add_(0.5)
151
- .clamp_(0, 255)
152
- .permute(1, 2, 0)
153
- .to("cpu", torch.uint8)
154
- .numpy()
155
- .astype(np.uint8)
156
- )
157
- for img in images
158
- ]
159
- img = img[0]
160
- return img, latency_str
161
-
162
-
163
- model_size = "1.6" if "1600M" in args.model_path else "0.6"
164
- title = f"""
165
- <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
166
- <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
167
- </div>
168
- """
169
- DESCRIPTION = f"""
170
- <p><span style="font-size: 36px; font-weight: bold;">Sana-ControlNet-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
171
- <p style="font-size: 18px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
172
- <p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
173
- <p style="font-size: 18px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, </p>running on node {socket.gethostname()}.
174
- <p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
175
- """
176
- if model_size == "0.6":
177
- DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
178
- if not torch.cuda.is_available():
179
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
180
-
181
-
182
- with gr.Blocks(css_paths="asset/app_styles/controlnet_app_style.css", title=f"Sana Sketch-to-Image Demo") as demo:
183
- gr.Markdown(title)
184
- gr.HTML(DESCRIPTION)
185
-
186
- with gr.Row(elem_id="main_row"):
187
- with gr.Column(elem_id="column_input"):
188
- gr.Markdown("## INPUT", elem_id="input_header")
189
- with gr.Group():
190
- canvas = gr.Sketchpad(
191
- value=blank_image,
192
- height=640,
193
- image_mode="RGB",
194
- sources=["upload", "clipboard"],
195
- type="pil",
196
- label="Sketch",
197
- show_label=False,
198
- show_download_button=True,
199
- interactive=True,
200
- transforms=[],
201
- canvas_size=(1024, 1024),
202
- scale=1,
203
- brush=gr.Brush(default_size=3, colors=["#000000"], color_mode="fixed"),
204
- format="png",
205
- layers=False,
206
- )
207
- with gr.Row():
208
- prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
209
- run_button = gr.Button("Run", scale=1, elem_id="run_button")
210
- download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
211
- with gr.Row():
212
- style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
213
- prompt_template = gr.Textbox(
214
- label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
215
- )
216
-
217
- with gr.Row():
218
- sketch_thickness = gr.Slider(
219
- label="Sketch Thickness",
220
- minimum=1,
221
- maximum=4,
222
- step=1,
223
- value=2,
224
- )
225
- with gr.Row():
226
- inference_steps = gr.Slider(
227
- label="Sampling steps",
228
- minimum=5,
229
- maximum=40,
230
- step=1,
231
- value=20,
232
- )
233
- guidance_scale = gr.Slider(
234
- label="CFG Guidance scale",
235
- minimum=1,
236
- maximum=10,
237
- step=0.1,
238
- value=4.5,
239
- )
240
- blend_alpha = gr.Slider(
241
- label="Blend Alpha",
242
- minimum=0,
243
- maximum=1,
244
- step=0.1,
245
- value=0,
246
- )
247
- with gr.Row():
248
- seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
249
- randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
250
-
251
- with gr.Column(elem_id="column_output"):
252
- gr.Markdown("## OUTPUT", elem_id="output_header")
253
- with gr.Group():
254
- result = gr.Image(
255
- format="png",
256
- height=640,
257
- image_mode="RGB",
258
- type="pil",
259
- label="Result",
260
- show_label=False,
261
- show_download_button=True,
262
- interactive=False,
263
- elem_id="output_image",
264
- )
265
- latency_result = gr.Text(label="Inference Latency", show_label=True)
266
-
267
- download_result = gr.DownloadButton("Download Result", elem_id="download_result")
268
- gr.Markdown("### Instructions")
269
- gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
270
- gr.Markdown("**2**. Start sketching or upload a reference image")
271
- gr.Markdown("**3**. Change the image style using a style template")
272
- gr.Markdown("**4**. Try different seeds to generate different results")
273
-
274
- run_inputs = [canvas, prompt, prompt_template, sketch_thickness, guidance_scale, inference_steps, seed, blend_alpha]
275
- run_outputs = [result, latency_result]
276
-
277
- randomize_seed.click(
278
- lambda: random.randint(0, MAX_SEED),
279
- inputs=[],
280
- outputs=seed,
281
- api_name=False,
282
- queue=False,
283
- ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
284
-
285
- style.change(
286
- lambda x: STYLES[x],
287
- inputs=[style],
288
- outputs=[prompt_template],
289
- api_name=False,
290
- queue=False,
291
- ).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
292
- gr.on(
293
- triggers=[prompt.submit, run_button.click, canvas.change],
294
- fn=run,
295
- inputs=run_inputs,
296
- outputs=run_outputs,
297
- api_name=False,
298
- )
299
-
300
- download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
301
- download_result.click(fn=save_image, inputs=result, outputs=download_result)
302
- gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
303
-
304
-
305
- if __name__ == "__main__":
306
- demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
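The event wiring above follows a standard Gradio pattern: several triggers (prompt submit, run button, canvas change) share a single handler via gr.on, and the random-seed button chains a seed update before re-running. A minimal, self-contained sketch of that pattern is below; the handler and component names are illustrative placeholders, not part of the deleted file.

import random

import gradio as gr

MAX_SEED = 2**31 - 1


def fake_run(prompt: str, seed: int) -> str:
    # Placeholder standing in for the demo's real `run` handler.
    return f"would generate {prompt!r} with seed {seed}"


with gr.Blocks() as sketch:
    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt")
    seed = gr.Slider(0, MAX_SEED, value=233, step=1, label="Seed")
    randomize = gr.Button("Random Seed")
    out = gr.Text(label="Result")
    # Chain: first update the seed, then re-run the handler with the new value.
    randomize.click(lambda: random.randint(0, MAX_SEED), outputs=seed, queue=False).then(
        fake_run, inputs=[prompt, seed], outputs=out
    )
    # Several triggers sharing one handler.
    gr.on(triggers=[prompt.submit], fn=fake_run, inputs=[prompt, seed], outputs=out)

# sketch.launch()  # uncomment to serve locally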
 
app/app_sana_multithread.py DELETED
@@ -1,565 +0,0 @@
1
- #!/usr/bin/env python
2
- # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- # SPDX-License-Identifier: Apache-2.0
17
- from __future__ import annotations
18
-
19
- import argparse
20
- import os
21
- import random
22
- import uuid
23
- from datetime import datetime
24
-
25
- import gradio as gr
26
- import numpy as np
27
- import spaces
28
- import torch
29
- from diffusers import FluxPipeline
30
- from PIL import Image
31
- from torchvision.utils import make_grid, save_image
32
- from transformers import AutoModelForCausalLM, AutoTokenizer
33
-
34
- from app import safety_check
35
- from app.sana_pipeline import SanaPipeline
36
-
37
- MAX_SEED = np.iinfo(np.int32).max
38
- CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
39
- MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
40
- USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
41
- ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
42
- DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))
43
- os.environ["GRADIO_EXAMPLES_CACHE"] = "./.gradio/cache"
44
-
45
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
46
-
47
- style_list = [
48
- {
49
- "name": "(No style)",
50
- "prompt": "{prompt}",
51
- "negative_prompt": "",
52
- },
53
- {
54
- "name": "Cinematic",
55
- "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, "
56
- "cinemascope, moody, epic, gorgeous, film grain, grainy",
57
- "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
58
- },
59
- {
60
- "name": "Photographic",
61
- "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
62
- "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
63
- },
64
- {
65
- "name": "Anime",
66
- "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
67
- "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
68
- },
69
- {
70
- "name": "Manga",
71
- "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
72
- "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
73
- },
74
- {
75
- "name": "Digital Art",
76
- "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
77
- "negative_prompt": "photo, photorealistic, realism, ugly",
78
- },
79
- {
80
- "name": "Pixel art",
81
- "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
82
- "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
83
- },
84
- {
85
- "name": "Fantasy art",
86
- "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, "
87
- "majestic, magical, fantasy art, cover art, dreamy",
88
- "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, "
89
- "glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, "
90
- "disfigured, sloppy, duplicate, mutated, black and white",
91
- },
92
- {
93
- "name": "Neonpunk",
94
- "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, "
95
- "detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, "
96
- "ultra detailed, intricate, professional",
97
- "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
98
- },
99
- {
100
- "name": "3D Model",
101
- "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
102
- "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
103
- },
104
- ]
105
-
106
- styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
107
- STYLE_NAMES = list(styles.keys())
108
- DEFAULT_STYLE_NAME = "(No style)"
109
- SCHEDULE_NAME = ["Flow_DPM_Solver"]
110
- DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
111
- NUM_IMAGES_PER_PROMPT = 1
112
- TEST_TIMES = 0
113
- FILENAME = f"output/port{DEMO_PORT}_inference_count.txt"
114
-
115
-
116
- def set_env(seed=0):
117
- torch.manual_seed(seed)
118
- torch.set_grad_enabled(False)
119
-
120
-
121
- def read_inference_count():
122
- global TEST_TIMES
123
- try:
124
- with open(FILENAME) as f:
125
- count = int(f.read().strip())
126
- except FileNotFoundError:
127
- count = 0
128
- TEST_TIMES = count
129
-
130
- return count
131
-
132
-
133
- def write_inference_count(count):
134
- with open(FILENAME, "w") as f:
135
- f.write(str(count))
136
-
137
-
138
- def run_inference(num_imgs=1):
139
- TEST_TIMES = read_inference_count()
140
- TEST_TIMES += int(num_imgs)
141
- write_inference_count(TEST_TIMES)
142
-
143
- return (
144
- f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
145
- f"16px; color:red; font-weight: bold;'>{TEST_TIMES}</span>"
146
- )
147
-
148
-
149
- def update_inference_count():
150
- count = read_inference_count()
151
- return (
152
- f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
153
- f"16px; color:red; font-weight: bold;'>{count}</span>"
154
- )
155
-
156
-
157
- def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
158
- p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
159
- if not negative:
160
- negative = ""
161
- return p.replace("{prompt}", positive), n + negative
162
-
163
-
164
- def get_args():
165
- parser = argparse.ArgumentParser()
166
- parser.add_argument("--config", type=str, help="config")
167
- parser.add_argument(
168
- "--model_path",
169
- nargs="?",
170
- default="output/Sana_D20/SANA.pth",
171
- type=str,
172
- help="Path to the model file (positional)",
173
- )
174
- parser.add_argument("--output", default="./", type=str)
175
- parser.add_argument("--bs", default=1, type=int)
176
- parser.add_argument("--image_size", default=1024, type=int)
177
- parser.add_argument("--cfg_scale", default=5.0, type=float)
178
- parser.add_argument("--pag_scale", default=2.0, type=float)
179
- parser.add_argument("--seed", default=42, type=int)
180
- parser.add_argument("--step", default=-1, type=int)
181
- parser.add_argument("--custom_image_size", default=None, type=int)
182
- parser.add_argument(
183
- "--shield_model_path",
184
- type=str,
185
- help="The path to shield model, we employ ShieldGemma-2B by default.",
186
- default="google/shieldgemma-2b",
187
- )
188
-
189
- return parser.parse_args()
190
-
191
-
192
- args = get_args()
193
-
194
- if torch.cuda.is_available():
195
- weight_dtype = torch.float16
196
- model_path = args.model_path
197
- pipe = SanaPipeline(args.config)
198
- pipe.from_pretrained(model_path)
199
- pipe.register_progress_bar(gr.Progress())
200
-
201
- repo_name = "black-forest-labs/FLUX.1-dev"
202
- pipe2 = FluxPipeline.from_pretrained(repo_name, torch_dtype=torch.float16).to("cuda")
203
-
204
- # safety checker
205
- safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
206
- safety_checker_model = AutoModelForCausalLM.from_pretrained(
207
- args.shield_model_path,
208
- device_map="auto",
209
- torch_dtype=torch.bfloat16,
210
- ).to(device)
211
-
212
- set_env(42)
213
-
214
-
215
- def save_image_sana(img, seed="", save_img=False):
216
- unique_name = f"{str(uuid.uuid4())}_{seed}.png"
217
- save_path = os.path.join(f"output/online_demo_img/{datetime.now().date()}")
218
- os.umask(0o000) # file permission: 666; dir permission: 777
219
- os.makedirs(save_path, exist_ok=True)
220
- unique_name = os.path.join(save_path, unique_name)
221
- if save_img:
222
- save_image(img, unique_name, nrow=1, normalize=True, value_range=(-1, 1))
223
-
224
- return unique_name
225
-
226
-
227
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
228
- if randomize_seed:
229
- seed = random.randint(0, MAX_SEED)
230
- return seed
231
-
232
-
233
- @spaces.GPU(enable_queue=True)
234
- async def generate_2(
235
- prompt: str = None,
236
- negative_prompt: str = "",
237
- style: str = DEFAULT_STYLE_NAME,
238
- use_negative_prompt: bool = False,
239
- num_imgs: int = 1,
240
- seed: int = 0,
241
- height: int = 1024,
242
- width: int = 1024,
243
- flow_dpms_guidance_scale: float = 5.0,
244
- flow_dpms_pag_guidance_scale: float = 2.0,
245
- flow_dpms_inference_steps: int = 20,
246
- randomize_seed: bool = False,
247
- ):
248
- seed = int(randomize_seed_fn(seed, randomize_seed))
249
- generator = torch.Generator(device=device).manual_seed(seed)
250
- print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
251
- if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
252
- prompt = "A red heart."
253
-
254
- print(prompt)
255
-
256
- if not use_negative_prompt:
257
- negative_prompt = None # type: ignore
258
- prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
259
-
260
- with torch.no_grad():
261
- images = pipe2(
262
- prompt=prompt,
263
- height=height,
264
- width=width,
265
- guidance_scale=3.5,
266
- num_inference_steps=50,
267
- num_images_per_prompt=num_imgs,
268
- max_sequence_length=256,
269
- generator=generator,
270
- ).images
271
-
272
- save_img = False
273
- img = images
274
- if save_img:
275
- img = [save_image_sana(img, seed, save_img=save_img) for img in images]
276
- print(img)
277
- torch.cuda.empty_cache()
278
-
279
- return img
280
-
281
-
282
- @spaces.GPU(enable_queue=True)
283
- async def generate(
284
- prompt: str = None,
285
- negative_prompt: str = "",
286
- style: str = DEFAULT_STYLE_NAME,
287
- use_negative_prompt: bool = False,
288
- num_imgs: int = 1,
289
- seed: int = 0,
290
- height: int = 1024,
291
- width: int = 1024,
292
- flow_dpms_guidance_scale: float = 5.0,
293
- flow_dpms_pag_guidance_scale: float = 2.0,
294
- flow_dpms_inference_steps: int = 20,
295
- randomize_seed: bool = False,
296
- ):
297
- global TEST_TIMES
298
- # seed = 823753551
299
- seed = int(randomize_seed_fn(seed, randomize_seed))
300
- generator = torch.Generator(device=device).manual_seed(seed)
301
- print(f"PORT: {DEMO_PORT}, model_path: {model_path}, time_times: {TEST_TIMES}")
302
- if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt):
303
- prompt = "A red heart."
304
-
305
- print(prompt)
306
-
307
- num_inference_steps = flow_dpms_inference_steps
308
- guidance_scale = flow_dpms_guidance_scale
309
- pag_guidance_scale = flow_dpms_pag_guidance_scale
310
-
311
- if not use_negative_prompt:
312
- negative_prompt = None # type: ignore
313
- prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
314
-
315
- pipe.progress_fn(0, desc="Sana Start")
316
-
317
- with torch.no_grad():
318
- images = pipe(
319
- prompt=prompt,
320
- height=height,
321
- width=width,
322
- negative_prompt=negative_prompt,
323
- guidance_scale=guidance_scale,
324
- pag_guidance_scale=pag_guidance_scale,
325
- num_inference_steps=num_inference_steps,
326
- num_images_per_prompt=num_imgs,
327
- generator=generator,
328
- )
329
-
330
- pipe.progress_fn(1.0, desc="Sana End")
331
-
332
- save_img = False
333
- if save_img:
334
- img = [save_image_sana(img, seed, save_img=save_img) for img in images]
335
- print(img)
336
- else:
337
- if num_imgs > 1:
338
- nrow = 2
339
- else:
340
- nrow = 1
341
- img = make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1))
342
- img = img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
343
- img = [Image.fromarray(img.astype(np.uint8))]
344
-
345
- torch.cuda.empty_cache()
346
-
347
- return img
348
-
349
-
350
- TEST_TIMES = read_inference_count()
351
- model_size = "1.6" if "D20" in args.model_path else "0.6"
352
- title = f"""
353
- <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
354
- <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
355
- </div>
356
- """
357
- DESCRIPTION = f"""
358
- <p><span style="font-size: 36px; font-weight: bold;">Sana-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
359
- <p style="font-size: 16px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
360
- <p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span></p>
361
- <p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space</p>
362
- <p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
363
- """
364
- if model_size == "0.6":
365
- DESCRIPTION += "\n<p>The 0.6B model's text rendering ability is limited.</p>"
366
- if not torch.cuda.is_available():
367
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
368
-
369
- examples = [
370
- 'a cyberpunk cat with a neon sign that says "Sana"',
371
- "A very detailed and realistic full body photo set of a tall, slim, and athletic Shiba Inu in a white oversized straight t-shirt, white shorts, and short white shoes.",
372
- "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
373
- "portrait photo of a girl, photograph, highly detailed face, depth of field",
374
- 'make me a logo that says "So Fast" with a really cool flying dragon shape with lightning sparks all over the sides and all of it contains Indonesian language',
375
- "🐶 Wearing 🕶 flying on the 🌈",
376
- # "👧 with 🌹 in the ❄️",
377
- # "an old rusted robot wearing pants and a jacket riding skis in a supermarket.",
378
- # "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
379
- # "Astronaut in a jungle, cold color palette, muted colors, detailed",
380
- # "a stunning and luxurious bedroom carved into a rocky mountainside seamlessly blending nature with modern design with a plush earth-toned bed textured stone walls circular fireplace massive uniquely shaped window framing snow-capped mountains dense forests",
381
- ]
382
-
383
- css = """
384
- .gradio-container{max-width: 1024px !important}
385
- h1{text-align:center}
386
- """
387
- with gr.Blocks(css=css) as demo:
388
- gr.Markdown(title)
389
- gr.Markdown(DESCRIPTION)
390
- gr.DuplicateButton(
391
- value="Duplicate Space for private use",
392
- elem_id="duplicate-button",
393
- visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
394
- )
395
- info_box = gr.Markdown(
396
- value=f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: 16px; color:red; font-weight: bold;'>{read_inference_count()}</span>"
397
- )
398
- demo.load(fn=update_inference_count, outputs=info_box) # update the value when re-loading the page
399
- # with gr.Row(equal_height=False):
400
- with gr.Group():
401
- with gr.Row():
402
- prompt = gr.Text(
403
- label="Prompt",
404
- show_label=False,
405
- max_lines=1,
406
- placeholder="Enter your prompt",
407
- container=False,
408
- )
409
- run_button = gr.Button("Run-sana", scale=0)
410
- run_button2 = gr.Button("Run-flux", scale=0)
411
-
412
- with gr.Row():
413
- result = gr.Gallery(label="Result from Sana", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp")
414
- result_2 = gr.Gallery(
415
- label="Result from FLUX", show_label=True, columns=NUM_IMAGES_PER_PROMPT, format="webp"
416
- )
417
-
418
- with gr.Accordion("Advanced options", open=False):
419
- with gr.Group():
420
- with gr.Row(visible=True):
421
- height = gr.Slider(
422
- label="Height",
423
- minimum=256,
424
- maximum=MAX_IMAGE_SIZE,
425
- step=32,
426
- value=1024,
427
- )
428
- width = gr.Slider(
429
- label="Width",
430
- minimum=256,
431
- maximum=MAX_IMAGE_SIZE,
432
- step=32,
433
- value=1024,
434
- )
435
- with gr.Row():
436
- flow_dpms_inference_steps = gr.Slider(
437
- label="Sampling steps",
438
- minimum=5,
439
- maximum=40,
440
- step=1,
441
- value=18,
442
- )
443
- flow_dpms_guidance_scale = gr.Slider(
444
- label="CFG Guidance scale",
445
- minimum=1,
446
- maximum=10,
447
- step=0.1,
448
- value=5.0,
449
- )
450
- flow_dpms_pag_guidance_scale = gr.Slider(
451
- label="PAG Guidance scale",
452
- minimum=1,
453
- maximum=4,
454
- step=0.5,
455
- value=2.0,
456
- )
457
- with gr.Row():
458
- use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
459
- negative_prompt = gr.Text(
460
- label="Negative prompt",
461
- max_lines=1,
462
- placeholder="Enter a negative prompt",
463
- visible=True,
464
- )
465
- style_selection = gr.Radio(
466
- show_label=True,
467
- container=True,
468
- interactive=True,
469
- choices=STYLE_NAMES,
470
- value=DEFAULT_STYLE_NAME,
471
- label="Image Style",
472
- )
473
- seed = gr.Slider(
474
- label="Seed",
475
- minimum=0,
476
- maximum=MAX_SEED,
477
- step=1,
478
- value=0,
479
- )
480
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
481
- with gr.Row(visible=True):
482
- schedule = gr.Radio(
483
- show_label=True,
484
- container=True,
485
- interactive=True,
486
- choices=SCHEDULE_NAME,
487
- value=DEFAULT_SCHEDULE_NAME,
488
- label="Sampler Schedule",
489
- visible=True,
490
- )
491
- num_imgs = gr.Slider(
492
- label="Num Images",
493
- minimum=1,
494
- maximum=6,
495
- step=1,
496
- value=1,
497
- )
498
-
499
- run_button.click(fn=run_inference, inputs=num_imgs, outputs=info_box)
500
-
501
- gr.Examples(
502
- examples=examples,
503
- inputs=prompt,
504
- outputs=[result],
505
- fn=generate,
506
- cache_examples=CACHE_EXAMPLES,
507
- )
508
- gr.Examples(
509
- examples=examples,
510
- inputs=prompt,
511
- outputs=[result_2],
512
- fn=generate_2,
513
- cache_examples=CACHE_EXAMPLES,
514
- )
515
-
516
- use_negative_prompt.change(
517
- fn=lambda x: gr.update(visible=x),
518
- inputs=use_negative_prompt,
519
- outputs=negative_prompt,
520
- api_name=False,
521
- )
522
-
523
- run_button.click(
524
- fn=generate,
525
- inputs=[
526
- prompt,
527
- negative_prompt,
528
- style_selection,
529
- use_negative_prompt,
530
- num_imgs,
531
- seed,
532
- height,
533
- width,
534
- flow_dpms_guidance_scale,
535
- flow_dpms_pag_guidance_scale,
536
- flow_dpms_inference_steps,
537
- randomize_seed,
538
- ],
539
- outputs=[result],
540
- queue=True,
541
- )
542
-
543
- run_button2.click(
544
- fn=generate_2,
545
- inputs=[
546
- prompt,
547
- negative_prompt,
548
- style_selection,
549
- use_negative_prompt,
550
- num_imgs,
551
- seed,
552
- height,
553
- width,
554
- flow_dpms_guidance_scale,
555
- flow_dpms_pag_guidance_scale,
556
- flow_dpms_inference_steps,
557
- randomize_seed,
558
- ],
559
- outputs=[result_2],
560
- queue=True,
561
- )
562
-
563
-
564
- if __name__ == "__main__":
565
- demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=True, share=True)
 
app/safety_check.py DELETED
@@ -1,72 +0,0 @@
1
- # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- #
15
- # SPDX-License-Identifier: Apache-2.0
16
- # This file is copy from https://github.com/mit-han-lab/hart/blob/main/hart/utils/safety_check.py
17
- import json
18
- from functools import lru_cache
19
- from typing import Dict, Sequence
20
-
21
- import torch
22
- import transformers
23
- from torch.nn.functional import softmax
24
- from transformers import AutoModelForCausalLM, AutoTokenizer
25
-
26
- __all__ = ["is_dangerous"]
27
-
28
- __rank, __local_rank, __world_size, __device = (
29
- 0,
30
- 0,
31
- 1,
32
- "cuda" if torch.cuda.is_available() else "cpu",
33
- )
34
-
35
-
36
- safety_policy = """
37
- The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
38
- """
39
-
40
-
41
- def is_dangerous(tokenizer, model, user_prompt, threshold=0.5):
42
- prompt = f"""You are a policy expert trying to help determine whether a user
43
- prompt is in violation of the defined safety policies.
44
-
45
- <start_of_turn>
46
- Human Question: {user_prompt.strip()}
47
- <end_of_turn>
48
-
49
- Our safety principle is defined in the below:
50
-
51
- {safety_policy.strip()}
52
-
53
- Does the human question violate the above principle? Your answer must start
54
- with 'Yes' or 'No'. And then walk through step by step to be sure we answer
55
- correctly.
56
- """
57
-
58
- inputs = tokenizer(prompt, return_tensors="pt").to(__device)  # use the module-level device instead of hard-coding "cuda", so CPU-only runs do not crash
59
- with torch.no_grad():
60
- logits = model(**inputs).logits
61
-
62
- # Extract the logits for the Yes and No tokens
63
- vocab = tokenizer.get_vocab()
64
- selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
65
-
66
- # Convert these logits to a probability with softmax
67
- probabilities = softmax(selected_logits, dim=0)
68
-
69
- # Return probability of 'Yes'
70
- score = probabilities[0].item()
71
-
72
- return score > threshold
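The check above asks ShieldGemma a Yes/No policy question and compares the softmax probability of the "Yes" token against a 0.5 threshold. A minimal usage sketch, assuming the pre-deletion layout (app/safety_check.py importable) and a GPU:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from app.safety_check import is_dangerous  # the module shown (deleted) above

# Same shield model the demos load by default.
tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/shieldgemma-2b", torch_dtype=torch.bfloat16, device_map="auto"
)

# True when P("Yes" | policy prompt) > 0.5; the demos then replace the prompt with "A red heart."
print(is_dangerous(tokenizer, model, "a cyberpunk cat with a neon sign"))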
 
app/sana_controlnet_pipeline.py DELETED
@@ -1,353 +0,0 @@
1
- # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- #
15
- # SPDX-License-Identifier: Apache-2.0
16
- import warnings
17
- from dataclasses import dataclass, field
18
- from typing import Optional, Tuple
19
-
20
- import numpy as np
21
- import pyrallis
22
- import torch
23
- import torch.nn as nn
24
- from PIL import Image
25
-
26
- warnings.filterwarnings("ignore") # ignore warning
27
-
28
-
29
- from diffusion import DPMS, FlowEuler
30
- from diffusion.data.datasets.utils import (
31
- ASPECT_RATIO_512_TEST,
32
- ASPECT_RATIO_1024_TEST,
33
- ASPECT_RATIO_2048_TEST,
34
- ASPECT_RATIO_4096_TEST,
35
- )
36
- from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode, vae_encode
37
- from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
38
- from diffusion.utils.config import SanaConfig, model_init_config
39
- from diffusion.utils.logger import get_root_logger
40
- from tools.controlnet.utils import get_scribble_map, transform_control_signal
41
- from tools.download import find_model
42
-
43
-
44
- def guidance_type_select(default_guidance_type, pag_scale, attn_type):
45
- guidance_type = default_guidance_type
46
- if not (pag_scale > 1.0 and attn_type == "linear"):
47
- guidance_type = "classifier-free"
48
- elif pag_scale > 1.0 and attn_type == "linear":
49
- guidance_type = "classifier-free_PAG"
50
- return guidance_type
51
-
52
-
53
- def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
54
- """Returns binned height and width."""
55
- ar = float(height / width)
56
- closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
57
- default_hw = ratios[closest_ratio]
58
- return int(default_hw[0]), int(default_hw[1])
59
-
60
-
61
- def get_ar_from_ref_image(ref_image):
62
- def reduce_ratio(h, w):
63
- def gcd(a, b):
64
- while b:
65
- a, b = b, a % b
66
- return a
67
-
68
- divisor = gcd(h, w)
69
- return f"{h // divisor}:{w // divisor}"
70
-
71
- if isinstance(ref_image, str):
72
- ref_image = Image.open(ref_image)
73
- w, h = ref_image.size
74
- return reduce_ratio(h, w)
75
-
76
-
77
- @dataclass
78
- class SanaControlNetInference(SanaConfig):
79
- config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml" # config
80
- model_path: str = field(
81
- default="output/Sana_D20/SANA.pth", metadata={"help": "Path to the model file (positional)"}
82
- )
83
- output: str = "./output"
84
- bs: int = 1
85
- image_size: int = 1024
86
- cfg_scale: float = 5.0
87
- pag_scale: float = 2.0
88
- seed: int = 42
89
- step: int = -1
90
- custom_image_size: Optional[int] = None
91
- shield_model_path: str = field(
92
- default="google/shieldgemma-2b",
93
- metadata={"help": "The path to shield model, we employ ShieldGemma-2B by default."},
94
- )
95
-
96
-
97
- class SanaControlNetPipeline(nn.Module):
98
- def __init__(
99
- self,
100
- config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml",
101
- ):
102
- super().__init__()
103
- config = pyrallis.load(SanaControlNetInference, open(config))
104
- self.args = self.config = config
105
-
106
- # set some hyper-parameters
107
- self.image_size = self.config.model.image_size
108
-
109
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
110
- logger = get_root_logger()
111
- self.logger = logger
112
- self.progress_fn = lambda progress, desc: None
113
- self.thickness = 2
114
- self.blend_alpha = 0.0
115
-
116
- self.latent_size = self.image_size // config.vae.vae_downsample_rate
117
- self.max_sequence_length = config.text_encoder.model_max_length
118
- self.flow_shift = config.scheduler.flow_shift
119
- guidance_type = "classifier-free_PAG"
120
-
121
- weight_dtype = get_weight_dtype(config.model.mixed_precision)
122
- self.weight_dtype = weight_dtype
123
- self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)
124
-
125
- self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
126
- self.vis_sampler = self.config.scheduler.vis_sampler
127
- logger.info(f"Sampler {self.vis_sampler}, flow_shift: {self.flow_shift}")
128
- self.guidance_type = guidance_type_select(guidance_type, self.args.pag_scale, config.model.attn_type)
129
- logger.info(f"Inference with {self.weight_dtype}, PAG guidance layer: {self.config.model.pag_applied_layers}")
130
-
131
- # 1. build vae and text encoder
132
- self.vae = self.build_vae(config.vae)
133
- self.tokenizer, self.text_encoder = self.build_text_encoder(config.text_encoder)
134
-
135
- # 2. build Sana model
136
- self.model = self.build_sana_model(config).to(self.device)
137
-
138
- # 3. pre-compute null embedding
139
- with torch.no_grad():
140
- null_caption_token = self.tokenizer(
141
- "", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
142
- ).to(self.device)
143
- self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
144
- 0
145
- ]
146
-
147
- def build_vae(self, config):
148
- vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
149
- return vae
150
-
151
- def build_text_encoder(self, config):
152
- tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder_name, device=self.device)
153
- return tokenizer, text_encoder
154
-
155
- def build_sana_model(self, config):
156
- # model setting
157
- model_kwargs = model_init_config(config, latent_size=self.latent_size)
158
- model = build_model(
159
- config.model.model,
160
- use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
161
- **model_kwargs,
162
- )
163
- self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
164
- self.logger.info(
165
- f"{model.__class__.__name__}:{config.model.model},"
166
- f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
167
- )
168
- return model
169
-
170
- def from_pretrained(self, model_path):
171
- state_dict = find_model(model_path)
172
- state_dict = state_dict.get("state_dict", state_dict)
173
- if "pos_embed" in state_dict:
174
- del state_dict["pos_embed"]
175
- missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
176
- self.model.eval().to(self.weight_dtype)
177
-
178
- self.logger.info("Generating sample from ckpt: %s" % model_path)
179
- self.logger.warning(f"Missing keys: {missing}")
180
- self.logger.warning(f"Unexpected keys: {unexpected}")
181
-
182
- def register_progress_bar(self, progress_fn=None):
183
- self.progress_fn = progress_fn if progress_fn is not None else self.progress_fn
184
-
185
- def set_blend_alpha(self, blend_alpha):
186
- self.blend_alpha = blend_alpha
187
-
188
- @torch.inference_mode()
189
- def forward(
190
- self,
191
- prompt=None,
192
- ref_image=None,
193
- negative_prompt="",
194
- num_inference_steps=20,
195
- guidance_scale=5,
196
- pag_guidance_scale=2.5,
197
- num_images_per_prompt=1,
198
- sketch_thickness=2,
199
- generator=torch.Generator().manual_seed(42),
200
- latents=None,
201
- ):
202
- self.ori_height, self.ori_width = ref_image.height, ref_image.width
203
- self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
204
-
205
- # 1. pre-compute negative embedding
206
- if negative_prompt != "":
207
- null_caption_token = self.tokenizer(
208
- negative_prompt,
209
- max_length=self.max_sequence_length,
210
- padding="max_length",
211
- truncation=True,
212
- return_tensors="pt",
213
- ).to(self.device)
214
- self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
215
- 0
216
- ]
217
-
218
- if prompt is None:
219
- prompt = [""]
220
- prompts = prompt if isinstance(prompt, list) else [prompt]
221
- samples = []
222
-
223
- for prompt in prompts:
224
- # data prepare
225
- prompts, hw, ar = (
226
- [],
227
- torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
228
- num_images_per_prompt, 1
229
- ),
230
- torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
231
- )
232
-
233
- ar = get_ar_from_ref_image(ref_image)
234
- prompt += f" --ar {ar}"
235
- for _ in range(num_images_per_prompt):
236
- prompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(
237
- prompt, self.base_ratios, device=self.device, show=False
238
- )
239
- prompts.append(prompt_clean.strip())
240
-
241
- self.latent_size_h, self.latent_size_w = (
242
- int(hw[0, 0] // self.config.vae.vae_downsample_rate),
243
- int(hw[0, 1] // self.config.vae.vae_downsample_rate),
244
- )
245
-
246
- with torch.no_grad():
247
- # prepare text feature
248
- if not self.config.text_encoder.chi_prompt:
249
- max_length_all = self.config.text_encoder.model_max_length
250
- prompts_all = prompts
251
- else:
252
- chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
253
- prompts_all = [chi_prompt + prompt for prompt in prompts]
254
- num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
255
- max_length_all = (
256
- num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
257
- ) # magic number 2: [bos], [_]
258
-
259
- caption_token = self.tokenizer(
260
- prompts_all,
261
- max_length=max_length_all,
262
- padding="max_length",
263
- truncation=True,
264
- return_tensors="pt",
265
- ).to(device=self.device)
266
- select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
267
- caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
268
- :, :, select_index
269
- ].to(self.weight_dtype)
270
- emb_masks = caption_token.attention_mask[:, select_index]
271
- null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
272
-
273
- n = len(prompts)
274
- if latents is None:
275
- z = torch.randn(
276
- n,
277
- self.config.vae.vae_latent_dim,
278
- self.latent_size_h,
279
- self.latent_size_w,
280
- generator=generator,
281
- device=self.device,
282
- )
283
- else:
284
- z = latents.to(self.device)
285
- model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
286
-
287
- # control signal
288
- if isinstance(ref_image, str):
289
- ref_image = cv2.imread(ref_image)  # NOTE: relies on OpenCV (cv2), which this module never imports
290
- elif isinstance(ref_image, Image.Image):
291
- ref_image = np.array(ref_image)
292
- control_signal = get_scribble_map(
293
- input_image=ref_image,
294
- det="Scribble_HED",
295
- detect_resolution=int(hw.min()),
296
- thickness=sketch_thickness,
297
- )
298
-
299
- control_signal = transform_control_signal(control_signal, hw).to(self.device).to(self.weight_dtype)
300
-
301
- control_signal_latent = vae_encode(
302
- self.config.vae.vae_type, self.vae, control_signal, self.config.vae.sample_posterior, self.device
303
- )
304
-
305
- model_kwargs["control_signal"] = control_signal_latent
306
-
307
- if self.vis_sampler == "flow_euler":
308
- flow_solver = FlowEuler(
309
- self.model,
310
- condition=caption_embs,
311
- uncondition=null_y,
312
- cfg_scale=guidance_scale,
313
- model_kwargs=model_kwargs,
314
- )
315
- sample = flow_solver.sample(
316
- z,
317
- steps=num_inference_steps,
318
- )
319
- elif self.vis_sampler == "flow_dpm-solver":
320
- scheduler = DPMS(
321
- self.model.forward_with_dpmsolver,
322
- condition=caption_embs,
323
- uncondition=null_y,
324
- guidance_type=self.guidance_type,
325
- cfg_scale=guidance_scale,
326
- model_type="flow",
327
- model_kwargs=model_kwargs,
328
- schedule="FLOW",
329
- )
330
- scheduler.register_progress_bar(self.progress_fn)
331
- sample = scheduler.sample(
332
- z,
333
- steps=num_inference_steps,
334
- order=2,
335
- skip_type="time_uniform_flow",
336
- method="multistep",
337
- flow_shift=self.flow_shift,
338
- )
339
-
340
- sample = sample.to(self.vae_dtype)
341
- with torch.no_grad():
342
- sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
343
-
344
- if self.blend_alpha > 0:
345
- print(f"blend image and mask with alpha: {self.blend_alpha}")
346
- sample = sample * (1 - self.blend_alpha) + control_signal * self.blend_alpha
347
-
348
- sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
349
- samples.append(sample)
350
-
351
- return sample  # NOTE: returns inside the per-prompt loop, so only the first prompt's sample is produced
352
-
353
- return samples
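A minimal usage sketch for the ControlNet pipeline above, assuming a CUDA device, the pre-deletion repository layout, and the default config/checkpoint paths from the SanaControlNetInference dataclass; the file paths are illustrative and should point at real assets.

import torch
from PIL import Image

from app.sana_controlnet_pipeline import SanaControlNetPipeline  # the module shown (deleted) above

pipe = SanaControlNetPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
pipe.from_pretrained("output/Sana_D20/SANA.pth")

# The reference image is turned into a scribble map (Scribble_HED), VAE-encoded,
# and passed to the model as `control_signal`.
sample = pipe(
    prompt="a cat sitting on a windowsill",
    ref_image=Image.open("sketch.png").convert("RGB"),  # illustrative path
    num_inference_steps=20,
    guidance_scale=4.5,
    sketch_thickness=2,
    generator=torch.Generator(device="cuda").manual_seed(233),
)
# `sample` is an image tensor in [-1, 1], resized back to the reference image's resolution.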
 
app/sana_pipeline.py DELETED
@@ -1,304 +0,0 @@
1
- # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- #
15
- # SPDX-License-Identifier: Apache-2.0
16
- import argparse
17
- import warnings
18
- from dataclasses import dataclass, field
19
- from typing import Optional, Tuple
20
-
21
- import pyrallis
22
- import torch
23
- import torch.nn as nn
24
-
25
- warnings.filterwarnings("ignore") # ignore warning
26
-
27
-
28
- from diffusion import DPMS, FlowEuler
29
- from diffusion.data.datasets.utils import (
30
- ASPECT_RATIO_512_TEST,
31
- ASPECT_RATIO_1024_TEST,
32
- ASPECT_RATIO_2048_TEST,
33
- ASPECT_RATIO_4096_TEST,
34
- )
35
- from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
36
- from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
37
- from diffusion.utils.config import SanaConfig, model_init_config
38
- from diffusion.utils.logger import get_root_logger
39
-
40
- # from diffusion.utils.misc import read_config
41
- from tools.download import find_model
42
-
43
-
44
- def guidance_type_select(default_guidance_type, pag_scale, attn_type):
45
- guidance_type = default_guidance_type
46
- if not (pag_scale > 1.0 and attn_type == "linear"):
47
- guidance_type = "classifier-free"
48
- elif pag_scale > 1.0 and attn_type == "linear":
49
- guidance_type = "classifier-free_PAG"
50
- return guidance_type
51
-
52
-
53
- def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
54
- """Returns binned height and width."""
55
- ar = float(height / width)
56
- closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
57
- default_hw = ratios[closest_ratio]
58
- return int(default_hw[0]), int(default_hw[1])
59
-
60
-
61
- @dataclass
62
- class SanaInference(SanaConfig):
63
- config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml" # config
64
- model_path: str = field(
65
- default="output/Sana_D20/SANA.pth", metadata={"help": "Path to the model file (positional)"}
66
- )
67
- output: str = "./output"
68
- bs: int = 1
69
- image_size: int = 1024
70
- cfg_scale: float = 5.0
71
- pag_scale: float = 2.0
72
- seed: int = 42
73
- step: int = -1
74
- custom_image_size: Optional[int] = None
75
- shield_model_path: str = field(
76
- default="google/shieldgemma-2b",
77
- metadata={"help": "The path to shield model, we employ ShieldGemma-2B by default."},
78
- )
79
-
80
-
81
- class SanaPipeline(nn.Module):
82
- def __init__(
83
- self,
84
- config: Optional[str] = "configs/sana_config/1024ms/Sana_1600M_img1024.yaml",
85
- ):
86
- super().__init__()
87
- config = pyrallis.load(SanaInference, open(config))
88
- self.args = self.config = config
89
-
90
- # set some hyper-parameters
91
- self.image_size = self.config.model.image_size
92
-
93
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
94
- logger = get_root_logger()
95
- self.logger = logger
96
- self.progress_fn = lambda progress, desc: None
97
-
98
- self.latent_size = self.image_size // config.vae.vae_downsample_rate
99
- self.max_sequence_length = config.text_encoder.model_max_length
100
- self.flow_shift = config.scheduler.flow_shift
101
- guidance_type = "classifier-free_PAG"
102
-
103
- weight_dtype = get_weight_dtype(config.model.mixed_precision)
104
- self.weight_dtype = weight_dtype
105
- self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)
106
-
107
- self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
108
- self.vis_sampler = self.config.scheduler.vis_sampler
109
- logger.info(f"Sampler {self.vis_sampler}, flow_shift: {self.flow_shift}")
110
- self.guidance_type = guidance_type_select(guidance_type, self.args.pag_scale, config.model.attn_type)
111
- logger.info(f"Inference with {self.weight_dtype}, PAG guidance layer: {self.config.model.pag_applied_layers}")
112
-
113
- # 1. build vae and text encoder
114
- self.vae = self.build_vae(config.vae)
115
- self.tokenizer, self.text_encoder = self.build_text_encoder(config.text_encoder)
116
-
117
- # 2. build Sana model
118
- self.model = self.build_sana_model(config).to(self.device)
119
-
120
- # 3. pre-compute null embedding
121
- with torch.no_grad():
122
- null_caption_token = self.tokenizer(
123
- "", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
124
- ).to(self.device)
125
- self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
126
- 0
127
- ]
128
-
129
- def build_vae(self, config):
130
- vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
131
- return vae
132
-
133
- def build_text_encoder(self, config):
134
- tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder_name, device=self.device)
135
- return tokenizer, text_encoder
136
-
137
- def build_sana_model(self, config):
138
- # model setting
139
- model_kwargs = model_init_config(config, latent_size=self.latent_size)
140
- model = build_model(
141
- config.model.model,
142
- use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
143
- **model_kwargs,
144
- )
145
- self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
146
- self.logger.info(
147
- f"{model.__class__.__name__}:{config.model.model},"
148
- f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}"
149
- )
150
- return model
151
-
152
- def from_pretrained(self, model_path):
153
- state_dict = find_model(model_path)
154
- state_dict = state_dict.get("state_dict", state_dict)
155
- if "pos_embed" in state_dict:
156
- del state_dict["pos_embed"]
157
- missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
158
- self.model.eval().to(self.weight_dtype)
159
-
160
- self.logger.info("Generating sample from ckpt: %s" % model_path)
161
- self.logger.warning(f"Missing keys: {missing}")
162
- self.logger.warning(f"Unexpected keys: {unexpected}")
163
-
164
- def register_progress_bar(self, progress_fn=None):
165
- self.progress_fn = progress_fn if progress_fn is not None else self.progress_fn
166
-
167
- @torch.inference_mode()
168
- def forward(
169
- self,
170
- prompt=None,
171
- height=1024,
172
- width=1024,
173
- negative_prompt="",
174
- num_inference_steps=20,
175
- guidance_scale=5,
176
- pag_guidance_scale=2.5,
177
- num_images_per_prompt=1,
178
- generator=torch.Generator().manual_seed(42),
179
- latents=None,
180
- ):
181
- self.ori_height, self.ori_width = height, width
182
- self.height, self.width = classify_height_width_bin(height, width, ratios=self.base_ratios)
183
- self.latent_size_h, self.latent_size_w = (
184
- self.height // self.config.vae.vae_downsample_rate,
185
- self.width // self.config.vae.vae_downsample_rate,
186
- )
187
- self.guidance_type = guidance_type_select(self.guidance_type, pag_guidance_scale, self.config.model.attn_type)
188
-
189
- # 1. pre-compute negative embedding
190
- if negative_prompt != "":
191
- null_caption_token = self.tokenizer(
192
- negative_prompt,
193
- max_length=self.max_sequence_length,
194
- padding="max_length",
195
- truncation=True,
196
- return_tensors="pt",
197
- ).to(self.device)
198
- self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
199
- 0
200
- ]
201
-
202
- if prompt is None:
203
- prompt = [""]
204
- prompts = prompt if isinstance(prompt, list) else [prompt]
205
- samples = []
206
-
207
- for prompt in prompts:
208
- # data prepare
209
- prompts, hw, ar = (
210
- [],
211
- torch.tensor([[self.image_size, self.image_size]], dtype=torch.float, device=self.device).repeat(
212
- num_images_per_prompt, 1
213
- ),
214
- torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
215
- )
216
-
217
- for _ in range(num_images_per_prompt):
218
- prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())
219
-
220
- with torch.no_grad():
221
- # prepare text feature
222
- if not self.config.text_encoder.chi_prompt:
223
- max_length_all = self.config.text_encoder.model_max_length
224
- prompts_all = prompts
225
- else:
226
- chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
227
- prompts_all = [chi_prompt + prompt for prompt in prompts]
228
- num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
229
- max_length_all = (
230
- num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
231
- ) # magic number 2: [bos], [_]
232
-
233
- caption_token = self.tokenizer(
234
- prompts_all,
235
- max_length=max_length_all,
236
- padding="max_length",
237
- truncation=True,
238
- return_tensors="pt",
239
- ).to(device=self.device)
240
- select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
241
- caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
242
- :, :, select_index
243
- ].to(self.weight_dtype)
244
- emb_masks = caption_token.attention_mask[:, select_index]
245
- null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
246
-
247
- n = len(prompts)
248
- if latents is None:
249
- z = torch.randn(
250
- n,
251
- self.config.vae.vae_latent_dim,
252
- self.latent_size_h,
253
- self.latent_size_w,
254
- generator=generator,
255
- device=self.device,
256
- )
257
- else:
258
- z = latents.to(self.device)
259
- model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
260
- if self.vis_sampler == "flow_euler":
261
- flow_solver = FlowEuler(
262
- self.model,
263
- condition=caption_embs,
264
- uncondition=null_y,
265
- cfg_scale=guidance_scale,
266
- model_kwargs=model_kwargs,
267
- )
268
- sample = flow_solver.sample(
269
- z,
270
- steps=num_inference_steps,
271
- )
272
- elif self.vis_sampler == "flow_dpm-solver":
273
- scheduler = DPMS(
274
- self.model,
275
- condition=caption_embs,
276
- uncondition=null_y,
277
- guidance_type=self.guidance_type,
278
- cfg_scale=guidance_scale,
279
- pag_scale=pag_guidance_scale,
280
- pag_applied_layers=self.config.model.pag_applied_layers,
281
- model_type="flow",
282
- model_kwargs=model_kwargs,
283
- schedule="FLOW",
284
- )
285
- scheduler.register_progress_bar(self.progress_fn)
286
- sample = scheduler.sample(
287
- z,
288
- steps=num_inference_steps,
289
- order=2,
290
- skip_type="time_uniform_flow",
291
- method="multistep",
292
- flow_shift=self.flow_shift,
293
- )
294
-
295
- sample = sample.to(self.vae_dtype)
296
- with torch.no_grad():
297
- sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
298
-
299
- sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
300
- samples.append(sample)
301
-
302
- return sample  # NOTE: returns inside the per-prompt loop, so only the first prompt's sample is produced
303
-
304
- return samples
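Finally, a minimal usage sketch for the text-to-image pipeline above, assuming a CUDA device and the default config/checkpoint paths from the SanaInference dataclass (substitute real files):

import torch
from torchvision.utils import save_image

from app.sana_pipeline import SanaPipeline  # the module shown (deleted) above

pipe = SanaPipeline("configs/sana_config/1024ms/Sana_1600M_img1024.yaml")
pipe.from_pretrained("output/Sana_D20/SANA.pth")

image = pipe(
    prompt='a cyberpunk cat with a neon sign that says "Sana"',
    height=1024,
    width=1024,
    guidance_scale=5.0,
    pag_guidance_scale=2.0,
    num_inference_steps=18,
    generator=torch.Generator(device="cuda").manual_seed(42),
)
# Outputs are in [-1, 1]; normalize on save, exactly as the Gradio apps above do.
save_image(image, "sana_sample.png", normalize=True, value_range=(-1, 1))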