Commit
·
e64d09b
1
Parent(s):
79fe833
Create README.md
Browse files
README.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
A simple use case:
|
3 |
+
|
4 |
+
```python
|
5 |
+
from transformers import Wav2Vec2Processor, AutoModel
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from datasets import load_dataset
|
9 |
+
|
10 |
+
# load demo audio and set processor
|
11 |
+
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
12 |
+
dataset = dataset.sort("id")
|
13 |
+
sampling_rate = dataset.features["audio"].sampling_rate
|
14 |
+
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
|
15 |
+
|
16 |
+
# loading our model weights
|
17 |
+
model = AutoModel.from_pretrained("m-a-p/MERT-v0")
|
18 |
+
|
19 |
+
# audio file is decoded on the fly
|
20 |
+
inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
|
21 |
+
with torch.no_grad():
|
22 |
+
    outputs = model(**inputs, output_hidden_states=True)
|
23 |
+
|
24 |
+
# take a look at the output shape, there are 13 layers of representation
|
25 |
+
# each layer performs differently on different downstream tasks, so choose the layer empirically
|
26 |
+
all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
|
27 |
+
print(all_layer_hidden_states.shape) # [13 layer, 292 timestep, 768 feature_dim]
|
28 |
+
|
29 |
+
# for utterance level classification tasks, you can simply reduce the representation in time
|
30 |
+
time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
|
31 |
+
print(time_reduced_hidden_states.shape) # [13, 768]
|
32 |
+
|
33 |
+
# you can even use a learnable weighted average representation
|
34 |
+
aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
|
35 |
+
weighted_avg_hidden_states = aggregator(time_reduced_hidden_states).squeeze()
|
36 |
+
print(weighted_avg_hidden_states.shape) # [768]
|
37 |
+
```
|