import torch
import torch.nn as nn
import os
import torch.nn.functional as F
from einops import rearrange
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import is_torch_version
from typing import Any, Callable, Dict, List, Optional, Union
from tqdm import tqdm

from .modeling_embedding import PatchEmbed3D, CombinedTimestepConditionEmbeddings
from .modeling_normalization import AdaLayerNormContinuous
from .modeling_mmdit_block import JointTransformerBlock

from trainer_misc import (
    is_sequence_parallel_initialized,
    get_sequence_parallel_group,
    get_sequence_parallel_world_size,
    get_sequence_parallel_rank,
    all_to_all,
)
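

# `rope` precomputes, for each 1-D position/frequency pair, the entries of the
# 2x2 rotation matrix [[cos, -sin], [sin, cos]] used by rotary position embeddings.
# Output shape: [batch_size, seq_length, dim // 2, 2, 2].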
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    assert dim % 2 == 0, "The dimension must be even."

    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    batch_size, seq_length = pos.shape
    out = torch.einsum("...n,d->...nd", pos, omega)
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)

    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
    return out.float()
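

# EmbedNDRoPE builds a multi-axis RoPE table: one `rope` table per coordinate axis,
# concatenated along the frequency dimension and given a singleton head axis, i.e.
# roughly [batch_size, seq_len, 1, sum(axes_dim) // 2, 2, 2].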
class EmbedNDRoPE(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        return emb.unsqueeze(2)
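

# The pyramid MM-DiT backbone: it patchifies multi-resolution video latents, packs the
# pyramid stages together with the text tokens into one sequence, and runs them through
# a stack of joint transformer blocks conditioned on the timestep and pooled text embedding.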
class PyramidDiffusionMMDiT(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 24,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        pos_embed_max_size: int = 192,
        max_num_frames: int = 200,
        qk_norm: str = 'rms_norm',
        pos_embed_type: str = 'rope',
        temp_pos_embed_type: str = 'sincos',
        joint_attention_dim: int = 4096,
        use_gradient_checkpointing: bool = False,
        use_flash_attn: bool = True,
        use_temporal_causal: bool = False,
        use_t5_mask: bool = False,
        add_temp_pos_embed: bool = False,
        interp_condition_pos: bool = False,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        assert temp_pos_embed_type in ['rope', 'sincos']

        # The input latent embedder; keep the name pos_embed to stay consistent with SD3
        self.pos_embed = PatchEmbed3D(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
            max_num_frames=max_num_frames,
            pos_embed_type=pos_embed_type,
            temp_pos_embed_type=temp_pos_embed_type,
            add_temp_pos_embed=add_temp_pos_embed,
            interp_condition_pos=interp_condition_pos,
        )

        # The RoPE embedding
        if pos_embed_type == 'rope':
            self.rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[16, 24, 24])
        else:
            self.rope_embed = None

        if temp_pos_embed_type == 'rope':
            self.temp_rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[attention_head_dim])
        else:
            self.temp_rope_embed = None

        self.time_text_embed = CombinedTimestepConditionEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim,
        )

        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=self.inner_dim,
                    qk_norm=qk_norm,
                    context_pre_only=i == num_layers - 1,
                    use_flash_attn=use_flash_attn,
                )
                for i in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = use_gradient_checkpointing
        self.patch_size = patch_size
        self.use_flash_attn = use_flash_attn
        self.use_temporal_causal = use_temporal_causal
        self.pos_embed_type = pos_embed_type
        self.temp_pos_embed_type = temp_pos_embed_type
        self.add_temp_pos_embed = add_temp_pos_embed

        if self.use_temporal_causal:
            print("Using temporal causal attention")
            assert self.use_flash_attn is False, "The flash attention does not support temporal causal"

        if interp_condition_pos:
            print("We interpolate the position embedding of condition latents")

        # init weights
        self.initialize_weights()
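
    # Weight init follows the standard DiT recipe: xavier-uniform for linear/conv layers,
    # normal(std=0.02) for the conditioning embedders, and zeros for the adaLN modulation
    # layers and the final output projection.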
    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.pos_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.pos_embed.proj.bias, 0)

        # Initialize all the conditioning to normal init
        nn.init.normal_(self.time_text_embed.timestep_embedder.linear_1.weight, std=0.02)
        nn.init.normal_(self.time_text_embed.timestep_embedder.linear_2.weight, std=0.02)
        nn.init.normal_(self.time_text_embed.text_embedder.linear_1.weight, std=0.02)
        nn.init.normal_(self.time_text_embed.text_embedder.linear_2.weight, std=0.02)
        nn.init.normal_(self.context_embedder.weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.transformer_blocks:
            nn.init.constant_(block.norm1.linear.weight, 0)
            nn.init.constant_(block.norm1.linear.bias, 0)
            nn.init.constant_(block.norm1_context.linear.weight, 0)
            nn.init.constant_(block.norm1_context.linear.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.norm_out.linear.weight, 0)
        nn.init.constant_(self.norm_out.linear.bias, 0)
        nn.init.constant_(self.proj_out.weight, 0)
        nn.init.constant_(self.proj_out.bias, 0)
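
    # RoPE coordinate helpers: each latent token gets an integer (t, h, w) coordinate
    # (or a single temporal index) that is later fed to EmbedNDRoPE. For the pyramid
    # variants, lower-resolution stages interpolate the base-resolution grid so that
    # positions stay aligned across stages.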
    def _prepare_latent_image_ids(self, batch_size, temp, height, width, device):
        latent_image_ids = torch.zeros(temp, height, width, 3)
        latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None]
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[None, :, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, None, :]

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
        latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')

        return latent_image_ids.to(device=device)

    def _prepare_pyramid_latent_image_ids(self, batch_size, temp_list, height_list, width_list, device):
        base_width = width_list[-1]; base_height = height_list[-1]
        assert base_width == max(width_list)
        assert base_height == max(height_list)

        image_ids_list = []
        for temp, height, width in zip(temp_list, height_list, width_list):
            latent_image_ids = torch.zeros(temp, height, width, 3)

            if height != base_height:
                height_pos = F.interpolate(torch.arange(base_height)[None, None, :].float(), height, mode='linear').squeeze(0, 1)
            else:
                height_pos = torch.arange(base_height).float()
            if width != base_width:
                width_pos = F.interpolate(torch.arange(base_width)[None, None, :].float(), width, mode='linear').squeeze(0, 1)
            else:
                width_pos = torch.arange(base_width).float()

            latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None]
            latent_image_ids[..., 1] = latent_image_ids[..., 1] + height_pos[None, :, None]
            latent_image_ids[..., 2] = latent_image_ids[..., 2] + width_pos[None, None, :]

            latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
            latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c').to(device)
            image_ids_list.append(latent_image_ids)

        return image_ids_list
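
    # Temporal-only RoPE ids: every token carries just its frame index; when a sample is
    # a list of clips, `start_time_stamp` offsets later clips so the time axis keeps
    # counting across the concatenated sequence.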
    def _prepare_temporal_rope_ids(self, batch_size, temp, height, width, device, start_time_stamp=0):
        latent_image_ids = torch.zeros(temp, height, width, 1)
        latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(start_time_stamp, start_time_stamp + temp)[:, None, None]
        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
        latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')
        return latent_image_ids.to(device=device)

    def _prepare_pyramid_temporal_rope_ids(self, sample, batch_size, device):
        image_ids_list = []

        for i_b, sample_ in enumerate(sample):
            if not isinstance(sample_, list):
                sample_ = [sample_]
            cur_image_ids = []
            start_time_stamp = 0

            for clip_ in sample_:
                _, _, temp, height, width = clip_.shape
                height = height // self.patch_size
                width = width // self.patch_size
                cur_image_ids.append(self._prepare_temporal_rope_ids(batch_size, temp, height, width, device, start_time_stamp=start_time_stamp))
                start_time_stamp += temp

            cur_image_ids = torch.cat(cur_image_ids, dim=1)
            image_ids_list.append(cur_image_ids)

        return image_ids_list

    def merge_input(self, sample, encoder_hidden_length, encoder_attention_mask):
        """
        Merge the input videos at different resolutions into one packed sequence.
        `sample` is ordered from low resolution to high resolution.
        """
        if isinstance(sample[0], list):
            device = sample[0][-1].device
            pad_batch_size = sample[0][-1].shape[0]
        else:
            device = sample[0].device
            pad_batch_size = sample[0].shape[0]

        num_stages = len(sample)
        height_list = []; width_list = []; temp_list = []
        trainable_token_list = []

        for i_b, sample_ in enumerate(sample):
            if isinstance(sample_, list):
                sample_ = sample_[-1]
            _, _, temp, height, width = sample_.shape
            height = height // self.patch_size
            width = width // self.patch_size
            temp_list.append(temp)
            height_list.append(height)
            width_list.append(width)
            trainable_token_list.append(height * width * temp)

        # prepare the RoPE embedding if needed
        if self.pos_embed_type == 'rope':
            # TODO: support the 3D RoPE for video
            raise NotImplementedError("Not compatible with video generation now")
            text_ids = torch.zeros(pad_batch_size, encoder_hidden_length, 3).to(device=device)
            image_ids_list = self._prepare_pyramid_latent_image_ids(pad_batch_size, temp_list, height_list, width_list, device)
            input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
            image_rotary_emb = [self.rope_embed(input_ids) for input_ids in input_ids_list]  # [bs, seq_len, 1, head_dim // 2, 2, 2]
        else:
            if self.temp_pos_embed_type == 'rope' and self.add_temp_pos_embed:
                image_ids_list = self._prepare_pyramid_temporal_rope_ids(sample, pad_batch_size, device)
                text_ids = torch.zeros(pad_batch_size, encoder_attention_mask.shape[1], 1).to(device=device)
                input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
                image_rotary_emb = [self.temp_rope_embed(input_ids) for input_ids in input_ids_list]  # [bs, seq_len, 1, head_dim // 2, 2, 2]

                if is_sequence_parallel_initialized():
                    sp_group = get_sequence_parallel_group()
                    sp_group_size = get_sequence_parallel_world_size()
                    image_rotary_emb = [all_to_all(x_.repeat(1, 1, sp_group_size, 1, 1, 1), sp_group, sp_group_size, scatter_dim=2, gather_dim=0) for x_ in image_rotary_emb]
                    input_ids_list = [all_to_all(input_ids.repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0) for input_ids in input_ids_list]
            else:
                image_rotary_emb = None

        hidden_states = self.pos_embed(sample)  # hidden_states is a list of token sequences [b, seq_len, d]; b = real_b // num_stages
        hidden_length = []

        for i_b in range(num_stages):
            hidden_length.append(hidden_states[i_b].shape[1])

        # prepare the attention mask
        if self.use_flash_attn:
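            # Flash-attention path: no dense masks; instead build the varlen metadata
            # (flattened valid-token indices and per-sample sequence lengths) that the
            # flash-attention blocks consume downstream.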
            attention_mask = None
            indices_list = []

            for i_p, length in enumerate(hidden_length):
                pad_attention_mask = torch.ones((pad_batch_size, length), dtype=encoder_attention_mask.dtype).to(device)
                pad_attention_mask = torch.cat([encoder_attention_mask[i_p::num_stages], pad_attention_mask], dim=1)

                if is_sequence_parallel_initialized():
                    sp_group = get_sequence_parallel_group()
                    sp_group_size = get_sequence_parallel_world_size()
                    pad_attention_mask = all_to_all(pad_attention_mask.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0)
                    pad_attention_mask = pad_attention_mask.squeeze(2)

                seqlens_in_batch = pad_attention_mask.sum(dim=-1, dtype=torch.int32)
                indices = torch.nonzero(pad_attention_mask.flatten(), as_tuple=False).flatten()

                indices_list.append(
                    {
                        'indices': indices,
                        'seqlens_in_batch': seqlens_in_batch,
                    }
                )

            encoder_attention_mask = indices_list
        else:
            assert encoder_attention_mask.shape[1] == encoder_hidden_length
            real_batch_size = encoder_attention_mask.shape[0]
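
            # Dense-mask path: give every sample a distinct positive id (0 marks padded
            # text tokens) and let tokens attend only to positions that share the same id,
            # which masks out text padding; optionally AND with a temporal causal mask.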
            # prepare text ids
            text_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, encoder_hidden_length)
            text_ids = text_ids.to(device)
            text_ids[encoder_attention_mask == 0] = 0

            # prepare image ids
            image_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, max(hidden_length))
            image_ids = image_ids.to(device)
            image_ids_list = []
            for i_p, length in enumerate(hidden_length):
                image_ids_list.append(image_ids[i_p::num_stages][:, :length])

            if is_sequence_parallel_initialized():
                sp_group = get_sequence_parallel_group()
                sp_group_size = get_sequence_parallel_world_size()
                text_ids = all_to_all(text_ids.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0).squeeze(2)
                image_ids_list = [all_to_all(image_ids_.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0).squeeze(2) for image_ids_ in image_ids_list]

            attention_mask = []
            for i_p in range(len(hidden_length)):
                image_ids = image_ids_list[i_p]
                token_ids = torch.cat([text_ids[i_p::num_stages], image_ids], dim=1)
                stage_attention_mask = rearrange(token_ids, 'b i -> b 1 i 1') == rearrange(token_ids, 'b j -> b 1 1 j')  # [bs, 1, q_len, k_len]
                if self.use_temporal_causal:
                    input_order_ids = input_ids_list[i_p].squeeze(2)
                    temporal_causal_mask = rearrange(input_order_ids, 'b i -> b 1 i 1') >= rearrange(input_order_ids, 'b j -> b 1 1 j')
                    stage_attention_mask = stage_attention_mask & temporal_causal_mask
                attention_mask.append(stage_attention_mask)

        return hidden_states, hidden_length, temp_list, height_list, width_list, trainable_token_list, encoder_attention_mask, attention_mask, image_rotary_emb
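
    # Inverse of merge_input for the outputs: split the packed sequence back into
    # per-stage chunks, keep only the trailing trainable tokens, and unpatchify
    # each chunk back to a [b, c, t, h, w] latent.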

    def split_output(self, batch_hidden_states, hidden_length, temps, heights, widths, trainable_token_list):
        # To split the hidden states
        batch_size = batch_hidden_states.shape[0]
        output_hidden_list = []
        batch_hidden_states = torch.split(batch_hidden_states, hidden_length, dim=1)

        if is_sequence_parallel_initialized():
            sp_group_size = get_sequence_parallel_world_size()
            batch_size = batch_size // sp_group_size

        for i_p, length in enumerate(hidden_length):
            width, height, temp = widths[i_p], heights[i_p], temps[i_p]
            trainable_token_num = trainable_token_list[i_p]
            hidden_states = batch_hidden_states[i_p]

            if is_sequence_parallel_initialized():
                sp_group = get_sequence_parallel_group()
                sp_group_size = get_sequence_parallel_world_size()
                hidden_states = all_to_all(hidden_states, sp_group, sp_group_size, scatter_dim=0, gather_dim=1)

            # only the trainable tokens take part in the loss computation
            hidden_states = hidden_states[:, -trainable_token_num:]

            # unpatchify
            hidden_states = hidden_states.reshape(
                shape=(batch_size, temp, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = rearrange(hidden_states, "b t h w p1 p2 c -> b t (h p1) (w p2) c")
            hidden_states = rearrange(hidden_states, "b t h w c -> b c t h w")
            output_hidden_list.append(hidden_states)

        return output_hidden_list

    def forward(
        self,
        sample: torch.FloatTensor,  # [num_stages]
        encoder_hidden_states: torch.FloatTensor = None,
        encoder_attention_mask: torch.FloatTensor = None,
        pooled_projections: torch.FloatTensor = None,
        timestep_ratio: torch.FloatTensor = None,
    ):
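        """
        sample: list of latent tensors (or lists of clips), one entry per pyramid stage,
            ordered from low to high resolution.
        encoder_hidden_states / encoder_attention_mask: text token features and their mask.
        pooled_projections: pooled text embedding used for the timestep/text conditioning.
        timestep_ratio: timestep signal passed to the combined timestep embedder.
        Returns a list of predicted latents, one per stage.
        """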
        # Get the timestep embedding
        temb = self.time_text_embed(timestep_ratio, pooled_projections)

        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
        encoder_hidden_length = encoder_hidden_states.shape[1]

        # Get the input sequence
        hidden_states, hidden_length, temps, heights, widths, trainable_token_list, encoder_attention_mask, \
            attention_mask, image_rotary_emb = self.merge_input(sample, encoder_hidden_length, encoder_attention_mask)

        # split the long latents if necessary
        if is_sequence_parallel_initialized():
            sp_group = get_sequence_parallel_group()
            sp_group_size = get_sequence_parallel_world_size()

            # sync the input hidden states
            batch_hidden_states = []
            for i_p, hidden_states_ in enumerate(hidden_states):
                assert hidden_states_.shape[1] % sp_group_size == 0, "The sequence length should be divisible by the sequence parallel size"
                hidden_states_ = all_to_all(hidden_states_, sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
                hidden_length[i_p] = hidden_length[i_p] // sp_group_size
                batch_hidden_states.append(hidden_states_)

            # sync the encoder hidden states
            hidden_states = torch.cat(batch_hidden_states, dim=1)
            encoder_hidden_states = all_to_all(encoder_hidden_states, sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
            temb = all_to_all(temb.unsqueeze(1).repeat(1, sp_group_size, 1), sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
            temb = temb.squeeze(1)
        else:
            hidden_states = torch.cat(hidden_states, dim=1)

        for i_b, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing and (i_b >= 2):
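                # Recompute activations for all but the first two blocks during backward;
                # blocks 0 and 1 stay on the normal autograd path.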
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    temb,
                    attention_mask,
                    hidden_length,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    temb=temb,
                    attention_mask=attention_mask,
                    hidden_length=hidden_length,
                    image_rotary_emb=image_rotary_emb,
                )

        hidden_states = self.norm_out(hidden_states, temb, hidden_length=hidden_length)
        hidden_states = self.proj_out(hidden_states)

        output = self.split_output(hidden_states, hidden_length, temps, heights, widths, trainable_token_list)

        return output