donghoney0416 committed on
Commit 6648ece · verified · 1 Parent(s): 263799a

Upload DeFTAN-II.py

Files changed (1): DeFTAN-II.py (new file, +329 lines)
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter

from einops import rearrange


class Network(nn.Module):
    def __init__(self, n_srcs=1, win=512, n_mics=4, n_layers=12, att_dim=64, hidden_dim=256,
                 n_head=4, emb_dim=64, emb_ks=4, emb_hs=1, dropout=0.1, eps=1.0e-5):
        super().__init__()
        self.n_srcs = n_srcs
        self.win = win
        self.hop = win // 2
        self.n_layers = n_layers
        self.n_mics = n_mics
        self.emb_dim = emb_dim
        assert win % 2 == 0

        t_ksize = 3
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
        self.conv = nn.Sequential(
            nn.Conv2d(2 * n_mics, emb_dim * n_head, ks, padding=padding),
            nn.GroupNorm(1, emb_dim * n_head, eps=eps),
            InverseDenseBlock2d(emb_dim * n_head, emb_dim, n_head)
        )
        self.blocks = nn.ModuleList([])
        for idx in range(n_layers):
            self.blocks.append(DeFTANblock(idx, emb_dim, emb_ks, emb_hs, att_dim, hidden_dim, n_head, dropout, eps))
        self.deconv = nn.Sequential(
            nn.Conv2d(emb_dim, 2 * n_srcs * n_head, ks, padding=padding),
            InverseDenseBlock2d(2 * n_srcs * n_head, 2 * n_srcs, n_head))

    def pad_signal(self, input):
        # input is the waveforms: (B, T) or (B, M, T)
        # reshape and zero-pad so the STFT frames tile the signal exactly
        if input.dim() not in [2, 3]:
            raise RuntimeError("Input can only be 2 or 3 dimensional.")

        if input.dim() == 2:
            input = input.unsqueeze(1)
        batch_size = input.size(0)
        nchannel = input.size(1)
        nsample = input.size(2)
        rest = self.win - (self.hop + nsample % self.win) % self.win
        if rest > 0:
            pad = torch.zeros(batch_size, nchannel, rest, dtype=input.dtype, device=input.device)
            input = torch.cat([input, pad], 2)

        pad_aux = torch.zeros(batch_size, nchannel, self.hop, dtype=input.dtype, device=input.device)
        input = torch.cat([pad_aux, input, pad_aux], 2)

        return input, rest

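    # Illustrative padding arithmetic (hypothetical numbers): with win=512,
    # hop=256 and nsample=16000, nsample % win = 128, so
    # rest = 512 - (256 + 128) % 512 = 128 zeros are appended, and the
    # hop-length pads on both ends bring the signal to 256 + 16128 + 256
    # = 16640 samples.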
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input, rest = self.pad_signal(input)
        B, M, N = input.size()  # batch B, mic M, time samples N
        mix_std_ = torch.std(input, dim=(1, 2), keepdim=True)  # [B, 1, 1]
        input = input / mix_std_  # std (RMS) normalization

        window = torch.hann_window(self.win, dtype=input.dtype, device=input.device)
        stft_input = torch.view_as_real(
            torch.stft(input.view([-1, N]), n_fft=self.win, hop_length=self.hop, window=window, return_complex=True)
        )
        _, Fq, T, _ = stft_input.size()  # B*M, Fq = num freqs, T = num frames, 2 = real/imag
        xi = stft_input.view([B, M, Fq, T, 2])  # [B*M, Fq, T, 2] -> [B, M, Fq, T, 2]
        xi = xi.permute(0, 1, 4, 3, 2).contiguous()  # [B, M, 2, T, Fq]
        batch = xi.view([B, M * 2, T, Fq])  # [B, 2*M, T, Fq]

        batch = self.conv(batch)  # [B, C, T, Fq]
        for ii in range(self.n_layers):
            batch = self.blocks[ii](batch)  # [B, C, T, Fq]
        batch = self.deconv(batch).view([B * self.n_srcs, 2, T, Fq])

        batch = batch.permute(0, 3, 2, 1)  # [B*n_srcs, 2, T, Fq] -> [B*n_srcs, Fq, T, 2]
        istft_input = torch.complex(batch[:, :, :, 0], batch[:, :, :, 1])
        istft_output = torch.istft(istft_input, n_fft=self.win, hop_length=self.hop, window=window)

        output = istft_output[:, self.hop:-(rest + self.hop)].unsqueeze(1)  # [B*n_srcs, 1, N]
        output = output.view([B, self.n_srcs, -1])  # [B, n_srcs, N]
        output = output * mix_std_  # undo the std normalization

        return output


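# Note on the two blocks below: the input channels are de-interleaved into
# `groups` strands; each stage convolves the running skip tensor and then
# concatenates the next strand, so later stages see all earlier outputs
# (a dense-connection pattern traversed in reverse, presumably the source
# of the "inverse dense" name).
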
class InverseDenseBlock1d(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super().__init__()
        assert in_channels // out_channels == groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.blocks = nn.ModuleList([])
        for idx in range(groups):
            self.blocks.append(nn.Sequential(
                nn.Conv1d(out_channels * ((idx > 0) + 1), out_channels, kernel_size=3, padding=1),
                nn.GroupNorm(1, out_channels, 1e-5),
                nn.PReLU(out_channels)
            ))

    def forward(self, x):
        B, C, L = x.size()
        g = self.groups
        x = x.view(B, g, C // g, L).transpose(1, 2).reshape(B, C, L)
        skip = x[:, ::g, :]
        for idx in range(g):
            output = self.blocks[idx](skip)
            skip = torch.cat([output, x[:, idx + 1::g, :]], dim=1)
        return output


class InverseDenseBlock2d(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super().__init__()
        assert in_channels // out_channels == groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.blocks = nn.ModuleList([])
        for idx in range(groups):
            self.blocks.append(nn.Sequential(
                nn.Conv2d(out_channels * ((idx > 0) + 1), out_channels, kernel_size=(3, 3), padding=(1, 1)),
                nn.GroupNorm(1, out_channels, 1e-5),
                nn.PReLU(out_channels)
            ))

    def forward(self, x):
        B, C, T, Q = x.size()
        g = self.groups
        x = x.view(B, g, C // g, T, Q).transpose(1, 2).reshape(B, C, T, Q)
        skip = x[:, ::g, :, :]
        for idx in range(g):
            output = self.blocks[idx](skip)
            skip = torch.cat([output, x[:, idx + 1::g, :, :]], dim=1)
        return output


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


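# Attention below is a linearized (efficient) variant: softmax is taken over
# K's sequence axis and Q's feature axis, and K^T V is aggregated first, so
# the cost scales linearly with sequence length instead of quadratically.
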
class Attention(nn.Module):
    def __init__(self, dim, heads, dim_head, dropout):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.cv_qk = nn.Sequential(
            nn.Conv1d(dim, dim * 2, kernel_size=3, padding=1, bias=False),
            nn.GLU(dim=1))
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        self.att_drop = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qk = self.cv_qk(x.transpose(1, 2)).transpose(1, 2)
        q = rearrange(self.to_q(qk), 'b n (h d) -> b h n d', h=self.heads)
        k = rearrange(self.to_k(qk), 'b n (h d) -> b h n d', h=self.heads)
        v = rearrange(self.to_v(x), 'b n (h d) -> b h n d', h=self.heads)

        weight = torch.matmul(F.softmax(k, dim=2).transpose(-1, -2), v) * self.scale
        out = torch.matmul(F.softmax(q, dim=3), self.att_drop(weight))
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, idx, dropout):
        super().__init__()
        self.PW1 = nn.Sequential(
            nn.Linear(dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.PW2 = nn.Sequential(
            nn.Linear(dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.DW_Conv = nn.Sequential(
            nn.Conv1d(hidden_dim // 2, hidden_dim // 2, kernel_size=5, dilation=2 ** idx, padding='same'),
            nn.GroupNorm(1, hidden_dim // 2, 1e-5),
            nn.PReLU(hidden_dim // 2)
        )
        self.PW3 = nn.Sequential(
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        ffw_out = self.PW1(x)
        dw_out = self.DW_Conv(self.PW2(x).transpose(1, 2)).transpose(1, 2)
        out = self.PW3(torch.cat((ffw_out, dw_out), dim=2))
        return out


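# Each DeFTAN block applies the same sub-network twice in dual-path fashion:
# across frequency within each frame (the "F-transformer") and then across
# time within each frequency bin (the "T-transformer"), with residual
# connections around both passes.
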
class DeFTANblock(nn.Module):
    def __getitem__(self, key):
        return getattr(self, key)

    def __init__(self, idx, emb_dim, emb_ks, emb_hs, att_dim, hidden_dim, n_head, dropout, eps):
        super().__init__()
        in_channels = emb_dim * emb_ks
        self.intra_norm = LayerNormalization4D(emb_dim, eps)
        self.intra_inv = InverseDenseBlock1d(in_channels, emb_dim, emb_ks)
        self.intra_att = PreNorm(emb_dim, Attention(emb_dim, n_head, att_dim, dropout))
        self.intra_ffw = PreNorm(emb_dim, FeedForward(emb_dim, hidden_dim, idx, dropout))
        self.intra_linear = nn.ConvTranspose1d(emb_dim, emb_dim, emb_ks, stride=emb_hs)

        self.inter_norm = LayerNormalization4D(emb_dim, eps)
        self.inter_inv = InverseDenseBlock1d(in_channels, emb_dim, emb_ks)
        self.inter_att = PreNorm(emb_dim, Attention(emb_dim, n_head, att_dim, dropout))
        self.inter_ffw = PreNorm(emb_dim, FeedForward(emb_dim, hidden_dim, idx, dropout))
        self.inter_linear = nn.ConvTranspose1d(emb_dim, emb_dim, emb_ks, stride=emb_hs)

        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def forward(self, x):
        B, C, old_T, old_Q = x.shape
        # pad T and Q so the unfold / transposed-conv pair reconstructs the full grid
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # F-transformer: attend across frequency within each frame
        input_ = x
        intra_rnn = self.intra_norm(input_)  # [B, C, T, Q]
        intra_rnn = intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q)  # [BT, C, Q]
        intra_rnn = F.unfold(intra_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1))  # [BT, C*emb_ks, -1]
        intra_rnn = self.intra_inv(intra_rnn)  # [BT, C, -1]

        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, -1, C]
        intra_rnn = self.intra_att(intra_rnn) + intra_rnn
        intra_rnn = self.intra_ffw(intra_rnn) + intra_rnn
        intra_rnn = intra_rnn.transpose(1, 2)  # [BT, C, -1]

        intra_rnn = self.intra_linear(intra_rnn)  # [BT, C, Q]
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()  # [B, C, T, Q]
        intra_rnn = intra_rnn + input_  # [B, C, T, Q]

        # T-transformer: attend across time within each frequency bin
        input_ = intra_rnn
        inter_rnn = self.inter_norm(input_)  # [B, C, T, Q]
        inter_rnn = inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)  # [BQ, C, T]
        inter_rnn = F.unfold(inter_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1))  # [BQ, C*emb_ks, -1]
        inter_rnn = self.inter_inv(inter_rnn)  # [BQ, C, -1]

        inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, -1, C]
        inter_rnn = self.inter_att(inter_rnn) + inter_rnn
        inter_rnn = self.inter_ffw(inter_rnn) + inter_rnn
        inter_rnn = inter_rnn.transpose(1, 2)  # [BQ, C, -1]

        inter_rnn = self.inter_linear(inter_rnn)  # [BQ, C, T]
        inter_rnn = inter_rnn.view([B, Q, C, T])
        inter_rnn = inter_rnn.permute(0, 2, 3, 1).contiguous()  # [B, C, T, Q]
        inter_rnn = inter_rnn + input_  # [B, C, T, Q]

        return inter_rnn


class LayerNormalization4D(nn.Module):
    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        param_size = [1, input_dimension, 1, 1]
        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))
        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))
        init.ones_(self.gamma)
        init.zeros_(self.beta)
        self.eps = eps

    def forward(self, x):
        if x.ndim == 4:
            _, C, _, _ = x.shape
            stat_dim = (1,)
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B,1,T,F]
        std_ = torch.sqrt(x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps)  # [B,1,T,F]
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
        return x_hat


class LayerNormalization4DCF(nn.Module):
    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        assert len(input_dimension) == 2
        param_size = [1, input_dimension[0], 1, input_dimension[1]]
        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))
        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))
        init.ones_(self.gamma)
        init.zeros_(self.beta)
        self.eps = eps

    def forward(self, x):
        if x.ndim == 4:
            stat_dim = (1, 3)
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B,1,T,1]
        std_ = torch.sqrt(x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps)  # [B,1,T,1]
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
        return x_hat
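

if __name__ == "__main__":
    # Minimal smoke test, not part of the original upload; the batch size,
    # signal length, and shallow n_layers=2 configuration are illustrative
    # assumptions chosen so the test runs quickly on CPU.
    model = Network(n_srcs=1, win=512, n_mics=4, n_layers=2).eval()
    mixture = torch.randn(2, 4, 16000)  # [B, M, N] multichannel waveform
    with torch.no_grad():
        estimate = model(mixture)
    print(estimate.shape)  # expected: torch.Size([2, 1, 16000])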