# Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
"""
MIT License

Copyright (c) 2023 Shivam Mehta

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from typing import Any, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention import (
    GEGLU,
    GELU,
    AdaLayerNorm,
    AdaLayerNormZero,
    ApproximateGELU,
)
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph
from flash_attn import flash_attn_varlen_func


def get_sequence_mask(inputs, inputs_length):
    """Build a prefix (right-padding) mask and the index that unpacks packed tokens.

    Args:
        inputs: If 3-D, its first two dims are read as ``(bsz, tgt_len)``.
            Otherwise only its device is used, ``bsz`` is
            ``inputs_length.shape[0]`` and ``tgt_len`` is ``max(inputs_length)``.
        inputs_length: Integer tensor with one valid length per batch item.

    Returns:
        sequence_mask: Bool tensor of shape ``(bsz, tgt_len, 1)``; ``True`` at
            positions whose index is smaller than that item's length.
        unpacking_index: Int64 tensor of shape ``(bsz * tgt_len,)``. For a valid
            position it is that token's row in the packed tensor obtained by
            selecting the masked positions in row-major order. Padded positions
            carry a non-negative placeholder and must be masked out by the caller.
    """
    if inputs.dim() == 3:
        bsz, tgt_len, _ = inputs.size()
    else:
        bsz, tgt_len = inputs_length.shape[0], int(torch.max(inputs_length))
    positions = torch.arange(tgt_len, device=inputs.device)
    sequence_mask = torch.lt(positions, inputs_length.reshape(bsz, 1)).view(
        bsz, tgt_len, 1
    )
    # Running count of valid tokens minus one == row index in the packed tensor.
    # clamp_min(0): padded positions that precede every valid token (e.g. a
    # zero-length first item) would otherwise get -1, which index_select rejects.
    unpacking_index = (
        torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
    ).clamp_min(0)
    return sequence_mask, unpacking_index


class OmniWhisperAttention(nn.Module):
    """Multi-head self-attention over packed variable-length sequences.

    All valid tokens of a batch are concatenated along dim 0 and attention is
    computed with ``flash_attn_varlen_func``, so tokens only attend within their
    own sequence. Dropout is not supported. ``k_proj`` has no bias; the
    query, value and output projections do.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads (``head_dim = embed_dim // num_heads``).
        causal: Whether to apply a causal mask within each sequence.
    """

    def __init__(self, embed_dim, num_heads, causal=False):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})."
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.causal = causal

    def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
        """Attend within each packed sequence.

        Args:
            hidden_states: Packed tokens of shape ``(total_len, embed_dim)``,
                where ``total_len == seq_len.sum()``.
            seq_len: 1-D integer tensor with each sequence's length, in the
                order the sequences were packed.

        Returns:
            Tensor of shape ``(total_len, embed_dim)``.
        """
        # dim 0 is the total number of packed tokens, not the batch size.
        total_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(
            total_len, self.num_heads, self.head_dim
        )
        key_states = self.k_proj(hidden_states).view(
            total_len, self.num_heads, self.head_dim
        )
        value_states = self.v_proj(hidden_states).view(
            total_len, self.num_heads, self.head_dim
        )

        # Sequence start offsets [0, l0, l0 + l1, ...] as int32 for flash-attn.
        cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(
            torch.int32
        )
        # flash-attn documents max_seqlen_q / max_seqlen_k as Python ints.
        max_seqlen = int(torch.max(seq_len))

        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_len,
            cu_len,
            max_seqlen,
            max_seqlen,
            causal=self.causal,
        )  # (total_len, num_heads, head_dim)
        attn_output = attn_output.reshape(total_len, self.embed_dim)
        return self.out_proj(attn_output)


class SnakeBeta(nn.Module):
    """
    A linear projection followed by a modified Snake activation that uses
    separate parameters for the frequency and the magnitude of the periodic
    component.

    Shape:
        - Input: (..., in_features)
        - Output: (..., out_features), the projection acts on the last dimension

    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude

    References:
        - This activation function is a modified version based on this paper by
          Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195

    Examples:
        >>> a1 = SnakeBeta(256, 256)
        >>> x = torch.randn(2, 256)
        >>> x = a1(x)
    """

    def __init__(
        self,
        in_features,
        out_features,
        alpha=1.0,
        alpha_trainable=True,
        alpha_logscale=True,
    ):
        """
        Initialization.

        INPUT:
            - in_features: input size of the linear projection
            - out_features: output size of the projection; alpha and beta have
              this shape
            - alpha: scale applied to the initial alpha/beta values
            - alpha_trainable: whether alpha and beta receive gradients
            - alpha_logscale: if True, alpha and beta are stored as logs
              (initialised to 0, i.e. exp(0) = 1); otherwise stored directly
              (initialised to ``alpha``)
        Higher alpha means higher frequency; lower beta means larger magnitude
        of the periodic term.
        """
        super().__init__()
        # Despite its name this holds the shape of alpha/beta (out_features).
        self.in_features = (
            out_features if isinstance(out_features, list) else [out_features]
        )
        self.proj = LoRACompatibleLinear(in_features, out_features)

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        # Keeps 1 / beta finite when beta is zero.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Project ``x``, then apply the activation elementwise:
        SnakeBeta(x) := x + 1/beta * sin^2(x * alpha)
        """
        x = self.proj(x)
        if self.alpha_logscale:
            alpha = torch.exp(self.alpha)
            beta = torch.exp(self.beta)
        else:
            alpha = self.alpha
            beta = self.beta

        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
            torch.sin(x * alpha), 2
        )

        return x


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`, `"snakebeta"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # Each activation module also performs the dim -> inner_dim projection.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        elif activation_fn == "snakebeta":
            act_fn = SnakeBeta(dim, inner_dim)
        else:
            raise ValueError(f"Unsupported activation_fn: {activation_fn!r}")

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        """Apply the layers in ``self.net`` sequentially."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head. Ignored by the
            self-attention when `use_omni_attn=True` (head size is then `dim // num_attention_heads`).
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (: obj: `int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (: obj: `bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to `False`): Forwarded to the diffusers `Attention` layers.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Passed to the `nn.LayerNorm` layers.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            `"ada_norm"` or `"ada_norm_zero"` (both require `num_embeds_ada_norm`); anything else uses `nn.LayerNorm`.
        final_dropout (`bool`, *optional*, defaults to `False`): Add a final dropout in the feed-forward.
        use_omni_attn (`bool`, *optional*, defaults to `False`):
            Use `OmniWhisperAttention` (flash-attn varlen, no dropout) for the first attention. Requires an
            `attention_mask` in `forward` and is incompatible with `only_cross_attention`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        use_omni_attn: bool = False,
    ):
        super().__init__()

        self.use_omni_attn = use_omni_attn
        self.dim = dim

        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (
            num_embeds_ada_norm is not None
        ) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (
            num_embeds_ada_norm is not None
        ) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        if self.use_omni_attn:
            if only_cross_attention:
                raise NotImplementedError(
                    "only_cross_attention is not supported together with use_omni_attn."
                )
            print(
                "Use OmniWhisperAttention with flash attention. Dropout is ignored."
            )
            self.attn1 = OmniWhisperAttention(
                embed_dim=dim, num_heads=num_attention_heads, causal=False
            )
        else:
            self.attn1 = Attention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=(
                    cross_attention_dim if only_cross_attention else None
                ),
                upcast_attention=upcast_attention,
            )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=(
                    cross_attention_dim if not double_self_attention else None
                ),
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                # scale_qk=False, # uncomment this to not to use flash attention
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Run the feed-forward in chunks of ``chunk_size`` along ``dim`` (None disables)."""
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def _omni_self_attention(
        self,
        norm_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run ``OmniWhisperAttention`` on the unpadded tokens and re-pad the output.

        Raises:
            ValueError: If ``attention_mask`` is None.
        """
        if attention_mask is None:
            raise ValueError("`attention_mask` is required when `use_omni_attn=True`.")
        bsz, tgt_len, d_model = norm_hidden_states.shape
        # Per-item length = sum of the mask's first row.
        # NOTE(review): assumes attention_mask is (bsz, 1, tgt_len) holding 1/True
        # at valid positions with right padding — confirm against the caller.
        seq_len = attention_mask[:, 0, :].long().sum(dim=1)
        var_len_attention_mask, unpacking_index = get_sequence_mask(
            norm_hidden_states, seq_len
        )
        # Pack the valid tokens of every item into (sum(seq_len), dim).
        packed = torch.masked_select(norm_hidden_states, var_len_attention_mask)
        packed = packed.view(torch.sum(seq_len), self.dim)
        attn_output = self.attn1(packed, seq_len)
        # Unpack back to (bsz, tgt_len, d_model); padded positions are zeroed.
        attn_output = torch.index_select(attn_output, 0, unpacking_index).view(
            bsz, tgt_len, d_model
        )
        return torch.where(var_len_attention_mask, attn_output, 0)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        """Pre-norm self-attention, optional cross-attention, then feed-forward,
        each added residually to ``hidden_states``."""
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = (
            cross_attention_kwargs if cross_attention_kwargs is not None else {}
        )

        if self.use_omni_attn:
            attn_output = self._omni_self_attention(norm_hidden_states, attention_mask)
        else:
            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=(
                    encoder_hidden_states if self.only_cross_attention else None
                ),
                attention_mask=(
                    encoder_attention_mask
                    if self.only_cross_attention
                    else attention_mask
                ),
                **cross_attention_kwargs,
            )

        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep)
                if self.use_ada_layer_norm
                else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = (
                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
            )

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice)
                    for hid_slice in norm_hidden_states.chunk(
                        num_chunks, dim=self._chunk_dim
                    )
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states