Upload DeFTAN-II.py
Browse files- DeFTAN-II.py +329 -0
DeFTAN-II.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from collections import OrderedDict
|
3 |
+
from typing import Dict, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch.autograd import Variable
|
9 |
+
from packaging.version import parse as V
|
10 |
+
from torch.nn import init
|
11 |
+
from torch.nn.parameter import Parameter
|
12 |
+
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
from einops.layers.torch import Rearrange
|
15 |
+
|
16 |
+
|
17 |
+
class Network(nn.Module):
    """Waveform-in / waveform-out network operating in the STFT domain.

    Pipeline: zero-pad -> normalize by the mixture std -> STFT -> conv encoder
    -> ``n_layers`` x ``DeFTANblock`` -> conv decoder -> iSTFT -> undo the
    normalization.

    Args:
        n_srcs: number of output signals estimated per input mixture.
        win: STFT window / FFT size in samples; must be even (hop is ``win // 2``).
        n_mics: number of input channels; the encoder consumes ``2 * n_mics``
            planes (real and imaginary part of every channel).
        n_layers: number of stacked ``DeFTANblock`` modules.
        att_dim, hidden_dim, emb_ks, emb_hs, dropout: forwarded to every block.
        n_head: forwarded to every block and also used as the group count of
            the encoder/decoder ``InverseDenseBlock2d``.
        emb_dim: channel width between encoder and decoder.
        eps: epsilon of the encoder ``GroupNorm``; also forwarded to the blocks.
    """

    def __init__(
        self,
        n_srcs=1,
        win=512,
        n_mics=4,
        n_layers=12,
        att_dim=64,
        hidden_dim=256,
        n_head=4,
        emb_dim=64,
        emb_ks=4,
        emb_hs=1,
        dropout=0.1,
        eps=1.0e-5,
    ):
        super().__init__()
        assert win % 2 == 0
        self.n_srcs = n_srcs
        self.win = win
        self.hop = win // 2
        self.n_layers = n_layers
        self.n_mics = n_mics
        self.emb_dim = emb_dim

        t_ksize = 3
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
        # Encoder: widen to emb_dim * n_head channels, then fold back to emb_dim.
        self.conv = nn.Sequential(
            nn.Conv2d(2 * n_mics, emb_dim * n_head, ks, padding=padding),
            nn.GroupNorm(1, emb_dim * n_head, eps=eps),
            InverseDenseBlock2d(emb_dim * n_head, emb_dim, n_head),
        )
        self.blocks = nn.ModuleList([])
        for idx in range(n_layers):
            self.blocks.append(
                DeFTANblock(idx, emb_dim, emb_ks, emb_hs, att_dim, hidden_dim, n_head, dropout, eps)
            )
        # Decoder: produce a real and an imaginary plane per source.
        self.deconv = nn.Sequential(
            nn.Conv2d(emb_dim, 2 * n_srcs * n_head, ks, padding=padding),
            InverseDenseBlock2d(2 * n_srcs * n_head, 2 * n_srcs, n_head),
        )

    def pad_signal(self, input):
        """Zero-pad waveforms so their length suits the STFT framing.

        Args:
            input: waveforms shaped (B, T) or (B, M, T); a 2-D input gets a
                singleton channel axis inserted.

        Returns:
            Tuple ``(padded, rest)``: ``padded`` is (B, M, hop + T + rest + hop)
            and ``rest`` is the number of zeros appended after the signal
            (before the trailing ``hop`` zeros).

        Raises:
            RuntimeError: if ``input`` is neither 2- nor 3-dimensional.
        """
        if input.dim() not in [2, 3]:
            raise RuntimeError("Input can only be 2 or 3 dimensional.")

        if input.dim() == 2:
            input = input.unsqueeze(1)
        nsample = input.size(2)
        # rest is always in [1, win], so trailing padding is never empty.
        rest = self.win - (self.hop + nsample % self.win) % self.win
        # A single constant pad keeps dtype and device of the input, replacing
        # the deprecated Variable / .type() based zero tensors.
        input = F.pad(input, (self.hop, rest + self.hop))

        return input, rest

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Estimate ``n_srcs`` waveforms from a (multichannel) mixture.

        Args:
            input: waveforms shaped (B, T) or (B, M, T).

        Returns:
            Tensor shaped (B, n_srcs, T).
        """
        input, rest = self.pad_signal(input)
        B, M, N = input.size()  # batch B, mic M, time samples N
        mix_std_ = torch.std(input, dim=(1, 2), keepdim=True)  # [B, 1, 1]
        # Guard against an all-zero mixture, which would otherwise yield NaN.
        mix_std_ = mix_std_.clamp_min(1.0e-8)
        input = input / mix_std_  # normalize by the mixture std

        window = torch.hann_window(self.win, dtype=input.dtype, device=input.device)
        # return_complex=False is deprecated; view_as_real gives the same
        # trailing (real, imag) axis.
        spec = torch.stft(
            input.reshape(-1, N),
            n_fft=self.win,
            hop_length=self.hop,
            window=window,
            return_complex=True,
        )
        stft_input = torch.view_as_real(spec)  # [B*M, n_freq, n_frame, 2]
        _, n_freq, n_frame, _ = stft_input.size()
        xi = stft_input.reshape(B, M, n_freq, n_frame, 2)
        xi = xi.permute(0, 1, 4, 3, 2).contiguous()  # [B, M, 2, T, F]
        batch = xi.view([B, M * 2, n_frame, n_freq])  # [B, 2*M, T, F]

        batch = self.conv(batch)  # [B, C, T, F]
        for ii in range(self.n_layers):
            batch = self.blocks[ii](batch)  # [B, C, T, F]
        batch = self.deconv(batch).view([B * self.n_srcs, 2, n_frame, n_freq])

        # [B*n_srcs, 2, T, F] -> [B*n_srcs, F, T, 2]
        batch = batch.permute(0, 3, 2, 1).to(input.dtype)
        istft_input = torch.complex(batch[:, :, :, 0], batch[:, :, :, 1])
        istft_output = torch.istft(
            istft_input, n_fft=self.win, hop_length=self.hop, window=window
        )

        # Drop the padding added by pad_signal.
        output = istft_output[:, self.hop:-(rest + self.hop)]  # [B*n_srcs, T]
        output = output.reshape(B, self.n_srcs, -1)  # [B, n_srcs, T]
        output = output * mix_std_  # undo the normalization

        return output
|
89 |
+
|
90 |
+
|
91 |
+
class InverseDenseBlock1d(nn.Module):
    """Fold ``groups * out_channels`` channels into ``out_channels`` (1-D).

    The input channels are split into ``groups`` groups of ``out_channels``.
    Stage 0 processes group 0; every later stage processes the previous
    stage's output concatenated with the next group.

    Args:
        in_channels: number of input channels; must equal
            ``out_channels * groups``.
        out_channels: number of channels produced by every stage.
        groups: number of channel groups (and of conv stages).
    """

    def __init__(self, in_channels, out_channels, groups):
        super().__init__()
        # The exact product is required: floor division would accept channel
        # counts that cannot be regrouped in forward().
        assert in_channels == out_channels * groups, (
            "in_channels must equal out_channels * groups"
        )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.blocks = nn.ModuleList([])
        for idx in range(groups):
            # Stage 0 sees one group; later stages see [previous output, group].
            stage_in = out_channels if idx == 0 else 2 * out_channels
            self.blocks.append(nn.Sequential(
                nn.Conv1d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.GroupNorm(1, out_channels, 1e-5),
                nn.PReLU(out_channels),
            ))

    def forward(self, x):
        """Map ``x`` of shape (B, C, L) to (B, out_channels, L)."""
        B, C, L = x.size()
        g = self.groups
        # Interleave so that group i occupies channels i, i + g, i + 2g, ...
        x = x.reshape(B, g, C // g, L).transpose(1, 2).reshape(B, C, L)
        output = self.blocks[0](x[:, ::g, :])
        for idx in range(1, g):
            # Concatenate only when a next stage consumes the result.
            output = self.blocks[idx](torch.cat([output, x[:, idx::g, :]], dim=1))
        return output
|
115 |
+
|
116 |
+
|
117 |
+
class InverseDenseBlock2d(nn.Module):
    """Fold ``groups * out_channels`` channels into ``out_channels`` (2-D).

    The input channels are split into ``groups`` groups of ``out_channels``.
    Stage 0 processes group 0; every later stage processes the previous
    stage's output concatenated with the next group.

    Args:
        in_channels: number of input channels; must equal
            ``out_channels * groups``.
        out_channels: number of channels produced by every stage.
        groups: number of channel groups (and of conv stages).
    """

    def __init__(self, in_channels, out_channels, groups):
        super().__init__()
        # The exact product is required: floor division would accept channel
        # counts that cannot be regrouped in forward().
        assert in_channels == out_channels * groups, (
            "in_channels must equal out_channels * groups"
        )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.blocks = nn.ModuleList([])
        for idx in range(groups):
            # Stage 0 sees one group; later stages see [previous output, group].
            stage_in = out_channels if idx == 0 else 2 * out_channels
            self.blocks.append(nn.Sequential(
                nn.Conv2d(stage_in, out_channels, kernel_size=(3, 3), padding=(1, 1)),
                nn.GroupNorm(1, out_channels, 1e-5),
                nn.PReLU(out_channels),
            ))

    def forward(self, x):
        """Map ``x`` of shape (B, C, T, Q) to (B, out_channels, T, Q)."""
        B, C, T, Q = x.size()
        g = self.groups
        # Interleave so that group i occupies channels i, i + g, i + 2g, ...
        x = x.reshape(B, g, C // g, T, Q).transpose(1, 2).reshape(B, C, T, Q)
        output = self.blocks[0](x[:, ::g, :, :])
        for idx in range(1, g):
            # Concatenate only when a next stage consumes the result.
            output = self.blocks[idx](torch.cat([output, x[:, idx::g, :, :]], dim=1))
        return output
|
141 |
+
|
142 |
+
|
143 |
+
class PreNorm(nn.Module):
    """Apply ``LayerNorm`` over the last axis, then call the wrapped callable.

    Args:
        dim: size of the last axis normalized by ``nn.LayerNorm``.
        fn: callable invoked on the normalized tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
