import gradio as gr
import torch
from diffusers import (
AutoPipelineForText2Image,
StableDiffusionXLControlNetPipeline,
DiffusionPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionAdapterPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionXLAdapterPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
ControlNetModel,
T2IAdapter,
)
import time
import utils
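
# Benchmarks run in fp16 on a single CUDA device; a GPU is assumed to be available.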
dtype = torch.float16
device = torch.device("cuda")
# pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile
examples = [["SD T2I", 4, True, True]]
pipeline_mapping = {
"SD T2I": (DiffusionPipeline, "runwayml/stable-diffusion-v1-5"),
"SD I2I": (StableDiffusionImg2ImgPipeline, "runwayml/stable-diffusion-v1-5"),
"SD Inpainting": (
StableDiffusionInpaintPipeline,
"runwayml/stable-diffusion-inpainting",
),
"SD ControlNet": (
StableDiffusionControlNetPipeline,
"runwayml/stable-diffusion-v1-5",
"lllyasviel/sd-controlnet-canny",
),
"SD T2I Adapters": (
StableDiffusionAdapterPipeline,
"CompVis/stable-diffusion-v1-4",
"TencentARC/t2iadapter_canny_sd14v1",
),
"SDXL T2I": (DiffusionPipeline, "stabilityai/stable-diffusion-xl-base-1.0"),
"SDXL I2I": (
StableDiffusionXLImg2ImgPipeline,
"stabilityai/stable-diffusion-xl-base-1.0",
),
"SDXL Inpainting": (
StableDiffusionXLInpaintPipeline,
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
),
"SDXL ControlNet": (
StableDiffusionXLControlNetPipeline,
"stabilityai/stable-diffusion-xl-base-1.0",
"diffusers/controlnet-canny-sdxl-1.0",
),
"SDXL T2I Adapters": (
StableDiffusionXLAdapterPipeline,
"stabilityai/stable-diffusion-xl-base-1.0",
"TencentARC/t2i-adapter-canny-sdxl-1.0",
),
"Kandinsky 2.2 (T2I)": (
AutoPipelineForText2Image,
"kandinsky-community/kandinsky-2-2-decoder",
),
"Würstchen (T2I)": (AutoPipelineForText2Image, "warp-ai/wuerstchen"),
}


def load_pipeline(
pipeline_to_benchmark: str,
use_channels_last: bool = False,
do_torch_compile: bool = False,
):
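    """Load the pipeline selected for benchmarking, along with any ControlNet or
    T2I-Adapter weights it needs, and optionally apply the channels-last memory
    format and torch.compile."""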
# Get pipeline details.
print(f"Loading pipeline: {pipeline_to_benchmark}")
pipeline_details = pipeline_mapping[pipeline_to_benchmark]
pipeline_cls = pipeline_details[0]
pipeline_ckpt = pipeline_details[1]

    # Load adapter if needed.
if "ControlNet" in pipeline_to_benchmark:
controlnet_ckpt = pipeline_details[2]
controlnet = ControlNetModel.from_pretrained(
controlnet_ckpt, torch_dtype=dtype
).to(device)
elif "Adapters" in pipeline_to_benchmark:
        adapter_ckpt = pipeline_details[2]
        adapter = T2IAdapter.from_pretrained(adapter_ckpt, torch_dtype=dtype).to(device)

    # Load pipeline.
if (
"ControlNet" not in pipeline_to_benchmark
and "Adapters" not in pipeline_to_benchmark
):
pipeline = pipeline_cls.from_pretrained(pipeline_ckpt, torch_dtype=dtype)
elif "ControlNet" in pipeline_to_benchmark:
pipeline = pipeline_cls.from_pretrained(
pipeline_ckpt, controlnet=controlnet, torch_dtype=dtype
)
elif "Adapters" in pipeline_to_benchmark:
pipeline = pipeline_cls.from_pretrained(
pipeline_ckpt, adapter=adapter, torch_dtype=dtype
)
pipeline.to(device)

    # Optionally set memory layout.
if use_channels_last:
print("Setting memory layout.")
if pipeline_to_benchmark not in ["Würstchen (T2I)", "Kandinsky 2.2 (T2I)"]:
pipeline.unet.to(memory_format=torch.channels_last)
elif pipeline_to_benchmark == "Würstchen (T2I)":
pipeline.prior_prior.to(memory_format=torch.channels_last)
pipeline.decoder.to(memory_format=torch.channels_last)
elif pipeline_to_benchmark == "Kandinsky 2.2 (T2I)":
pipeline.unet.to(memory_format=torch.channels_last)
if hasattr(pipeline, "controlnet"):
pipeline.controlnet.to(memory_format=torch.channels_last)
elif hasattr(pipeline, "adapter"):
pipeline.adapter.to(memory_format=torch.channels_last)

    # Optional torch compilation.
if do_torch_compile:
print("Compiling pipeline.")
if pipeline_to_benchmark not in ["Würstchen (T2I)", "Kandinsky 2.2 (T2I)"]:
pipeline.unet = torch.compile(
pipeline.unet, mode="reduce-overhead", fullgraph=True
)
elif pipeline_to_benchmark == "Würstchen (T2I)":
pipeline.prior_prior = torch.compile(
pipeline.prior_prior, mode="reduce-overhead", fullgraph=True
)
pipeline.decoder = torch.compile(
pipeline.decoder, mode="reduce-overhead", fullgraph=True
)
elif pipeline_to_benchmark == "Kandinsky 2.2 (T2I)":
pipeline.unet = torch.compile(
pipeline.unet, mode="reduce-overhead", fullgraph=True
)
if hasattr(pipeline, "controlnet"):
pipeline.controlnet = torch.compile(
pipeline.controlnet, mode="reduce-overhead", fullgraph=True
)
elif hasattr(pipeline, "adapter"):
pipeline.adapter = torch.compile(
pipeline.adapter, mode="reduce-overhead", fullgraph=True
)
print("Pipeline loaded.")
pipeline.set_progress_bar_config(disable=True)
return pipeline


def generate(
pipeline_to_benchmark: str,
num_images_per_prompt: int = 1,
use_channels_last: bool = False,
do_torch_compile: bool = False,
):
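    """Benchmark the selected pipeline and return the average latency per
    denoising step of the last run."""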
    if pipeline_to_benchmark is None or isinstance(pipeline_to_benchmark, list):
        # No usable value is passed when no pipeline has been selected.
        raise ValueError(
            "pipeline_to_benchmark cannot be None. Please select a pipeline to benchmark."
        )
print("Start...")
print("Torch version", torch.__version__)
print("Torch CUDA version", torch.version.cuda)
pipeline = load_pipeline(
pipeline_to_benchmark=pipeline_to_benchmark,
use_channels_last=use_channels_last,
do_torch_compile=do_torch_compile,
)
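
    # Run the pipeline a few times so the reported timing is not dominated by
    # torch.compile warm-up or one-off CUDA initialization costs.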
for _ in range(3):
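        # Fixed dummy prompt and step count keep runs comparable across pipelines.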
prompt = 77 * "a"
num_inference_steps = 20
call_args = dict(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
num_inference_steps=num_inference_steps,
)
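        # Pipelines conditioned on an image (img2img, inpainting, ControlNet,
        # adapters) get their input images from the local utils module.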
if pipeline_to_benchmark in ["SD I2I", "SDXL I2I"]:
image = utils.get_image_for_img_to_img(pipeline_to_benchmark)
call_args.update({"image": image})
elif "Inpainting" in pipeline_to_benchmark:
image, mask_image = utils.get_image_from_inpainting(pipeline_to_benchmark)
call_args.update({"image": image, "mask_image": mask_image})
elif "ControlNet" in pipeline_to_benchmark:
image = utils.get_image_for_controlnet(pipeline_to_benchmark)
call_args.update({"image": image})
elif "Adapters" in pipeline_to_benchmark:
image = utils.get_image_for_adapters(pipeline_to_benchmark)
call_args.update({"image": image})
start_time = time.time()
_ = pipeline(**call_args).images
end_time = time.time()
print(f"For {num_inference_steps} steps", end_time - start_time)
print("Avg per step", (end_time - start_time) / num_inference_steps)
return (
f"Avg per step: {((end_time - start_time) / num_inference_steps):.4f} seconds."
)


with gr.Blocks(css="style.css") as demo:
do_torch_compile = gr.Checkbox(label="Enable torch.compile()?")
use_channels_last = gr.Checkbox(label="Use `channels_last` memory layout?")
pipeline_to_benchmark = gr.Dropdown(
list(pipeline_mapping.keys()),
value=None,
multiselect=False,
label="Pipeline to benchmark",
)
batch_size = gr.Slider(
label="Number of images per prompt",
minimum=1,
maximum=16,
step=1,
value=1,
)
btn = gr.Button("Benchmark!").style(
margin=False,
rounded=(False, True, True, False),
full_width=False,
)
result = gr.Text(label="Result")
gr.Examples(
examples=examples,
inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
outputs=result,
fn=generate,
cache_examples=True,
)
btn.click(
fn=generate,
inputs=[pipeline_to_benchmark, batch_size, use_channels_last, do_torch_compile],
outputs=result,
)
demo.launch(show_error=True)