# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
from typing import AnyStr | |
import pathlib | |
from collections import OrderedDict | |
from packaging import version | |
import torch | |
from diffusers import StableDiffusionPipeline, SchedulerMixin | |
from diffusers import UNet2DConditionModel | |
from diffusers.utils import is_torch_version, is_xformers_available | |
# Hugging Face Hub ids of the supported Stable Diffusion checkpoints,
# keyed by the short aliases used throughout this project.
DiffusersModels = OrderedDict([
    ("sd14", "CompVis/stable-diffusion-v1-4"),  # resolution: 512
    ("sd15", "runwayml/stable-diffusion-v1-5"),  # resolution: 512
    ("sd21b", "stabilityai/stable-diffusion-2-1-base"),  # resolution: 512
    ("sd21", "stabilityai/stable-diffusion-2-1"),  # resolution: 768
    ("sdxl", "stabilityai/stable-diffusion-xl-base-1.0"),  # resolution: 1024
])
# default output resolution (pixels) for each supported checkpoint alias
_model2resolution = {
    "sd14": 512,
    "sd15": 512,
    "sd21b": 512,
    "sd21": 768,
    "sdxl": 1024,
}


def model2res(model_id: str):
    """Return the default resolution for *model_id*; unknown ids fall back to 512."""
    try:
        return _model2resolution[model_id]
    except KeyError:
        return 512
def init_StableDiffusion_pipeline(model_id: AnyStr,
                                  custom_pipeline: StableDiffusionPipeline,
                                  custom_scheduler: SchedulerMixin = None,
                                  device: torch.device = "cuda",
                                  torch_dtype: torch.dtype = torch.float32,
                                  local_files_only: bool = True,
                                  force_download: bool = False,
                                  resume_download: bool = False,
                                  ldm_speed_up: bool = False,
                                  enable_xformers: bool = True,
                                  gradient_checkpoint: bool = False,
                                  cpu_offload: bool = False,
                                  vae_slicing: bool = False,
                                  lora_path: AnyStr = None,
                                  unet_path: AnyStr = None) -> StableDiffusionPipeline:
    """
    A tool for initializing a diffusers pipeline.

    Args:
        model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path,
            either a short alias from `DiffusersModels` or a Hub id / local path
        custom_pipeline: any StableDiffusionPipeline class (its `from_pretrained` is used)
        custom_scheduler: any scheduler class; when given, its checkpoint is loaded
            from the model's "scheduler" subfolder and injected into the pipeline
        device: set device
        torch_dtype: data type
        local_files_only: prohibit downloading the model (use the local cache only)
        force_download: force re-downloading the model
        resume_download: resume an interrupted download
        ldm_speed_up: use the `torch.compile` API to speed up the U-Net (torch >= 2.0)
        enable_xformers: enable memory-efficient attention from xFormers
        gradient_checkpoint: activate gradient checkpointing for the U-Net
        cpu_offload: enable sequential cpu offload
        vae_slicing: enable sliced VAE decoding
        lora_path: load LoRA attention-processor checkpoint into the U-Net
        unet_path: load a fine-tuned U-Net checkpoint (diffusers-format directory)

    Returns:
        diffusers.StableDiffusionPipeline
    """
    # resolve short alias -> hub id; unknown ids are passed through unchanged
    model_id = DiffusersModels.get(model_id, model_id)

    # kwargs shared by both load paths (with and without a custom scheduler)
    load_kwargs = dict(
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
        force_download=force_download,
        resume_download=resume_download,
    )
    if custom_scheduler is not None:
        load_kwargs["scheduler"] = custom_scheduler.from_pretrained(
            model_id,
            subfolder="scheduler",
            local_files_only=local_files_only,
            force_download=force_download,
            resume_download=resume_download,
        )
    pipeline = custom_pipeline.from_pretrained(model_id, **load_kwargs).to(device)
    print(f"load diffusers pipeline: {model_id}")

    # process unet model if exist
    if unet_path is not None and pathlib.Path(unet_path).exists():
        print(f"=> load u-net from {unet_path}")
        # BUG FIX: the original called `pipeline.unet.from_pretrained(model_id, ...)`,
        # which is a classmethod whose result was discarded (a no-op) and which read
        # from `model_id` instead of `unet_path` — the checkpoint was never loaded.
        # assumes `unet_path` is a diffusers-format model directory — TODO confirm
        pipeline.unet = UNet2DConditionModel.from_pretrained(
            unet_path, torch_dtype=torch_dtype
        ).to(device)

    # process lora layers if exist
    if lora_path is not None and pathlib.Path(lora_path).exists():
        pipeline.unet.load_attn_procs(lora_path)
        print(f"=> load lora layers into U-Net from {lora_path} ...")

    # torch.compile (requires torch >= 2.0.0)
    if ldm_speed_up:
        if is_torch_version(">=", "2.0.0"):
            pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
            print(f"=> enable torch.compile on U-Net")
        else:
            # fixed message: the failing condition is torch < 2.0.0, not <= 2.0.0
            print(f"=> warning: calling torch.compile speed-up failed, since torch version < 2.0.0")

    # Meta xformers memory-efficient attention
    if enable_xformers:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. "
                    "If you observe problems during training, please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            print(f"=> enable xformers")
            pipeline.unet.enable_xformers_memory_efficient_attention()
        else:
            print(f"=> warning: xformers is not available.")

    # gradient checkpointing (original guarded this with a dead `if True:` branch)
    if gradient_checkpoint:
        print(f"=> enable gradient checkpointing")
        pipeline.unet.enable_gradient_checkpointing()

    if cpu_offload:
        pipeline.enable_sequential_cpu_offload()

    if vae_slicing:
        pipeline.enable_vae_slicing()

    print(pipeline.scheduler)
    return pipeline