|
150 |
+
|
151 |
+
|
152 |
+
class Attention(nn.Module):
    """Multi-head attention that aggregates keys/values before applying queries.

    Queries and keys are projected from a gated 1-D convolution of the input;
    values are projected from the input itself. Keys are soft-maxed over the
    sequence axis and queries over the per-head feature axis.

    Args:
        dim: feature size of the input and of the output.
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout probability for the aggregated key/value weights and
            for the output projection.
    """

    def __init__(self, dim, heads, dim_head, dropout):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5

        # Conv doubles the channels; GLU gates them back down to ``dim``.
        self.cv_qk = nn.Sequential(
            nn.Conv1d(dim, dim * 2, kernel_size=3, padding=1, bias=False),
            nn.GLU(dim=1),
        )
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        self.att_drop = nn.Dropout(dropout)

        # The output projection is skipped only when it would be square and
        # single-headed.
        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )

    def _split_heads(self, t):
        """Reshape (b, n, h * d) into (b, h, n, d)."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2)

    def forward(self, x):
        """Map ``x`` of shape (b, n, dim) to a tensor of the same shape."""
        gated = self.cv_qk(x.transpose(1, 2)).transpose(1, 2)
        q = self._split_heads(self.to_q(gated))
        k = self._split_heads(self.to_k(gated))
        v = self._split_heads(self.to_v(x))

        # (b, h, d, d): keys normalized over the sequence axis, folded with values.
        context = torch.matmul(F.softmax(k, dim=2).transpose(-1, -2), v) * self.scale
        out = torch.matmul(F.softmax(q, dim=3), self.att_drop(context))

        # Merge the heads back: (b, h, n, d) -> (b, n, h * d).
        b, h, n, d = out.shape
        out = out.transpose(1, 2).reshape(b, n, h * d)
        return self.to_out(out)
|
185 |
+
|
186 |
+
|
187 |
+
class FeedForward(nn.Module):
    """Two-branch feed-forward layer: point-wise branch plus dilated-conv branch.

    Both branches project the input to ``hidden_dim // 2`` features; the
    second branch additionally runs a dilated 1-D convolution along the
    sequence axis. The branches are concatenated and projected back to ``dim``.

    Args:
        dim: feature size of the input and of the output.
        hidden_dim: total hidden width; each branch gets ``hidden_dim // 2``.
        idx: layer index; the convolution dilation is ``2 ** idx``.
        dropout: dropout probability used after every linear projection.
    """

    def __init__(self, dim, hidden_dim, idx, dropout):
        super().__init__()
        branch_dim = hidden_dim // 2
        self.PW1 = nn.Sequential(
            nn.Linear(dim, branch_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.PW2 = nn.Sequential(
            nn.Linear(dim, branch_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.DW_Conv = nn.Sequential(
            nn.Conv1d(branch_dim, branch_dim, kernel_size=5, dilation=2**idx, padding='same'),
            nn.GroupNorm(1, branch_dim, 1e-5),
            nn.PReLU(branch_dim),
        )
        # Sized from the real concatenated width so an odd hidden_dim no
        # longer causes a shape mismatch; identical for even hidden_dim.
        self.PW3 = nn.Sequential(
            nn.Linear(2 * branch_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Map ``x`` of shape (b, n, dim) to a tensor of the same shape."""
        ffw_out = self.PW1(x)
        # Conv1d works on (b, channels, n), hence the transposes.
        dw_out = self.DW_Conv(self.PW2(x).transpose(1, 2)).transpose(1, 2)
        out = self.PW3(torch.cat((ffw_out, dw_out), dim=2))
        return out
