import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.models.auto import AutoModel
from transformers.modeling_utils import PreTrainedModel
# from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.activations import ACT2FN
from transformers.utils import logging

from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig

logger = logging.get_logger(__name__)


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back.
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        return output

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'


def modulate(x, shift, scale):
    """Apply adaptive modulation (shift and scale) to the input tensor."""
    return x * (1 + scale) + shift


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Args:
        hidden_size (`int`):
            Size of the output embedding.
        frequency_embedding_size (`int`, *optional*, defaults to 256):
            Size of the intermediate frequency embedding.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=False),
            # nn.SiLU(),
            ACT2FN['silu'],
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t (`torch.Tensor`):
                A 1-D tensor of N indices, one per batch element. These may be fractional.
            dim (`int`):
                The dimension of the output.
            max_period (`int`, *optional*, defaults to 10000):
                Controls the minimum frequency of the embeddings.

        Returns:
            `torch.Tensor`: An [N, D] tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad with a zero column when `dim` is odd.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        # Casts back to the caller's dtype; assumes floating-point timesteps.
        return embedding.to(t.dtype)

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class FeedForwardNetwork(nn.Module):
    """
    Standard feed-forward network with SwiGLU activation.

    Args:
        embed_dim (`int`):
            Input dimension.
        ffn_dim (`int`):
            Hidden dimension.
    """
    def __init__(
        self,
        embed_dim,
        ffn_dim,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
        self.act_fn = ACT2FN['silu']  # SiLU is the gating activation

    def forward(self, x):
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        # SwiGLU activation
        # gate = F.silu(gate)
        gate = self.act_fn(gate)
        return self.down_proj(gate * up)
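
# Reference formula for the block above (a sketch of what FeedForwardNetwork
# computes): FFN(x) = down_proj(SiLU(gate_proj(x)) * up_proj(x)). With
# hypothetical sizes embed_dim=8 and ffn_dim=32, a [batch, 8] input expands
# into two [batch, 32] branches that are gated elementwise and projected back
# to [batch, 8].
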
class HeadLayer(nn.Module):
    """
    A layer in the diffusion head.

    Args:
        embed_dim (`int`):
            Input dimension.
        ffn_dim (`int`):
            Hidden dimension.
        cond_dim (`int`):
            Condition embedding dimension.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            Epsilon for normalization.
    """
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        cond_dim,
        norm_eps=1e-5,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.cond_dim = cond_dim
        self.ffn_dim = ffn_dim
        self.ffn = FeedForwardNetwork(
            self.embed_dim,
            self.ffn_dim,
        )
        self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
        self.adaLN_modulation = nn.Sequential(
            # nn.SiLU(),
            ACT2FN['silu'],
            nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
        )

    def forward(self, x, c):
        shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
        x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
        return x


class FinalLayer(nn.Module):
    """
    Final layer in the diffusion head.

    Args:
        hidden_size (`int`):
            Input dimension.
        output_size (`int`):
            Output dimension.
        cond_size (`int`):
            Condition embedding dimension.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            Epsilon for normalization.
    """
    def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
        super().__init__()
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, output_size, bias=False)
        self.adaLN_modulation = nn.Sequential(
            # nn.SiLU(),
            ACT2FN['silu'],
            nn.Linear(cond_size, 2 * hidden_size, bias=False)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class VibeVoiceDiffusionHead(PreTrainedModel):
    """
    Diffusion head model for VibeVoice.

    Args:
        config (`VibeVoiceDiffusionHeadConfig`):
            Model configuration. The latent size is taken from `config.latent_size`.
    """
    config_class = VibeVoiceDiffusionHeadConfig
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config,
    ):
        super().__init__(config)
        self.config = config
        self.cond_dim = config.hidden_size
        latent_size = config.latent_size

        self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
        self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
        self.t_embedder = TimestepEmbedder(self.cond_dim)

        ffn_dim = int(config.hidden_size * config.head_ffn_ratio)

        # Create the intermediate layers
        self.layers = nn.ModuleList([
            HeadLayer(
                embed_dim=config.hidden_size,
                ffn_dim=ffn_dim,
                cond_dim=self.cond_dim,
                norm_eps=config.rms_norm_eps
            )
            for _ in range(config.head_layers)
        ])

        # Final layer for output
        self.final_layer = FinalLayer(
            hidden_size=config.hidden_size,
            output_size=latent_size,
            cond_size=self.cond_dim,
            norm_eps=config.rms_norm_eps
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the weights of the model."""
        # Initialize timestep embedder
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers
        for layer in self.layers:
            nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
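    # Note on conditioning: cond_proj(condition) and t_embedder(timesteps) are
    # summed into a single vector `c` below; every HeadLayer and the FinalLayer
    # consume `c` through their adaLN_modulation branch. Because those branches
    # and the output projection are zero-initialized above, a freshly
    # initialized head applies no modulation and predicts all zeros.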
    def forward(
        self,
        noisy_images,
        timesteps,
        condition,
    ):
        """
        Forward pass of the prediction head.

        Args:
            noisy_images (`torch.Tensor`):
                Noisy images/latents to denoise.
            timesteps (`torch.Tensor`):
                Timesteps for diffusion.
            condition (`torch.Tensor`):
                Conditioning information.

        Returns:
            `torch.Tensor`: The predicted noise/velocity.
        """
        x = self.noisy_images_proj(noisy_images)
        t = self.t_embedder(timesteps)
        condition = self.cond_proj(condition)
        c = condition + t

        for layer in self.layers:
            x = layer(x, c)

        x = self.final_layer(x, c)
        return x


AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)

__all__ = [
    "VibeVoiceDiffusionHead",
]
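
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the library API).
# It assumes `VibeVoiceDiffusionHeadConfig` accepts the fields this module
# reads (hidden_size, latent_size, head_ffn_ratio, head_layers, rms_norm_eps)
# as keyword arguments; the concrete values below are hypothetical. Run with
# `python -m <package>.<module>` so the relative import above resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = VibeVoiceDiffusionHeadConfig(
        hidden_size=768,      # hypothetical
        latent_size=64,       # hypothetical
        head_ffn_ratio=4.0,   # hypothetical
        head_layers=4,        # hypothetical
        rms_norm_eps=1e-5,
    )
    head = VibeVoiceDiffusionHead(config)

    batch = 2
    noisy_latents = torch.randn(batch, config.latent_size)
    timesteps = torch.rand(batch)                    # fractional timesteps are allowed
    condition = torch.randn(batch, config.hidden_size)

    prediction = head(noisy_latents, timesteps, condition)
    # torch.Size([2, 64]); all zeros at init because the output layer is zero-initialized
    print(prediction.shape)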