import gc
import logging
import os
from typing import Any, Callable, Dict, List, Literal, Tuple, Union

import torch
from torch import nn
from safetensors import safe_open
from safetensors.torch import load_file
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from .convert_from_ckpt import (
    convert_ldm_unet_checkpoint,
    convert_ldm_vae_checkpoint,
    convert_ldm_clip_checkpoint,
)
from .convert_lora_safetensor_to_diffusers import convert_motion_lora_ckpt_to_diffusers

logger = logging.getLogger(__name__)


def update_pipeline_model_parameters(
    pipeline: DiffusionPipeline,
    model_path: str = None,
    lora_dict: Dict[str, Dict] = None,
    text_model_path: str = None,
    device: str = "cuda",
    need_unload: bool = False,
):
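    """Refresh a pipeline's weights from a base checkpoint and/or LoRA files.

    Thin dispatcher: when ``model_path`` is given, the base UNet/VAE/text-encoder
    weights are replaced via ``update_pipeline_basemodel``; when ``lora_dict`` is
    given, the LoRAs are merged in via ``update_pipeline_lora_models``. If
    ``need_unload`` is True, the LoRA unload records are returned alongside the
    pipeline so the merge can later be reverted with ``unload_lora``.
    """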
    # Stays empty if no LoRA is merged but ``need_unload`` is set.
    unload_dict = []
    if model_path is not None:
        pipeline = update_pipeline_basemodel(
            pipeline, model_path, text_sd_model_path=text_model_path, device=device
        )
    if lora_dict is not None:
        pipeline, unload_dict = update_pipeline_lora_models(
            pipeline,
            lora_dict,
            device=device,
            need_unload=need_unload,
        )
    if need_unload:
        return pipeline, unload_dict
    return pipeline


def update_pipeline_basemodel(
    pipeline: DiffusionPipeline,
    model_path: str,
    text_sd_model_path: str,
    device: str = "cuda",
):
    """Update the pipeline's base model weights from ``model_path``.

    Args:
        pipeline (DiffusionPipeline): pipeline whose UNet/VAE/text encoder are updated.
        model_path (str): path to a ``.ckpt`` or ``.safetensors`` base checkpoint.
        text_sd_model_path (str): pretrained CLIP text model used when rebuilding
            the text encoder.
        device (str, optional): device to load tensors onto. Defaults to "cuda".

    Returns:
        DiffusionPipeline: the updated pipeline.
    """
    if model_path.endswith(".ckpt"):
        state_dict = torch.load(model_path, map_location=device)
        pipeline.unet.load_state_dict(state_dict)
        print("update sd_model", model_path)
    elif model_path.endswith(".safetensors"):
        base_state_dict = {}
        with safe_open(model_path, framework="pt", device=device) as f:
            for key in f.keys():
                base_state_dict[key] = f.get_tensor(key)

        is_lora = all("lora" in k for k in base_state_dict.keys())
        assert not is_lora, "Base model cannot be LoRA: {}".format(model_path)

        # vae
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(
            base_state_dict, pipeline.vae.config
        )
        pipeline.vae.load_state_dict(converted_vae_checkpoint)

        # unet
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(
            base_state_dict, pipeline.unet.config
        )
        pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        # text encoder
        pipeline.text_encoder = convert_ldm_clip_checkpoint(
            base_state_dict, text_sd_model_path
        )
        print("update sd_model", model_path)
    pipeline.to(device)
    return pipeline


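# Per-block LoRA weight presets. Each list has 17 entries: index 0 scales the
# text-encoder update, indices 1-16 scale the updates of the 16 UNet attention
# blocks listed in ``lora_unet_layers`` below (down blocks, mid block, up blocks).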
LORA_BLOCK_WEIGHT_MAP = {
    "FACE": [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
    "DEFACE": [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
    "ALL": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    "MIDD": [1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
    "OUTALL": [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
}


def update_pipeline_lora_model(
    pipeline: DiffusionPipeline,
    lora: Union[str, Dict],
    alpha: float = 0.75,
    device: str = "cuda",
    lora_prefix_unet: str = "lora_unet",
    lora_prefix_text_encoder: str = "lora_te",
    lora_unet_layers=[
        "lora_unet_down_blocks_0_attentions_0",
        "lora_unet_down_blocks_0_attentions_1",
        "lora_unet_down_blocks_1_attentions_0",
        "lora_unet_down_blocks_1_attentions_1",
        "lora_unet_down_blocks_2_attentions_0",
        "lora_unet_down_blocks_2_attentions_1",
        "lora_unet_mid_block_attentions_0",
        "lora_unet_up_blocks_1_attentions_0",
        "lora_unet_up_blocks_1_attentions_1",
        "lora_unet_up_blocks_1_attentions_2",
        "lora_unet_up_blocks_2_attentions_0",
        "lora_unet_up_blocks_2_attentions_1",
        "lora_unet_up_blocks_2_attentions_2",
        "lora_unet_up_blocks_3_attentions_0",
        "lora_unet_up_blocks_3_attentions_1",
        "lora_unet_up_blocks_3_attentions_2",
    ],
    lora_block_weight_str: Literal["FACE", "DEFACE", "ALL", "MIDD", "OUTALL"] = "ALL",
    need_unload: bool = False,
):
    """Merge a single LoRA into the pipeline's UNet and text encoder weights.

    Args:
        pipeline (DiffusionPipeline): pipeline whose weights are modified in place.
        lora (Union[str, Dict]): path to a ``.safetensors`` LoRA file, or an
            already loaded LoRA state dict.
        alpha (float, optional): global merge strength. Defaults to 0.75.
        device (str, optional): device to load the LoRA tensors onto. Defaults to "cuda".
        lora_prefix_unet (str, optional): key prefix of UNet entries. Defaults to "lora_unet".
        lora_prefix_text_encoder (str, optional): key prefix of text-encoder entries.
            Defaults to "lora_te".
        lora_unet_layers (list, optional): the 16 UNet attention-block key prefixes
            used to look up per-block weights. Defaults to the SD1.x attention blocks.
        lora_block_weight_str (Literal, optional): preset name in ``LORA_BLOCK_WEIGHT_MAP``.
            Defaults to "ALL".
        need_unload (bool, optional): if True, also return records for ``unload_lora``.
            Defaults to False.

    Returns:
        DiffusionPipeline, or (DiffusionPipeline, list) when ``need_unload`` is True.
    """
    # Resolve the per-block weight preset; None disables per-block scaling.
    lora_block_weight = None
    if lora_block_weight_str is not None:
        lora_block_weight = LORA_BLOCK_WEIGHT_MAP[lora_block_weight_str.upper()]
    if lora_block_weight:
        assert len(lora_block_weight) == 17

    if isinstance(lora, str):
        state_dict = load_file(lora, device=device)
    else:
        for k in lora:
            lora[k] = lora[k].to(device)
        state_dict = lora

    visited = set()
    unload_dict = []

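    # Keys appear to follow the kohya-ss naming scheme, e.g.
    # "lora_unet_down_blocks_0_attentions_0_..._to_q.lora_down.weight" together
    # with a matching "lora_up.weight" and an optional ".alpha" entry. For every
    # down/up pair we locate the target module by walking attribute names, then
    # merge the low-rank product directly into its weight.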
    for key in state_dict:
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = (
                key.split(".")[0].split(lora_prefix_text_encoder + "_")[-1].split("_")
            )
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(lora_prefix_unet + "_")[-1].split("_")
            curr_layer = pipeline.unet

        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
            alpha_key = key.replace("lora_down.weight", "alpha")
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))
            alpha_key = key.replace("lora_up.weight", "alpha")

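        # Merge the low-rank update: W += alpha * scale * (up @ down), where
        # scale = stored_alpha / rank if the file provides an ".alpha" entry,
        # else 1.0. 1x1-conv LoRA weights (4-D tensors) are squeezed to 2-D for
        # the matmul and reshaped back; otherwise an einsum handles k x k kernels.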
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = (
                state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            )
            if alpha_key in state_dict:
                weight_scale = state_dict[alpha_key].item() / weight_up.shape[1]
            else:
                weight_scale = 1.0

            if len(weight_up.shape) == len(weight_down.shape):
                adding_weight = (
                    alpha
                    * weight_scale
                    * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                )
            else:
                adding_weight = (
                    alpha
                    * weight_scale
                    * torch.einsum("a b, b c h w -> a c h w", weight_up, weight_down)
                )
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            if alpha_key in state_dict:
                weight_scale = state_dict[alpha_key].item() / weight_up.shape[1]
            else:
                weight_scale = 1.0
            adding_weight = alpha * weight_scale * torch.mm(weight_up, weight_down)
        adding_weight = adding_weight.to(torch.float16)

        if lora_block_weight:
            if "text" in key:
                adding_weight *= lora_block_weight[0]
            else:
                for idx, layer in enumerate(lora_unet_layers):
                    if layer in key:
                        adding_weight *= lora_block_weight[idx + 1]
                        break

        curr_layer_unload_data = {"layer": curr_layer, "added_weight": adding_weight}
        curr_layer.weight.data += adding_weight

        unload_dict.append(curr_layer_unload_data)

        for item in pair_keys:
            visited.add(item)
    if need_unload:
        return pipeline, unload_dict
    else:
        return pipeline


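# Earlier variant of ``update_pipeline_lora_model`` that ignores the per-key
# ".alpha" entries in the LoRA file (no rank-based rescaling); apparently kept
# for reference.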
def update_pipeline_lora_model_old(
    pipeline: DiffusionPipeline,
    lora: Union[str, Dict],
    alpha: float = 0.75,
    device: str = "cuda",
    lora_prefix_unet: str = "lora_unet",
    lora_prefix_text_encoder: str = "lora_te",
    lora_unet_layers=[
        "lora_unet_down_blocks_0_attentions_0",
        "lora_unet_down_blocks_0_attentions_1",
        "lora_unet_down_blocks_1_attentions_0",
        "lora_unet_down_blocks_1_attentions_1",
        "lora_unet_down_blocks_2_attentions_0",
        "lora_unet_down_blocks_2_attentions_1",
        "lora_unet_mid_block_attentions_0",
        "lora_unet_up_blocks_1_attentions_0",
        "lora_unet_up_blocks_1_attentions_1",
        "lora_unet_up_blocks_1_attentions_2",
        "lora_unet_up_blocks_2_attentions_0",
        "lora_unet_up_blocks_2_attentions_1",
        "lora_unet_up_blocks_2_attentions_2",
        "lora_unet_up_blocks_3_attentions_0",
        "lora_unet_up_blocks_3_attentions_1",
        "lora_unet_up_blocks_3_attentions_2",
    ],
    lora_block_weight_str: Literal["FACE", "DEFACE", "ALL", "MIDD", "OUTALL"] = "ALL",
    need_unload: bool = False,
):
    """Merge a single LoRA into the pipeline's UNet and text encoder weights.

    Older implementation: identical to ``update_pipeline_lora_model`` except that
    the per-key ".alpha" scale stored in the LoRA file is not applied.

    Args:
        pipeline (DiffusionPipeline): pipeline whose weights are modified in place.
        lora (Union[str, Dict]): path to a ``.safetensors`` LoRA file, or an
            already loaded LoRA state dict.
        alpha (float, optional): global merge strength. Defaults to 0.75.
        device (str, optional): device to load the LoRA tensors onto. Defaults to "cuda".
        lora_prefix_unet (str, optional): key prefix of UNet entries. Defaults to "lora_unet".
        lora_prefix_text_encoder (str, optional): key prefix of text-encoder entries.
            Defaults to "lora_te".
        lora_unet_layers (list, optional): the 16 UNet attention-block key prefixes
            used to look up per-block weights. Defaults to the SD1.x attention blocks.
        lora_block_weight_str (Literal, optional): preset name in ``LORA_BLOCK_WEIGHT_MAP``.
            Defaults to "ALL".
        need_unload (bool, optional): if True, also return records for ``unload_lora``.
            Defaults to False.

    Returns:
        DiffusionPipeline, or (DiffusionPipeline, list) when ``need_unload`` is True.
    """
    # Resolve the per-block weight preset; None disables per-block scaling.
    lora_block_weight = None
    if lora_block_weight_str is not None:
        lora_block_weight = LORA_BLOCK_WEIGHT_MAP[lora_block_weight_str.upper()]
    if lora_block_weight:
        assert len(lora_block_weight) == 17

    if isinstance(lora, str):
        state_dict = load_file(lora, device=device)
    else:
        for k in lora:
            lora[k] = lora[k].to(device)
        state_dict = lora

    visited = set()
    unload_dict = []

    for key in state_dict:
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = (
                key.split(".")[0].split(lora_prefix_text_encoder + "_")[-1].split("_")
            )
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(lora_prefix_unet + "_")[-1].split("_")
            curr_layer = pipeline.unet

        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = (
                state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            )
            adding_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(
                2
            ).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            adding_weight = alpha * torch.mm(weight_up, weight_down)

        if lora_block_weight:
            if "text" in key:
                adding_weight *= lora_block_weight[0]
            else:
                for idx, layer in enumerate(lora_unet_layers):
                    if layer in key:
                        adding_weight *= lora_block_weight[idx + 1]
                        break

        curr_layer_unload_data = {"layer": curr_layer, "added_weight": adding_weight}
        curr_layer.weight.data += adding_weight

        unload_dict.append(curr_layer_unload_data)

        for item in pair_keys:
            visited.add(item)
    if need_unload:
        return pipeline, unload_dict
    else:
        return pipeline


def update_pipeline_lora_models(
    pipeline: DiffusionPipeline,
    lora_dict: Dict[str, Dict],
    device: str = "cuda",
    need_unload: bool = True,
    lora_prefix_unet: str = "lora_unet",
    lora_prefix_text_encoder: str = "lora_te",
    lora_unet_layers=[
        "lora_unet_down_blocks_0_attentions_0",
        "lora_unet_down_blocks_0_attentions_1",
        "lora_unet_down_blocks_1_attentions_0",
        "lora_unet_down_blocks_1_attentions_1",
        "lora_unet_down_blocks_2_attentions_0",
        "lora_unet_down_blocks_2_attentions_1",
        "lora_unet_mid_block_attentions_0",
        "lora_unet_up_blocks_1_attentions_0",
        "lora_unet_up_blocks_1_attentions_1",
        "lora_unet_up_blocks_1_attentions_2",
        "lora_unet_up_blocks_2_attentions_0",
        "lora_unet_up_blocks_2_attentions_1",
        "lora_unet_up_blocks_2_attentions_2",
        "lora_unet_up_blocks_3_attentions_0",
        "lora_unet_up_blocks_3_attentions_1",
        "lora_unet_up_blocks_3_attentions_2",
    ],
):
    """Merge every LoRA in ``lora_dict`` into the pipeline.

    Args:
        pipeline (DiffusionPipeline): pipeline whose weights are modified in place.
        lora_dict (Dict[str, Dict]): maps a LoRA ``.safetensors`` path to its options
            ("strength", "strength_offset", "lora_block_weight").
        device (str, optional): device to load the LoRA tensors onto. Defaults to "cuda".
        need_unload (bool, optional): currently unused; unload records are always
            collected and returned. Defaults to True.
        lora_prefix_unet (str, optional): key prefix of UNet entries. Defaults to "lora_unet".
        lora_prefix_text_encoder (str, optional): key prefix of text-encoder entries.
            Defaults to "lora_te".
        lora_unet_layers (list, optional): the 16 UNet attention-block key prefixes
            used to look up per-block weights. Defaults to the SD1.x attention blocks.

    Returns:
        (DiffusionPipeline, list): the updated pipeline and the concatenated unload records.
    """
    unload_dicts = []
    for lora, value in lora_dict.items():
        lora_name = os.path.basename(lora).replace(".safetensors", "")
        strength_offset = value.get("strength_offset", 0.0)
        alpha = value.get("strength", 1.0)
        alpha += strength_offset
        lora_weight_str = value.get("lora_block_weight", "ALL")
        lora = load_file(lora)
        pipeline, unload_dict = update_pipeline_lora_model(
            pipeline,
            lora=lora,
            device=device,
            alpha=alpha,
            lora_prefix_unet=lora_prefix_unet,
            lora_prefix_text_encoder=lora_prefix_text_encoder,
            lora_unet_layers=lora_unet_layers,
            lora_block_weight_str=lora_weight_str,
            need_unload=True,
        )
        print(
            "Update LoRA {} with alpha {} and weight {}".format(
                lora_name, alpha, lora_weight_str
            )
        )
        unload_dicts += unload_dict
    return pipeline, unload_dicts


def unload_lora(unload_dict: List[Dict[str, nn.Module]]):
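    """Revert LoRA merges by subtracting the recorded added weights in place."""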
    for layer_data in unload_dict:
        layer = layer_data["layer"]
        added_weight = layer_data["added_weight"]
        layer.weight.data -= added_weight

    gc.collect()
    torch.cuda.empty_cache()


def load_motion_lora_weights(
    animation_pipeline,
    motion_module_lora_configs=[],
):
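    """Apply motion-module LoRAs, each given as {"path": ..., "alpha": ...}, to the animation pipeline."""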
    for motion_module_lora_config in motion_module_lora_configs:
        path, alpha = (
            motion_module_lora_config["path"],
            motion_module_lora_config["alpha"],
        )
        print(f"load motion LoRA from {path}")

        motion_lora_state_dict = torch.load(path, map_location="cpu")
        motion_lora_state_dict = (
            motion_lora_state_dict["state_dict"]
            if "state_dict" in motion_lora_state_dict
            else motion_lora_state_dict
        )

        animation_pipeline = convert_motion_lora_ckpt_to_diffusers(
            animation_pipeline, motion_lora_state_dict, alpha
        )

    return animation_pipeline