|
215 |
+
|
216 |
+
|
217 |
+
class DeFTANblock(nn.Module):
    """One block: a transformer pass along the last axis, then along the time axis.

    Each pass normalizes the input, unfolds ``emb_ks`` neighbouring positions
    into channels, folds them back to ``emb_dim`` channels, applies attention
    and a feed-forward layer (both residual), and restores the axis length
    with a transposed convolution. The pass output is added to its input.

    Args:
        idx: layer index, forwarded to the ``FeedForward`` sub-modules.
        emb_dim: channel count of the input and of the output.
        emb_ks: kernel size of the unfold and of the transposed convolution.
        emb_hs: stride of the unfold and of the transposed convolution.
        att_dim: per-head feature size of the attention modules.
        hidden_dim: hidden width of the feed-forward modules.
        n_head: number of attention heads.
        dropout: dropout probability of attention and feed-forward modules.
        eps: epsilon of the two ``LayerNormalization4D`` modules.
    """

    def __getitem__(self, key):
        # Allows sub-modules to be fetched by name: block["intra_att"].
        return getattr(self, key)

    def __init__(self, idx, emb_dim, emb_ks, emb_hs, att_dim, hidden_dim, n_head, dropout, eps):
        super().__init__()
        in_channels = emb_dim * emb_ks
        self.intra_norm = LayerNormalization4D(emb_dim, eps)
        self.intra_inv = InverseDenseBlock1d(in_channels, emb_dim, emb_ks)
        self.intra_att = PreNorm(emb_dim, Attention(emb_dim, n_head, att_dim, dropout))
        self.intra_ffw = PreNorm(emb_dim, FeedForward(emb_dim, hidden_dim, idx, dropout))
        self.intra_linear = nn.ConvTranspose1d(emb_dim, emb_dim, emb_ks, stride=emb_hs)

        self.inter_norm = LayerNormalization4D(emb_dim, eps)
        self.inter_inv = InverseDenseBlock1d(in_channels, emb_dim, emb_ks)
        self.inter_att = PreNorm(emb_dim, Attention(emb_dim, n_head, att_dim, dropout))
        self.inter_ffw = PreNorm(emb_dim, FeedForward(emb_dim, hidden_dim, idx, dropout))
        self.inter_linear = nn.ConvTranspose1d(emb_dim, emb_dim, emb_ks, stride=emb_hs)

        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def _padded_length(self, length):
        """Smallest length >= ``length`` of the form ``k * emb_hs + emb_ks``, k >= 0."""
        # max(..., 0) also pads axes shorter than the kernel up to emb_ks.
        n_hops = max(math.ceil((length - self.emb_ks) / self.emb_hs), 0)
        return n_hops * self.emb_hs + self.emb_ks

    def _axis_transformer(self, seq, inv, att, ffw, linear):
        """Run unfold -> inv -> attention -> feed-forward -> linear on (N, C, L)."""
        seq = F.unfold(seq[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1))  # [N, C*emb_ks, -1]
        seq = inv(seq)  # [N, C, -1]
        seq = seq.transpose(1, 2)  # [N, -1, C]
        seq = att(seq) + seq
        seq = ffw(seq) + seq
        return linear(seq.transpose(1, 2))  # [N, C, L]

    def forward(self, x):
        """Map ``x`` of shape (B, C, T, Q) to a tensor of the same shape."""
        B, C, old_T, old_Q = x.shape
        T = self._padded_length(old_T)
        Q = self._padded_length(old_Q)
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # F-transformer: sequences run along the last axis.
        input_ = x
        intra_rnn = self.intra_norm(input_)  # [B, C, T, Q]
        intra_rnn = intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q)  # [BT, C, Q]
        intra_rnn = self._axis_transformer(
            intra_rnn, self.intra_inv, self.intra_att, self.intra_ffw, self.intra_linear
        )  # [BT, C, Q]
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()  # [B, C, T, Q]
        intra_rnn = intra_rnn + input_  # [B, C, T, Q]

        # T-transformer: sequences run along the time axis.
        input_ = intra_rnn
        inter_rnn = self.inter_norm(input_)  # [B, C, T, Q]
        inter_rnn = inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)  # [BQ, C, T]
        inter_rnn = self._axis_transformer(
            inter_rnn, self.inter_inv, self.inter_att, self.inter_ffw, self.inter_linear
        )  # [BQ, C, T]
        inter_rnn = inter_rnn.view([B, Q, C, T])
        inter_rnn = inter_rnn.permute(0, 2, 3, 1).contiguous()  # [B, C, T, Q]
        inter_rnn = inter_rnn + input_  # [B, C, T, Q]

        # Remove the padding added above so the output shape matches the
        # input; a no-op when nothing was padded.
        return inter_rnn[..., :old_T, :old_Q].contiguous()
