# hjc-owo
# init repo
# 966ae59
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
from typing import AnyStr
import pathlib
from collections import OrderedDict
from packaging import version
import torch
from diffusers import StableDiffusionPipeline, SchedulerMixin
from diffusers import UNet2DConditionModel
from diffusers.utils import is_torch_version, is_xformers_available
# Short aliases -> Hugging Face Hub repo ids for the supported checkpoints.
# Insertion order is preserved (OrderedDict) so iteration lists models
# from oldest to newest.
DiffusersModels = OrderedDict([
    ("sd14", "CompVis/stable-diffusion-v1-4"),                # resolution: 512
    ("sd15", "runwayml/stable-diffusion-v1-5"),               # resolution: 512
    ("sd21b", "stabilityai/stable-diffusion-2-1-base"),       # resolution: 512
    ("sd21", "stabilityai/stable-diffusion-2-1"),             # resolution: 768
    ("sdxl", "stabilityai/stable-diffusion-xl-base-1.0"),     # resolution: 1024
])
# Native training resolution for each model alias.
_model2resolution = {
    "sd14": 512,
    "sd15": 512,
    "sd21b": 512,
    "sd21": 768,
    "sdxl": 1024,
}


def model2res(model_id: str):
    """Return the native resolution for *model_id*, falling back to 512 for unknown ids."""
    try:
        return _model2resolution[model_id]
    except KeyError:
        return 512
def init_StableDiffusion_pipeline(model_id: AnyStr,
                                  custom_pipeline: StableDiffusionPipeline,
                                  custom_scheduler: SchedulerMixin = None,
                                  device: torch.device = "cuda",
                                  torch_dtype: torch.dtype = torch.float32,
                                  local_files_only: bool = True,
                                  force_download: bool = False,
                                  resume_download: bool = False,
                                  ldm_speed_up: bool = False,
                                  enable_xformers: bool = True,
                                  gradient_checkpoint: bool = False,
                                  cpu_offload: bool = False,
                                  vae_slicing: bool = False,
                                  lora_path: AnyStr = None,
                                  unet_path: AnyStr = None) -> StableDiffusionPipeline:
    """
    A tool for initial diffusers pipeline.

    Args:
        model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path,
            or one of the short aliases in ``DiffusersModels`` (e.g. "sd15")
        custom_pipeline: any StableDiffusionPipeline pipeline class
        custom_scheduler: any scheduler class (loaded from the model's "scheduler" subfolder)
        device: set device
        torch_dtype: data type
        local_files_only: prohibit downloading the model
        force_download: force re-downloading the model
        resume_download: resume an interrupted download
        ldm_speed_up: use the `torch.compile` api to speed up unet (torch >= 2.0 only)
        enable_xformers: enable memory efficient attention from [xFormers]
        gradient_checkpoint: activates gradient checkpointing for the current model
        cpu_offload: enable sequential cpu offload
        vae_slicing: enable sliced VAE decoding
        lora_path: load LoRA checkpoint (attention processors) into the U-Net
        unet_path: load a fine-tuned U-Net checkpoint, replacing the pipeline's U-Net

    Returns:
        diffusers.StableDiffusionPipeline
    """
    # resolve short aliases (e.g. "sd15") to hub repo ids; unknown ids pass through
    model_id = DiffusersModels.get(model_id, model_id)

    # download options shared by pipeline and scheduler loading
    download_kwargs = dict(
        local_files_only=local_files_only,
        force_download=force_download,
        resume_download=resume_download,
    )

    # process diffusion model: optionally swap in a custom scheduler
    if custom_scheduler is not None:
        pipeline = custom_pipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            scheduler=custom_scheduler.from_pretrained(model_id,
                                                       subfolder="scheduler",
                                                       **download_kwargs),
            **download_kwargs,
        ).to(device)
    else:
        pipeline = custom_pipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            **download_kwargs,
        ).to(device)
    print(f"load diffusers pipeline: {model_id}")

    # process unet model if exist
    # FIX: previously this called `pipeline.unet.from_pretrained(model_id, ...)`,
    # which re-loaded the BASE weights (not unet_path) and discarded the returned
    # model, so the checkpoint was never actually loaded.
    if unet_path is not None and pathlib.Path(unet_path).exists():
        print(f"=> load u-net from {unet_path}")
        # NOTE(review): assumes unet_path is a diffusers-format directory — confirm
        pipeline.unet = UNet2DConditionModel.from_pretrained(
            unet_path, torch_dtype=torch_dtype
        ).to(device)

    # process lora layers if exist
    if lora_path is not None and pathlib.Path(lora_path).exists():
        pipeline.unet.load_attn_procs(lora_path)
        print(f"=> load lora layers into U-Net from {lora_path} ...")

    # torch.compile: available from PyTorch 2.0
    if ldm_speed_up:
        if is_torch_version(">=", "2.0.0"):
            pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
            print(f"=> enable torch.compile on U-Net")
        else:
            # FIX: message previously said "<= 2.0.0", contradicting the check above
            print(f"=> warning: calling torch.compile speed-up failed, since torch version < 2.0.0")

    # Meta xformers
    if enable_xformers:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. "
                    "If you observe problems during training, please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            print(f"=> enable xformers")
            pipeline.unet.enable_xformers_memory_efficient_attention()
        else:
            print(f"=> warning: xformers is not available.")

    # gradient checkpointing
    # FIX: removed the dead `if True:` wrapper whose else-branch was unreachable
    if gradient_checkpoint:
        print(f"=> enable gradient checkpointing")
        pipeline.unet.enable_gradient_checkpointing()

    if cpu_offload:
        pipeline.enable_sequential_cpu_offload()

    if vae_slicing:
        pipeline.enable_vae_slicing()

    print(pipeline.scheduler)
    return pipeline
