# Adapted from https://github.com/AppleHolic/source_separation/tree/master/source_separation

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.init import calculate_gain
from typing import Optional, Tuple
from scipy.signal import get_window
from librosa.util import pad_center
from remfx.utils import single, concat_complex


class ComplexConvBlock(nn.Module):
    """
    Stack of complex convolutions with pre-activation (norm -> act -> conv)
    and an optional residual connection.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        padding: int = 0,
        layers: int = 4,
        bn_func=nn.BatchNorm1d,
        act_func=nn.LeakyReLU,
        skip_res: bool = False,
    ):
        super().__init__()
        # modules
        self.blocks = nn.ModuleList()
        self.skip_res = skip_res
        for idx in range(layers):
            in_ = in_channels if idx == 0 else out_channels
            self.blocks.append(
                nn.Sequential(
                    bn_func(in_),
                    act_func(),
                    ComplexConv1d(in_, out_channels, kernel_size, padding=padding),
                )
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        temp = x
        for block in self.blocks:
            x = block(x)
        # residual connection, applied only when the shapes allow it
        if temp.size() != x.size() or self.skip_res:
            return x
        else:
            return x + temp
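

# Usage sketch (illustrative, not part of the original module): ComplexConvBlock
# preserves the (batch, channels, time) layout, where the first half of the
# channels is the real part and the second half the imaginary part. With equal
# in/out channels and "same" padding, the residual branch is active.
def _complex_conv_block_demo():
    block = ComplexConvBlock(in_channels=16, out_channels=16, kernel_size=5, padding=2)
    x = torch.randn(2, 16, 128)  # (batch, channels, time)
    y = block(x)
    assert y.size() == x.size()  # shapes match, so the residual sum was applied
    return y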


class SpectrogramUnet(nn.Module):
    def __init__(
        self,
        spec_dim: int,
        hidden_dim: int,
        filter_len: int,
        hop_len: int,
        layers: int = 3,
        block_layers: int = 3,
        kernel_size: int = 5,
        is_mask: bool = False,
        norm: str = "bn",
        act: str = "tanh",
    ):
        super().__init__()
        self.layers = layers
        self.is_mask = is_mask

        # stft modules
        self.stft = STFT(filter_len, hop_len)

        if norm == "bn":
            self.bn_func = nn.BatchNorm1d
        elif norm == "ins":
            self.bn_func = lambda x: nn.InstanceNorm1d(x, affine=True)
        else:
            raise NotImplementedError("{} is not implemented!".format(norm))

        if act == "tanh":
            self.act_func = nn.Tanh
            self.act_out = nn.Tanh
        elif act == "comp":
            self.act_func = ComplexActLayer
            self.act_out = lambda: ComplexActLayer(is_out=True)
        else:
            raise NotImplementedError("{} is not implemented!".format(act))

        # prev conv
        self.prev_conv = ComplexConv1d(spec_dim * 2, hidden_dim, 1)

        # down
        self.down = nn.ModuleList()
        self.down_pool = nn.MaxPool1d(3, stride=2, padding=1)
        for idx in range(self.layers):
            block = ComplexConvBlock(
                hidden_dim,
                hidden_dim,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bn_func=self.bn_func,
                act_func=self.act_func,
                layers=block_layers,
            )
            self.down.append(block)

        # up
        self.up = nn.ModuleList()
        for idx in range(self.layers):
            in_c = hidden_dim if idx == 0 else hidden_dim * 2
            self.up.append(
                nn.Sequential(
                    ComplexConvBlock(
                        in_c,
                        hidden_dim,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        bn_func=self.bn_func,
                        act_func=self.act_func,
                        layers=block_layers,
                    ),
                    self.bn_func(hidden_dim),
                    self.act_func(),
                    ComplexTransposedConv1d(
                        hidden_dim, hidden_dim, kernel_size=2, stride=2
                    ),
                )
            )

        # out_conv
        self.out_conv = nn.Sequential(
            ComplexConvBlock(
                hidden_dim * 2,
                spec_dim * 2,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bn_func=self.bn_func,
                act_func=self.act_func,
            ),
            self.bn_func(spec_dim * 2),
            self.act_func(),
        )

        # refine conv
        self.refine_conv = nn.Sequential(
            ComplexConvBlock(
                spec_dim * 4,
                spec_dim * 2,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bn_func=self.bn_func,
                act_func=self.act_func,
            ),
            self.bn_func(spec_dim * 2),
            self.act_func(),
        )

    def log_stft(self, wav):
        # stft
        mag, phase = self.stft.transform(wav)
        return torch.log(mag + 1), phase

    def exp_istft(self, log_mag, phase):
        # exp, undoing log(mag + 1)
        mag = torch.exp(log_mag) - 1
        # istft
        wav = self.stft.inverse(mag, phase)
        return wav

    def adjust_diff(self, x, target):
        # pad x so that it matches the temporal length of target
        size_diff = target.size()[-1] - x.size()[-1]
        assert size_diff >= 0
        if size_diff > 0:
            x = F.pad(
                x.unsqueeze(1), (size_diff // 2, size_diff // 2), "reflect"
            ).squeeze(1)
        return x

    def masking(self, mag, phase, origin_mag, origin_phase):
        abs_mag = torch.abs(mag)
        # bounded magnitude mask in [0, 1) and a +/- 1 sign mask for the phase
        mag_mask = torch.tanh(abs_mag)
        phase_mask = mag / abs_mag
        # masking
        mag = mag_mask * origin_mag
        phase = phase_mask * (origin_phase + phase)
        return mag, phase

    def forward(self, wav):
        # stft
        origin_mag, origin_phase = self.log_stft(wav)
        origin_x = torch.cat([origin_mag, origin_phase], dim=1)

        # prev
        x = self.prev_conv(origin_x)

        # body
        # down
        down_cache = []
        for idx, block in enumerate(self.down):
            x = block(x)
            down_cache.append(x)
            x = self.down_pool(x)

        # up
        for idx, block in enumerate(self.up):
            x = block(x)
            res = F.interpolate(
                down_cache[self.layers - (idx + 1)],
                size=[x.size()[2]],
                mode="linear",
                align_corners=False,
            )
            x = concat_complex(x, res, dim=1)

        # match spec dimension
        x = self.out_conv(x)
        if origin_mag.size(2) != x.size(2):
            x = F.interpolate(
                x, size=[origin_mag.size(2)], mode="linear", align_corners=False
            )

        # refine
        x = self.refine_conv(concat_complex(x, origin_x))

        def to_wav(stft):
            mag, phase = stft.chunk(2, 1)
            if self.is_mask:
                mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
            out = self.exp_istft(mag, phase)
            out = self.adjust_diff(out, wav)
            return out

        refine_wav = to_wav(x)
        return refine_wav
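

# End-to-end sketch (assumed sizes, not from the original repo): spec_dim must
# match the STFT bin count, filter_len // 2 + 1, for the channel arithmetic in
# forward() to line up.
def _spectrogram_unet_demo():
    net = SpectrogramUnet(spec_dim=513, hidden_dim=32, filter_len=1024, hop_len=256)
    wav = torch.randn(1, 8192)  # (batch, samples)
    out = net(wav)
    return out.shape  # same length as the input waveform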


class RefineSpectrogramUnet(SpectrogramUnet):
    def __init__(
        self,
        spec_dim: int,
        hidden_dim: int,
        filter_len: int,
        hop_len: int,
        layers: int = 4,
        block_layers: int = 4,
        kernel_size: int = 3,
        is_mask: bool = True,
        norm: str = "ins",
        act: str = "comp",
        refine_layers: int = 1,
        add_spec_results: bool = False,
    ):
        super().__init__(
            spec_dim,
            hidden_dim,
            filter_len,
            hop_len,
            layers,
            block_layers,
            kernel_size,
            is_mask,
            norm,
            act,
        )
        self.add_spec_results = add_spec_results
        # refine conv. Note: list multiplication repeats the *same* module
        # instance, so all refine layers share their parameters.
        self.refine_conv = nn.ModuleList(
            [
                nn.Sequential(
                    ComplexConvBlock(
                        spec_dim * 2,
                        spec_dim * 2,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        bn_func=self.bn_func,
                        act_func=self.act_func,
                    ),
                    self.bn_func(spec_dim * 2),
                    self.act_func(),
                )
            ]
            * refine_layers
        )

    def forward(self, wav):
        # stft
        origin_mag, origin_phase = self.log_stft(wav)
        origin_x = torch.cat([origin_mag, origin_phase], dim=1)

        # prev
        x = self.prev_conv(origin_x)

        # body
        # down
        down_cache = []
        for idx, block in enumerate(self.down):
            x = block(x)
            down_cache.append(x)
            x = self.down_pool(x)

        # up
        for idx, block in enumerate(self.up):
            x = block(x)
            res = F.interpolate(
                down_cache[self.layers - (idx + 1)],
                size=[x.size()[2]],
                mode="linear",
                align_corners=False,
            )
            x = concat_complex(x, res, dim=1)

        # match spec dimension
        x = self.out_conv(x)
        if origin_mag.size(2) != x.size(2):
            x = F.interpolate(
                x, size=[origin_mag.size(2)], mode="linear", align_corners=False
            )

        # refine
        for idx, refine_module in enumerate(self.refine_conv):
            x = refine_module(x)
            mag, phase = x.chunk(2, 1)
            mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
            if idx < len(self.refine_conv) - 1:
                x = torch.cat([mag, phase], dim=1)

        # clamp phase
        phase = phase.clamp(-np.pi, np.pi)

        out = self.exp_istft(mag, phase)
        out = self.adjust_diff(out, wav)

        if self.add_spec_results:
            out = (out, mag, phase)
        return out
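

# Numeric sketch of the masking step shared by both U-Nets: tanh(|m|) is a
# bounded magnitude mask in [0, 1), and m / |m| is a +/- 1 sign mask
# (undefined at m == 0) that can flip the summed phase.
def _masking_demo():
    m = torch.tensor([[-2.0, 0.5, 3.0]])
    mag_mask = torch.tanh(m.abs())  # tensor([[0.9640, 0.4621, 0.9951]])
    phase_mask = m / m.abs()        # tensor([[-1., 1., 1.]])
    return mag_mask, phase_mask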


class _ComplexConvNd(nn.Module):
    """
    Base class for complex convolutions.
    A: real-part weight
    B: imaginary-part weight
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.output_padding = output_padding
        self.transposed = transposed

        self.A = self.make_weight(in_channels, out_channels, kernel_size)
        self.B = self.make_weight(in_channels, out_channels, kernel_size)
        self.reset_parameters()

    def make_weight(self, in_ch, out_ch, kernel_size):
        if self.transposed:
            tensor = nn.Parameter(torch.Tensor(in_ch, out_ch // 2, *kernel_size))
        else:
            tensor = nn.Parameter(torch.Tensor(out_ch, in_ch // 2, *kernel_size))
        return tensor

    def reset_parameters(self):
        # init A (real weight) with a scaled Kaiming-uniform bound
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.A)
        gain = calculate_gain("leaky_relu", 0)
        std = gain / np.sqrt(fan_in)
        bound = np.sqrt(3.0) * std
        with torch.no_grad():
            # TODO: find more stable initial values
            self.A.uniform_(-bound * (1 / (np.pi**2)), bound * (1 / (np.pi**2)))
            # B (imaginary weight) is initialized in (-1/pi, 1/pi):
            # the full (-pi, pi) range is too large, so the bound is inverted
            self.B.uniform_(-1 / np.pi, 1 / np.pi)


class ComplexConv1d(_ComplexConvNd):
    """
    Complex Convolution 1d
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
    ):
        kernel_size = single(kernel_size)
        stride = single(stride)
        # padding is kept as a plain int: it is applied manually with
        # reflect padding in forward(), not passed to F.conv1d
        dilation = single(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            single(0),
        )

    def forward(self, x):
        """
        Complex convolution implemented by combining a grouped convolution with
        separate real / imaginary weights.
        :param x: data (N, C, T); C concatenates C/2 real channels and C/2 imaginary channels
        :return: complex convolution result
        """
        # apply reflect padding manually
        if self.padding:
            x = F.pad(x, (self.padding, self.padding), "reflect")

        # forward real part
        real_part = F.conv1d(
            x,
            self.A,
            None,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=2,
        )

        # forward imaginary part: negate the first half of B
        # (no .data here, so gradients still flow into B)
        spl = self.in_channels // 2
        weight_B = torch.cat([self.B[:spl] * (-1), self.B[spl:]])
        idea_part = F.conv1d(
            x,
            weight_B,
            None,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=2,
        )

        return real_part + idea_part
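

# Shape sketch (assumed dimensions): the grouped convolution requires both
# in_channels and out_channels to be even, and keeps the concatenated
# real/imaginary channel layout.
def _complex_conv1d_demo():
    conv = ComplexConv1d(in_channels=8, out_channels=8, kernel_size=3, padding=1)
    x = torch.randn(2, 8, 64)  # first 4 channels real, last 4 imaginary
    return conv(x).shape       # torch.Size([2, 8, 64])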


class ComplexTransposedConv1d(_ComplexConvNd):
    """
    Complex Transposed Convolution 1d
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        dilation=1,
    ):
        kernel_size = single(kernel_size)
        stride = single(stride)
        # padding stays a plain int; it is applied manually in forward()
        dilation = single(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            True,
            output_padding,
        )

    def forward(self, x, output_size=None):
        """
        Complex transposed convolution implemented by combining a grouped
        transposed convolution with separate real / imaginary weights.
        :param x: data (N, C, T); C concatenates C/2 real channels and C/2 imaginary channels
        :return: complex transposed convolution result
        """
        # forward real part
        if self.padding:
            x = F.pad(x, (self.padding, self.padding), "reflect")

        real_part = F.conv_transpose1d(
            x,
            self.A,
            None,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=2,
        )

        # forward imaginary part: negate the first half of B
        spl = self.out_channels // 2
        weight_B = torch.cat([self.B[:spl] * (-1), self.B[spl:]])
        idea_part = F.conv_transpose1d(
            x,
            weight_B,
            None,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=2,
        )

        if self.output_padding:
            real_part = F.pad(
                real_part, (self.output_padding, self.output_padding), "reflect"
            )
            idea_part = F.pad(
                idea_part, (self.output_padding, self.output_padding), "reflect"
            )

        return real_part + idea_part


class ComplexActLayer(nn.Module):
    """
    Activate the 'real' part and the 'imaginary' part differently.
    In the DCUnet implemented in this repository, the real part lives in log
    space, and the phase (imaginary) part is distributed in [-pi, pi].
    """

    def __init__(self, is_out: bool = False):
        # is_out is accepted so this layer can be constructed via act_out
        # (see SpectrogramUnet); this implementation applies the same
        # activation either way
        super().__init__()
        self.is_out = is_out

    def forward(self, x):
        real, img = x.chunk(2, 1)
        return torch.cat([F.leaky_relu(real), torch.tanh(img) * np.pi], dim=1)
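

# Quick illustration (assumed sizes): the real half passes through leaky_relu,
# while the imaginary half is squashed into [-pi, pi] by tanh(x) * pi.
def _complex_act_demo():
    act = ComplexActLayer()
    x = torch.randn(1, 4, 10) * 10.0  # 2 real channels + 2 imaginary channels
    out = act(x)
    real, img = out.chunk(2, 1)
    assert img.abs().max() <= np.pi  # phase half is bounded by pi
    return real, img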


class STFT(nn.Module):
    """
    Re-implementation of the STFT so that the backward pass can be computed.
    Based on: https://github.com/pseeth/torch-stft/blob/master/torch_stft/stft.py
    """

    def __init__(
        self,
        filter_length: int = 1024,
        hop_length: int = 512,
        win_length: Optional[int] = None,
        window: str = "hann",
    ):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length if win_length else filter_length
        self.window = window
        self.pad_amount = self.filter_length // 2

        # make fft window
        assert filter_length >= self.win_length
        # get window and zero center pad it to filter_length
        fft_window = get_window(window, self.win_length, fftbins=True)
        fft_window = pad_center(fft_window, size=filter_length)
        fft_window = torch.from_numpy(fft_window).float()

        # calculate fourier_basis
        cut_off = int((self.filter_length / 2 + 1))
        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cut_off, :]), np.imag(fourier_basis[:cut_off, :])]
        )

        # make forward & inverse basis
        self.register_buffer("square_window", fft_window**2)
        forward_basis = torch.FloatTensor(fourier_basis[:, np.newaxis, :]) * fft_window
        inverse_basis = (
            torch.FloatTensor(
                np.linalg.pinv(self.filter_length / self.hop_length * fourier_basis).T[
                    :, np.newaxis, :
                ]
            )
            * fft_window
        )
        # torch.pinverse has a bug, so for now the pseudo-inverse is computed in numpy
        self.register_buffer("forward_basis", forward_basis)
        self.register_buffer("inverse_basis", inverse_basis)

    def transform(self, wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # reflect padding
        wav = wav.unsqueeze(1).unsqueeze(1)
        wav = F.pad(
            wav, (self.pad_amount, self.pad_amount, 0, 0), mode="reflect"
        ).squeeze(1)

        # conv
        forward_trans = F.conv1d(
            wav, self.forward_basis, stride=self.hop_length, padding=0
        )
        real_part, imag_part = forward_trans.chunk(2, 1)

        # .data detaches the phase from the graph, as in the reference implementation
        return torch.sqrt(real_part**2 + imag_part**2), torch.atan2(
            imag_part.data, real_part.data
        )

    def inverse(
        self, magnitude: torch.Tensor, phase: torch.Tensor, eps: float = 1e-9
    ) -> torch.Tensor:
        comp = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )
        inverse_transform = F.conv_transpose1d(
            comp, self.inverse_basis, stride=self.hop_length, padding=0
        )

        # remove window effect
        n_frames = comp.size(-1)
        inverse_size = inverse_transform.size(-1)

        window_filter = torch.ones(1, 1, n_frames).type_as(inverse_transform)
        weight = self.square_window[: self.filter_length].unsqueeze(0).unsqueeze(0)
        window_filter = F.conv_transpose1d(
            window_filter, weight, stride=self.hop_length, padding=0
        )
        window_filter = window_filter.squeeze()[:inverse_size] + eps

        inverse_transform /= window_filter

        # scale by hop ratio
        inverse_transform *= self.filter_length / self.hop_length

        return inverse_transform[..., self.pad_amount : -self.pad_amount].squeeze(1)
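

# Round-trip sketch (assumed sizes): transform followed by inverse should
# approximately reconstruct the waveform, up to windowing error at the edges.
def _stft_roundtrip_demo():
    stft = STFT(filter_length=1024, hop_length=256)
    wav = torch.randn(1, 8192)
    mag, phase = stft.transform(wav)
    rec = stft.inverse(mag, phase)
    return (wav - rec).abs().mean()  # expected to be small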