# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import re import torch from ..utils import is_peft_version, logging logger = logging.get_logger(__name__) def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5): # 1. get all state_dict_keys all_keys = list(state_dict.keys()) sgm_patterns = ["input_blocks", "middle_block", "output_blocks"] # 2. check if needs remapping, if not return original dict is_in_sgm_format = False for key in all_keys: if any(p in key for p in sgm_patterns): is_in_sgm_format = True break if not is_in_sgm_format: return state_dict # 3. Else remap from SGM patterns new_state_dict = {} inner_block_map = ["resnets", "attentions", "upsamplers"] # Retrieves # of down, mid and up blocks input_block_ids, middle_block_ids, output_block_ids = set(), set(), set() for layer in all_keys: if "text" in layer: new_state_dict[layer] = state_dict.pop(layer) else: layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) if sgm_patterns[0] in layer: input_block_ids.add(layer_id) elif sgm_patterns[1] in layer: middle_block_ids.add(layer_id) elif sgm_patterns[2] in layer: output_block_ids.add(layer_id) else: raise ValueError(f"Checkpoint not supported because layer {layer} not supported.") input_blocks = { layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] for layer_id in input_block_ids } middle_blocks = { layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key] for layer_id in middle_block_ids } output_blocks = { layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key] for layer_id in output_block_ids } # Rename keys accordingly for i in input_block_ids: block_id = (i - 1) // (unet_config.layers_per_block + 1) layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1) for key in input_blocks[i]: inner_block_id = int(key.split(delimiter)[block_slice_pos]) inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers" inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0" new_key = delimiter.join( key.split(delimiter)[: block_slice_pos - 1] + [str(block_id), inner_block_key, inner_layers_in_block] + key.split(delimiter)[block_slice_pos + 1 :] ) new_state_dict[new_key] = state_dict.pop(key) for i in middle_block_ids: key_part = None if i == 0: key_part = [inner_block_map[0], "0"] elif i == 1: key_part = [inner_block_map[1], "0"] elif i == 2: key_part = [inner_block_map[0], "1"] else: raise ValueError(f"Invalid middle block id {i}.") for key in middle_blocks[i]: new_key = delimiter.join( key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:] ) new_state_dict[new_key] = state_dict.pop(key) for i in output_block_ids: block_id = i // (unet_config.layers_per_block + 1) layer_in_block_id = i % (unet_config.layers_per_block + 1) for key in output_blocks[i]: inner_block_id = int(key.split(delimiter)[block_slice_pos]) inner_block_key = inner_block_map[inner_block_id] inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0" new_key = delimiter.join( key.split(delimiter)[: block_slice_pos - 1] + [str(block_id), inner_block_key, inner_layers_in_block] + key.split(delimiter)[block_slice_pos + 1 :] ) new_state_dict[new_key] = state_dict.pop(key) if len(state_dict) > 0: raise ValueError("At this point all state dict entries have to be converted.") return new_state_dict def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"): """ Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict. Args: state_dict (`dict`): The state dict to convert. unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet". text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to "text_encoder". Returns: `tuple`: A tuple containing the converted state dict and a dictionary of alphas. """ unet_state_dict = {} te_state_dict = {} te2_state_dict = {} network_alphas = {} # Check for DoRA-enabled LoRAs. dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict) dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict) dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict) if dora_present_in_unet or dora_present_in_te or dora_present_in_te2: if is_peft_version("<", "0.9.0"): raise ValueError( "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." ) # Iterate over all LoRA weights. all_lora_keys = list(state_dict.keys()) for key in all_lora_keys: if not key.endswith("lora_down.weight"): continue # Extract LoRA name. lora_name = key.split(".")[0] # Find corresponding up weight and alpha. lora_name_up = lora_name + ".lora_up.weight" lora_name_alpha = lora_name + ".alpha" # Handle U-Net LoRAs. if lora_name.startswith("lora_unet_"): diffusers_name = _convert_unet_lora_key(key) # Store down and up weights. unet_state_dict[diffusers_name] = state_dict.pop(key) unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) # Store DoRA scale if present. if dora_present_in_unet: dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." unet_state_dict[ diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) # Handle text encoder LoRAs. elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): diffusers_name = _convert_text_encoder_lora_key(key, lora_name) # Store down and up weights for te or te2. if lora_name.startswith(("lora_te_", "lora_te1_")): te_state_dict[diffusers_name] = state_dict.pop(key) te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) else: te2_state_dict[diffusers_name] = state_dict.pop(key) te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) # Store DoRA scale if present. if dora_present_in_te or dora_present_in_te2: dora_scale_key_to_replace_te = ( "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." ) if lora_name.startswith(("lora_te_", "lora_te1_")): te_state_dict[ diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) elif lora_name.startswith("lora_te2_"): te2_state_dict[ diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) # Store alpha if present. if lora_name_alpha in state_dict: alpha = state_dict.pop(lora_name_alpha).item() network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha)) # Check if any keys remain. if len(state_dict) > 0: raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}") logger.info("Non-diffusers checkpoint detected.") # Construct final state dict. unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()} te2_state_dict = ( {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()} if len(te2_state_dict) > 0 else None ) if te2_state_dict is not None: te_state_dict.update(te2_state_dict) new_state_dict = {**unet_state_dict, **te_state_dict} return new_state_dict, network_alphas def _convert_unet_lora_key(key): """ Converts a U-Net LoRA key to a Diffusers compatible key. """ diffusers_name = key.replace("lora_unet_", "").replace("_", ".") # Replace common U-Net naming patterns. diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") diffusers_name = diffusers_name.replace("middle.block", "mid_block") diffusers_name = diffusers_name.replace("mid.block", "mid_block") diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") diffusers_name = diffusers_name.replace("proj.in", "proj_in") diffusers_name = diffusers_name.replace("proj.out", "proj_out") diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") # SDXL specific conversions. if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name: pattern = r"\.\d+(?=\D*$)" diffusers_name = re.sub(pattern, "", diffusers_name, count=1) if ".in." in diffusers_name: diffusers_name = diffusers_name.replace("in.layers.2", "conv1") if ".out." in diffusers_name: diffusers_name = diffusers_name.replace("out.layers.3", "conv2") if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: diffusers_name = diffusers_name.replace("op", "conv") if "skip" in diffusers_name: diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") # LyCORIS specific conversions. if "time.emb.proj" in diffusers_name: diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj") if "conv.shortcut" in diffusers_name: diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut") # General conversions. if "transformer_blocks" in diffusers_name: if "attn1" in diffusers_name or "attn2" in diffusers_name: diffusers_name = diffusers_name.replace("attn1", "attn1.processor") diffusers_name = diffusers_name.replace("attn2", "attn2.processor") elif "ff" in diffusers_name: pass elif any(key in diffusers_name for key in ("proj_in", "proj_out")): pass else: pass return diffusers_name def _convert_text_encoder_lora_key(key, lora_name): """ Converts a text encoder LoRA key to a Diffusers compatible key. """ if lora_name.startswith(("lora_te_", "lora_te1_")): key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_" else: key_to_replace = "lora_te2_" diffusers_name = key.replace(key_to_replace, "").replace("_", ".") diffusers_name = diffusers_name.replace("text.model", "text_model") diffusers_name = diffusers_name.replace("self.attn", "self_attn") diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") diffusers_name = diffusers_name.replace("text.projection", "text_projection") if "self_attn" in diffusers_name or "text_projection" in diffusers_name: pass elif "mlp" in diffusers_name: # Be aware that this is the new diffusers convention and the rest of the code might # not utilize it yet. diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") return diffusers_name def _get_alpha_name(lora_name_alpha, diffusers_name, alpha): """ Gets the correct alpha name for the Diffusers model. """ if lora_name_alpha.startswith("lora_unet_"): prefix = "unet." elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): prefix = "text_encoder." else: prefix = "text_encoder_2." new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" return {new_name: alpha} # The utilities under `_convert_kohya_flux_lora_to_diffusers()` # are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py # All credits go to `kohya-ss`. def _convert_kohya_flux_lora_to_diffusers(state_dict): def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): if sds_key + ".lora_down.weight" not in sds_sd: return down_weight = sds_sd.pop(sds_key + ".lora_down.weight") # scale weight by alpha and dim rank = down_weight.shape[0] alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 scale_down = scale scale_up = 1.0 while scale_down * 2 < scale_up: scale_down *= 2 scale_up /= 2 ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): if sds_key + ".lora_down.weight" not in sds_sd: return down_weight = sds_sd.pop(sds_key + ".lora_down.weight") up_weight = sds_sd.pop(sds_key + ".lora_up.weight") sd_lora_rank = down_weight.shape[0] # scale weight by alpha and dim alpha = sds_sd.pop(sds_key + ".alpha") scale = alpha / sd_lora_rank # calculate scale_down and scale_up scale_down = scale scale_up = 1.0 while scale_down * 2 < scale_up: scale_down *= 2 scale_up /= 2 down_weight = down_weight * scale_down up_weight = up_weight * scale_up # calculate dims if not provided num_splits = len(ait_keys) if dims is None: dims = [up_weight.shape[0] // num_splits] * num_splits else: assert sum(dims) == up_weight.shape[0] # check upweight is sparse or not is_sparse = False if sd_lora_rank % num_splits == 0: ait_rank = sd_lora_rank // num_splits is_sparse = True i = 0 for j in range(len(dims)): for k in range(len(dims)): if j == k: continue is_sparse = is_sparse and torch.all( up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0 ) i += dims[j] if is_sparse: logger.info(f"weight is sparse: {sds_key}") # make ai-toolkit weight ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] if not is_sparse: # down_weight is copied to each split ait_sd.update({k: down_weight for k in ait_down_keys}) # up_weight is split to each split ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 else: # down_weight is chunked to each split ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 # up_weight is sparse: only non-zero values are copied to each split i = 0 for j in range(len(dims)): ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() i += dims[j] def _convert_sd_scripts_to_ai_toolkit(sds_sd): ait_sd = {} for i in range(19): _convert_to_ai_toolkit( sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0", ) _convert_to_ai_toolkit_cat( sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_qkv", [ f"transformer.transformer_blocks.{i}.attn.to_q", f"transformer.transformer_blocks.{i}.attn.to_k", f"transformer.transformer_blocks.{i}.attn.to_v", ], ) _convert_to_ai_toolkit( sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj", ) _convert_to_ai_toolkit( sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2", ) _convert_to_ai_toolkit( sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear", ) _convert_to_ai_toolkit( sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out", ) _convert_to_ai_toolkit_cat( sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_qkv", [ f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"transformer.transformer_blocks.{i}.attn.add_v_proj", ], ) _convert_to_ai_toolkit( sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", ) _convert_to_ai_toolkit( sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2", ) _convert_to_ai_toolkit( sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear", ) for i in range(38): _convert_to_ai_toolkit_cat( sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear1", [ f"transformer.single_transformer_blocks.{i}.attn.to_q", f"transformer.single_transformer_blocks.{i}.attn.to_k", f"transformer.single_transformer_blocks.{i}.attn.to_v", f"transformer.single_transformer_blocks.{i}.proj_mlp", ], dims=[3072, 3072, 3072, 12288], ) _convert_to_ai_toolkit( sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out", ) _convert_to_ai_toolkit( sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear", ) remaining_keys = list(sds_sd.keys()) te_state_dict = {} if remaining_keys: if not all(k.startswith("lora_te1") for k in remaining_keys): raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}") for key in remaining_keys: if not key.endswith("lora_down.weight"): continue lora_name = key.split(".")[0] lora_name_up = f"{lora_name}.lora_up.weight" lora_name_alpha = f"{lora_name}.alpha" diffusers_name = _convert_text_encoder_lora_key(key, lora_name) if lora_name.startswith(("lora_te_", "lora_te1_")): down_weight = sds_sd.pop(key) sd_lora_rank = down_weight.shape[0] te_state_dict[diffusers_name] = down_weight te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up) if lora_name_alpha in sds_sd: alpha = sds_sd.pop(lora_name_alpha).item() scale = alpha / sd_lora_rank scale_down = scale scale_up = 1.0 while scale_down * 2 < scale_up: scale_down *= 2 scale_up /= 2 te_state_dict[diffusers_name] *= scale_down te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up if len(sds_sd) > 0: logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}") if te_state_dict: te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()} new_state_dict = {**ait_sd, **te_state_dict} return new_state_dict return _convert_sd_scripts_to_ai_toolkit(state_dict) # Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6 # Some utilities were reused from # https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py def _convert_xlabs_flux_lora_to_diffusers(old_state_dict): new_state_dict = {} orig_keys = list(old_state_dict.keys()) def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): down_weight = sds_sd.pop(sds_key) up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight")) # calculate dims if not provided num_splits = len(ait_keys) if dims is None: dims = [up_weight.shape[0] // num_splits] * num_splits else: assert sum(dims) == up_weight.shape[0] # make ai-toolkit weight ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] # down_weight is copied to each split ait_sd.update({k: down_weight for k in ait_down_keys}) # up_weight is split to each split ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 for old_key in orig_keys: # Handle double_blocks if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")): block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1) new_key = f"transformer.transformer_blocks.{block_num}" if "processor.proj_lora1" in old_key: new_key += ".attn.to_out.0" elif "processor.proj_lora2" in old_key: new_key += ".attn.to_add_out" # Handle text latents. elif "processor.qkv_lora2" in old_key and "up" not in old_key: handle_qkv( old_state_dict, new_state_dict, old_key, [ f"transformer.transformer_blocks.{block_num}.attn.add_q_proj", f"transformer.transformer_blocks.{block_num}.attn.add_k_proj", f"transformer.transformer_blocks.{block_num}.attn.add_v_proj", ], ) # continue # Handle image latents. elif "processor.qkv_lora1" in old_key and "up" not in old_key: handle_qkv( old_state_dict, new_state_dict, old_key, [ f"transformer.transformer_blocks.{block_num}.attn.to_q", f"transformer.transformer_blocks.{block_num}.attn.to_k", f"transformer.transformer_blocks.{block_num}.attn.to_v", ], ) # continue if "down" in old_key: new_key += ".lora_A.weight" elif "up" in old_key: new_key += ".lora_B.weight" # Handle single_blocks elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")): block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1) new_key = f"transformer.single_transformer_blocks.{block_num}" if "proj_lora" in old_key: new_key += ".proj_out" elif "qkv_lora" in old_key and "up" not in old_key: handle_qkv( old_state_dict, new_state_dict, old_key, [f"transformer.single_transformer_blocks.{block_num}.norm.linear"], ) if "down" in old_key: new_key += ".lora_A.weight" elif "up" in old_key: new_key += ".lora_B.weight" else: # Handle other potential key patterns here new_key = old_key # Since we already handle qkv above. if "qkv" not in old_key: new_state_dict[new_key] = old_state_dict.pop(old_key) if len(old_state_dict) > 0: raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") return new_state_dict def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): converted_state_dict = {} original_state_dict_keys = list(original_state_dict.keys()) num_layers = 19 num_single_layers = 38 inner_dim = 3072 mlp_ratio = 4.0 def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) new_weight = torch.cat([scale, shift], dim=0) return new_weight for lora_key in ["lora_A", "lora_B"]: ## time_text_embed.timestep_embedder <- time_in converted_state_dict[ f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight" ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[ f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias" ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") converted_state_dict[ f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight" ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[ f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias" ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") ## time_text_embed.text_embedder <- vector_in converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop( f"vector_in.in_layer.{lora_key}.weight" ) if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop( f"vector_in.in_layer.{lora_key}.bias" ) converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop( f"vector_in.out_layer.{lora_key}.weight" ) if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop( f"vector_in.out_layer.{lora_key}.bias" ) # guidance has_guidance = any("guidance" in k for k in original_state_dict) if has_guidance: converted_state_dict[ f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight" ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[ f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias" ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") converted_state_dict[ f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight" ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[ f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias" ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") # context_embedder converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop( f"txt_in.{lora_key}.weight" ) if f"txt_in.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop( f"txt_in.{lora_key}.bias" ) # x_embedder converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight") if f"img_in.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias") # double transformer blocks for i in range(num_layers): block_prefix = f"transformer_blocks.{i}." for lora_key in ["lora_A", "lora_B"]: # norms converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_mod.lin.{lora_key}.weight" ) if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop( f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" ) converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight" ) if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop( f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" ) # Q, K, V if lora_key == "lora_A": sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight") converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight]) converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight]) converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight]) context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight") converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat( [context_lora_weight] ) converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat( [context_lora_weight] ) converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat( [context_lora_weight] ) else: sample_q, sample_k, sample_v = torch.chunk( original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0 ) converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q]) converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k]) converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v]) context_q, context_k, context_v = torch.chunk( original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0 ) converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q]) converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k]) converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v]) if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys: sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0 ) converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias]) converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias]) converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias]) if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys: context_q_bias, context_k_bias, context_v_bias = torch.chunk( original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0 ) converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias]) converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias]) converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias]) # ff img_mlp converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_mlp.0.{lora_key}.weight" ) if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop( f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" ) converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_mlp.2.{lora_key}.weight" ) if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop( f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" ) converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight" ) if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop( f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" ) converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight" ) if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop( f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" ) # output projections. converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_attn.proj.{lora_key}.weight" ) if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop( f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" ) converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop( f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight" ) if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop( f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" ) # qk_norm converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_attn.norm.query_norm.scale" ) converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( f"double_blocks.{i}.img_attn.norm.key_norm.scale" ) converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop( f"double_blocks.{i}.txt_attn.norm.query_norm.scale" ) converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop( f"double_blocks.{i}.txt_attn.norm.key_norm.scale" ) # single transfomer blocks for i in range(num_single_layers): block_prefix = f"single_transformer_blocks.{i}." for lora_key in ["lora_A", "lora_B"]: # norm.linear <- single_blocks.0.modulation.lin converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop( f"single_blocks.{i}.modulation.lin.{lora_key}.weight" ) if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop( f"single_blocks.{i}.modulation.lin.{lora_key}.bias" ) # Q, K, V, mlp mlp_hidden_dim = int(inner_dim * mlp_ratio) split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) if lora_key == "lora_A": lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight") converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight]) converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight]) converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight]) converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight]) if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys: lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias") converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias]) converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias]) converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias]) converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias]) else: q, k, v, mlp = torch.split( original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0 ) converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q]) converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k]) converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v]) converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp]) if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys: q_bias, k_bias, v_bias, mlp_bias = torch.split( original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0 ) converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias]) converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias]) converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias]) converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias]) # output projections. converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop( f"single_blocks.{i}.linear2.{lora_key}.weight" ) if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop( f"single_blocks.{i}.linear2.{lora_key}.bias" ) # qk norm converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( f"single_blocks.{i}.norm.query_norm.scale" ) converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( f"single_blocks.{i}.norm.key_norm.scale" ) for lora_key in ["lora_A", "lora_B"]: converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop( f"final_layer.linear.{lora_key}.weight" ) if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop( f"final_layer.linear.{lora_key}.bias" ) converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift( original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight") ) if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys: converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift( original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias") ) if len(original_state_dict) > 0: raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.") for key in list(converted_state_dict.keys()): converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) return converted_state_dict