import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, SeamlessM4TModel
from transformers.modeling_outputs import SequenceClassifierOutput

logger = logging.getLogger(__name__)


class SeamlessCrossAttentionConfig(PretrainedConfig):
    """Configuration class for the SeamlessCrossAttention model."""

    model_type = "seamless_crossattention"

    def __init__(
        self,
        seamless_model_name="facebook/hf-seamless-m4t-medium",
        hidden_size=1024,
        dropout_prob=0.1,
        num_attention_heads=8,
        embedding_regularization=0.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.seamless_model_name = seamless_model_name
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob
        self.num_attention_heads = num_attention_heads
        self.embedding_regularization = embedding_regularization
|
|
class ScalarMix(nn.Module):
    """Scalar mixing layer for combining multiple embeddings.

    Learns one softmax-normalized weight per input plus a global scale
    factor gamma, and returns the scaled weighted sum of the inputs.
    """

    def __init__(self, num_inputs=4):
        super().__init__()
        self.weights = nn.Parameter(torch.ones(num_inputs))
        self.gamma = nn.Parameter(torch.tensor(1.0))

    def forward(self, *tensors):
        # Normalize the mixing weights so they sum to 1.
        weights = F.softmax(self.weights, dim=0)
        weighted_sum = sum(w * t for w, t in zip(weights, tensors))
        return self.gamma * weighted_sum
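

# Illustrative check of ScalarMix (an added sketch, not part of the original
# source): mixing any number of same-shaped tensors returns a tensor of that
# same shape.
#
#     mix = ScalarMix(num_inputs=3)
#     out = mix(torch.randn(2, 4), torch.randn(2, 4), torch.randn(2, 4))
#     assert out.shape == (2, 4)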
|
|
class HFSeamlessCrossAttention(PreTrainedModel):
    """SeamlessM4T-based model with bidirectional cross-attention, packaged for the HuggingFace Hub."""

    config_class = SeamlessCrossAttentionConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Load the pretrained SeamlessM4T model and keep references to its
        # speech and text encoders; the decoders are not used by this model.
        self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
        self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
        self.seamless_model_text_encoder = self.seamless_model.text_encoder

        # Freeze both encoders; only the layers added below are trained.
        for param in self.seamless_model_speech_encoder.parameters():
            param.requires_grad = False
        for param in self.seamless_model_text_encoder.parameters():
            param.requires_grad = False

        # Project each modality into a shared hidden space.
        self.audio_proj = nn.Linear(
            self.seamless_model_speech_encoder.config.hidden_size,
            config.hidden_size
        )
        self.text_proj = nn.Linear(
            self.seamless_model_text_encoder.config.hidden_size,
            config.hidden_size
        )

        self.audio_norm = nn.LayerNorm(config.hidden_size)
        self.text_norm = nn.LayerNorm(config.hidden_size)

        # Bidirectional cross-attention: audio attends to text, and text
        # attends to audio.
        self.audio_to_text_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.dropout_prob,
            batch_first=True
        )
        self.text_to_audio_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.dropout_prob,
            batch_first=True
        )

        # Learned scalar mix over the four pooled embeddings.
        self.scalar_mix = ScalarMix(num_inputs=4)

        # Regression head that maps the mixed embedding to a single score.
        self.fc = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(256, 1)
        )

        self._initialize_new_layers()

    def _initialize_new_layers(self):
        """Initialize the newly added linear layers with Xavier-uniform weights."""
        for module in [self.audio_proj, self.text_proj, self.fc]:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Sequential):
                for layer in module:
                    if isinstance(layer, nn.Linear):
                        nn.init.xavier_uniform_(layer.weight)
                        nn.init.zeros_(layer.bias)

    def forward(
        self,
        input_features,
        input_ids,
        text_attention_mask,
        audio_attention_mask=None,
        labels=None,
        **kwargs
    ):
        # Default to attending over every audio frame when no mask is given.
        if audio_attention_mask is None:
            audio_attention_mask = torch.ones(
                input_features.size(0), input_features.size(1),
                device=input_features.device
            )

        # Encode the audio with the frozen speech encoder.
        audio_output = self.seamless_model_speech_encoder(
            input_features=input_features,
            attention_mask=audio_attention_mask
        )
        audio_hidden_states = audio_output.last_hidden_state

        # Encode the text with the frozen text encoder.
        text_output = self.seamless_model_text_encoder(
            input_ids=input_ids,
            attention_mask=text_attention_mask
        )
        text_hidden_states = text_output.last_hidden_state

        # Project both modalities into the shared hidden space and normalize.
        audio_projected = self.audio_proj(audio_hidden_states)
        text_projected = self.text_proj(text_hidden_states)

        audio_projected = self.audio_norm(audio_projected)
        text_projected = self.text_norm(text_projected)

        # Global embeddings: mean pooling over the sequence dimension.
        # Note that this averages over padded positions as well.
        audio_global = audio_projected.mean(dim=1)
        text_global = text_projected.mean(dim=1)

        # Cross-attention in both directions.
        audio_attended_to_text, _ = self.audio_to_text_attention(
            query=audio_projected,
            key=text_projected,
            value=text_projected,
        )
        text_attended_to_audio, _ = self.text_to_audio_attention(
            query=text_projected,
            key=audio_projected,
            value=audio_projected,
        )

        # Pool the cross-attended sequences the same way.
        audio_attended_emb = audio_attended_to_text.mean(dim=1)
        text_attended_emb = text_attended_to_audio.mean(dim=1)

        # Combine the four pooled embeddings with the learned scalar mix.
        final_embedding = self.scalar_mix(
            audio_global,
            text_global,
            audio_attended_emb,
            text_attended_emb
        )

        # Regression head: one scalar score per example.
        logits = self.fc(final_embedding).squeeze(-1)

        loss = None
        if labels is not None:
            # Cast labels so the regression loss works for integer or
            # double-precision label tensors as well.
            mse_loss = F.mse_loss(logits, labels.float())

            # Optional L2 regularization on the pooled embeddings.
            reg_loss = 0.0
            if self.config.embedding_regularization > 0:
                reg_loss = (
                    torch.norm(audio_global, p=2, dim=1).mean() +
                    torch.norm(text_global, p=2, dim=1).mean() +
                    torch.norm(audio_attended_emb, p=2, dim=1).mean() +
                    torch.norm(text_attended_emb, p=2, dim=1).mean()
                ) / 4.0

            loss = mse_loss + self.config.embedding_regularization * reg_loss

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )
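

# A minimal usage sketch (added for illustration; not part of the original
# source). The input shapes below are assumptions: SeamlessM4T's feature
# extractor typically yields 160-dimensional log-mel feature stacks, and the
# token ids here are random placeholders rather than real tokenizer output.
if __name__ == "__main__":
    config = SeamlessCrossAttentionConfig()
    model = HFSeamlessCrossAttention(config)
    model.eval()

    batch_size, num_frames, feature_dim = 2, 60, 160  # hypothetical shapes
    text_len = 12
    input_features = torch.randn(batch_size, num_frames, feature_dim)
    input_ids = torch.randint(0, 1000, (batch_size, text_len))
    text_attention_mask = torch.ones(batch_size, text_len, dtype=torch.long)

    with torch.no_grad():
        output = model(
            input_features=input_features,
            input_ids=input_ids,
            text_attention_mask=text_attention_mask,
            labels=torch.randn(batch_size),
        )
    print(output.logits.shape, output.loss)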