|
282 |
+
|
283 |
+
|
284 |
+
class LayerNormalization4D(nn.Module):
    """Normalize a 4-D tensor over its channel axis (dim 1) with learned scale/shift.

    Args:
        input_dimension: number of channels (size of dim 1).
        eps: value added to the variance before the square root.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        shape = (1, input_dimension, 1, 1)
        # Scale starts at one and shift at zero, one value per channel.
        self.gamma = Parameter(torch.ones(shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        """Return the normalized tensor, same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        stat_dim = (1,)
        mean = x.mean(dim=stat_dim, keepdim=True)  # [B, 1, T, F]
        variance = x.var(dim=stat_dim, unbiased=False, keepdim=True)  # [B, 1, T, F]
        centered = x - mean
        return (centered / torch.sqrt(variance + self.eps)) * self.gamma + self.beta
|
306 |
+
|
307 |
+
|
308 |
+
class LayerNormalization4DCF(nn.Module):
    """Normalize a 4-D tensor jointly over dims 1 and 3 with learned scale/shift.

    Args:
        input_dimension: pair ``(size of dim 1, size of dim 3)``.
        eps: value added to the variance before the square root.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        assert len(input_dimension) == 2
        n_dim1, n_dim3 = input_dimension[0], input_dimension[1]
        shape = (1, n_dim1, 1, n_dim3)
        # Scale starts at one and shift at zero, one value per (dim 1, dim 3) cell.
        self.gamma = Parameter(torch.ones(shape, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        """Return the normalized tensor, same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.ndim != 4:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        stat_dim = (1, 3)
        mean = x.mean(dim=stat_dim, keepdim=True)  # [B, 1, T, 1]
        variance = x.var(dim=stat_dim, unbiased=False, keepdim=True)  # [B, 1, T, 1]
        centered = x - mean
        return (centered / torch.sqrt(variance + self.eps)) * self.gamma + self.beta
|