def init_diffusers_unet(model_id: AnyStr,
                        device: torch.device = "cuda",
                        torch_dtype: torch.dtype = torch.float32,
                        local_files_only: bool = True,
                        force_download: bool = False,
                        resume_download: bool = False,
                        ldm_speed_up: bool = False,
                        enable_xformers: bool = True,
                        gradient_checkpoint: bool = False,
                        lora_path: AnyStr = None,
                        unet_path: AnyStr = None):
    """
    A tool for initial diffusers UNet model.

    Args:
        model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path,
            or one of the short aliases in ``DiffusersModels`` (e.g. "sd15")
        device: set device
        torch_dtype: data type
        local_files_only: prohibit downloading the model
        force_download: force re-downloading the model
        resume_download: resume an interrupted download
        ldm_speed_up: use the `torch.compile` api to speed up unet (torch >= 2.0 only)
        enable_xformers: enable memory efficient attention from [xFormers]
        gradient_checkpoint: activates gradient checkpointing for the current model
        lora_path: load LoRA checkpoint (attention processors) into the U-Net
        unet_path: load a fine-tuned U-Net checkpoint, replacing the base U-Net

    Returns:
        diffusers.UNet2DConditionModel
    """
    # resolve short aliases (e.g. "sd15") to hub repo ids; unknown ids pass through
    model_id = DiffusersModels.get(model_id, model_id)

    # process UNet model: load the base weights from the "unet" subfolder
    unet = UNet2DConditionModel.from_pretrained(
        model_id,
        subfolder="unet",
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
        force_download=force_download,
        resume_download=resume_download,
    ).to(device)
    print(f"load diffusers UNet: {model_id}")

    # process unet model if exist
    # FIX: previously this called `unet.from_pretrained(model_id)`, which ignored
    # unet_path and discarded the returned model, so the checkpoint was never loaded.
    if unet_path is not None and pathlib.Path(unet_path).exists():
        print(f"=> load u-net from {unet_path}")
        # NOTE(review): assumes unet_path is a diffusers-format directory — confirm
        unet = UNet2DConditionModel.from_pretrained(
            unet_path, torch_dtype=torch_dtype
        ).to(device)

    # process lora layers if exist
    if lora_path is not None and pathlib.Path(lora_path).exists():
        unet.load_attn_procs(lora_path)
        print(f"=> load lora layers into U-Net from {lora_path} ...")

    # torch.compile: available from PyTorch 2.0
    if ldm_speed_up:
        if is_torch_version(">=", "2.0.0"):
            unet = torch.compile(unet, mode="reduce-overhead", fullgraph=True)
            print(f"=> enable torch.compile on U-Net")
        else:
            # FIX: message previously said "<= 2.0.0", contradicting the check above
            print(f"=> warning: calling torch.compile speed-up failed, since torch version < 2.0.0")

    # Meta xformers
    if enable_xformers:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. "
                    "If you observe problems during training, please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            print(f"=> enable xformers")
            unet.enable_xformers_memory_efficient_attention()
        else:
            print(f"=> warning: xformers is not available.")

    # gradient checkpointing
    # FIX: removed the dead `if True:` wrapper whose else-branch was unreachable
    if gradient_checkpoint:
        print(f"=> enable gradient checkpointing")
        unet.enable_gradient_checkpointing()

    return unet