import copy

from einops import repeat

from diffusers import __version__
from diffusers.models.modeling_utils import (
    _add_variant, _get_checkpoint_shard_files, _get_model_file,  # diffusers.utils
    _determine_device_map, _fetch_index_file,  # diffusers.models.model_loading_utils
)
from diffusers.models.modeling_utils import *
from diffusers.models.transformers.transformer_sd3 import *

from extensions.diffusers_diffsplat.models.mv_attention import JointMVTransformerBlock


if is_torch_version(">=", "1.9.0"):
    _LOW_CPU_MEM_USAGE_DEFAULT = True
else:
    _LOW_CPU_MEM_USAGE_DEFAULT = False


# Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel
# The only modifications: `JointTransformerBlock` -> `JointMVTransformerBlock`
class SD3TransformerMV2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
    """
    Multi-view variant of the Transformer model introduced in Stable Diffusion 3. Every transformer block is a
    `JointMVTransformerBlock` instead of the upstream `JointTransformerBlock`.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*, defaults to 4096): Input feature size of the `context_embedder`
            linear layer that is applied to `encoder_hidden_states`.
        caption_projection_dim (`int`): Output feature size of the `context_embedder` linear layer.
        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
        out_channels (`int`, defaults to 16): Number of output channels. Falls back to `in_channels` when `None`.
        pos_embed_max_size (`int`, defaults to 96): Forwarded to `PatchEmbed` as `pos_embed_max_size`.
        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`): Indices of the blocks that are built with
            `use_dual_attention=True`.
        qk_norm (`str`, *optional*): Forwarded unchanged to every `JointMVTransformerBlock`.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 18,
        attention_head_dim: int = 64,
        num_attention_heads: int = 18,
        joint_attention_dim: int = 4096,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
        dual_attention_layers: Tuple[
            int, ...
        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
        default_out_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
        )
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
        )
        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)

        # NOTE (comment inherited from upstream diffusers):
        # `attention_head_dim` is doubled to account for the mixing.
        # It needs to be crafted when we get the actual checkpoints.
        self.transformer_blocks = nn.ModuleList(
            [
                JointMVTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    # Only the last block is built with `context_pre_only=True`.
                    context_pre_only=i == num_layers - 1,
                    qk_norm=qk_norm,
                    use_dual_attention=i in dual_attention_layers,
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).

        Raises:
            ValueError: If `dim` is neither 0 nor 1.
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
    def disable_forward_chunking(self):
        """Calls `set_chunk_feed_forward(chunk_size=None, dim=0)` on every sub-module that exposes it."""

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name (`"<module path>.processor"`).
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        Raises:
            ValueError: If a dict is passed whose length differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # NOTE: entries are popped, i.e. the dict handed in is consumed.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Raises:
            ValueError: If any current attention processor class name contains `"Added"`.
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # Remember the current processors so that `unfuse_qkv_projections` can restore them.
        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedJointAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Does nothing when `fuse_qkv_projections()` has not stored any processors on this instance.
        """
        # `original_attn_processors` only exists once `fuse_qkv_projections()` has run; a plain attribute
        # access would raise `AttributeError` on a freshly constructed model.
        original_attn_processors = getattr(self, "original_attn_processors", None)
        if original_attn_processors is not None:
            self.set_attn_processor(original_attn_processors)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `module.gradient_checkpointing = value` when the module has that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        pooled_projections: torch.FloatTensor = None,
        timestep: torch.LongTensor = None,
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        skip_layers: Optional[List[int]] = None,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`SD3TransformerMV2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
                The keys `"scale"` (LoRA scale, default 1.0), `"ip_adapter_image_embeds"` and `"num_views"`
                (default 1) are also read by this method. The caller's dict is not modified.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            skip_layers (`list` of `int`, *optional*):
                A list of layer indices to skip during the forward pass.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        scale_was_passed = False
        if joint_attention_kwargs is not None:
            # Work on a shallow copy so the pops/updates below never leak into the caller's dict.
            joint_attention_kwargs = joint_attention_kwargs.copy()
            # Must be evaluated BEFORE the pop; testing for "scale" afterwards can never be true.
            scale_was_passed = joint_attention_kwargs.get("scale", None) is not None
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_was_passed:
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
        temb = self.time_text_embed(timestep, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            # `self.image_proj` is not created in `__init__`; presumably it is attached by the IP-Adapter
            # loader mixin before this branch can be used - TODO confirm.
            ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)

            joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)

        for index_block, block in enumerate(self.transformer_blocks):
            # Skip specified layers
            is_skip = skip_layers is not None and index_block in skip_layers

            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    joint_attention_kwargs,
                    **ckpt_kwargs,
                )
            elif not is_skip:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                # Spread the (possibly fewer) controlnet residuals evenly over the transformer blocks.
                interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
                hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]

        # `joint_attention_kwargs` is optional, so fall back to a single view instead of calling
        # `.get` on `None`.
        num_views = 1 if joint_attention_kwargs is None else joint_attention_kwargs.get("num_views", 1)
        # Repeat the conditioning embedding once per view along the batch axis.
        # NOTE(review): assumes the blocks return `hidden_states` with views folded into the batch
        # as `(b v)` - confirm against `JointMVTransformerBlock`.
        temb = repeat(temb, "b d -> (b v) d", v=num_views)

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # unpatchify
        patch_size = self.config.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
        )

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

    # Copied from diffusers.models.modeling_utils.ModelingMixin.from_pretrained
    @classmethod
    @validate_hf_hub_args
    def from_pretrained_new(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        sample_size: int = 32,  # `input_res` / 8
        in_channels: int = 16,
        out_channels: int = 16,
        zero_init_conv_in: bool = True,
        view_concat_condition: bool = False,
        input_concat_plucker: bool = False,
        input_concat_binary_mask: bool = False,
        from_scratch: bool = False,  # do not load pretrained parameters
        **kwargs
    ):
        """
        Build the model from a pretrained config, overriding a few config entries, and (optionally) load the
        pretrained weights.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Passed to `cls.load_config` and to the checkpoint-file resolvers.
            sample_size (`int`, defaults to 32): Written to `config["sample_size"]`.
            in_channels (`int`, defaults to 16): Written to `config["in_channels"]`.
            out_channels (`int`, defaults to 16): Written to `config["out_channels"]`.
            zero_init_conv_in (`bool`, defaults to `True`):
                When the patch-embedding conv has more input channels than the checkpoint, zero the weights of
                the extra input channels after copying the pretrained ones.
            view_concat_condition, input_concat_plucker, input_concat_binary_mask (`bool`):
                Stored verbatim in the config under the same names.
            from_scratch (`bool`, defaults to `False`):
                Instantiate from the config only; no pretrained parameters are loaded.
            **kwargs:
                The loading options popped below (`cache_dir`, `ignore_mismatched_sizes`, `force_download`,
                `from_flax`, `proxies`, `output_loading_info`, `local_files_only`, `token`, `revision`,
                `torch_dtype`, `subfolder`, `device_map`, `max_memory`, `offload_folder`, `offload_state_dict`,
                `low_cpu_mem_usage`, `variant`, `use_safetensors`); anything left is forwarded to
                `cls.load_config`.

        Returns:
            The model in eval mode, or `(model, loading_info)` when `output_loading_info=True`.

        Raises:
            NotImplementedError: `device_map` / `low_cpu_mem_usage` requested without `accelerate` or with a
                too-old torch.
            ValueError: Inconsistent `device_map` / `low_cpu_mem_usage` arguments, missing checkpoint keys on
                the low-memory path, or `torch_dtype` not being a `torch.dtype`.

        NOTE(review): the patch-embedding re-initialisation for an enlarged `in_channels` only runs on the
        `low_cpu_mem_usage=False` PyTorch path, the only path that keeps the original checkpoint tensor
        around. With `from_scratch=True` and `low_cpu_mem_usage=True` nothing is loaded after
        `init_empty_weights()` - confirm callers pass `low_cpu_mem_usage=False` in that case.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        from_flax = kwargs.pop("from_flax", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        subfolder = kwargs.pop("subfolder", None)
        device_map = kwargs.pop("device_map", None)
        max_memory = kwargs.pop("max_memory", None)
        offload_folder = kwargs.pop("offload_folder", None)
        offload_state_dict = kwargs.pop("offload_state_dict", False)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        # When the caller did not choose, prefer safetensors but allow falling back to pickle weights.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if device_map is not None and not is_accelerate_available():
            raise NotImplementedError(
                "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
                " `device_map=None`. You can install accelerate with `pip install accelerate`."
            )

        # Check if we can handle device_map and dispatching the weights
        if device_map is not None and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `device_map=None`."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        if low_cpu_mem_usage is False and device_map is not None:
            raise ValueError(
                f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
            )

        # change device_map into a map if we passed an int, a str or a torch.device
        if isinstance(device_map, torch.device):
            device_map = {"": device_map}
        elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
            try:
                device_map = {"": torch.device(device_map)}
            except RuntimeError as err:
                raise ValueError(
                    "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
                    f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
                ) from err
        elif isinstance(device_map, int):
            if device_map < 0:
                raise ValueError(
                    "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
                )
            else:
                device_map = {"": device_map}

        if device_map is not None:
            if low_cpu_mem_usage is None:
                low_cpu_mem_usage = True
            elif not low_cpu_mem_usage:
                raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")

        if low_cpu_mem_usage:
            if device_map is not None and not is_torch_version(">=", "1.10"):
                # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
                raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")

        # Load config if we don't provide a configuration
        config_path = pretrained_model_name_or_path

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # load config
        config, unused_kwargs, commit_hash = cls.load_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            **kwargs,
        )

        # Modify configs for the multi-view cross-domain diffusion model
        config["_class_name"] = cls.__name__
        config["sample_size"] = sample_size  # training resolution
        config["in_channels"] = in_channels
        config["out_channels"] = out_channels
        config["view_concat_condition"] = view_concat_condition
        config["input_concat_plucker"] = input_concat_plucker
        config["input_concat_binary_mask"] = input_concat_binary_mask

        # Determine if we're loading from a directory of sharded checkpoints.
        is_sharded = False
        index_file = None
        is_local = os.path.isdir(pretrained_model_name_or_path)
        index_file = _fetch_index_file(
            is_local=is_local,
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder or "",
            use_safetensors=use_safetensors,
            cache_dir=cache_dir,
            variant=variant,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            user_agent=user_agent,
            commit_hash=commit_hash,
        )
        if index_file is not None and index_file.is_file():
            is_sharded = True

        if is_sharded and from_flax:
            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")

        # Defaults shared by every loading path, so the code after the branches never touches an
        # unbound name. Only the plain PyTorch path below overrides them.
        loading_info = {
            "missing_keys": [],
            "unexpected_keys": [],
            "mismatched_keys": [],
            "error_msgs": [],
        }
        # Pretrained patch-embedding conv weight, kept for the re-initialisation step further down.
        pretrained_pos_embed_proj_weight = None

        # load model
        model_file = None
        if from_flax:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=FLAX_WEIGHTS_NAME,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
                commit_hash=commit_hash,
            )
            model = cls.from_config(config, **unused_kwargs)

            # Convert the weights
            from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

            if not from_scratch:
                model = load_flax_checkpoint_in_pytorch_model(model, model_file)
        else:
            if is_sharded:
                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
                    pretrained_model_name_or_path,
                    index_file,
                    cache_dir=cache_dir,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    revision=revision,
                    subfolder=subfolder or "",
                )

            elif use_safetensors and not is_sharded:
                try:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        token=token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                        commit_hash=commit_hash,
                    )

                except IOError as e:
                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                    if not allow_pickle:
                        raise
                    logger.warning(
                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
                    )

            if model_file is None and not is_sharded:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(WEIGHTS_NAME, variant),
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                    commit_hash=commit_hash,
                )

            if low_cpu_mem_usage:
                # Instantiate model with empty weights
                with accelerate.init_empty_weights():
                    model = cls.from_config(config, **unused_kwargs)

                if not from_scratch:
                    # if device_map is None, load the state dict and move the params from meta device to the cpu
                    if device_map is None and not is_sharded:
                        param_device = "cpu"
                        state_dict = load_state_dict(model_file, variant=variant)
                        model._convert_deprecated_attention_blocks(state_dict)
                        # move the params from meta device to cpu
                        missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                        if len(missing_keys) > 0:
                            raise ValueError(
                                f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
                                f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                                " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                                " those weights or else make sure your checkpoint file is correct."
                            )

                        unexpected_keys = load_model_dict_into_meta(
                            model,
                            state_dict,
                            device=param_device,
                            dtype=torch_dtype,
                            model_name_or_path=pretrained_model_name_or_path,
                        )

                        if cls._keys_to_ignore_on_load_unexpected is not None:
                            for pat in cls._keys_to_ignore_on_load_unexpected:
                                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                        if len(unexpected_keys) > 0:
                            logger.warning(
                                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
                            )

                    else:  # else let accelerate handle loading and dispatching.
                        # Load weights and dispatch according to the device_map
                        # by default the device_map is None and the weights are loaded on the CPU
                        force_hook = True
                        device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
                        if device_map is None and is_sharded:
                            # we load the parameters on the cpu
                            device_map = {"": "cpu"}
                            force_hook = False
                        try:
                            accelerate.load_checkpoint_and_dispatch(
                                model,
                                model_file if not is_sharded else index_file,
                                device_map,
                                max_memory=max_memory,
                                offload_folder=offload_folder,
                                offload_state_dict=offload_state_dict,
                                dtype=torch_dtype,
                                force_hooks=force_hook,
                                strict=True,
                            )
                        except AttributeError as e:
                            # When using accelerate loading, we do not have the ability to load the state
                            # dict and rename the weight names manually. Additionally, accelerate skips
                            # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
                            # (which look like they should be private variables?), so we can't use the standard hooks
                            # to rename parameters on load. We need to mimic the original weight names so the correct
                            # attributes are available. After we have loaded the weights, we convert the deprecated
                            # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
                            # the weights so we don't have to do this again.

                            if "'Attention' object has no attribute" in str(e):
                                logger.warning(
                                    f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
                                    " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
                                    " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
                                    " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
                                    " please also re-upload it or open a PR on the original repository."
                                )
                                model._temp_convert_self_to_deprecated_attention_blocks()
                                accelerate.load_checkpoint_and_dispatch(
                                    model,
                                    model_file if not is_sharded else index_file,
                                    device_map,
                                    max_memory=max_memory,
                                    offload_folder=offload_folder,
                                    offload_state_dict=offload_state_dict,
                                    dtype=torch_dtype,
                                    force_hooks=force_hook,
                                    strict=True,
                                )
                                model._undo_temp_convert_self_to_deprecated_attention_blocks()
                            else:
                                raise e
            else:
                model = cls.from_config(config, **unused_kwargs)

                if not from_scratch:
                    state_dict = load_state_dict(model_file, variant=variant)
                    model._convert_deprecated_attention_blocks(state_dict)

                    # Keep a reference to the single tensor needed later instead of deep-copying the
                    # whole checkpoint. Holding the reference keeps the tensor reachable even if the
                    # loader drops the key from `state_dict`.
                    pretrained_pos_embed_proj_weight = state_dict.get("pos_embed.proj.weight")

                    model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                        model,
                        state_dict,
                        model_file,
                        pretrained_model_name_or_path,
                        ignore_mismatched_sizes=ignore_mismatched_sizes,
                    )

                    loading_info = {
                        "missing_keys": missing_keys,
                        "unexpected_keys": unexpected_keys,
                        "mismatched_keys": mismatched_keys,
                        "error_msgs": error_msgs,
                    }

        if not from_scratch and pretrained_pos_embed_proj_weight is not None:
            # Handle initializations for some layers
            ## Patch embedding conv
            latent_channels = pretrained_pos_embed_proj_weight.shape[1]
            if model.pos_embed.proj.weight.data.shape[1] != latent_channels:
                # Initialize from the original weights
                model.pos_embed.proj.weight.data[:, :latent_channels] = pretrained_pos_embed_proj_weight
                # Whether to place all zero to new layers ?
                if zero_init_conv_in:
                    model.pos_embed.proj.weight.data[:, latent_channels:] = 0

        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
            )
        elif torch_dtype is not None:
            model = model.to(torch_dtype)

        model.register_to_config(_name_or_path=pretrained_model_name_or_path)

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()
        if output_loading_info:
            return model, loading_info

        return model