import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter

from einops import rearrange


class Network(nn.Module):
    def __init__(
        self,
        n_srcs=1,
        win=512,
        n_mics=4,
        n_layers=12,
        att_dim=64,
        hidden_dim=256,
        n_head=4,
        emb_dim=64,
        emb_ks=4,
        emb_hs=1,
        dropout=0.1,
        eps=1.0e-5,
    ):
        super().__init__()
        self.n_srcs = n_srcs
        self.win = win
        self.hop = win // 2
        self.n_layers = n_layers
        self.n_mics = n_mics
        self.emb_dim = emb_dim
        assert win % 2 == 0

        # Encoder: stacked real/imag spectrogram channels -> emb_dim feature maps
        t_ksize = 3
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
        self.conv = nn.Sequential(
            nn.Conv2d(2 * n_mics, emb_dim * n_head, ks, padding=padding),
            nn.GroupNorm(1, emb_dim * n_head, eps=eps),
            InverseDenseBlock2d(emb_dim * n_head, emb_dim, n_head),
        )
        self.blocks = nn.ModuleList([])
        for idx in range(n_layers):
            self.blocks.append(
                DeFTANblock(idx, emb_dim, emb_ks, emb_hs, att_dim, hidden_dim, n_head, dropout, eps)
            )
        # Decoder: emb_dim feature maps -> real/imag spectrograms of the n_srcs estimates
        self.deconv = nn.Sequential(
            nn.Conv2d(emb_dim, 2 * n_srcs * n_head, ks, padding=padding),
            InverseDenseBlock2d(2 * n_srcs * n_head, 2 * n_srcs, n_head),
        )

    def pad_signal(self, input):
        # input: (B, M, N) multichannel waveform, or (B, N) for a single channel
        if input.dim() not in [2, 3]:
            raise RuntimeError("Input can only be 2 or 3 dimensional.")

        if input.dim() == 2:
            input = input.unsqueeze(1)
        batch_size = input.size(0)
        nchannel = input.size(1)
        nsample = input.size(2)
        rest = self.win - (self.hop + nsample % self.win) % self.win
        if rest > 0:
            pad = torch.zeros(batch_size, nchannel, rest, dtype=input.dtype, device=input.device)
            input = torch.cat([input, pad], 2)

        pad_aux = torch.zeros(batch_size, nchannel, self.hop, dtype=input.dtype, device=input.device)
        input = torch.cat([pad_aux, input, pad_aux], 2)

        return input, rest

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input, rest = self.pad_signal(input)
        B, M, N = input.size()
        mix_std_ = torch.std(input, dim=(1, 2), keepdim=True)
        input = input / mix_std_  # normalize the mixture by its standard deviation

        window = torch.hann_window(self.win, dtype=input.dtype, device=input.device)
        stft_input = torch.view_as_real(
            torch.stft(input.view([-1, N]), n_fft=self.win, hop_length=self.hop, window=window, return_complex=True)
        )
        _, n_freq, T, _ = stft_input.size()
        xi = stft_input.view([B, M, n_freq, T, 2])
        xi = xi.permute(0, 1, 4, 3, 2).contiguous()
        batch = xi.view([B, M * 2, T, n_freq])

        batch = self.conv(batch)
        for ii in range(self.n_layers):
            batch = self.blocks[ii](batch)
        batch = self.deconv(batch).view([B, self.n_srcs, 2, T, n_freq]).view([B * self.n_srcs, 2, T, n_freq])

        batch = batch.permute(0, 3, 2, 1).to(input.dtype)
        istft_input = torch.complex(batch[:, :, :, 0], batch[:, :, :, 1])
        istft_output = torch.istft(istft_input, n_fft=self.win, hop_length=self.hop, window=window, return_complex=False)

        # Undo the padding added in pad_signal and the mixture normalization
        output = istft_output[:, self.hop:-(rest + self.hop)].unsqueeze(1)
        output = output.view([B, self.n_srcs, -1])
        output = output * mix_std_

        return output


class InverseDenseBlock1d(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super().__init__()
        assert in_channels // out_channels == groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.blocks = nn.ModuleList([])
        for idx in range(groups):
            self.blocks.append(nn.Sequential(
                nn.Conv1d(out_channels * ((idx > 0) + 1), out_channels, kernel_size=3, padding=1),
                nn.GroupNorm(1, out_channels, 1e-5),
                nn.PReLU(out_channels),
            ))

    def forward(self, x):
        B, C, L = x.size()
        g = self.groups
        # Regroup channels so that strided slicing (::g) picks out one of the g
        # input groups; each block then refines the previous output concatenated
        # with the next input group (inverse dense connectivity).
        x = x.view(B, g, C // g, L).transpose(1, 2).reshape(B, C, L)
        skip = x[:, ::g, :]
        for idx in range(g):
            output = self.blocks[idx](skip)
            skip = torch.cat([output, x[:, idx + 1::g, :]], dim=1)
        return output


class InverseDenseBlock2d(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super().__init__()
        assert in_channels // out_channels == groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.blocks = nn.ModuleList([])
        for idx in range(groups):
            self.blocks.append(nn.Sequential(
                nn.Conv2d(out_channels * ((idx > 0) + 1), out_channels, kernel_size=(3, 3), padding=(1, 1)),
                nn.GroupNorm(1, out_channels, 1e-5),
                nn.PReLU(out_channels),
            ))

    def forward(self, x):
        B, C, T, Q = x.size()
        g = self.groups
        x = x.view(B, g, C // g, T, Q).transpose(1, 2).reshape(B, C, T, Q)
        skip = x[:, ::g, :, :]
        for idx in range(g):
            output = self.blocks[idx](skip)
            skip = torch.cat([output, x[:, idx + 1::g, :, :]], dim=1)
        return output
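
# Shape sketch (illustrative, not from the original code): with groups=4,
# InverseDenseBlock2d maps (B, 4*C, T, Q) -> (B, C, T, Q), e.g.
#   idb = InverseDenseBlock2d(256, 64, 4)
#   y = idb(torch.randn(1, 256, 10, 20))   # y.shape == torch.Size([1, 64, 10, 20])
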
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class Attention(nn.Module):
    def __init__(self, dim, heads, dim_head, dropout):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        # Queries and keys are derived from a gated (GLU) convolutional view of the input
        self.cv_qk = nn.Sequential(
            nn.Conv1d(dim, dim * 2, kernel_size=3, padding=1, bias=False),
            nn.GLU(dim=1),
        )
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        self.att_drop = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ) if project_out else nn.Identity()

    def forward(self, x):
        qk = self.cv_qk(x.transpose(1, 2)).transpose(1, 2)
        q = rearrange(self.to_q(qk), 'b n (h d) -> b h n d', h=self.heads)
        k = rearrange(self.to_k(qk), 'b n (h d) -> b h n d', h=self.heads)
        v = rearrange(self.to_v(x), 'b n (h d) -> b h n d', h=self.heads)

        # Efficient-attention-style factorization: softmax over the sequence for k
        # and over the feature dimension for q, so the k^T v product is (d x d)
        weight = torch.matmul(F.softmax(k, dim=2).transpose(-1, -2), v) * self.scale
        out = torch.matmul(F.softmax(q, dim=3), self.att_drop(weight))
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, idx, dropout):
        super().__init__()
        self.PW1 = nn.Sequential(
            nn.Linear(dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.PW2 = nn.Sequential(
            nn.Linear(dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.DW_Conv = nn.Sequential(
            nn.Conv1d(hidden_dim // 2, hidden_dim // 2, kernel_size=5, dilation=2 ** idx, padding='same'),
            nn.GroupNorm(1, hidden_dim // 2, 1e-5),
            nn.PReLU(hidden_dim // 2),
        )
        self.PW3 = nn.Sequential(
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        ffw_out = self.PW1(x)
        dw_out = self.DW_Conv(self.PW2(x).transpose(1, 2)).transpose(1, 2)
        out = self.PW3(torch.cat((ffw_out, dw_out), dim=2))
        return out


class DeFTANblock(nn.Module):
    def __getitem__(self, key):
        return getattr(self, key)

    def __init__(self, idx, emb_dim, emb_ks, emb_hs, att_dim, hidden_dim, n_head, dropout, eps):
        super().__init__()
        in_channels = emb_dim * emb_ks
        self.intra_norm = LayerNormalization4D(emb_dim, eps)
        self.intra_inv = InverseDenseBlock1d(in_channels, emb_dim, emb_ks)
        self.intra_att = PreNorm(emb_dim, Attention(emb_dim, n_head, att_dim, dropout))
        self.intra_ffw = PreNorm(emb_dim, FeedForward(emb_dim, hidden_dim, idx, dropout))
        self.intra_linear = nn.ConvTranspose1d(emb_dim, emb_dim, emb_ks, stride=emb_hs)

        self.inter_norm = LayerNormalization4D(emb_dim, eps)
        self.inter_inv = InverseDenseBlock1d(in_channels, emb_dim, emb_ks)
        self.inter_att = PreNorm(emb_dim, Attention(emb_dim, n_head, att_dim, dropout))
        self.inter_ffw = PreNorm(emb_dim, FeedForward(emb_dim, hidden_dim, idx, dropout))
        self.inter_linear = nn.ConvTranspose1d(emb_dim, emb_dim, emb_ks, stride=emb_hs)

        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def forward(self, x):
        B, C, old_T, old_Q = x.shape
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # Intra path: attention along the frequency axis, one time frame at a time
        input_ = x
        intra_rnn = self.intra_norm(input_)
        intra_rnn = intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q)
        intra_rnn = F.unfold(intra_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1))
        intra_rnn = self.intra_inv(intra_rnn)

        intra_rnn = intra_rnn.transpose(1, 2)
        intra_rnn = self.intra_att(intra_rnn) + intra_rnn
        intra_rnn = self.intra_ffw(intra_rnn) + intra_rnn
        intra_rnn = intra_rnn.transpose(1, 2)

        intra_rnn = self.intra_linear(intra_rnn)
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()
        intra_rnn = intra_rnn + input_

        # Inter path: attention along the time axis, one frequency bin at a time
        input_ = intra_rnn
        inter_rnn = self.inter_norm(input_)
        inter_rnn = inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)
        inter_rnn = F.unfold(inter_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1))
        inter_rnn = self.inter_inv(inter_rnn)

        inter_rnn = inter_rnn.transpose(1, 2)
        inter_rnn = self.inter_att(inter_rnn) + inter_rnn
        inter_rnn = self.inter_ffw(inter_rnn) + inter_rnn
        inter_rnn = inter_rnn.transpose(1, 2)

        inter_rnn = self.inter_linear(inter_rnn)
        inter_rnn = inter_rnn.view([B, Q, C, T])
        inter_rnn = inter_rnn.permute(0, 2, 3, 1).contiguous()
        inter_rnn = inter_rnn + input_

        return inter_rnn


class LayerNormalization4D(nn.Module):
    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        param_size = [1, input_dimension, 1, 1]
        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))
        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))
        init.ones_(self.gamma)
        init.zeros_(self.beta)
        self.eps = eps

    def forward(self, x):
        if x.ndim == 4:
            _, C, _, _ = x.shape
            stat_dim = (1,)
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        mu_ = x.mean(dim=stat_dim, keepdim=True)
        std_ = torch.sqrt(
            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
        return x_hat


class LayerNormalization4DCF(nn.Module):
    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        assert len(input_dimension) == 2
        param_size = [1, input_dimension[0], 1, input_dimension[1]]
        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))
        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))
        init.ones_(self.gamma)
        init.zeros_(self.beta)
        self.eps = eps

    def forward(self, x):
        if x.ndim == 4:
            stat_dim = (1, 3)
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        mu_ = x.mean(dim=stat_dim, keepdim=True)
        std_ = torch.sqrt(
            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
        return x_hat
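
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original model code). It instantiates
# the Network with its default hyperparameters and runs a random multichannel
# mixture through it; the batch size and signal length below are arbitrary
# choices for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    model = Network(n_srcs=1, n_mics=4).eval()   # win=512, 12 DeFTAN blocks by default
    mix = torch.randn(1, 4, 8000)                # (batch, mics, samples)
    with torch.no_grad():
        est = model(mix)
    print(est.shape)                             # torch.Size([1, 1, 8000])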