def init_diffusers_unet(model_id: AnyStr,
                        device: torch.device = "cuda",
                        torch_dtype: torch.dtype = torch.float32,
                        local_files_only: bool = True,
                        force_download: bool = False,
                        resume_download: bool = False,
                        ldm_speed_up: bool = False,
                        enable_xformers: bool = True,
                        gradient_checkpoint: bool = False,
                        lora_path: AnyStr = None,
                        unet_path: AnyStr = None):
    """
    A tool for initializing a diffusers UNet model.

    Args:
        model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path,
            either a short alias from `DiffusersModels` or a Hub id / local path
        device: set device
        torch_dtype: data type
        local_files_only: prohibit downloading the model (use the local cache only)
        force_download: force re-downloading the model
        resume_download: resume an interrupted download
        ldm_speed_up: use the `torch.compile` API to speed up the U-Net (torch >= 2.0)
        enable_xformers: enable memory-efficient attention from xFormers
        gradient_checkpoint: activate gradient checkpointing for the U-Net
        lora_path: load LoRA attention-processor checkpoint into the U-Net
        unet_path: load a fine-tuned U-Net checkpoint (diffusers-format directory)

    Returns:
        diffusers.UNet2DConditionModel
    """
    # resolve short alias -> hub id; unknown ids are passed through unchanged
    model_id = DiffusersModels.get(model_id, model_id)

    # process UNet model
    unet = UNet2DConditionModel.from_pretrained(
        model_id,
        subfolder="unet",
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
        force_download=force_download,
        resume_download=resume_download,
    ).to(device)
    print(f"load diffusers UNet: {model_id}")

    # process unet model if exist
    if unet_path is not None and pathlib.Path(unet_path).exists():
        print(f"=> load u-net from {unet_path}")
        # BUG FIX: the original called `unet.from_pretrained(model_id)`, which is a
        # classmethod whose result was discarded (a no-op) and which read from
        # `model_id` instead of `unet_path` — the checkpoint was never loaded.
        # assumes `unet_path` is a diffusers-format model directory — TODO confirm
        unet = UNet2DConditionModel.from_pretrained(
            unet_path, torch_dtype=torch_dtype
        ).to(device)

    # process lora layers if exist
    if lora_path is not None and pathlib.Path(lora_path).exists():
        unet.load_attn_procs(lora_path)
        print(f"=> load lora layers into U-Net from {lora_path} ...")

    # torch.compile (requires torch >= 2.0.0)
    if ldm_speed_up:
        if is_torch_version(">=", "2.0.0"):
            unet = torch.compile(unet, mode="reduce-overhead", fullgraph=True)
            print(f"=> enable torch.compile on U-Net")
        else:
            # fixed message: the failing condition is torch < 2.0.0, not <= 2.0.0
            print(f"=> warning: calling torch.compile speed-up failed, since torch version < 2.0.0")

    # Meta xformers memory-efficient attention
    if enable_xformers:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. "
                    "If you observe problems during training, please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            print(f"=> enable xformers")
            unet.enable_xformers_memory_efficient_attention()
        else:
            print(f"=> warning: xformers is not available.")

    # gradient checkpointing (original guarded this with a dead `if True:` branch)
    if gradient_checkpoint:
        print(f"=> enable gradient checkpointing")
        unet.enable_gradient_checkpointing()

    return unet