# modeling_seamless_crossattention.py
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, SeamlessM4TModel
from transformers.modeling_outputs import SequenceClassifierOutput

logger = logging.getLogger(__name__)

class SeamlessCrossAttentionConfig(PretrainedConfig):
"""Configuration class for SeamlessCrossAttention model."""
model_type = "seamless_crossattention"
def __init__(
self,
seamless_model_name="facebook/hf-seamless-m4t-medium",
hidden_size=1024,
dropout_prob=0.1,
num_attention_heads=8,
embedding_regularization=0.0,
**kwargs
):
super().__init__(**kwargs)
self.seamless_model_name = seamless_model_name
self.hidden_size = hidden_size
self.dropout_prob = dropout_prob
self.num_attention_heads = num_attention_heads
self.embedding_regularization = embedding_regularization
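

# Sketch (an assumption, not part of the original upload): to make these
# custom classes loadable via AutoConfig / AutoModel with
# trust_remote_code=True, they can be registered for auto classes before
# pushing to the Hub, e.g.:
#
#   SeamlessCrossAttentionConfig.register_for_auto_class()
#   HFSeamlessCrossAttention.register_for_auto_class("AutoModel")
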
class ScalarMix(nn.Module):
"""Scalar mixing layer for combining multiple embeddings."""
def __init__(self, num_inputs=4):
super().__init__()
self.weights = nn.Parameter(torch.ones(num_inputs))
self.gamma = nn.Parameter(torch.tensor(1.0))
def forward(self, *tensors):
# Normalize weights with softmax
weights = F.softmax(self.weights, dim=0)
# Weighted sum
weighted_sum = sum(w * t for w, t in zip(weights, tensors))
# Scale by gamma
return self.gamma * weighted_sum
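

# Illustrative sketch (not part of the original file): ScalarMix computes
# gamma * sum_i softmax(w)_i * t_i, so its output keeps the shape of each
# input tensor. The helper below is hypothetical and only demonstrates the layer.
def _scalar_mix_example():
    """Hypothetical demo: mix four pooled [batch, hidden] embeddings."""
    mix = ScalarMix(num_inputs=4)
    pooled = [torch.randn(2, 1024) for _ in range(4)]
    mixed = mix(*pooled)  # learned convex combination, scaled by gamma
    assert mixed.shape == (2, 1024)
    return mixed
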
class HFSeamlessCrossAttention(PreTrainedModel):
    """SeamlessM4T speech and text encoders fused via bidirectional
    cross-attention, with an MLP regression head, for the HuggingFace Hub."""
config_class = SeamlessCrossAttentionConfig
supports_gradient_checkpointing = True
def __init__(self, config):
super().__init__(config)
self.config = config
        # Load the full SeamlessM4T model; only its speech and text encoders are used
self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
self.seamless_model_text_encoder = self.seamless_model.text_encoder
# Freeze pre-trained models
for param in self.seamless_model_speech_encoder.parameters():
param.requires_grad = False
for param in self.seamless_model_text_encoder.parameters():
param.requires_grad = False
# Projection layers
self.audio_proj = nn.Linear(
self.seamless_model_speech_encoder.config.hidden_size,
config.hidden_size
)
self.text_proj = nn.Linear(
self.seamless_model_text_encoder.config.hidden_size,
config.hidden_size
)
# Layer norms
self.audio_norm = nn.LayerNorm(config.hidden_size)
self.text_norm = nn.LayerNorm(config.hidden_size)
# Cross-attention layers
self.audio_to_text_attention = nn.MultiheadAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.dropout_prob,
batch_first=True
)
self.text_to_audio_attention = nn.MultiheadAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.dropout_prob,
batch_first=True
)
        # Scalar mix for combining the four pooled embeddings
self.scalar_mix = ScalarMix(num_inputs=4)
        # Regression head: a three-layer MLP mapping the fused embedding to a single score
self.fc = nn.Sequential(
nn.Linear(config.hidden_size, 512),
nn.ReLU(),
nn.Dropout(config.dropout_prob),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(config.dropout_prob),
nn.Linear(256, 1)
)
# Initialize new layers
self._initialize_new_layers()
def _initialize_new_layers(self):
"""Initialize new layers with proper weights."""
for module in [self.audio_proj, self.text_proj, self.fc]:
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Sequential):
for layer in module:
if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)
def forward(
self,
input_features,
input_ids,
text_attention_mask,
audio_attention_mask=None,
labels=None,
**kwargs # Accept additional features but ignore them
):
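        """Score an (audio, text) pair with one scalar output per example.

        Expected shapes (an assumption, following SeamlessM4T conventions):
        input_features [batch, frames, feat_dim]; input_ids and
        text_attention_mask [batch, text_len]; labels [batch] float targets.
        """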
# Create default audio attention mask if not provided
if audio_attention_mask is None:
audio_attention_mask = torch.ones(
input_features.size(0), input_features.size(1),
device=input_features.device
)
# 1. Encode audio
audio_output = self.seamless_model_speech_encoder(
input_features=input_features,
attention_mask=audio_attention_mask
)
audio_hidden_states = audio_output.last_hidden_state # [batch_size, audio_seq_len, hidden_size]
# 2. Encode text
text_output = self.seamless_model_text_encoder(
input_ids=input_ids,
attention_mask=text_attention_mask
)
text_hidden_states = text_output.last_hidden_state # [batch_size, text_seq_len, hidden_size]
# 3. Project to common dimension
audio_projected = self.audio_proj(audio_hidden_states) # [batch_size, audio_seq_len, hidden_size]
text_projected = self.text_proj(text_hidden_states) # [batch_size, text_seq_len, hidden_size]
audio_projected = self.audio_norm(audio_projected)
text_projected = self.text_norm(text_projected)
        # 4. Global mean pooling of the projected embeddings (padding positions are included in the mean)
audio_global = audio_projected.mean(dim=1) # [batch_size, hidden_size]
text_global = text_projected.mean(dim=1) # [batch_size, hidden_size]
        # 5. Cross-attention (note: no key_padding_mask is applied, so padded
        # key positions can also receive attention)
        # Audio attends to text: each audio position attends to all text positions
audio_attended_to_text, _ = self.audio_to_text_attention(
query=audio_projected, # [batch_size, audio_seq_len, hidden_size]
key=text_projected, # [batch_size, text_seq_len, hidden_size]
value=text_projected, # [batch_size, text_seq_len, hidden_size]
)
# Text attends to audio - each text token attends to all audio tokens
text_attended_to_audio, _ = self.text_to_audio_attention(
query=text_projected, # [batch_size, text_seq_len, hidden_size]
key=audio_projected, # [batch_size, audio_seq_len, hidden_size]
value=audio_projected, # [batch_size, audio_seq_len, hidden_size]
)
# 6. Global pooling (mean) of attended embeddings
audio_attended_emb = audio_attended_to_text.mean(dim=1) # [batch_size, hidden_size]
text_attended_emb = text_attended_to_audio.mean(dim=1) # [batch_size, hidden_size]
# 7. Combine with scalar mix
final_embedding = self.scalar_mix(
audio_global,
text_global,
audio_attended_emb,
text_attended_emb
)
        # 8. Regression head: one scalar score per example
logits = self.fc(final_embedding).squeeze(-1)
        # Compute the MSE loss if labels are provided
        loss = None
        if labels is not None:
            mse_loss = F.mse_loss(logits, labels.float())  # cast in case labels arrive as ints
# Add embedding regularization if specified
reg_loss = 0.0
if self.config.embedding_regularization > 0:
reg_loss = (
torch.norm(audio_global, p=2, dim=1).mean() +
torch.norm(text_global, p=2, dim=1).mean() +
torch.norm(audio_attended_emb, p=2, dim=1).mean() +
torch.norm(text_attended_emb, p=2, dim=1).mean()
) / 4.0
loss = mse_loss + self.config.embedding_regularization * reg_loss
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=None,
attentions=None
)
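

# Minimal usage sketch (illustrative, not part of the original upload).
# Instantiating the model downloads the frozen facebook/hf-seamless-m4t-medium
# weights; the 160-dim filterbank feature size and the dummy token ids are
# assumptions for demonstration only.
if __name__ == "__main__":
    config = SeamlessCrossAttentionConfig()
    model = HFSeamlessCrossAttention(config).eval()
    input_features = torch.randn(2, 50, 160)        # [batch, frames, fbank dim]
    input_ids = torch.randint(4, 1000, (2, 12))     # dummy token ids
    text_attention_mask = torch.ones(2, 12, dtype=torch.long)
    with torch.no_grad():
        output = model(
            input_features=input_features,
            input_ids=input_ids,
            text_attention_mask=text_attention_mask,
        )
    print(output.logits.shape)  # expected: torch.Size([2])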