donghoney0416 committed
Commit 795921a · verified · 1 Parent(s): 57ac7f5

Upload DeFTAN2.py

Files changed (1):
  1. DeFTAN2.py +329 -0

DeFTAN2.py ADDED
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from einops import rearrange


class DeFTAN2(nn.Module):
    def __init__(self, n_srcs=1, win=512, n_mics=4, n_layers=6, att_dim=64, hidden_dim=256,
                 n_head=4, emb_dim=64, emb_ks=4, emb_hs=1, dropout=0.1, eps=1.0e-5):
        super().__init__()
        self.n_srcs = n_srcs  # number of output sources
        self.win = win        # STFT window length
        self.hop = win // 2   # 50% overlap
        self.n_layers = n_layers
        self.n_mics = n_mics
        self.emb_dim = emb_dim
        assert win % 2 == 0

        t_ksize = 3
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
        # encoder: lift the 2*n_mics real/imag input channels to emb_dim features
        self.up_conv = nn.Sequential(
            nn.Conv2d(2 * n_mics, emb_dim * n_head, ks, padding=padding),
            nn.GroupNorm(1, emb_dim * n_head, eps=eps),
            SDB2d(emb_dim * n_head, emb_dim, n_head)
        )
        self.blocks = nn.ModuleList([])
        for idx in range(n_layers):
            self.blocks.append(DeFTAN2block(idx, emb_dim, emb_ks, emb_hs, att_dim, hidden_dim, n_head, dropout, eps))
        # decoder: project back to real/imag spectra of the n_srcs sources
        self.down_conv = nn.Sequential(
            nn.Conv2d(emb_dim, 2 * n_srcs * n_head, ks, padding=padding),
            SDB2d(2 * n_srcs * n_head, 2 * n_srcs, n_head))

    def pad_signal(self, input):
        # input: waveforms of shape (B, T) or (B, M, T)
        if input.dim() not in [2, 3]:
            raise RuntimeError("Input can only be 2 or 3 dimensional.")

        if input.dim() == 2:
            input = input.unsqueeze(1)
        batch_size = input.size(0)
        nchannel = input.size(1)
        nsample = input.size(2)
        # zero-pad the tail so the signal covers a whole number of hops
        rest = self.win - (self.hop + nsample % self.win) % self.win
        if rest > 0:
            pad = torch.zeros(batch_size, nchannel, rest, dtype=input.dtype, device=input.device)
            input = torch.cat([input, pad], 2)

        # pad half a window of zeros on both ends for the STFT/iSTFT round trip
        pad_aux = torch.zeros(batch_size, nchannel, self.hop, dtype=input.dtype, device=input.device)
        input = torch.cat([pad_aux, input, pad_aux], 2)

        return input, rest

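    # Worked example of the framing arithmetic above (values illustrative): with
    # win=512, hop=256 and nsample=16000, 16000 % 512 = 128, so
    # rest = 512 - (256 + 128) % 512 = 128 tail samples are zero-padded; adding one
    # hop of zeros at each end gives 16000 + 128 + 2 * 256 = 16640 samples, an
    # exact multiple of the 256-sample hop.
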
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input, rest = self.pad_signal(input)
        B, M, N = input.size()  # batch B, mics M, time samples N
        mix_std_ = torch.std(input, dim=(1, 2), keepdim=True)  # [B, 1, 1]
        input = input / mix_std_  # standard-deviation normalization
        # Encoding
        window = torch.hann_window(self.win, dtype=input.dtype, device=input.device)
        # recent PyTorch requires return_complex=True for torch.stft;
        # view_as_real restores the trailing (real, imag) layout used below
        stft_input = torch.view_as_real(
            torch.stft(input.view([-1, N]), n_fft=self.win, hop_length=self.hop,
                       window=window, return_complex=True))
        _, Fr, T, _ = stft_input.size()  # Fr = num freq bins, T = num frames, 2 = real/imag
        xi = stft_input.view([B, M, Fr, T, 2])  # [B*M, Fr, T, 2] -> [B, M, Fr, T, 2]
        xi = xi.permute(0, 1, 4, 3, 2).contiguous()  # [B, M, 2, T, Fr]
        xi = xi.view([B, M * 2, T, Fr])  # [B, 2*M, T, Fr]
        # Separation
        feature = self.up_conv(xi)  # [B, C, T, Fr]
        for ii in range(self.n_layers):
            feature = self.blocks[ii](feature)  # [B, C, T, Fr]
        xo = self.down_conv(feature).view([B * self.n_srcs, 2, T, Fr])  # [B, 2*n_srcs, T, Fr] -> [B*n_srcs, 2, T, Fr]
        # Decoding
        xo = xo.permute(0, 3, 2, 1)  # [B*n_srcs, 2, T, Fr] -> [B*n_srcs, Fr, T, 2]
        istft_input = torch.complex(xo[:, :, :, 0], xo[:, :, :, 1])
        istft_output = torch.istft(istft_input, n_fft=self.win, hop_length=self.hop,
                                   window=window, return_complex=False)

        output = istft_output[:, self.hop:-(rest + self.hop)].unsqueeze(1)  # [B*n_srcs, 1, N]
        output = output.view([B, self.n_srcs, -1])  # [B, n_srcs, N]
        output = output * mix_std_  # undo the normalization

        return output


class SDB1d(nn.Module):
    # split dense block (1D): channels are split into `groups` interleaved subsets
    # that are refined sequentially with dense connections, reducing C by `groups`.
    def __init__(self, in_channels, out_channels, groups):
        super().__init__()
        assert in_channels // out_channels == groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.blocks = nn.ModuleList([])
        for idx in range(groups):
            self.blocks.append(nn.Sequential(
                # the first block sees one subset; later blocks see [previous output, next subset]
                nn.Conv1d(out_channels * ((idx > 0) + 1), out_channels, kernel_size=3, padding=1),
                nn.GroupNorm(1, out_channels, eps=1e-5),
                nn.PReLU(out_channels)
            ))

    def forward(self, x):
        B, C, L = x.size()
        g = self.groups
        skip = x[:, ::g, :]  # every g-th channel: the first interleaved subset
        for idx in range(g):
            output = self.blocks[idx](skip)
            skip = torch.cat([output, x[:, idx + 1::g, :]], dim=1)
        return output


118
+ def __init__(self, in_channels, out_channels, groups):
119
+ super().__init__()
120
+ assert in_channels // out_channels == groups
121
+ self.in_channels = in_channels
122
+ self.out_channels = out_channels
123
+ self.groups = groups
124
+ self.blocks = nn.ModuleList([])
125
+ for idx in range(groups):
126
+ self.blocks.append(nn.Sequential(
127
+ nn.Conv2d(out_channels * ((idx > 0) + 1), out_channels, kernel_size=(3, 3), padding=(1, 1)),
128
+ nn.GroupNorm(1, out_channels, 1e-5),
129
+ nn.PReLU(out_channels)
130
+ ))
131
+
132
+ def forward(self, x):
133
+ B, C, T, Q = x.size()
134
+ g = self.groups
135
+ # x = x.view(B, g, C//g, T, Q).transpose(1, 2).reshape(B, C, T, Q)
136
+ skip = x[:, ::g, :, :]
137
+ for idx in range(g):
138
+ output = self.blocks[idx](skip)
139
+ skip = torch.cat([output, x[:, idx+1::g, :, :]], dim=1)
140
+ return output
141
+
142
+
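# Toy shape example (illustrative, not from the original file): SDB2d with
# groups=4 reduces the channel count by 4x while densely refining the
# interleaved channel subsets.
#
#     sdb = SDB2d(in_channels=8, out_channels=2, groups=4)
#     y = sdb(torch.randn(1, 8, 10, 12))   # -> torch.Size([1, 2, 10, 12])
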
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class CEA(nn.Module):
    # efficient attention variant: keys are softmax-normalized along the sequence
    # and queries along the feature dimension, so the cost is linear in sequence
    # length (O(N * d^2)) rather than quadratic (O(N^2 * d))
    def __init__(self, dim, heads, dim_head, dropout):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        # shared convolutional front-end for queries and keys
        self.cv_qk = nn.Sequential(
            nn.Conv1d(dim, dim * 2, kernel_size=3, padding=1, bias=False),
            nn.GLU(dim=1))
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        self.att_drop = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qk = self.cv_qk(x.transpose(1, 2)).transpose(1, 2)
        q = rearrange(self.to_q(qk), 'b n (h d) -> b h n d', h=self.heads)
        k = rearrange(self.to_k(qk), 'b n (h d) -> b h n d', h=self.heads)
        v = rearrange(self.to_v(x), 'b n (h d) -> b h n d', h=self.heads)

        # aggregate values with sequence-normalized keys, then attend with
        # feature-normalized queries (efficient-attention factorization)
        weight = torch.matmul(F.softmax(k, dim=2).transpose(-1, -2), v) * self.scale
        out = torch.matmul(F.softmax(q, dim=3), self.att_drop(weight))
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


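# Shape sketch (illustrative, not from the original file): CEA maps a sequence
# [B, N, dim] to [B, N, dim], projecting back through to_out when
# heads * dim_head != dim.
#
#     att = CEA(dim=64, heads=4, dim_head=64, dropout=0.0)
#     y = att(torch.randn(2, 100, 64))   # -> torch.Size([2, 100, 64])
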
class DPFN(nn.Module):
    # feed-forward module with two parallel branches: a point-wise (linear) branch
    # and a dilated convolution branch whose dilation 2**idx grows with block depth
    def __init__(self, dim, hidden_dim, idx, dropout):
        super().__init__()
        self.PW1 = nn.Sequential(
            nn.Linear(dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.PW2 = nn.Sequential(
            nn.Linear(dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.DW_Conv = nn.Sequential(
            nn.Conv1d(hidden_dim // 2, hidden_dim // 2, kernel_size=5, dilation=2 ** idx, padding='same'),
            nn.GroupNorm(1, hidden_dim // 2, eps=1e-5),
            nn.PReLU(hidden_dim // 2)
        )
        self.PW3 = nn.Sequential(
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        ffw_out = self.PW1(x)
        dw_out = self.DW_Conv(self.PW2(x).transpose(1, 2)).transpose(1, 2)
        out = self.PW3(torch.cat((ffw_out, dw_out), dim=2))  # concat branches along features
        return out


class DeFTAN2block(nn.Module):
    def __getitem__(self, key):
        return getattr(self, key)

    def __init__(self, idx, emb_dim, emb_ks, emb_hs, att_dim, hidden_dim, n_head, dropout, eps):
        super().__init__()
        in_channels = emb_dim * emb_ks
        # F-transformer: attention across frequency bins
        self.F_norm = LayerNormalization4D(emb_dim, eps)
        self.F_inv = SDB1d(in_channels, emb_dim, emb_ks)
        self.F_att = PreNorm(emb_dim, CEA(emb_dim, n_head, att_dim, dropout))
        self.F_ffw = PreNorm(emb_dim, DPFN(emb_dim, hidden_dim, idx, dropout))
        self.F_linear = nn.ConvTranspose1d(emb_dim, emb_dim, emb_ks, stride=emb_hs)

        # T-transformer: attention across time frames
        self.T_norm = LayerNormalization4D(emb_dim, eps)
        self.T_inv = SDB1d(in_channels, emb_dim, emb_ks)
        self.T_att = PreNorm(emb_dim, CEA(emb_dim, n_head, att_dim, dropout))
        self.T_ffw = PreNorm(emb_dim, DPFN(emb_dim, hidden_dim, idx, dropout))
        self.T_linear = nn.ConvTranspose1d(emb_dim, emb_dim, emb_ks, stride=emb_hs)

        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def forward(self, x):
        B, C, old_T, old_Q = x.shape
        # pad T and Q so that unfold with (emb_ks, emb_hs) tiles them exactly
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # F-transformer (frequency path)
        input_ = x
        F_feat = self.F_norm(input_)  # [B, C, T, Q]
        F_feat = F_feat.transpose(1, 2).contiguous().view(B * T, C, Q)  # [BT, C, Q]
        F_feat = F.unfold(F_feat[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1))  # [BT, C*emb_ks, -1]
        F_feat = self.F_inv(F_feat)  # [BT, C, -1]

        F_feat = F_feat.transpose(1, 2)  # [BT, -1, C]
        F_feat = self.F_att(F_feat) + F_feat  # residual attention
        F_feat = self.F_ffw(F_feat) + F_feat  # residual feed-forward
        F_feat = F_feat.transpose(1, 2)  # [BT, C, -1]

        F_feat = self.F_linear(F_feat)  # [BT, C, Q]
        F_feat = F_feat.view([B, T, C, Q])
        F_feat = F_feat.transpose(1, 2).contiguous()  # [B, C, T, Q]
        F_feat = F_feat + input_  # [B, C, T, Q]

        # T-transformer (time path)
        input_ = F_feat
        T_feat = self.T_norm(input_)  # [B, C, T, Q]
        T_feat = T_feat.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)  # [BQ, C, T]
        T_feat = F.unfold(T_feat[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1))  # [BQ, C*emb_ks, -1]
        T_feat = self.T_inv(T_feat)  # [BQ, C, -1]

        T_feat = T_feat.transpose(1, 2)  # [BQ, -1, C]
        T_feat = self.T_att(T_feat) + T_feat  # residual attention
        T_feat = self.T_ffw(T_feat) + T_feat  # residual feed-forward
        T_feat = T_feat.transpose(1, 2)  # [BQ, C, -1]

        T_feat = self.T_linear(T_feat)  # [BQ, C, T]
        T_feat = T_feat.view([B, Q, C, T])
        T_feat = T_feat.permute(0, 2, 3, 1).contiguous()  # [B, C, T, Q]
        T_feat = T_feat + input_  # [B, C, T, Q]

        return T_feat


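# Unfold arithmetic used above (illustrative numbers): with emb_ks=4, emb_hs=1 and
# Q=257 bins, F.unfold yields 257 - 4 + 1 = 254 patches of C * 4 stacked channels,
# SDB1d maps them back to C channels, and ConvTranspose1d restores
# (254 - 1) * 1 + 4 = 257 bins, so the residual connection lines up.
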
class LayerNormalization4D(nn.Module):
    # layer norm over the channel dimension of a [B, C, T, F] tensor
    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        param_size = [1, input_dimension, 1, 1]
        self.gamma = Parameter(torch.ones(*param_size, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(*param_size, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        if x.ndim == 4:
            stat_dim = (1,)
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B, 1, T, F]
        std_ = torch.sqrt(
            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )  # [B, 1, T, F]
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
        return x_hat


class LayerNormalization4DCF(nn.Module):
    # layer norm over the channel and frequency dimensions of a [B, C, T, F] tensor
    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        assert len(input_dimension) == 2
        param_size = [1, input_dimension[0], 1, input_dimension[1]]
        self.gamma = Parameter(torch.ones(*param_size, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(*param_size, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        if x.ndim == 4:
            stat_dim = (1, 3)
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B, 1, T, 1]
        std_ = torch.sqrt(
            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )  # [B, 1, T, 1]
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
        return x_hat
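

# Minimal smoke test (a sketch: the batch size, mic count, and sample length below
# are illustrative and simply match the default constructor, not values fixed by
# the model).
if __name__ == "__main__":
    model = DeFTAN2(n_srcs=1, win=512, n_mics=4, n_layers=6)
    mix = torch.randn(2, 4, 16000)  # [batch, mics, samples]
    with torch.no_grad():
        est = model(mix)
    print(est.shape)  # expected: torch.Size([2, 1, 16000])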