giuseppe-tanzi committed
Commit 988bd5d · verified · 1 Parent(s): 12d3a8b

Upload folder using huggingface_hub

README.md ADDED
---
language:
- multilingual
tags:
- audio
- text
- multimodal
- seamless
- subtitle-editing-time-prediction
- cross-attention
- attention-mechanism
library_name: transformers
base_model: facebook/hf-seamless-m4t-medium
---

# videoloc/seamless-crossattention

## Model Description

This is a **SeamlessCrossAttention** model that processes paired audio and text inputs with cross-modal attention to predict **Time To Edit (TTE)** for subtitle segments. Given an audio segment and its corresponding subtitle text, the model predicts how much time (in seconds) is required to edit/refine that subtitle segment, using cross-attention between the audio and text modalities.

The model extends the SeamlessM4T architecture with bidirectional cross-attention layers that let audio and text representations attend to each other, producing cross-modal embeddings that capture temporal and semantic relationships. It covers 5 languages: **English, French, Spanish, Italian, and German**.

### Key Features

- **Cross-Modal Attention**: Bidirectional attention between audio and text representations
- **Advanced Architecture**: Audio-to-text and text-to-audio attention mechanisms
- **Scalar Mixing**: Learnable combination of global and attended embeddings
- **Embedding Regularization**: Optional L2 regularization for embedding stability
- **Multimodal Processing**: Simultaneously processes audio (16kHz) and text inputs
- **Frozen Encoders**: Uses pre-trained SeamlessM4T encoders (frozen for stability)
- **TTE Prediction**: Predicts editing time required for subtitle segments
- **Direct Output**: Raw time values in seconds for immediate use

## Model Architecture

The model combines frozen SeamlessM4T encoders with trainable cross-modal attention:

1. **Audio Processing**:
   - SeamlessM4T speech encoder (frozen) processes raw audio input
   - Audio projection layer maps speech encoder output to 1024 dimensions
   - Layer normalization for stability

2. **Text Processing**:
   - SeamlessM4T text encoder (frozen) processes tokenized text input
   - Text projection layer maps text encoder output to 1024 dimensions
   - Layer normalization for stability

3. **Cross-Modal Attention**:
   - **Audio-to-Text Attention**: Each audio token attends to all text tokens
   - **Text-to-Audio Attention**: Each text token attends to all audio tokens
   - Multi-head attention (8 heads) with dropout for regularization
   - Bidirectional information flow between modalities

4. **Feature Fusion**:
   - Global (mean) pooling of the original audio and text embeddings
   - Global (mean) pooling of the cross-attended embeddings
   - Scalar mixing layer combines all four pooled embeddings with learnable weights (see the formula after this list)
   - Final embedding captures both global and cross-modal patterns

5. **Regression Head**:
   - Multi-layer perceptron: 1024 → 512 → 256 → 1
   - ReLU activations and dropout for regularization
   - Single output for TTE prediction (regression, in seconds)

6. **Optional Regularization**:
   - L2 regularization on embedding norms for training stability
   - Configurable regularization strength

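The scalar mixing step is a softmax-weighted sum of the four pooled embeddings, scaled by a learnable scalar; this mirrors the `ScalarMix` module in `modeling_seamless_crossattention.py` (included in this repo):

$$
e_{\text{final}} = \gamma \sum_{i=1}^{4} \mathrm{softmax}(w)_i \, e_i,
\qquad
e_i \in \{e_{\text{audio}},\; e_{\text{text}},\; e_{\text{audio}\rightarrow\text{text}},\; e_{\text{text}\rightarrow\text{audio}}\}
$$

where $w \in \mathbb{R}^4$ and $\gamma$ are learned parameters.
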
## Quick Start

### Installation
```bash
pip install transformers torch torchaudio huggingface_hub
```

### Basic Usage
```python
from transformers import AutoModel, AutoConfig
from huggingface_hub import hf_hub_download
import torch
import numpy as np
import importlib.util

# Load model - custom architecture requires importing the model class
model_files = hf_hub_download(repo_id="videoloc/seamless-crossattention", filename="modeling_seamless_crossattention.py")
spec = importlib.util.spec_from_file_location("modeling_seamless_crossattention", model_files)
modeling_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(modeling_module)

# Now load the model using the custom class
config = modeling_module.SeamlessCrossAttentionConfig.from_pretrained("videoloc/seamless-crossattention")
model = modeling_module.HFSeamlessCrossAttention.from_pretrained("videoloc/seamless-crossattention")

# Load the data collator (included in this repo)
collator_file = hf_hub_download(repo_id="videoloc/seamless-crossattention", filename="data_collator.py")
spec = importlib.util.spec_from_file_location("data_collator", collator_file)
collator_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(collator_module)

# Initialize data collator
data_collator = collator_module.DataCollatorSimpleSeamless(
    processor="facebook/hf-seamless-m4t-medium",
    max_audio_length_sec=8.0,
    max_text_length=256
)

# Prepare your data
your_data = [
    {
        'raw_audio': np.random.randn(16000 * 5),  # 5 seconds at 16kHz
        'raw_text': "Your subtitle text here",
        # Note: the cross-attention model doesn't require translation features
    }
]

# Process and run inference
batch = data_collator(your_data)
model.eval()
with torch.no_grad():
    outputs = model(**batch)
    tte_prediction = outputs.logits.item()

print(f"Predicted Time To Edit (TTE): {tte_prediction:.2f} seconds")
```
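
In practice you would replace the random-noise placeholder with real audio resampled to 16 kHz. A minimal sketch using `torchaudio`; the file path `my_clip.wav` is a hypothetical example:

```python
import numpy as np
import torchaudio

def load_audio_16k(path: str) -> np.ndarray:
    # Load the file and average channels down to mono
    waveform, sr = torchaudio.load(path)      # shape: (channels, num_samples)
    waveform = waveform.mean(dim=0)           # shape: (num_samples,)
    # Resample to the 16 kHz rate expected by the model
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
    return waveform.numpy()

your_data = [{
    'raw_audio': load_audio_16k("my_clip.wav"),  # hypothetical path
    'raw_text': "Your subtitle text here",
}]
```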
126
+
127
+ ## Model Details
128
+
129
+ - **Base Model**: SeamlessM4T (facebook/hf-seamless-m4t-medium)
130
+ - **Audio Encoder**: Frozen SeamlessM4T speech encoder
131
+ - **Text Encoder**: Frozen SeamlessM4T text encoder
132
+ - **Hidden Size**: 1024
133
+ - **Attention Heads**: 8 (configurable)
134
+ - **Cross-Attention**: Bidirectional (audio↔text)
135
+ - **Scalar Mix**: 4 embeddings (audio global, text global, audio→text, text→audio)
136
+ - **Audio Input**: 16kHz
137
+ - **Output**: Single regression value (TTE in seconds)
138
+ - **Task**: Subtitle editing time prediction
139
+
140
+ ## Data Format
141
+
142
+ Your input data should be a list of dictionaries with:
143
+ - `raw_audio`: NumPy array of audio samples (16kHz sampling rate)
144
+ - `raw_text`: String of subtitle text
145
+ - `labels`: Target TTE values in seconds (optional, for training)
146
+
147
+ Example:
148
+ ```python
149
+ data = [
150
+ {
151
+ 'raw_audio': audio_samples, # shape: (num_samples,) at 16kHz
152
+ 'raw_text': "Subtitle text content",
153
+ 'labels': 2.5 # optional TTE target value in seconds
154
+ }
155
+ ]
156
+ ```
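
When `labels` are provided, the collator passes them through and the model returns an MSE training loss alongside the prediction. A minimal sketch, reusing `model`, `data_collator`, `np`, and `torch` from the Quick Start; the label value is illustrative:

```python
labeled_data = [{
    'raw_audio': np.random.randn(16000 * 4),  # placeholder 4-second clip at 16 kHz
    'raw_text': "Subtitle text content",
    'labels': 2.5,                            # illustrative TTE target in seconds
}]

batch = data_collator(labeled_data)
model.eval()
with torch.no_grad():
    outputs = model(**batch)                  # forward pass with labels attached

print(f"MSE loss: {outputs.loss.item():.4f}, prediction: {outputs.logits.item():.2f} s")
```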
157
+
158
+ ## Performance Metrics
159
+
160
+ - **Best Eval RMSE**: 33.34
161
+
162
+ ## Training Details
163
+
164
+ - **Base Model**: facebook/hf-seamless-m4t-medium
165
+ - **Model Type**: seamless_cross_attention
166
+ - **Epochs**: 10
167
+ - **Batch Size (Train)**: 32
168
+ - **Batch Size (Eval)**: 64
169
+ - **Learning Rate**: 1.2e-4
170
+ - **LR Scheduler**: cosine_with_restarts
171
+ - **Warmup Ratio**: 0.05
172
+ - **Weight Decay**: 0.001
173
+ - **Optimizer**: AdamW (torch)
174
+ - **Max Grad Norm**: 1.0
175
+ - **FP16**: True
176
+ - **Early Stopping Patience**: 5
177
+ - **Audio Max Length**: 8.0 seconds
178
+ - **Text Max Length**: 256 tokens
179
+ - **Sample Rate**: 16kHz
180
+ - **Cross-Attention**: 8-head multi-head attention
181
+ - **Scalar Mixing**: 4 embedding types
182
+ - **Embedding Regularization**: Optional L2
183
+ - **Normalization**: None (raw values)
184
+ - **Dataset Split**: 80/20 train/test
185
+ - **Random Seed**: 42
186
+ - **Metric**: RMSE (lower is better)
187
+
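For reference, one plausible way these hyperparameters map onto `transformers.TrainingArguments`; this is a hedged sketch rather than the original training script, and the output directory, evaluation cadence, and metric wiring are assumptions:

```python
from transformers import TrainingArguments, EarlyStoppingCallback

training_args = TrainingArguments(
    output_dir="seamless-crossattention-tte",  # placeholder
    num_train_epochs=10,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    learning_rate=1.2e-4,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.05,
    weight_decay=0.001,
    max_grad_norm=1.0,
    fp16=True,
    eval_strategy="epoch",                     # assumption: evaluate once per epoch
    save_strategy="epoch",
    load_best_model_at_end=True,               # required for early stopping
    metric_for_best_model="rmse",              # assumes compute_metrics reports "rmse"
    greater_is_better=False,                   # lower RMSE is better
    seed=42,
)

# Early stopping with patience 5, as listed above
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)
```
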
188
+ ## Training Configuration
189
+
190
+ The model was trained with the following specifications:
191
+
192
+ - **Dataset**: Multimodal audio-subtitle pairs with TTE annotations (5 languages: EN, FR, ES, IT, DE)
193
+ - **Train/Test Split**: 80/20 with random seed 42
194
+ - **Audio Processing**: 16kHz sampling, max 8.0 seconds, no offset
195
+ - **Text Processing**: Max 256 tokens
196
+ - **Cross-Attention**: 8-head multi-head attention with dropout
197
+ - **Scalar Mixing**: Learnable combination of 4 embedding types
198
+ - **Normalization**: None (raw TTE values in seconds)
199
+ - **Caching**: Audio segments cached and compressed for efficiency
200
+
201
+ ## Usage Notes
202
+
203
+ - This is the **advanced cross-attention** variant with sophisticated attention mechanisms
204
+ - For simpler models, see `seamless-basic`, `seamless-translation`, or `seamless-langpairs`
205
+ - Model expects 16kHz audio input (automatically resampled by data collator)
206
+ - Cross-attention captures complex temporal and semantic relationships
207
+ - No feature normalization applied - outputs raw TTE predictions in seconds
208
+ - Optimized for detailed subtitle editing time estimation tasks
209
+
210
+ ## Architecture Advantages
211
+
212
+ - **Rich Cross-Modal Interactions**: Audio and text modalities directly attend to each other
213
+ - **Temporal Alignment**: Cross-attention naturally captures temporal relationships
214
+ - **Semantic Understanding**: Text-to-audio attention helps model understand content meaning
215
+ - **Flexible Combination**: Scalar mixing allows model to weight different embedding types
216
+ - **Regularization Options**: Optional embedding regularization for training stability
217
+
218
+ ## Limitations
219
+
220
+ - Higher computational complexity than basic models due to attention mechanisms
221
+ - Requires more training data to fully leverage cross-attention capabilities
222
+ - Designed for TTE prediction, not general audio-text matching
223
+ - Performance may vary on out-of-domain content or different editing workflows
224
+ - Requires specific data preprocessing (use included data collator)
225
+
226
+ ## Related Models
227
+
228
+ - **seamless-basic**: Basic audio+text model without attention mechanisms
229
+ - **seamless-translation**: Includes translation awareness but no cross-attention
230
+ - **seamless-langpairs**: Includes language pair embeddings but no cross-attention
__init__.py ADDED
"""
SeamlessCrossAttention model for HuggingFace Transformers
"""
from .modeling_seamless_crossattention import HFSeamlessCrossAttention, SeamlessCrossAttentionConfig

__all__ = ["HFSeamlessCrossAttention", "SeamlessCrossAttentionConfig"]
config.json ADDED
{
  "architectures": [
    "HFSeamlessCrossAttention"
  ],
  "dropout_prob": 0.1,
  "embedding_regularization": 0.0,
  "hidden_size": 1024,
  "model_type": "seamless_crossattention",
  "num_attention_heads": 8,
  "seamless_model_name": "facebook/hf-seamless-m4t-medium",
  "torch_dtype": "float32",
  "transformers_version": "4.50.2"
}
data_collator.py ADDED
import torch
import numpy as np
from transformers import AutoProcessor
from typing import Dict, List, Union
import logging

logger = logging.getLogger(__name__)

class DataCollatorSimpleSeamless:
    def __init__(
        self,
        processor: str,
        sample_rate: int = 16000,
        max_audio_length_sec: float = 8.0,
        max_text_length: int = 256,
        normalization_type: str = "none"
    ):
        """Initialize the data collator.

        Args:
            processor: Name or path of the SeamlessM4T processor to load.
            sample_rate: Audio sample rate in Hz.
            max_audio_length_sec: Maximum audio length in seconds (longer audio is truncated).
            max_text_length: Maximum text length in tokens.
            normalization_type: Normalization applied to labels. Options: "log1p", "none".
        """
        logger.info(f"Loading processor: {processor}")
        self.processor = AutoProcessor.from_pretrained(processor)

        self.sample_rate = sample_rate
        self.max_audio_sample_length = int(max_audio_length_sec * sample_rate)
        self.max_text_length = max_text_length
        self.normalization_type = normalization_type

    def __call__(self, batch: List[Dict[str, Union[np.ndarray, str, float]]]) -> Dict[str, torch.Tensor]:
        """Process a batch of raw features into model inputs."""
        # Extract raw data
        raw_audios = [item['raw_audio'] for item in batch]
        raw_texts = [item['raw_text'] for item in batch]

        raw_audios = [torch.tensor(audio) for audio in raw_audios]

        audio_inputs = self.processor(
            audios=raw_audios,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_audio_sample_length,
        )

        text_inputs = self.processor(
            text=raw_texts,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_text_length,
        )

        # Extract translation features (defaults to 0 when absent)
        is_translation = torch.tensor([item.get('is_translation', 0) for item in batch], dtype=torch.float32)

        # Extract language pair features (defaults to 0 when absent)
        language_pair_id = torch.tensor([item.get('language_pair_id', 0) for item in batch], dtype=torch.long)

        if 'labels' in batch[0]:
            labels = [item['labels'] for item in batch]
            labels = torch.tensor(labels, dtype=torch.float32)

            # Apply normalization based on type
            if self.normalization_type == "log1p":
                labels = torch.log1p(labels)
            elif self.normalization_type == "none":
                pass
            else:
                raise ValueError(f"Unknown normalization type: {self.normalization_type}")
        else:
            labels = None

        return {
            'input_features': audio_inputs['input_features'],
            'audio_attention_mask': audio_inputs.get('attention_mask'),
            'input_ids': text_inputs['input_ids'],
            'text_attention_mask': text_inputs['attention_mask'],
            'is_translation': is_translation,
            'language_pair_id': language_pair_id,
            **({'labels': labels} if labels is not None else {})
        }
example_usage.py ADDED
#!/usr/bin/env python3
# Example usage for videoloc/seamless-crossattention

from transformers import AutoModel, AutoConfig
from huggingface_hub import hf_hub_download
import torch
import numpy as np
import importlib.util

def load_model_and_collator():
    # Load model - custom architecture requires importing the model class
    model_files = hf_hub_download(repo_id="videoloc/seamless-crossattention", filename="modeling_seamless_crossattention.py")
    spec = importlib.util.spec_from_file_location("modeling_seamless_crossattention", model_files)
    modeling_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(modeling_module)

    # Now load the model using the custom class
    config = modeling_module.SeamlessCrossAttentionConfig.from_pretrained("videoloc/seamless-crossattention")
    model = modeling_module.HFSeamlessCrossAttention.from_pretrained("videoloc/seamless-crossattention")

    # Load data collator
    collator_file = hf_hub_download(repo_id="videoloc/seamless-crossattention", filename="data_collator.py")
    spec = importlib.util.spec_from_file_location("data_collator", collator_file)
    collator_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(collator_module)

    data_collator = collator_module.DataCollatorSimpleSeamless(
        processor="facebook/hf-seamless-m4t-medium",
        max_audio_length_sec=8.0,
        max_text_length=256
    )

    return model, data_collator

def example_inference():
    model, collator = load_model_and_collator()

    # Example data: audio segment + subtitle text for cross-attention TTE prediction
    data = [{
        'raw_audio': np.random.randn(16000 * 3),  # 3 seconds at 16kHz
        'raw_text': "Example subtitle text with cross-modal attention for TTE prediction",
    }]

    batch = collator(data)
    model.eval()
    with torch.no_grad():
        outputs = model(**batch)
        tte_prediction = outputs.logits.item()

    print(f"Predicted Time To Edit (TTE): {tte_prediction:.2f} seconds")
    return tte_prediction

if __name__ == "__main__":
    example_inference()
modeling_seamless_crossattention.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import SeamlessM4TModel
import logging

logger = logging.getLogger(__name__)


class SeamlessCrossAttentionConfig(PretrainedConfig):
    """Configuration class for SeamlessCrossAttention model."""

    model_type = "seamless_crossattention"

    def __init__(
        self,
        seamless_model_name="facebook/hf-seamless-m4t-medium",
        hidden_size=1024,
        dropout_prob=0.1,
        num_attention_heads=8,
        embedding_regularization=0.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.seamless_model_name = seamless_model_name
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob
        self.num_attention_heads = num_attention_heads
        self.embedding_regularization = embedding_regularization


class ScalarMix(nn.Module):
    """Scalar mixing layer for combining multiple embeddings."""

    def __init__(self, num_inputs=4):
        super().__init__()
        self.weights = nn.Parameter(torch.ones(num_inputs))
        self.gamma = nn.Parameter(torch.tensor(1.0))

    def forward(self, *tensors):
        # Normalize weights with softmax
        weights = F.softmax(self.weights, dim=0)

        # Weighted sum
        weighted_sum = sum(w * t for w, t in zip(weights, tensors))

        # Scale by gamma
        return self.gamma * weighted_sum


class HFSeamlessCrossAttention(PreTrainedModel):
    """SeamlessM4T model with cross attention for HuggingFace Hub."""

    config_class = SeamlessCrossAttentionConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Load the underlying SeamlessM4T model
        self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
        self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
        self.seamless_model_text_encoder = self.seamless_model.text_encoder

        # Freeze pre-trained encoders
        for param in self.seamless_model_speech_encoder.parameters():
            param.requires_grad = False
        for param in self.seamless_model_text_encoder.parameters():
            param.requires_grad = False

        # Projection layers
        self.audio_proj = nn.Linear(
            self.seamless_model_speech_encoder.config.hidden_size,
            config.hidden_size
        )
        self.text_proj = nn.Linear(
            self.seamless_model_text_encoder.config.hidden_size,
            config.hidden_size
        )

        # Layer norms
        self.audio_norm = nn.LayerNorm(config.hidden_size)
        self.text_norm = nn.LayerNorm(config.hidden_size)

        # Cross-attention layers
        self.audio_to_text_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.dropout_prob,
            batch_first=True
        )

        self.text_to_audio_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.dropout_prob,
            batch_first=True
        )

        # Scalar mix for combining embeddings
        self.scalar_mix = ScalarMix(num_inputs=4)

        # Regression head: plain MLP (1024 -> 512 -> 256 -> 1)
        self.fc = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(256, 1)
        )

        # Initialize new layers
        self._initialize_new_layers()

    def _initialize_new_layers(self):
        """Initialize new layers with proper weights."""
        for module in [self.audio_proj, self.text_proj, self.fc]:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Sequential):
                for layer in module:
                    if isinstance(layer, nn.Linear):
                        nn.init.xavier_uniform_(layer.weight)
                        nn.init.zeros_(layer.bias)

    def forward(
        self,
        input_features,
        input_ids,
        text_attention_mask,
        audio_attention_mask=None,
        labels=None,
        **kwargs  # Accept additional features but ignore them
    ):
        # Create default audio attention mask if not provided
        if audio_attention_mask is None:
            audio_attention_mask = torch.ones(
                input_features.size(0), input_features.size(1),
                device=input_features.device
            )

        # 1. Encode audio
        audio_output = self.seamless_model_speech_encoder(
            input_features=input_features,
            attention_mask=audio_attention_mask
        )
        audio_hidden_states = audio_output.last_hidden_state  # [batch_size, audio_seq_len, hidden_size]

        # 2. Encode text
        text_output = self.seamless_model_text_encoder(
            input_ids=input_ids,
            attention_mask=text_attention_mask
        )
        text_hidden_states = text_output.last_hidden_state  # [batch_size, text_seq_len, hidden_size]

        # 3. Project to common dimension
        audio_projected = self.audio_proj(audio_hidden_states)  # [batch_size, audio_seq_len, hidden_size]
        text_projected = self.text_proj(text_hidden_states)  # [batch_size, text_seq_len, hidden_size]

        audio_projected = self.audio_norm(audio_projected)
        text_projected = self.text_norm(text_projected)

        # 4. Global pooling (mean) of original embeddings
        audio_global = audio_projected.mean(dim=1)  # [batch_size, hidden_size]
        text_global = text_projected.mean(dim=1)  # [batch_size, hidden_size]

        # 5. Cross-attention (no padding masks are passed; every token attends to all tokens)
        # Audio attends to text - each audio token attends to all text tokens
        audio_attended_to_text, _ = self.audio_to_text_attention(
            query=audio_projected,  # [batch_size, audio_seq_len, hidden_size]
            key=text_projected,  # [batch_size, text_seq_len, hidden_size]
            value=text_projected,  # [batch_size, text_seq_len, hidden_size]
        )

        # Text attends to audio - each text token attends to all audio tokens
        text_attended_to_audio, _ = self.text_to_audio_attention(
            query=text_projected,  # [batch_size, text_seq_len, hidden_size]
            key=audio_projected,  # [batch_size, audio_seq_len, hidden_size]
            value=audio_projected,  # [batch_size, audio_seq_len, hidden_size]
        )

        # 6. Global pooling (mean) of attended embeddings
        audio_attended_emb = audio_attended_to_text.mean(dim=1)  # [batch_size, hidden_size]
        text_attended_emb = text_attended_to_audio.mean(dim=1)  # [batch_size, hidden_size]

        # 7. Combine with scalar mix
        final_embedding = self.scalar_mix(
            audio_global,
            text_global,
            audio_attended_emb,
            text_attended_emb
        )

        # 8. Regression head
        logits = self.fc(final_embedding).squeeze(-1)

        # Compute loss if labels are provided
        loss = None
        if labels is not None:
            mse_loss = F.mse_loss(logits, labels)

            # Add embedding regularization if specified
            reg_loss = 0.0
            if self.config.embedding_regularization > 0:
                reg_loss = (
                    torch.norm(audio_global, p=2, dim=1).mean() +
                    torch.norm(text_global, p=2, dim=1).mean() +
                    torch.norm(audio_attended_emb, p=2, dim=1).mean() +
                    torch.norm(text_attended_emb, p=2, dim=1).mean()
                ) / 4.0

            loss = mse_loss + self.config.embedding_regularization * reg_loss

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )
pytorch_model.bin ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:29764a5e44028038b2251c4bc21f8bccaefc03f06bd4e796f77683b4e7914e51
size 4883154633
requirements.txt ADDED
transformers>=4.50.2
torch>=2.6.0
torchaudio>=2.6.0
huggingface_hub>=0.33.0
numpy>=2.2.3
sentencepiece>=0.2.0
accelerate>=1.5.2
soundfile>=0.13.1