Upload folder using huggingface_hub
Browse files- README.md +230 -0
- __init__.py +6 -0
- config.json +13 -0
- data_collator.py +88 -0
- example_usage.py +54 -0
- modeling_seamless_crossattention.py +225 -0
- pytorch_model.bin +3 -0
- requirements.txt +8 -0
README.md
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- multilingual
|
| 4 |
+
tags:
|
| 5 |
+
- audio
|
| 6 |
+
- text
|
| 7 |
+
- multimodal
|
| 8 |
+
- seamless
|
| 9 |
+
- subtitle-editing-time-prediction
|
| 10 |
+
- cross-attention
|
| 11 |
+
- attention-mechanism
|
| 12 |
+
library_name: transformers
|
| 13 |
+
base_model: facebook/hf-seamless-m4t-medium
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# videoloc/seamless-crossattention
|
| 17 |
+
|
| 18 |
+
## Model Description
|
| 19 |
+
|
| 20 |
+
This is a **SeamlessCrossAttention** model that processes audio and text inputs with advanced cross-modal attention mechanisms to predict **Time To Edit (TTE)** for subtitle segments. Given an audio segment and its corresponding subtitle text, the model predicts how much time (in seconds) would be required to edit/refine that subtitle segment, leveraging sophisticated cross-attention patterns between audio and text modalities.
|
| 21 |
+
|
| 22 |
+
The model extends the SeamlessM4T architecture with bidirectional cross-attention layers that allow audio and text representations to attend to each other, creating rich cross-modal embeddings that capture temporal and semantic relationships across 5 languages: **English, French, Spanish, Italian, and German**.
|
| 23 |
+
|
| 24 |
+
### Key Features
|
| 25 |
+
|
| 26 |
+
- **Cross-Modal Attention**: Bidirectional attention between audio and text representations
|
| 27 |
+
- **Advanced Architecture**: Audio-to-text and text-to-audio attention mechanisms
|
| 28 |
+
- **Scalar Mixing**: Learnable combination of global and attended embeddings
|
| 29 |
+
- **Embedding Regularization**: Optional L2 regularization for embedding stability
|
| 30 |
+
- **Multimodal Processing**: Simultaneously processes audio (16kHz) and text inputs
|
| 31 |
+
- **Frozen Encoders**: Uses pre-trained SeamlessM4T encoders (frozen for stability)
|
| 32 |
+
- **TTE Prediction**: Predicts editing time required for subtitle segments
|
| 33 |
+
- **Direct Output**: Raw time values in seconds for immediate use
|
| 34 |
+
|
| 35 |
+
## Model Architecture
|
| 36 |
+
|
| 37 |
+
The model implements sophisticated cross-modal attention mechanisms:
|
| 38 |
+
|
| 39 |
+
1. **Audio Processing**:
|
| 40 |
+
- SeamlessM4T speech encoder (frozen) processes raw audio input
|
| 41 |
+
- Audio projection layer maps speech encoder output to 1024 dimensions
|
| 42 |
+
- Layer normalization for stability
|
| 43 |
+
|
| 44 |
+
2. **Text Processing**:
|
| 45 |
+
- SeamlessM4T text encoder (frozen) processes tokenized text input
|
| 46 |
+
- Text projection layer maps text encoder output to 1024 dimensions
|
| 47 |
+
- Layer normalization for stability
|
| 48 |
+
|
| 49 |
+
3. **Cross-Modal Attention**:
|
| 50 |
+
- **Audio-to-Text Attention**: Each audio token attends to all text tokens
|
| 51 |
+
- **Text-to-Audio Attention**: Each text token attends to all audio tokens
|
| 52 |
+
- Multi-head attention (8 heads) with dropout for regularization
|
| 53 |
+
- Bidirectional information flow between modalities
|
| 54 |
+
|
| 55 |
+
4. **Feature Fusion**:
|
| 56 |
+
- Global pooling of original audio and text embeddings
|
| 57 |
+
- Global pooling of cross-attended embeddings
|
| 58 |
+
- Scalar mixing layer combines all four embeddings with learnable weights
|
| 59 |
+
- Final embedding captures both global and cross-modal patterns
|
| 60 |
+
|
| 61 |
+
5. **Regression Head**:
|
| 62 |
+
- Multi-layer perceptron: 1024 → 512 → 256 → 1
|
| 63 |
+
- ReLU activations and dropout for regularization
|
| 64 |
+
- Single output for TTE prediction (regression, in seconds)
|
| 65 |
+
|
| 66 |
+
6. **Optional Regularization**:
|
| 67 |
+
- L2 regularization on embedding norms for training stability
|
| 68 |
+
- Configurable regularization strength
|
| 69 |
+
|
| 70 |
+
## Quick Start
|
| 71 |
+
|
| 72 |
+
### Installation
|
| 73 |
+
```bash
|
| 74 |
+
pip install transformers torch torchaudio huggingface_hub
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### Basic Usage
|
| 78 |
+
```python
|
| 79 |
+
from transformers import AutoModel, AutoConfig
|
| 80 |
+
from huggingface_hub import hf_hub_download
|
| 81 |
+
import torch
|
| 82 |
+
import numpy as np
|
| 83 |
+
import importlib.util
|
| 84 |
+
|
| 85 |
+
# Load model - custom architecture requires importing the model class
|
| 86 |
+
model_files = hf_hub_download(repo_id="videoloc/seamless-crossattention", filename="modeling_seamless_crossattention.py")
|
| 87 |
+
spec = importlib.util.spec_from_file_location("modeling_seamless_crossattention", model_files)
|
| 88 |
+
modeling_module = importlib.util.module_from_spec(spec)
|
| 89 |
+
spec.loader.exec_module(modeling_module)
|
| 90 |
+
|
| 91 |
+
# Now load the model using the custom class
|
| 92 |
+
config = modeling_module.SeamlessCrossAttentionConfig.from_pretrained("videoloc/seamless-crossattention")
|
| 93 |
+
model = modeling_module.HFSeamlessCrossAttention.from_pretrained("videoloc/seamless-crossattention")
|
| 94 |
+
|
| 95 |
+
# Load the data collator (included in this repo)
|
| 96 |
+
collator_file = hf_hub_download(repo_id="videoloc/seamless-crossattention", filename="data_collator.py")
|
| 97 |
+
spec = importlib.util.spec_from_file_location("data_collator", collator_file)
|
| 98 |
+
collator_module = importlib.util.module_from_spec(spec)
|
| 99 |
+
spec.loader.exec_module(collator_module)
|
| 100 |
+
|
| 101 |
+
# Initialize data collator
|
| 102 |
+
data_collator = collator_module.DataCollatorSimpleSeamless(
|
| 103 |
+
processor="facebook/hf-seamless-m4t-medium",
|
| 104 |
+
max_audio_length_sec=8.0,
|
| 105 |
+
max_text_length=256
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Prepare your data
|
| 109 |
+
your_data = [
|
| 110 |
+
{
|
| 111 |
+
'raw_audio': np.random.randn(16000 * 5), # 5 seconds at 16kHz
|
| 112 |
+
'raw_text': "Your subtitle text here",
|
| 113 |
+
# Note: Cross-attention model doesn't require translation features
|
| 114 |
+
}
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
# Process and run inference
|
| 118 |
+
batch = data_collator(your_data)
|
| 119 |
+
model.eval()
|
| 120 |
+
with torch.no_grad():
|
| 121 |
+
outputs = model(**batch)
|
| 122 |
+
tte_prediction = outputs.logits.item()
|
| 123 |
+
|
| 124 |
+
print(f"Predicted Time To Edit (TTE): {tte_prediction:.2f} seconds")
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
## Model Details
|
| 128 |
+
|
| 129 |
+
- **Base Model**: SeamlessM4T (facebook/hf-seamless-m4t-medium)
|
| 130 |
+
- **Audio Encoder**: Frozen SeamlessM4T speech encoder
|
| 131 |
+
- **Text Encoder**: Frozen SeamlessM4T text encoder
|
| 132 |
+
- **Hidden Size**: 1024
|
| 133 |
+
- **Attention Heads**: 8 (configurable)
|
| 134 |
+
- **Cross-Attention**: Bidirectional (audio↔text)
|
| 135 |
+
- **Scalar Mix**: 4 embeddings (audio global, text global, audio→text, text→audio)
|
| 136 |
+
- **Audio Input**: 16kHz
|
| 137 |
+
- **Output**: Single regression value (TTE in seconds)
|
| 138 |
+
- **Task**: Subtitle editing time prediction
|
| 139 |
+
|
| 140 |
+
## Data Format
|
| 141 |
+
|
| 142 |
+
Your input data should be a list of dictionaries with:
|
| 143 |
+
- `raw_audio`: NumPy array of audio samples (16kHz sampling rate)
|
| 144 |
+
- `raw_text`: String of subtitle text
|
| 145 |
+
- `labels`: Target TTE values in seconds (optional, for training)
|
| 146 |
+
|
| 147 |
+
Example:
|
| 148 |
+
```python
|
| 149 |
+
data = [
|
| 150 |
+
{
|
| 151 |
+
'raw_audio': audio_samples, # shape: (num_samples,) at 16kHz
|
| 152 |
+
'raw_text': "Subtitle text content",
|
| 153 |
+
'labels': 2.5 # optional TTE target value in seconds
|
| 154 |
+
}
|
| 155 |
+
]
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
## Performance Metrics
|
| 159 |
+
|
| 160 |
+
- **Best Eval RMSE**: 33.34
|
| 161 |
+
|
| 162 |
+
## Training Details
|
| 163 |
+
|
| 164 |
+
- **Base Model**: facebook/hf-seamless-m4t-medium
|
| 165 |
+
- **Model Type**: seamless_cross_attention
|
| 166 |
+
- **Epochs**: 10
|
| 167 |
+
- **Batch Size (Train)**: 32
|
| 168 |
+
- **Batch Size (Eval)**: 64
|
| 169 |
+
- **Learning Rate**: 1.2e-4
|
| 170 |
+
- **LR Scheduler**: cosine_with_restarts
|
| 171 |
+
- **Warmup Ratio**: 0.05
|
| 172 |
+
- **Weight Decay**: 0.001
|
| 173 |
+
- **Optimizer**: AdamW (torch)
|
| 174 |
+
- **Max Grad Norm**: 1.0
|
| 175 |
+
- **FP16**: True
|
| 176 |
+
- **Early Stopping Patience**: 5
|
| 177 |
+
- **Audio Max Length**: 8.0 seconds
|
| 178 |
+
- **Text Max Length**: 256 tokens
|
| 179 |
+
- **Sample Rate**: 16kHz
|
| 180 |
+
- **Cross-Attention**: 8-head multi-head attention
|
| 181 |
+
- **Scalar Mixing**: 4 embedding types
|
| 182 |
+
- **Embedding Regularization**: Optional L2
|
| 183 |
+
- **Normalization**: None (raw values)
|
| 184 |
+
- **Dataset Split**: 80/20 train/test
|
| 185 |
+
- **Random Seed**: 42
|
| 186 |
+
- **Metric**: RMSE (lower is better)
|
| 187 |
+
|
| 188 |
+
## Training Configuration
|
| 189 |
+
|
| 190 |
+
The model was trained with the following specifications:
|
| 191 |
+
|
| 192 |
+
- **Dataset**: Multimodal audio-subtitle pairs with TTE annotations (5 languages: EN, FR, ES, IT, DE)
|
| 193 |
+
- **Train/Test Split**: 80/20 with random seed 42
|
| 194 |
+
- **Audio Processing**: 16kHz sampling, max 8.0 seconds, no offset
|
| 195 |
+
- **Text Processing**: Max 256 tokens
|
| 196 |
+
- **Cross-Attention**: 8-head multi-head attention with dropout
|
| 197 |
+
- **Scalar Mixing**: Learnable combination of 4 embedding types
|
| 198 |
+
- **Normalization**: None (raw TTE values in seconds)
|
| 199 |
+
- **Caching**: Audio segments cached and compressed for efficiency
|
| 200 |
+
|
| 201 |
+
## Usage Notes
|
| 202 |
+
|
| 203 |
+
- This is the **advanced cross-attention** variant with sophisticated attention mechanisms
|
| 204 |
+
- For simpler models, see `seamless-basic`, `seamless-translation`, or `seamless-langpairs`
|
| 205 |
+
- Model expects 16kHz audio input (automatically resampled by data collator)
|
| 206 |
+
- Cross-attention captures complex temporal and semantic relationships
|
| 207 |
+
- No feature normalization applied - outputs raw TTE predictions in seconds
|
| 208 |
+
- Optimized for detailed subtitle editing time estimation tasks
|
| 209 |
+
|
| 210 |
+
## Architecture Advantages
|
| 211 |
+
|
| 212 |
+
- **Rich Cross-Modal Interactions**: Audio and text modalities directly attend to each other
|
| 213 |
+
- **Temporal Alignment**: Cross-attention naturally captures temporal relationships
|
| 214 |
+
- **Semantic Understanding**: Text-to-audio attention helps model understand content meaning
|
| 215 |
+
- **Flexible Combination**: Scalar mixing allows model to weight different embedding types
|
| 216 |
+
- **Regularization Options**: Optional embedding regularization for training stability
|
| 217 |
+
|
| 218 |
+
## Limitations
|
| 219 |
+
|
| 220 |
+
- Higher computational complexity than basic models due to attention mechanisms
|
| 221 |
+
- Requires more training data to fully leverage cross-attention capabilities
|
| 222 |
+
- Designed for TTE prediction, not general audio-text matching
|
| 223 |
+
- Performance may vary on out-of-domain content or different editing workflows
|
| 224 |
+
- Requires specific data preprocessing (use included data collator)
|
| 225 |
+
|
| 226 |
+
## Related Models
|
| 227 |
+
|
| 228 |
+
- **seamless-basic**: Basic audio+text model without attention mechanisms
|
| 229 |
+
- **seamless-translation**: Includes translation awareness but no cross-attention
|
| 230 |
+
- **seamless-langpairs**: Includes language pair embeddings but no cross-attention
|
__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SeamlessCrossAttention model for HuggingFace Transformers
|
| 3 |
+
"""
|
| 4 |
+
from .modeling_seamless_crossattention import HFSeamlessCrossAttention, SeamlessCrossAttentionConfig
|
| 5 |
+
|
| 6 |
+
__all__ = ["HFSeamlessCrossAttention", "SeamlessCrossAttentionConfig"]
|
config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"HFSeamlessCrossAttention"
|
| 4 |
+
],
|
| 5 |
+
"dropout_prob": 0.1,
|
| 6 |
+
"embedding_regularization": 0.0,
|
| 7 |
+
"hidden_size": 1024,
|
| 8 |
+
"model_type": "seamless_crossattention",
|
| 9 |
+
"num_attention_heads": 8,
|
| 10 |
+
"seamless_model_name": "facebook/hf-seamless-m4t-medium",
|
| 11 |
+
"torch_dtype": "float32",
|
| 12 |
+
"transformers_version": "4.50.2"
|
| 13 |
+
}
|
data_collator.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from transformers import AutoProcessor
|
| 4 |
+
from typing import Dict, List, Union
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
class DataCollatorSimpleSeamless:
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
processor: str,
|
| 13 |
+
sample_rate: int = 16000,
|
| 14 |
+
max_audio_length_sec: float = 8.0,
|
| 15 |
+
max_text_length: int = 256,
|
| 16 |
+
normalization_type: str = "none"
|
| 17 |
+
):
|
| 18 |
+
"""Initialize the data collator.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
processor: The processor to use.
|
| 22 |
+
sample_rate: Audio sample rate.
|
| 23 |
+
max_audio_length_sec: Maximum audio length in seconds.
|
| 24 |
+
max_text_length: Maximum text length.
|
| 25 |
+
normalization_type: Type of normalization to apply to labels. Options: "log1p", "none"
|
| 26 |
+
"""
|
| 27 |
+
logger.info(f"Loading processor: {processor}")
|
| 28 |
+
self.processor = AutoProcessor.from_pretrained(processor)
|
| 29 |
+
|
| 30 |
+
self.sample_rate = sample_rate
|
| 31 |
+
self.max_audio_sample_length = int(max_audio_length_sec * sample_rate)
|
| 32 |
+
self.max_text_length = max_text_length
|
| 33 |
+
self.normalization_type = normalization_type
|
| 34 |
+
|
| 35 |
+
def __call__(self, batch: List[Dict[str, Union[np.ndarray, str, float]]]) -> Dict[str, torch.Tensor]:
|
| 36 |
+
"""Process a batch of raw features into model inputs."""
|
| 37 |
+
# Extract raw data
|
| 38 |
+
raw_audios = [item['raw_audio'] for item in batch]
|
| 39 |
+
raw_texts = [item['raw_text'] for item in batch]
|
| 40 |
+
|
| 41 |
+
raw_audios = [torch.tensor(audio) for audio in raw_audios]
|
| 42 |
+
|
| 43 |
+
audio_inputs = self.processor(
|
| 44 |
+
audios=raw_audios,
|
| 45 |
+
sampling_rate=self.sample_rate,
|
| 46 |
+
return_tensors="pt",
|
| 47 |
+
padding="longest",
|
| 48 |
+
truncation=True,
|
| 49 |
+
max_length=self.max_audio_sample_length,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
text_inputs = self.processor(
|
| 53 |
+
text=raw_texts,
|
| 54 |
+
return_tensors="pt",
|
| 55 |
+
padding="longest",
|
| 56 |
+
truncation=True,
|
| 57 |
+
max_length=self.max_text_length,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Extract translation features
|
| 61 |
+
is_translation = torch.tensor([item.get('is_translation', 0) for item in batch], dtype=torch.float32)
|
| 62 |
+
|
| 63 |
+
# Extract language pair features
|
| 64 |
+
language_pair_id = torch.tensor([item.get('language_pair_id', 0) for item in batch], dtype=torch.long)
|
| 65 |
+
|
| 66 |
+
if 'labels' in batch[0]:
|
| 67 |
+
labels = [item['labels'] for item in batch]
|
| 68 |
+
labels = torch.tensor(labels, dtype=torch.float32)
|
| 69 |
+
|
| 70 |
+
# Apply normalization based on type
|
| 71 |
+
if self.normalization_type == "log1p":
|
| 72 |
+
labels = torch.log1p(labels)
|
| 73 |
+
elif self.normalization_type == "none":
|
| 74 |
+
pass
|
| 75 |
+
else:
|
| 76 |
+
raise ValueError(f"Unknown normalization type: {self.normalization_type}")
|
| 77 |
+
else:
|
| 78 |
+
labels = None
|
| 79 |
+
|
| 80 |
+
return {
|
| 81 |
+
'input_features': audio_inputs['input_features'],
|
| 82 |
+
'audio_attention_mask': audio_inputs.get('attention_mask', None) if audio_inputs.get('attention_mask') is not None else None,
|
| 83 |
+
'input_ids': text_inputs['input_ids'],
|
| 84 |
+
'text_attention_mask': text_inputs['attention_mask'],
|
| 85 |
+
'is_translation': is_translation,
|
| 86 |
+
'language_pair_id': language_pair_id,
|
| 87 |
+
**({'labels': labels} if labels is not None else {})
|
| 88 |
+
}
|
example_usage.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Example usage for videoloc/seamless-crossattention
|
| 3 |
+
|
| 4 |
+
from transformers import AutoModel, AutoConfig
|
| 5 |
+
from huggingface_hub import hf_hub_download
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import importlib.util
|
| 9 |
+
|
| 10 |
+
def load_model_and_collator():
|
| 11 |
+
# Load model - custom architecture requires importing the model class
|
| 12 |
+
model_files = hf_hub_download(repo_id="videoloc/seamless-crossattention", filename="modeling_seamless_crossattention.py")
|
| 13 |
+
spec = importlib.util.spec_from_file_location("modeling_seamless_crossattention", model_files)
|
| 14 |
+
modeling_module = importlib.util.module_from_spec(spec)
|
| 15 |
+
spec.loader.exec_module(modeling_module)
|
| 16 |
+
|
| 17 |
+
# Now load the model using the custom class
|
| 18 |
+
config = modeling_module.SeamlessCrossAttentionConfig.from_pretrained("videoloc/seamless-crossattention")
|
| 19 |
+
model = modeling_module.HFSeamlessCrossAttention.from_pretrained("videoloc/seamless-crossattention")
|
| 20 |
+
|
| 21 |
+
# Load data collator
|
| 22 |
+
collator_file = hf_hub_download(repo_id="videoloc/seamless-crossattention", filename="data_collator.py")
|
| 23 |
+
spec = importlib.util.spec_from_file_location("data_collator", collator_file)
|
| 24 |
+
collator_module = importlib.util.module_from_spec(spec)
|
| 25 |
+
spec.loader.exec_module(collator_module)
|
| 26 |
+
|
| 27 |
+
data_collator = collator_module.DataCollatorSimpleSeamless(
|
| 28 |
+
processor="facebook/hf-seamless-m4t-medium",
|
| 29 |
+
max_audio_length_sec=8.0,
|
| 30 |
+
max_text_length=256
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
return model, data_collator
|
| 34 |
+
|
| 35 |
+
def example_inference():
|
| 36 |
+
model, collator = load_model_and_collator()
|
| 37 |
+
|
| 38 |
+
# Example data: audio segment + subtitle text for cross-attention TTE prediction
|
| 39 |
+
data = [{
|
| 40 |
+
'raw_audio': np.random.randn(16000 * 3), # 3 seconds at 16kHz
|
| 41 |
+
'raw_text': "Example subtitle text with cross-modal attention for TTE prediction",
|
| 42 |
+
}]
|
| 43 |
+
|
| 44 |
+
batch = collator(data)
|
| 45 |
+
model.eval()
|
| 46 |
+
with torch.no_grad():
|
| 47 |
+
outputs = model(**batch)
|
| 48 |
+
tte_prediction = outputs.logits.item()
|
| 49 |
+
|
| 50 |
+
print(f"Predicted Time To Edit (TTE): {tte_prediction:.2f} seconds")
|
| 51 |
+
return tte_prediction
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
example_inference()
|
modeling_seamless_crossattention.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 5 |
+
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 6 |
+
from transformers import SeamlessM4TModel
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SeamlessCrossAttentionConfig(PretrainedConfig):
|
| 13 |
+
"""Configuration class for SeamlessCrossAttention model."""
|
| 14 |
+
|
| 15 |
+
model_type = "seamless_crossattention"
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
seamless_model_name="facebook/hf-seamless-m4t-medium",
|
| 20 |
+
hidden_size=1024,
|
| 21 |
+
dropout_prob=0.1,
|
| 22 |
+
num_attention_heads=8,
|
| 23 |
+
embedding_regularization=0.0,
|
| 24 |
+
**kwargs
|
| 25 |
+
):
|
| 26 |
+
super().__init__(**kwargs)
|
| 27 |
+
self.seamless_model_name = seamless_model_name
|
| 28 |
+
self.hidden_size = hidden_size
|
| 29 |
+
self.dropout_prob = dropout_prob
|
| 30 |
+
self.num_attention_heads = num_attention_heads
|
| 31 |
+
self.embedding_regularization = embedding_regularization
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ScalarMix(nn.Module):
|
| 35 |
+
"""Scalar mixing layer for combining multiple embeddings."""
|
| 36 |
+
|
| 37 |
+
def __init__(self, num_inputs=4):
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.weights = nn.Parameter(torch.ones(num_inputs))
|
| 40 |
+
self.gamma = nn.Parameter(torch.tensor(1.0))
|
| 41 |
+
|
| 42 |
+
def forward(self, *tensors):
|
| 43 |
+
# Normalize weights with softmax
|
| 44 |
+
weights = F.softmax(self.weights, dim=0)
|
| 45 |
+
|
| 46 |
+
# Weighted sum
|
| 47 |
+
weighted_sum = sum(w * t for w, t in zip(weights, tensors))
|
| 48 |
+
|
| 49 |
+
# Scale by gamma
|
| 50 |
+
return self.gamma * weighted_sum
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class HFSeamlessCrossAttention(PreTrainedModel):
|
| 54 |
+
"""SeamlessM4T model with cross attention for HuggingFace Hub."""
|
| 55 |
+
|
| 56 |
+
config_class = SeamlessCrossAttentionConfig
|
| 57 |
+
supports_gradient_checkpointing = True
|
| 58 |
+
|
| 59 |
+
def __init__(self, config):
|
| 60 |
+
super().__init__(config)
|
| 61 |
+
self.config = config
|
| 62 |
+
|
| 63 |
+
# Load the underlying SeamlessM4T model
|
| 64 |
+
self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
|
| 65 |
+
self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
|
| 66 |
+
self.seamless_model_text_encoder = self.seamless_model.text_encoder
|
| 67 |
+
|
| 68 |
+
# Freeze pre-trained models
|
| 69 |
+
for param in self.seamless_model_speech_encoder.parameters():
|
| 70 |
+
param.requires_grad = False
|
| 71 |
+
for param in self.seamless_model_text_encoder.parameters():
|
| 72 |
+
param.requires_grad = False
|
| 73 |
+
|
| 74 |
+
# Projection layers
|
| 75 |
+
self.audio_proj = nn.Linear(
|
| 76 |
+
self.seamless_model_speech_encoder.config.hidden_size,
|
| 77 |
+
config.hidden_size
|
| 78 |
+
)
|
| 79 |
+
self.text_proj = nn.Linear(
|
| 80 |
+
self.seamless_model_text_encoder.config.hidden_size,
|
| 81 |
+
config.hidden_size
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# Layer norms
|
| 85 |
+
self.audio_norm = nn.LayerNorm(config.hidden_size)
|
| 86 |
+
self.text_norm = nn.LayerNorm(config.hidden_size)
|
| 87 |
+
|
| 88 |
+
# Cross-attention layers
|
| 89 |
+
self.audio_to_text_attention = nn.MultiheadAttention(
|
| 90 |
+
embed_dim=config.hidden_size,
|
| 91 |
+
num_heads=config.num_attention_heads,
|
| 92 |
+
dropout=config.dropout_prob,
|
| 93 |
+
batch_first=True
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
self.text_to_audio_attention = nn.MultiheadAttention(
|
| 97 |
+
embed_dim=config.hidden_size,
|
| 98 |
+
num_heads=config.num_attention_heads,
|
| 99 |
+
dropout=config.dropout_prob,
|
| 100 |
+
batch_first=True
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# Scalar mix for combining embeddings
|
| 104 |
+
self.scalar_mix = ScalarMix(num_inputs=4)
|
| 105 |
+
|
| 106 |
+
# Enhanced classifier with residual connections
|
| 107 |
+
self.fc = nn.Sequential(
|
| 108 |
+
nn.Linear(config.hidden_size, 512),
|
| 109 |
+
nn.ReLU(),
|
| 110 |
+
nn.Dropout(config.dropout_prob),
|
| 111 |
+
nn.Linear(512, 256),
|
| 112 |
+
nn.ReLU(),
|
| 113 |
+
nn.Dropout(config.dropout_prob),
|
| 114 |
+
nn.Linear(256, 1)
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# Initialize new layers
|
| 118 |
+
self._initialize_new_layers()
|
| 119 |
+
|
| 120 |
+
def _initialize_new_layers(self):
|
| 121 |
+
"""Initialize new layers with proper weights."""
|
| 122 |
+
for module in [self.audio_proj, self.text_proj, self.fc]:
|
| 123 |
+
if isinstance(module, nn.Linear):
|
| 124 |
+
nn.init.xavier_uniform_(module.weight)
|
| 125 |
+
nn.init.zeros_(module.bias)
|
| 126 |
+
elif isinstance(module, nn.Sequential):
|
| 127 |
+
for layer in module:
|
| 128 |
+
if isinstance(layer, nn.Linear):
|
| 129 |
+
nn.init.xavier_uniform_(layer.weight)
|
| 130 |
+
nn.init.zeros_(layer.bias)
|
| 131 |
+
|
| 132 |
+
def forward(
|
| 133 |
+
self,
|
| 134 |
+
input_features,
|
| 135 |
+
input_ids,
|
| 136 |
+
text_attention_mask,
|
| 137 |
+
audio_attention_mask=None,
|
| 138 |
+
labels=None,
|
| 139 |
+
**kwargs # Accept additional features but ignore them
|
| 140 |
+
):
|
| 141 |
+
# Create default audio attention mask if not provided
|
| 142 |
+
if audio_attention_mask is None:
|
| 143 |
+
audio_attention_mask = torch.ones(
|
| 144 |
+
input_features.size(0), input_features.size(1),
|
| 145 |
+
device=input_features.device
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# 1. Encode audio
|
| 149 |
+
audio_output = self.seamless_model_speech_encoder(
|
| 150 |
+
input_features=input_features,
|
| 151 |
+
attention_mask=audio_attention_mask
|
| 152 |
+
)
|
| 153 |
+
audio_hidden_states = audio_output.last_hidden_state # [batch_size, audio_seq_len, hidden_size]
|
| 154 |
+
|
| 155 |
+
# 2. Encode text
|
| 156 |
+
text_output = self.seamless_model_text_encoder(
|
| 157 |
+
input_ids=input_ids,
|
| 158 |
+
attention_mask=text_attention_mask
|
| 159 |
+
)
|
| 160 |
+
text_hidden_states = text_output.last_hidden_state # [batch_size, text_seq_len, hidden_size]
|
| 161 |
+
|
| 162 |
+
# 3. Project to common dimension
|
| 163 |
+
audio_projected = self.audio_proj(audio_hidden_states) # [batch_size, audio_seq_len, hidden_size]
|
| 164 |
+
text_projected = self.text_proj(text_hidden_states) # [batch_size, text_seq_len, hidden_size]
|
| 165 |
+
|
| 166 |
+
audio_projected = self.audio_norm(audio_projected)
|
| 167 |
+
text_projected = self.text_norm(text_projected)
|
| 168 |
+
|
| 169 |
+
# 4. Global pooling (mean) of original embeddings
|
| 170 |
+
audio_global = audio_projected.mean(dim=1) # [batch_size, hidden_size]
|
| 171 |
+
text_global = text_projected.mean(dim=1) # [batch_size, hidden_size]
|
| 172 |
+
|
| 173 |
+
# 5. Cross-attention with masks
|
| 174 |
+
# Audio attends to text - each audio token attends to all text tokens
|
| 175 |
+
audio_attended_to_text, _ = self.audio_to_text_attention(
|
| 176 |
+
query=audio_projected, # [batch_size, audio_seq_len, hidden_size]
|
| 177 |
+
key=text_projected, # [batch_size, text_seq_len, hidden_size]
|
| 178 |
+
value=text_projected, # [batch_size, text_seq_len, hidden_size]
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Text attends to audio - each text token attends to all audio tokens
|
| 182 |
+
text_attended_to_audio, _ = self.text_to_audio_attention(
|
| 183 |
+
query=text_projected, # [batch_size, text_seq_len, hidden_size]
|
| 184 |
+
key=audio_projected, # [batch_size, audio_seq_len, hidden_size]
|
| 185 |
+
value=audio_projected, # [batch_size, audio_seq_len, hidden_size]
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# 6. Global pooling (mean) of attended embeddings
|
| 189 |
+
audio_attended_emb = audio_attended_to_text.mean(dim=1) # [batch_size, hidden_size]
|
| 190 |
+
text_attended_emb = text_attended_to_audio.mean(dim=1) # [batch_size, hidden_size]
|
| 191 |
+
|
| 192 |
+
# 7. Combine with scalar mix
|
| 193 |
+
final_embedding = self.scalar_mix(
|
| 194 |
+
audio_global,
|
| 195 |
+
text_global,
|
| 196 |
+
audio_attended_emb,
|
| 197 |
+
text_attended_emb
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# 8. Classification
|
| 201 |
+
logits = self.fc(final_embedding).squeeze(-1)
|
| 202 |
+
|
| 203 |
+
# Compute loss if labels are provided
|
| 204 |
+
loss = None
|
| 205 |
+
if labels is not None:
|
| 206 |
+
mse_loss = F.mse_loss(logits, labels)
|
| 207 |
+
|
| 208 |
+
# Add embedding regularization if specified
|
| 209 |
+
reg_loss = 0.0
|
| 210 |
+
if self.config.embedding_regularization > 0:
|
| 211 |
+
reg_loss = (
|
| 212 |
+
torch.norm(audio_global, p=2, dim=1).mean() +
|
| 213 |
+
torch.norm(text_global, p=2, dim=1).mean() +
|
| 214 |
+
torch.norm(audio_attended_emb, p=2, dim=1).mean() +
|
| 215 |
+
torch.norm(text_attended_emb, p=2, dim=1).mean()
|
| 216 |
+
) / 4.0
|
| 217 |
+
|
| 218 |
+
loss = mse_loss + self.config.embedding_regularization * reg_loss
|
| 219 |
+
|
| 220 |
+
return SequenceClassifierOutput(
|
| 221 |
+
loss=loss,
|
| 222 |
+
logits=logits,
|
| 223 |
+
hidden_states=None,
|
| 224 |
+
attentions=None
|
| 225 |
+
)
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:29764a5e44028038b2251c4bc21f8bccaefc03f06bd4e796f77683b4e7914e51
|
| 3 |
+
size 4883154633
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers>=4.50.2
|
| 2 |
+
torch>=2.6.0
|
| 3 |
+
torchaudio>=2.6.0
|
| 4 |
+
huggingface_hub>=0.33.0
|
| 5 |
+
numpy>=2.2.3
|
| 6 |
+
sentencepiece>=0.2.0
|
| 7 |
+
accelerate>=1.5.2
|
| 8 |
+
soundfile>=0.13.1
|