ldhldh committed
Commit bb5a96d · verified · 1 Parent(s): 274c600

Upload 11 files

src/models/__init__.py ADDED
File without changes
src/models/assets/mel_filters.npz ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd2cc75e70e36fcbdd8ffbc2499062f30094093e6bf2cbafa9859f59972b420b
size 2048
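The pointer above only records the Git LFS object; the underlying asset is an 80-band mel filterbank matrix. A minimal sketch of how such a file can be regenerated and inspected, assuming librosa is available (the parameters mirror the docstring of mel_filters() in whisper_main.py below):

# Sketch: regenerate and inspect the mel filterbank asset (assumes librosa is installed).
import librosa
import numpy as np

np.savez_compressed(
    "mel_filters.npz",
    mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
)

with np.load("mel_filters.npz") as f:
    print(f["mel_80"].shape)  # (80, 201): n_mels x (n_fft // 2 + 1)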
src/models/lcnn.py ADDED
@@ -0,0 +1,247 @@
"""
This code is a modified version of the LCNN baseline
from the ASVspoof 2021 challenge - https://github.com/asvspoof-challenge/2021/blob/main/LA/Baseline-LFCC-LCNN/project/baseline_LA/model.py
"""
import sys

import torch
import torch.nn as torch_nn

from src import frontends


NUM_COEFFICIENTS = 384


# For blstm
class BLSTMLayer(torch_nn.Module):
    """ Wrapper over a bi-directional LSTM
    Input tensor: (batchsize=1, length, dim_in)
    Output tensor: (batchsize=1, length, dim_out)
    We want to keep the length the same
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        if output_dim % 2 != 0:
            print("Output_dim of BLSTMLayer is {:d}".format(output_dim))
            print("BLSTMLayer expects a layer size of even number")
            sys.exit(1)
        # bi-directional LSTM
        self.l_blstm = torch_nn.LSTM(
            input_dim,
            output_dim // 2,
            bidirectional=True
        )

    def forward(self, x):
        # permute to (length, batchsize=1, dim)
        blstm_data, _ = self.l_blstm(x.permute(1, 0, 2))
        # permute it back to (batchsize=1, length, dim)
        return blstm_data.permute(1, 0, 2)


class MaxFeatureMap2D(torch_nn.Module):
    """ Max feature map (along 2D)

    MaxFeatureMap2D(max_dim=1)

    l_conv2d = MaxFeatureMap2D(1)
    data_in = torch.rand([1, 4, 5, 5])
    data_out = l_conv2d(data_in)

    Input:
    ------
    data_in: tensor of shape (batch, channel, ...)

    Output:
    -------
    data_out: tensor of shape (batch, channel//2, ...)

    Note
    ----
    By default, Max-feature-map is on the channel dimension,
    and maxout is used on (channel ...)
    """
    def __init__(self, max_dim=1):
        super().__init__()
        self.max_dim = max_dim

    def forward(self, inputs):
        # suppose inputs (batchsize, channel, length, dim)

        shape = list(inputs.size())

        if self.max_dim >= len(shape):
            print("MaxFeatureMap: maximize on %d dim" % (self.max_dim))
            print("But input has %d dimensions" % (len(shape)))
            sys.exit(1)
        if shape[self.max_dim] // 2 * 2 != shape[self.max_dim]:
            print("MaxFeatureMap: maximize on %d dim" % (self.max_dim))
            print("But this dimension has an odd number of data")
            sys.exit(1)
        shape[self.max_dim] = shape[self.max_dim] // 2
        shape.insert(self.max_dim, 2)

        # view to (batchsize, 2, channel//2, ...)
        # maximize on the 2nd dim
        m, i = inputs.view(*shape).max(self.max_dim)
        return m


##############
## FOR MODEL
##############

class LCNN(torch_nn.Module):
    """ Model definition
    """
    def __init__(self, **kwargs):
        super().__init__()
        input_channels = kwargs.get("input_channels", 1)
        num_coefficients = kwargs.get("num_coefficients", NUM_COEFFICIENTS)

        # number of frequency coefficients in the input features
        self.num_coefficients = num_coefficients

        # dimension of embedding vectors
        # here, the embedding is just the activation before sigmoid()
        self.v_emd_dim = 1

        # the model can handle multiple front-end configurations;
        # by default, only a single front-end is used

        self.m_transform = torch_nn.Sequential(
            torch_nn.Conv2d(input_channels, 64, (5, 5), 1, padding=(2, 2)),
            MaxFeatureMap2D(),
            torch.nn.MaxPool2d((2, 2), (2, 2)),

            torch_nn.Conv2d(32, 64, (1, 1), 1, padding=(0, 0)),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(32, affine=False),
            torch_nn.Conv2d(32, 96, (3, 3), 1, padding=(1, 1)),
            MaxFeatureMap2D(),

            torch.nn.MaxPool2d((2, 2), (2, 2)),
            torch_nn.BatchNorm2d(48, affine=False),

            torch_nn.Conv2d(48, 96, (1, 1), 1, padding=(0, 0)),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(48, affine=False),
            torch_nn.Conv2d(48, 128, (3, 3), 1, padding=(1, 1)),
            MaxFeatureMap2D(),

            torch.nn.MaxPool2d((2, 2), (2, 2)),

            torch_nn.Conv2d(64, 128, (1, 1), 1, padding=(0, 0)),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(64, affine=False),
            torch_nn.Conv2d(64, 64, (3, 3), 1, padding=(1, 1)),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(32, affine=False),

            torch_nn.Conv2d(32, 64, (1, 1), 1, padding=(0, 0)),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(32, affine=False),
            torch_nn.Conv2d(32, 64, (3, 3), 1, padding=(1, 1)),
            MaxFeatureMap2D(),
            torch_nn.MaxPool2d((2, 2), (2, 2)),

            torch_nn.Dropout(0.7)
        )

        self.m_before_pooling = torch_nn.Sequential(
            BLSTMLayer((self.num_coefficients // 16) * 32, (self.num_coefficients // 16) * 32),
            BLSTMLayer((self.num_coefficients // 16) * 32, (self.num_coefficients // 16) * 32)
        )

        self.m_output_act = torch_nn.Linear((self.num_coefficients // 16) * 32, self.v_emd_dim)

    def _compute_embedding(self, x):
        """ definition of forward method
        Assume x (batchsize, length, dim)
        Output x (batchsize * number_filter, output_dim)
        """
        # resample if necessary
        # x = self.m_resampler(x.squeeze(-1)).unsqueeze(-1)

        # number of sub models
        batch_size = x.shape[0]

        # buffer to store output scores from sub-models
        output_emb = torch.zeros(
            [batch_size, self.v_emd_dim],
            device=x.device,
            dtype=x.dtype
        )

        # compute scores for each sub-model
        idx = 0

        # compute scores
        # 1. unsqueeze to (batch, 1, frame_length, fft_bin)
        # 2. compute hidden features
        x = x.permute(0, 1, 3, 2)
        hidden_features = self.m_transform(x)

        # 3. (batch, channel, frame//N, feat_dim//N) ->
        #    (batch, frame//N, channel * feat_dim//N)
        #    where N is caused by conv with stride
        hidden_features = hidden_features.permute(0, 2, 1, 3).contiguous()
        frame_num = hidden_features.shape[1]

        hidden_features = hidden_features.view(batch_size, frame_num, -1)
        # 4. pooling: pass through bi-LSTM layers, then average over frames
        hidden_features_lstm = self.m_before_pooling(hidden_features)

        # 5. pass through the output layer
        tmp_emb = self.m_output_act((hidden_features_lstm + hidden_features).mean(1))
        output_emb[idx * batch_size: (idx + 1) * batch_size] = tmp_emb

        return output_emb

    def _compute_score(self, feature_vec):
        # feature_vec is [batch * submodel, 1]
        return torch.sigmoid(feature_vec).squeeze(1)

    def forward(self, x):
        feature_vec = self._compute_embedding(x)
        return feature_vec


class FrontendLCNN(LCNN):
    """ Model definition
    """
    def __init__(self, device: str = "cuda", **kwargs):
        super().__init__(**kwargs)

        self.device = device

        frontend_name = kwargs.get("frontend_algorithm", [])
        self.frontend = frontends.get_frontend(frontend_name)
        print(f"Using {frontend_name} frontend")

    def _compute_frontend(self, x):
        frontend = self.frontend(x)
        if frontend.ndim < 4:
            return frontend.unsqueeze(1)  # (bs, 1, n_lfcc, frames)
        return frontend  # (bs, n, n_lfcc, frames)

    def forward(self, x):
        x = self._compute_frontend(x)
        feature_vec = self._compute_embedding(x)

        return feature_vec


if __name__ == "__main__":

    device = "cuda"
    print("Definition of model")
    model = FrontendLCNN(input_channels=2, num_coefficients=80, device=device, frontend_algorithm=["mel_spec"])
    model = model.to(device)
    batch_size = 12
    mock_input = torch.rand((batch_size, 64_600), device=device)
    output = model(mock_input)
    print(output.shape)
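MaxFeatureMap2D above halves the channel dimension by taking an element-wise maximum over the two channel halves. A minimal standalone sketch of the equivalent computation, useful for sanity-checking shapes (assumes only torch):

# Sketch: MaxFeatureMap2D over the channel dim equals a max of the two channel halves.
import torch

x = torch.rand(1, 4, 5, 5)                 # (batch, channel, H, W)
first, second = x.chunk(2, dim=1)          # two (1, 2, 5, 5) halves
mfm_reference = torch.maximum(first, second)

# MaxFeatureMap2D reshapes to (batch, 2, channel // 2, H, W) and maxes over dim 1:
mfm_view = x.view(1, 2, 2, 5, 5).max(dim=1).values
assert torch.equal(mfm_reference, mfm_view)  # both yield shape (1, 2, 5, 5)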
src/models/meso_net.py ADDED
@@ -0,0 +1,146 @@
"""
This code is a modified version of the MesoNet DeepFake detection solution
from the FakeAVCeleb repository - https://github.com/DASH-Lab/FakeAVCeleb/blob/main/models/MesoNet.py.
"""
import torch
import torch.nn as nn

from src import frontends


class MesoInception4(nn.Module):
    """
    Pytorch Implementation of MesoInception4
    Author: Honggu Liu
    Date: July 7, 2019
    """
    def __init__(self, num_classes=1, **kwargs):
        super().__init__()

        self.fc1_dim = kwargs.get("fc1_dim", 1024)
        input_channels = kwargs.get("input_channels", 3)
        self.num_classes = num_classes

        # InceptionLayer1
        self.Incption1_conv1 = nn.Conv2d(input_channels, 1, 1, padding=0, bias=False)
        self.Incption1_conv2_1 = nn.Conv2d(input_channels, 4, 1, padding=0, bias=False)
        self.Incption1_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False)
        self.Incption1_conv3_1 = nn.Conv2d(input_channels, 4, 1, padding=0, bias=False)
        self.Incption1_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False)
        self.Incption1_conv4_1 = nn.Conv2d(input_channels, 2, 1, padding=0, bias=False)
        self.Incption1_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False)
        self.Incption1_bn = nn.BatchNorm2d(11)

        # InceptionLayer2
        self.Incption2_conv1 = nn.Conv2d(11, 2, 1, padding=0, bias=False)
        self.Incption2_conv2_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False)
        self.Incption2_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False)
        self.Incption2_conv3_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False)
        self.Incption2_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False)
        self.Incption2_conv4_1 = nn.Conv2d(11, 2, 1, padding=0, bias=False)
        self.Incption2_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False)
        self.Incption2_bn = nn.BatchNorm2d(12)

        # Normal Layer
        self.conv1 = nn.Conv2d(12, 16, 5, padding=2, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.leakyrelu = nn.LeakyReLU(0.1)
        self.bn1 = nn.BatchNorm2d(16)
        self.maxpooling1 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv2 = nn.Conv2d(16, 16, 5, padding=2, bias=False)
        self.maxpooling2 = nn.MaxPool2d(kernel_size=(4, 4))

        self.dropout = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(self.fc1_dim, 16)
        self.fc2 = nn.Linear(16, num_classes)

    # InceptionLayer
    def InceptionLayer1(self, input):
        x1 = self.Incption1_conv1(input)
        x2 = self.Incption1_conv2_1(input)
        x2 = self.Incption1_conv2_2(x2)
        x3 = self.Incption1_conv3_1(input)
        x3 = self.Incption1_conv3_2(x3)
        x4 = self.Incption1_conv4_1(input)
        x4 = self.Incption1_conv4_2(x4)
        y = torch.cat((x1, x2, x3, x4), 1)
        y = self.Incption1_bn(y)
        y = self.maxpooling1(y)

        return y

    def InceptionLayer2(self, input):
        x1 = self.Incption2_conv1(input)
        x2 = self.Incption2_conv2_1(input)
        x2 = self.Incption2_conv2_2(x2)
        x3 = self.Incption2_conv3_1(input)
        x3 = self.Incption2_conv3_2(x3)
        x4 = self.Incption2_conv4_1(input)
        x4 = self.Incption2_conv4_2(x4)
        y = torch.cat((x1, x2, x3, x4), 1)
        y = self.Incption2_bn(y)
        y = self.maxpooling1(y)

        return y

    def forward(self, input):
        x = self._compute_embedding(input)
        return x

    def _compute_embedding(self, input):
        x = self.InceptionLayer1(input)  # (Batch, 11, 128, 128)
        x = self.InceptionLayer2(x)  # (Batch, 12, 64, 64)

        x = self.conv1(x)  # (Batch, 16, 64, 64)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.maxpooling1(x)  # (Batch, 16, 32, 32)

        x = self.conv2(x)  # (Batch, 16, 32, 32)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.maxpooling2(x)  # (Batch, 16, 8, 8)

        x = x.view(x.size(0), -1)  # (Batch, 16*8*8)
        x = self.dropout(x)

        x = nn.AdaptiveAvgPool1d(self.fc1_dim)(x)
        x = self.fc1(x)  # (Batch, 16)
        x = self.leakyrelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class FrontendMesoInception4(MesoInception4):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.device = kwargs['device']

        frontend_name = kwargs.get("frontend_algorithm", [])
        self.frontend = frontends.get_frontend(frontend_name)
        print(f"Using {frontend_name} frontend")

    def forward(self, x):
        x = self.frontend(x)
        x = self._compute_embedding(x)
        return x


if __name__ == "__main__":
    model = FrontendMesoInception4(
        input_channels=2,
        fc1_dim=1024,
        device='cuda',
        frontend_algorithm="lfcc"
    )

    def count_parameters(model) -> int:
        pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return pytorch_total_params
    print(count_parameters(model))
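Note that _compute_embedding above flattens whatever spatial size remains and then adaptively pools it to fc1_dim, so fc1 does not constrain the input spectrogram size. A minimal sketch of that step in isolation (the feature shape below is illustrative, not taken from a real frontend):

# Sketch: flatten + AdaptiveAvgPool1d makes fc1's input width independent of the input size.
import torch
import torch.nn as nn

fc1_dim = 1024
features = torch.rand(8, 16, 13, 37)          # (batch, channels, H, W) of arbitrary spatial size
flat = features.view(features.size(0), -1)    # (8, 16 * 13 * 37)
pooled = nn.AdaptiveAvgPool1d(fc1_dim)(flat)  # (8, 1024), regardless of H and W
print(pooled.shape)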
src/models/models.py ADDED
@@ -0,0 +1,73 @@
from typing import Dict

from src.models import (
    lcnn,
    specrnet,
    whisper_specrnet,
    rawnet3,
    whisper_lcnn,
    meso_net,
    whisper_meso_net
)


def get_model(model_name: str, config: Dict, device: str):
    if model_name == "rawnet3":
        return rawnet3.prepare_model()
    elif model_name == "lcnn":
        return lcnn.FrontendLCNN(device=device, **config)
    elif model_name == "specrnet":
        return specrnet.FrontendSpecRNet(
            device=device,
            **config,
        )
    elif model_name == "mesonet":
        return meso_net.FrontendMesoInception4(
            input_channels=config.get("input_channels", 1),
            fc1_dim=config.get("fc1_dim", 1024),
            frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
            device=device,
        )
    elif model_name == "whisper_lcnn":
        return whisper_lcnn.WhisperLCNN(
            input_channels=config.get("input_channels", 1),
            freeze_encoder=config.get("freeze_encoder", False),
            device=device,
        )
    elif model_name == "whisper_specrnet":
        return whisper_specrnet.WhisperSpecRNet(
            input_channels=config.get("input_channels", 1),
            freeze_encoder=config.get("freeze_encoder", False),
            device=device,
        )
    elif model_name == "whisper_mesonet":
        return whisper_meso_net.WhisperMesoNet(
            input_channels=config.get("input_channels", 1),
            freeze_encoder=config.get("freeze_encoder", True),
            fc1_dim=config.get("fc1_dim", 1024),
            device=device,
        )
    elif model_name == "whisper_frontend_lcnn":
        return whisper_lcnn.WhisperMultiFrontLCNN(
            input_channels=config.get("input_channels", 2),
            freeze_encoder=config.get("freeze_encoder", False),
            frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
            device=device,
        )
    elif model_name == "whisper_frontend_specrnet":
        return whisper_specrnet.WhisperMultiFrontSpecRNet(
            input_channels=config.get("input_channels", 2),
            freeze_encoder=config.get("freeze_encoder", False),
            frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
            device=device,
        )
    elif model_name == "whisper_frontend_mesonet":
        return whisper_meso_net.WhisperMultiFrontMesoNet(
            input_channels=config.get("input_channels", 2),
            fc1_dim=config.get("fc1_dim", 1024),
            freeze_encoder=config.get("freeze_encoder", True),
            frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
            device=device,
        )
    else:
        raise ValueError(f"Model '{model_name}' not supported")
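A minimal sketch of how this factory might be called; the config keys below are assumptions consistent with the defaults above, not values taken from a training config in this commit:

# Sketch: instantiating a classifier through the factory (config keys are assumed, not from this commit).
from src.models.models import get_model

config = {
    "input_channels": 1,
    "fc1_dim": 1024,
    "frontend_algorithm": "lfcc",
}
model = get_model(model_name="mesonet", config=config, device="cpu")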
src/models/rawnet3.py ADDED
@@ -0,0 +1,323 @@
"""
This file contains implementation of RawNet3 architecture.
The original implementation can be found here: https://github.com/Jungjee/RawNet/tree/master/python/RawNet3
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from asteroid_filterbanks import Encoder, ParamSincFB  # pip install asteroid_filterbanks


class RawNet3(nn.Module):
    def __init__(self, block, model_scale, context, summed, C=1024, **kwargs):
        super().__init__()

        nOut = kwargs["nOut"]

        self.context = context
        self.encoder_type = kwargs["encoder_type"]
        self.log_sinc = kwargs["log_sinc"]
        self.norm_sinc = kwargs["norm_sinc"]
        self.out_bn = kwargs["out_bn"]
        self.summed = summed

        self.preprocess = nn.Sequential(
            PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True)
        )
        self.conv1 = Encoder(
            ParamSincFB(
                C // 4,
                251,
                stride=kwargs["sinc_stride"],
            )
        )
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(C // 4)

        self.layer1 = block(
            C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5
        )
        self.layer2 = block(
            C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3
        )
        self.layer3 = block(C, C, kernel_size=3, dilation=4, scale=model_scale)
        self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)

        if self.context:
            attn_input = 1536 * 3
        else:
            attn_input = 1536
        print("self.encoder_type", self.encoder_type)
        if self.encoder_type == "ECA":
            attn_output = 1536
        elif self.encoder_type == "ASP":
            attn_output = 1
        else:
            raise ValueError("Undefined encoder")

        self.attention = nn.Sequential(
            nn.Conv1d(attn_input, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, attn_output, kernel_size=1),
            nn.Softmax(dim=2),
        )

        self.bn5 = nn.BatchNorm1d(3072)

        self.fc6 = nn.Linear(3072, nOut)
        self.bn6 = nn.BatchNorm1d(nOut)

        self.mp3 = nn.MaxPool1d(3)

    def forward(self, x):
        """
        :param x: input mini-batch (bs, samp)
        """

        with torch.cuda.amp.autocast(enabled=False):
            x = self.preprocess(x)
            x = torch.abs(self.conv1(x))
            if self.log_sinc:
                x = torch.log(x + 1e-6)
            if self.norm_sinc == "mean":
                x = x - torch.mean(x, dim=-1, keepdim=True)
            elif self.norm_sinc == "mean_std":
                m = torch.mean(x, dim=-1, keepdim=True)
                s = torch.std(x, dim=-1, keepdim=True)
                s[s < 0.001] = 0.001
                x = (x - m) / s

        if self.summed:
            x1 = self.layer1(x)
            x2 = self.layer2(x1)
            x3 = self.layer3(self.mp3(x1) + x2)
        else:
            x1 = self.layer1(x)
            x2 = self.layer2(x1)
            x3 = self.layer3(x2)

        x = self.layer4(torch.cat((self.mp3(x1), x2, x3), dim=1))
        x = self.relu(x)

        t = x.size()[-1]

        if self.context:
            global_x = torch.cat(
                (
                    x,
                    torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
                    torch.sqrt(
                        torch.var(x, dim=2, keepdim=True).clamp(
                            min=1e-4, max=1e4
                        )
                    ).repeat(1, 1, t),
                ),
                dim=1,
            )
        else:
            global_x = x

        w = self.attention(global_x)

        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt(
            (torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4)
        )

        x = torch.cat((mu, sg), 1)

        x = self.bn5(x)

        x = self.fc6(x)

        if self.out_bn:
            x = self.bn6(x)

        return x


class PreEmphasis(torch.nn.Module):
    def __init__(self, coef: float = 0.97) -> None:
        super().__init__()
        self.coef = coef
        # make kernel
        # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped.
        self.register_buffer(
            "flipped_filter",
            torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0),
        )

    def forward(self, input: torch.tensor) -> torch.tensor:
        assert (
            len(input.size()) == 2
        ), "The number of dimensions of input tensor must be 2!"
        # reflect padding to match lengths of in/out
        input = input.unsqueeze(1)
        input = F.pad(input, (1, 0), "reflect")
        return F.conv1d(input, self.flipped_filter)


class AFMS(nn.Module):
    """
    Alpha-Feature map scaling, added to the output of each residual block[1,2].

    Reference:
    [1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
    [2] AMFS : https://www.koreascience.or.kr/article/JAKO202029757857763.page
    """

    def __init__(self, nb_dim: int) -> None:
        super().__init__()
        self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        y = F.adaptive_avg_pool1d(x, 1).view(x.size(0), -1)
        y = self.sig(self.fc(y)).view(x.size(0), x.size(1), -1)

        x = x + self.alpha
        x = x * y
        return x


class Bottle2neck(nn.Module):
    def __init__(
        self,
        inplanes,
        planes,
        kernel_size=None,
        dilation=None,
        scale=4,
        pool=False,
    ):

        super().__init__()

        width = int(math.floor(planes / scale))

        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)

        self.nums = scale - 1

        convs = []
        bns = []

        num_pad = math.floor(kernel_size / 2) * dilation

        for i in range(self.nums):
            convs.append(
                nn.Conv1d(
                    width,
                    width,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=num_pad,
                )
            )
            bns.append(nn.BatchNorm1d(width))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)

        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)

        self.relu = nn.ReLU()

        self.width = width

        self.mp = nn.MaxPool1d(pool) if pool else False
        self.afms = AFMS(planes)

        if inplanes != planes:  # if change in number of filters
            self.residual = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        residual = self.residual(x)

        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(sp)
            sp = self.bns[i](sp)
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = torch.cat((out, spx[self.nums]), 1)

        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)

        out += residual
        if self.mp:
            out = self.mp(out)
        out = self.afms(out)

        return out


def prepare_model():
    model = RawNet3(
        Bottle2neck,
        model_scale=8,
        context=True,
        summed=True,
        encoder_type="ECA",
        nOut=1,  # number of slices
        out_bn=False,
        sinc_stride=10,
        log_sinc=True,
        norm_sinc="mean",
        grad_mult=1,
    )
    return model


if __name__ == "__main__":
    model = RawNet3(
        Bottle2neck,
        model_scale=8,
        context=True,
        summed=True,
        encoder_type="ECA",
        nOut=1,  # number of slices
        out_bn=False,
        sinc_stride=10,
        log_sinc=True,
        norm_sinc="mean",
        grad_mult=1,
    )
    gpu = False

    model.eval()
    print("RawNet3 initialised & weights loaded!")

    if torch.cuda.is_available():
        print("Cuda available, conducting inference on GPU")
        model = model.to("cuda")
        gpu = True

    audios = torch.rand(32, 64_600)

    out = model(audios)
    print(out.shape)
src/models/specrnet.py ADDED
@@ -0,0 +1,226 @@
"""
This file contains implementation of SpecRNet architecture.
We base our codebase on the implementation of RawNet2 by Hemlata Tak ([email protected]).
It is available here: https://github.com/asvspoof-challenge/2021/blob/main/LA/Baseline-RawNet2/model.py
"""
from typing import Dict

import torch.nn as nn

from src import frontends


def get_config(input_channels: int) -> Dict:
    return {
        "filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]],
        "nb_fc_node": 64,
        "gru_node": 64,
        "nb_gru_layer": 2,
        "nb_classes": 1,
    }


class Residual_block2D(nn.Module):
    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])

        self.lrelu = nn.LeakyReLU(negative_slope=0.3)

        self.conv1 = nn.Conv2d(
            in_channels=nb_filts[0],
            out_channels=nb_filts[1],
            kernel_size=3,
            padding=1,
            stride=1,
        )

        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(
            in_channels=nb_filts[1],
            out_channels=nb_filts[1],
            padding=1,
            kernel_size=3,
            stride=1,
        )

        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(
                in_channels=nb_filts[0],
                out_channels=nb_filts[1],
                padding=0,
                kernel_size=1,
                stride=1,
            )

        else:
            self.downsample = False
        self.mp = nn.MaxPool2d(2)

    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.lrelu(out)
        else:
            out = x

        out = self.conv1(x)
        out = self.bn2(out)
        out = self.lrelu(out)
        out = self.conv2(out)

        if self.downsample:
            identity = self.conv_downsample(identity)

        out += identity
        out = self.mp(out)
        return out


class SpecRNet(nn.Module):
    def __init__(self, input_channels, **kwargs):
        super().__init__()
        config = get_config(input_channels=input_channels)

        self.device = kwargs.get("device", "cuda")

        self.first_bn = nn.BatchNorm2d(num_features=config["filts"][0])
        self.selu = nn.SELU(inplace=True)
        self.block0 = nn.Sequential(
            Residual_block2D(nb_filts=config["filts"][1], first=True)
        )
        self.block2 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2]))
        config["filts"][2][0] = config["filts"][2][1]
        self.block4 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2]))
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        self.fc_attention0 = self._make_attention_fc(
            in_features=config["filts"][1][-1], l_out_features=config["filts"][1][-1]
        )
        self.fc_attention2 = self._make_attention_fc(
            in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1]
        )
        self.fc_attention4 = self._make_attention_fc(
            in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1]
        )

        self.bn_before_gru = nn.BatchNorm2d(num_features=config["filts"][2][-1])
        self.gru = nn.GRU(
            input_size=config["filts"][2][-1],
            hidden_size=config["gru_node"],
            num_layers=config["nb_gru_layer"],
            batch_first=True,
            bidirectional=True,
        )

        self.fc1_gru = nn.Linear(
            in_features=config["gru_node"] * 2, out_features=config["nb_fc_node"] * 2
        )

        self.fc2_gru = nn.Linear(
            in_features=config["nb_fc_node"] * 2,
            out_features=config["nb_classes"],
            bias=True,
        )

        self.sig = nn.Sigmoid()

    def _compute_embedding(self, x):
        x = self.first_bn(x)
        x = self.selu(x)

        x0 = self.block0(x)
        y0 = self.avgpool(x0).view(x0.size(0), -1)
        y0 = self.fc_attention0(y0)
        y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1)
        y0 = y0.unsqueeze(-1)
        x = x0 * y0 + y0

        x = nn.MaxPool2d(2)(x)

        x2 = self.block2(x)
        y2 = self.avgpool(x2).view(x2.size(0), -1)
        y2 = self.fc_attention2(y2)
        y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1)
        y2 = y2.unsqueeze(-1)
        x = x2 * y2 + y2

        x = nn.MaxPool2d(2)(x)

        x4 = self.block4(x)
        y4 = self.avgpool(x4).view(x4.size(0), -1)
        y4 = self.fc_attention4(y4)
        y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1)
        y4 = y4.unsqueeze(-1)
        x = x4 * y4 + y4

        x = nn.MaxPool2d(2)(x)

        x = self.bn_before_gru(x)
        x = self.selu(x)
        x = nn.AdaptiveAvgPool2d((1, None))(x)
        x = x.squeeze(-2)
        x = x.permute(0, 2, 1)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]
        x = self.fc1_gru(x)
        x = self.fc2_gru(x)
        return x

    def forward(self, x):
        x = self._compute_embedding(x)
        return x

    def _make_attention_fc(self, in_features, l_out_features):
        l_fc = []
        l_fc.append(nn.Linear(in_features=in_features, out_features=l_out_features))
        return nn.Sequential(*l_fc)


class FrontendSpecRNet(SpecRNet):
    def __init__(self, input_channels, **kwargs):
        super().__init__(input_channels, **kwargs)

        self.device = kwargs['device']

        frontend_name = kwargs.get("frontend_algorithm", [])
        self.frontend = frontends.get_frontend(frontend_name)
        print(f"Using {frontend_name} frontend")

    def _compute_frontend(self, x):
        frontend = self.frontend(x)
        if frontend.ndim < 4:
            return frontend.unsqueeze(1)  # (bs, 1, n_lfcc, frames)
        return frontend  # (bs, n, n_lfcc, frames)

    def forward(self, x):
        x = self._compute_frontend(x)
        x = self._compute_embedding(x)
        return x


if __name__ == "__main__":
    print("Definition of model")
    device = "cuda"

    input_channels = 1
    config = {
        "filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]],
        "nb_fc_node": 64,
        "gru_node": 64,
        "nb_gru_layer": 2,
        "nb_classes": 1,
    }

    def count_parameters(model) -> int:
        pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return pytorch_total_params
    model = FrontendSpecRNet(input_channels=1, device=device, frontend_algorithm=["lfcc"])
    model = model.to(device)
    print(count_parameters(model))
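The __main__ block above only counts parameters. A forward pass can be sanity-checked the same way as in lcnn.py, assuming the "lfcc" frontend returns a (bs, 1, n_lfcc, frames) tensor for raw 16 kHz audio:

# Sketch: forward pass through FrontendSpecRNet on mock raw audio (assumes the "lfcc" frontend is available).
import torch

model = FrontendSpecRNet(input_channels=1, device="cpu", frontend_algorithm=["lfcc"])
model = model.to("cpu")
mock_input = torch.rand((12, 64_600))  # ~4 s of 16 kHz audio, as in lcnn.py's __main__
output = model(mock_input)
print(output.shape)  # expected: torch.Size([12, 1])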
src/models/whisper_lcnn.py ADDED
@@ -0,0 +1,89 @@
import torch

from src.models.whisper_main import ModelDimensions, Whisper, log_mel_spectrogram
from src.models.lcnn import LCNN
from src import frontends
from src.commons import WHISPER_MODEL_WEIGHTS_PATH


class WhisperLCNN(LCNN):

    def __init__(self, input_channels, freeze_encoder, **kwargs):
        super().__init__(input_channels=input_channels, **kwargs)

        self.device = kwargs['device']
        checkpoint = torch.load(WHISPER_MODEL_WEIGHTS_PATH)
        dims = ModelDimensions(**checkpoint["dims"].__dict__)
        model = Whisper(dims)
        model = model.to(self.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        self.whisper_model = model
        if freeze_encoder:
            for param in self.whisper_model.parameters():
                param.requires_grad = False

    def compute_whisper_features(self, x):
        specs = []
        for sample in x:
            specs.append(log_mel_spectrogram(sample))
        x = torch.stack(specs)
        x = self.whisper_model(x)

        x = x.permute(0, 2, 1)  # (bs, frames, 3 x n_lfcc)
        x = x.unsqueeze(1)  # (bs, 1, frames, 3 x n_lfcc)
        x = x.repeat(
            (1, 1, 1, 2)
        )  # (bs, 1, frames, 3 x n_lfcc) -> (bs, 1, frames, 3000)
        return x

    def forward(self, x):
        # we assume that the data is correct (i.e. 30s)
        x = self.compute_whisper_features(x)
        out = self._compute_embedding(x)
        return out


class WhisperMultiFrontLCNN(WhisperLCNN):

    def __init__(self, input_channels, freeze_encoder, **kwargs):
        super().__init__(input_channels=input_channels, freeze_encoder=freeze_encoder, **kwargs)

        self.frontend = frontends.get_frontend(kwargs['frontend_algorithm'])
        print(f"Using {self.frontend} frontend!")

    def forward(self, x):
        # Frontend computation
        frontend_x = self.frontend(x)
        x = self.compute_whisper_features(x)

        x = torch.cat([x, frontend_x], 1)
        out = self._compute_embedding(x)
        return out


if __name__ == "__main__":
    import numpy as np

    input_channels = 1
    device = "cpu"
    classifier = WhisperLCNN(
        input_channels=input_channels,
        freeze_encoder=True,
        device=device,
    )

    input_channels = 2
    classifier_2 = WhisperMultiFrontLCNN(
        input_channels=input_channels,
        freeze_encoder=True,
        device=device,
        frontend_algorithm="lfcc"
    )
    x = np.random.rand(2, 30 * 16_000).astype(np.float32)
    x = torch.from_numpy(x)

    out = classifier(x)
    print(out.shape)

    out = classifier_2(x)
    print(out.shape)
src/models/whisper_main.py ADDED
@@ -0,0 +1,323 @@
# Based on https://github.com/openai/whisper/blob/main/whisper/model.py
from dataclasses import dataclass
from functools import lru_cache
import os
from typing import Iterable, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn


def exact_div(x, y):
    assert x % y == 0
    return x // y


# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
N_FRAMES = exact_div(
    N_SAMPLES, HOP_LENGTH
)  # 3000: number of frames in a mel spectrogram input


def pad_or_trim(
    array: Union[torch.Tensor, np.ndarray],
    length: int = N_SAMPLES,
    *,
    axis: int = -1,
) -> torch.Tensor:
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if not torch.is_tensor(array):
        array = torch.from_numpy(array)

    if array.shape[axis] > length:
        array = array.index_select(
            dim=axis, index=torch.arange(length, device=array.device)
        )

    if array.shape[axis] < length:
        # pad by repeating the signal multiple times
        num_repeats = int(length / array.shape[axis]) + 1
        array = torch.tile(array, (1, num_repeats))[:, :length]
    return array


@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling the librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    with np.load(
        os.path.join(os.path.dirname(__file__), "assets/mel_filters.npz")
    ) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(audio: torch.Tensor, n_mels: int = N_MELS):
    """
    Compute the log-Mel spectrogram of the audio.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 is supported

    Returns
    -------
    torch.Tensor, shape = (80, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[:, :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )


class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )


def sinusoids(length, channels, max_timescale=10_000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv = self.qkv_attention(q, k, v, mask)
        return self.out(wv)

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]

        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
        x = x + self.mlp(self.mlp_ln(x))
        return x


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)
        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x


class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset: offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits


class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )

    def forward(self, mel: torch.Tensor):
        return self.encoder(mel)

    @property
    def device(self):
        return next(self.parameters()).device
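A minimal sketch of how the pieces in this file fit together for a single 30 s clip, using the names defined above; the ModelDimensions values are assumed tiny-model-like numbers, not values read from the checkpoint used elsewhere in this commit:

# Sketch: mel extraction + encoder forward (dims are assumed tiny-like values, not from the real checkpoint).
import torch

audio = torch.rand(SAMPLE_RATE * CHUNK_LENGTH)  # mock 30 s waveform, shape (480000,)
audio = pad_or_trim(audio)                      # guarantees exactly N_SAMPLES samples
mel = log_mel_spectrogram(audio)                # (80, 3000) = (N_MELS, N_FRAMES)

dims = ModelDimensions(
    n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6, n_audio_layer=4,
    n_vocab=51864, n_text_ctx=448, n_text_state=384, n_text_head=6, n_text_layer=4,
)
encoder_only = Whisper(dims)                    # randomly initialised; no checkpoint loaded here
features = encoder_only(mel.unsqueeze(0))       # (1, 1500, 384) = (bs, n_audio_ctx, n_audio_state)
print(features.shape)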
src/models/whisper_meso_net.py ADDED
@@ -0,0 +1,88 @@
import torch
from src import frontends

from src.models.whisper_main import ModelDimensions, Whisper, log_mel_spectrogram
from src.models.meso_net import MesoInception4
from src.commons import WHISPER_MODEL_WEIGHTS_PATH


class WhisperMesoNet(MesoInception4):
    def __init__(self, freeze_encoder, **kwargs):
        super().__init__(**kwargs)

        self.device = kwargs['device']
        checkpoint = torch.load(WHISPER_MODEL_WEIGHTS_PATH)
        dims = ModelDimensions(**checkpoint["dims"].__dict__)
        model = Whisper(dims)
        model = model.to(self.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        self.whisper_model = model
        if freeze_encoder:
            for param in self.whisper_model.parameters():
                param.requires_grad = False

    def compute_whisper_features(self, x):
        specs = []
        for sample in x:
            specs.append(log_mel_spectrogram(sample))
        x = torch.stack(specs)
        x = self.whisper_model(x)

        x = x.permute(0, 2, 1)  # (bs, frames, 3 x n_lfcc)
        x = x.unsqueeze(1)  # (bs, 1, frames, 3 x n_lfcc)
        x = x.repeat(
            (1, 1, 1, 2)
        )  # (bs, 1, frames, 3 x n_lfcc) -> (bs, 1, frames, 3000)
        return x

    def forward(self, x):
        # we assume that the data is correct (i.e. 30s)
        x = self.compute_whisper_features(x)
        out = self._compute_embedding(x)
        return out


class WhisperMultiFrontMesoNet(WhisperMesoNet):
    def __init__(self, freeze_encoder, **kwargs):
        super().__init__(freeze_encoder=freeze_encoder, **kwargs)
        self.frontend = frontends.get_frontend(kwargs['frontend_algorithm'])
        print(f"Using {self.frontend} frontend!")

    def forward(self, x):
        # Frontend computation
        frontend_x = self.frontend(x)
        x = self.compute_whisper_features(x)

        x = torch.cat([x, frontend_x], 1)
        out = self._compute_embedding(x)
        return out


if __name__ == "__main__":
    import numpy as np

    input_channels = 1
    device = "cpu"
    classifier = WhisperMesoNet(
        input_channels=input_channels,
        freeze_encoder=True,
        fc1_dim=1024,
        device=device,
    )

    input_channels = 2
    classifier_2 = WhisperMultiFrontMesoNet(
        input_channels=input_channels,
        freeze_encoder=True,
        fc1_dim=1024,
        device=device,
        frontend_algorithm="lfcc"
    )
    x = np.random.rand(2, 30 * 16_000).astype(np.float32)
    x = torch.from_numpy(x)

    out = classifier(x)
    print(out.shape)

    out = classifier_2(x)
    print(out.shape)
src/models/whisper_specrnet.py ADDED
@@ -0,0 +1,97 @@
import numpy as np
import torch

from src import frontends
from src.models.whisper_main import ModelDimensions, Whisper, log_mel_spectrogram
from src.models.specrnet import SpecRNet
from src.commons import WHISPER_MODEL_WEIGHTS_PATH


class WhisperSpecRNet(SpecRNet):
    def __init__(self, input_channels, freeze_encoder, **kwargs):
        super().__init__(input_channels=input_channels, **kwargs)

        self.device = kwargs["device"]
        checkpoint = torch.load(WHISPER_MODEL_WEIGHTS_PATH)
        dims = ModelDimensions(**checkpoint["dims"].__dict__)
        model = Whisper(dims)
        model = model.to(self.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        self.whisper_model = model
        if freeze_encoder:
            for param in self.whisper_model.parameters():
                param.requires_grad = False

    def compute_whisper_features(self, x):
        specs = []
        for sample in x:
            specs.append(log_mel_spectrogram(sample))
        x = torch.stack(specs)
        x = self.whisper_model(x)

        x = x.permute(0, 2, 1)  # (bs, frames, 3 x n_lfcc)
        x = x.unsqueeze(1)  # (bs, 1, frames, 3 x n_lfcc)
        x = x.repeat(
            (1, 1, 1, 2)
        )  # (bs, 1, frames, 3 x n_lfcc) -> (bs, 1, frames, 3000)
        return x

    def forward(self, x):
        # we assume that the data is correct (i.e. 30s)
        x = self.compute_whisper_features(x)
        out = self._compute_embedding(x)
        return out


class WhisperMultiFrontSpecRNet(WhisperSpecRNet):
    def __init__(self, input_channels, freeze_encoder, **kwargs):
        super().__init__(
            input_channels=input_channels,
            freeze_encoder=freeze_encoder,
            **kwargs,
        )
        self.frontend = frontends.get_frontend(kwargs["frontend_algorithm"])
        print(f"Using {self.frontend} frontend!")

    def forward(self, x):
        # Frontend computation
        frontend_x = self.frontend(x)
        x = self.compute_whisper_features(x)

        x = torch.cat([x, frontend_x], 1)
        out = self._compute_embedding(x)
        return out


if __name__ == "__main__":
    import numpy as np

    input_channels = 1
    config = {
        "filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]],
        "nb_fc_node": 64,
        "gru_node": 64,
        "nb_gru_layer": 2,
        "nb_classes": 1,
    }
    device = "cpu"
    classifier = WhisperSpecRNet(
        input_channels,
        freeze_encoder=False,
        device=device,
    )
    input_channels = 2
    classifier_2 = WhisperMultiFrontSpecRNet(
        input_channels,
        freeze_encoder=False,
        device=device,
        frontend_algorithm="lfcc"
    )
    x = np.random.rand(2, 30 * 16_000).astype(np.float32)
    x = torch.from_numpy(x)

    out = classifier(x)
    print(out.shape)

    out = classifier_2(x)
    print(out.shape)