# remfx/dcunet.py
# Adapted from https://github.com/AppleHolic/source_separation/tree/master/source_separation
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.init import calculate_gain
from typing import Optional, Tuple
from scipy.signal import get_window
from librosa.util import pad_center
from remfx.utils import single, concat_complex
class ComplexConvBlock(nn.Module):
    """
    Stack of complex 1-D convolutions; a residual connection is added when the
    input and output shapes match (and `skip_res` is False).
    """
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: int = 0,
layers: int = 4,
bn_func=nn.BatchNorm1d,
act_func=nn.LeakyReLU,
skip_res: bool = False,
):
super().__init__()
# modules
self.blocks = nn.ModuleList()
self.skip_res = skip_res
for idx in range(layers):
in_ = in_channels if idx == 0 else out_channels
self.blocks.append(
nn.Sequential(
*[
bn_func(in_),
act_func(),
ComplexConv1d(in_, out_channels, kernel_size, padding=padding),
]
)
)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        temp = x
        for idx, block in enumerate(self.blocks):
            x = block(x)
        # residual connection, skipped when shapes differ or skip_res is set
        if temp.size() != x.size() or self.skip_res:
            return x
        return x + temp
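
# A minimal shape-check sketch (not part of the original module; the helper
# name `_demo_complex_conv_block` is hypothetical): with padding set to
# kernel_size // 2 the block preserves (N, C, T), so the residual branch runs.
def _demo_complex_conv_block():
    block = ComplexConvBlock(16, 16, kernel_size=5, padding=2, layers=2)
    x = torch.randn(8, 16, 100)  # 16 channels = 8 real + 8 imaginary
    y = block(x)
    assert y.shape == x.shape  # shapes match, so the residual branch is taken
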
class SpectrogramUnet(nn.Module):
def __init__(
self,
spec_dim: int,
hidden_dim: int,
filter_len: int,
hop_len: int,
layers: int = 3,
block_layers: int = 3,
kernel_size: int = 5,
is_mask: bool = False,
norm: str = "bn",
act: str = "tanh",
):
super().__init__()
self.layers = layers
self.is_mask = is_mask
# stft modules
self.stft = STFT(filter_len, hop_len)
if norm == "bn":
self.bn_func = nn.BatchNorm1d
elif norm == "ins":
self.bn_func = lambda x: nn.InstanceNorm1d(x, affine=True)
else:
            raise NotImplementedError("{} is not implemented".format(norm))
if act == "tanh":
self.act_func = nn.Tanh
self.act_out = nn.Tanh
elif act == "comp":
self.act_func = ComplexActLayer
self.act_out = lambda: ComplexActLayer(is_out=True)
else:
            raise NotImplementedError("{} is not implemented".format(act))
# prev conv
self.prev_conv = ComplexConv1d(spec_dim * 2, hidden_dim, 1)
# down
self.down = nn.ModuleList()
self.down_pool = nn.MaxPool1d(3, stride=2, padding=1)
for idx in range(self.layers):
block = ComplexConvBlock(
hidden_dim,
hidden_dim,
kernel_size=kernel_size,
padding=kernel_size // 2,
bn_func=self.bn_func,
act_func=self.act_func,
layers=block_layers,
)
self.down.append(block)
# up
self.up = nn.ModuleList()
for idx in range(self.layers):
in_c = hidden_dim if idx == 0 else hidden_dim * 2
self.up.append(
nn.Sequential(
ComplexConvBlock(
in_c,
hidden_dim,
kernel_size=kernel_size,
padding=kernel_size // 2,
bn_func=self.bn_func,
act_func=self.act_func,
layers=block_layers,
),
self.bn_func(hidden_dim),
self.act_func(),
ComplexTransposedConv1d(
hidden_dim, hidden_dim, kernel_size=2, stride=2
),
)
)
# out_conv
self.out_conv = nn.Sequential(
ComplexConvBlock(
hidden_dim * 2,
spec_dim * 2,
kernel_size=kernel_size,
padding=kernel_size // 2,
bn_func=self.bn_func,
act_func=self.act_func,
),
self.bn_func(spec_dim * 2),
self.act_func(),
)
# refine conv
self.refine_conv = nn.Sequential(
ComplexConvBlock(
spec_dim * 4,
spec_dim * 2,
kernel_size=kernel_size,
padding=kernel_size // 2,
bn_func=self.bn_func,
act_func=self.act_func,
),
self.bn_func(spec_dim * 2),
self.act_func(),
)
    def log_stft(self, wav):
        # stft, with log1p-compressed magnitude
        mag, phase = self.stft.transform(wav)
        return torch.log(mag + 1), phase

    def exp_istft(self, log_mag, phase):
        # undo the log1p compression
        mag = torch.exp(log_mag) - 1
        # istft
        wav = self.stft.inverse(mag, phase)
        return wav
    def adjust_diff(self, x, target):
        size_diff = target.size()[-1] - x.size()[-1]
        assert size_diff >= 0
        if size_diff > 0:
            # pad asymmetrically so an odd difference is also fully covered
            x = F.pad(
                x.unsqueeze(1), (size_diff // 2, size_diff - size_diff // 2), "reflect"
            ).squeeze(1)
        return x
    def masking(self, mag, phase, origin_mag, origin_phase):
        abs_mag = torch.abs(mag)
        mag_mask = torch.tanh(abs_mag)
        # sign of mag; the clamp is an added guard against division by zero
        phase_mask = mag / abs_mag.clamp_min(1e-8)
        # masking
        mag = mag_mask * origin_mag
        phase = phase_mask * (origin_phase + phase)
        return mag, phase
def forward(self, wav):
# stft
origin_mag, origin_phase = self.log_stft(wav)
origin_x = torch.cat([origin_mag, origin_phase], dim=1)
# prev
x = self.prev_conv(origin_x)
# body
# down
down_cache = []
for idx, block in enumerate(self.down):
x = block(x)
down_cache.append(x)
x = self.down_pool(x)
# up
for idx, block in enumerate(self.up):
x = block(x)
res = F.interpolate(
down_cache[self.layers - (idx + 1)],
size=[x.size()[2]],
mode="linear",
align_corners=False,
)
x = concat_complex(x, res, dim=1)
# match spec dimension
x = self.out_conv(x)
if origin_mag.size(2) != x.size(2):
x = F.interpolate(
x, size=[origin_mag.size(2)], mode="linear", align_corners=False
)
# refine
x = self.refine_conv(concat_complex(x, origin_x))
def to_wav(stft):
mag, phase = stft.chunk(2, 1)
if self.is_mask:
mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
out = self.exp_istft(mag, phase)
out = self.adjust_diff(out, wav)
return out
refine_wav = to_wav(x)
return refine_wav
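
# A hedged smoke-test sketch (assumed values; `_demo_spectrogram_unet` is not
# part of the original repo). Note that spec_dim must equal filter_len // 2 + 1
# so the mag/phase concat has spec_dim * 2 channels, and hidden_dim must be
# even so it can be split into real / imaginary halves.
def _demo_spectrogram_unet():
    model = SpectrogramUnet(spec_dim=513, hidden_dim=32, filter_len=1024, hop_len=256)
    wav = torch.randn(2, 16384)  # (batch, samples)
    out = model(wav)
    assert out.shape == wav.shape  # waveform in, waveform out
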
class RefineSpectrogramUnet(SpectrogramUnet):
def __init__(
self,
spec_dim: int,
hidden_dim: int,
filter_len: int,
hop_len: int,
layers: int = 4,
block_layers: int = 4,
kernel_size: int = 3,
is_mask: bool = True,
norm: str = "ins",
act: str = "comp",
refine_layers: int = 1,
add_spec_results: bool = False,
):
super().__init__(
spec_dim,
hidden_dim,
filter_len,
hop_len,
layers,
block_layers,
kernel_size,
is_mask,
norm,
act,
)
self.add_spec_results = add_spec_results
        # refine conv: build refine_layers independent copies; multiplying a
        # Python list would register the same module object repeatedly and
        # share its weights across every refine step
        self.refine_conv = nn.ModuleList(
            [
                nn.Sequential(
                    ComplexConvBlock(
                        spec_dim * 2,
                        spec_dim * 2,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        bn_func=self.bn_func,
                        act_func=self.act_func,
                    ),
                    self.bn_func(spec_dim * 2),
                    self.act_func(),
                )
                for _ in range(refine_layers)
            ]
        )
def forward(self, wav):
# stft
origin_mag, origin_phase = self.log_stft(wav)
origin_x = torch.cat([origin_mag, origin_phase], dim=1)
# prev
x = self.prev_conv(origin_x)
# body
# down
down_cache = []
for idx, block in enumerate(self.down):
x = block(x)
down_cache.append(x)
x = self.down_pool(x)
# up
for idx, block in enumerate(self.up):
x = block(x)
res = F.interpolate(
down_cache[self.layers - (idx + 1)],
size=[x.size()[2]],
mode="linear",
align_corners=False,
)
x = concat_complex(x, res, dim=1)
# match spec dimension
x = self.out_conv(x)
if origin_mag.size(2) != x.size(2):
x = F.interpolate(
x, size=[origin_mag.size(2)], mode="linear", align_corners=False
)
# refine
for idx, refine_module in enumerate(self.refine_conv):
x = refine_module(x)
mag, phase = x.chunk(2, 1)
mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
if idx < len(self.refine_conv) - 1:
x = torch.cat([mag, phase], dim=1)
# clamp phase
phase = phase.clamp(-np.pi, np.pi)
out = self.exp_istft(mag, phase)
out = self.adjust_diff(out, wav)
if self.add_spec_results:
out = (out, mag, phase)
return out
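
# Usage sketch for the refine variant (assumed sizes; `_demo_refine_unet` is a
# hypothetical name). With add_spec_results=True the forward would instead
# return the tuple (wav, mag, phase).
def _demo_refine_unet():
    model = RefineSpectrogramUnet(
        spec_dim=513, hidden_dim=32, filter_len=1024, hop_len=256
    )
    wav = torch.randn(2, 16384)
    out = model(wav)
    assert out.shape == wav.shape
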
class _ComplexConvNd(nn.Module):
    """
    Base class for complex convolutions.
    A: real-part weight
    B: imaginary-part weight
    """
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
transposed,
output_padding,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.output_padding = output_padding
self.transposed = transposed
self.A = self.make_weight(in_channels, out_channels, kernel_size)
self.B = self.make_weight(in_channels, out_channels, kernel_size)
self.reset_parameters()
def make_weight(self, in_ch, out_ch, kernel_size):
if self.transposed:
tensor = nn.Parameter(torch.Tensor(in_ch, out_ch // 2, *kernel_size))
else:
tensor = nn.Parameter(torch.Tensor(out_ch, in_ch // 2, *kernel_size))
return tensor
    def reset_parameters(self):
        # fan-in based bound for the real weight A
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.A)
        gain = calculate_gain("leaky_relu", 0)
        std = gain / np.sqrt(fan_in)
        bound = np.sqrt(3.0) * std
        with torch.no_grad():
            # TODO: find more stable initial values
            self.A.uniform_(-bound * (1 / (np.pi**2)), bound * (1 / (np.pi**2)))
            # B models the phase; [-pi, pi] is too wide a range, so its
            # reciprocal 1 / pi is used as the bound instead
            self.B.uniform_(-1 / np.pi, 1 / np.pi)
class ComplexConv1d(_ComplexConvNd):
"""
Complex Convolution 1d
"""
    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
    ):
        kernel_size = single(kernel_size)
        stride = single(stride)
        # padding is kept as an int; reflect padding is applied manually in forward
        dilation = single(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            single(0),
        )
    def forward(self, x):
        """
        Complex convolution implemented by combining grouped convolution with
        separate real / imaginary weights.
        :param x: data (N, C, T); C concatenates C/2 real and C/2 imaginary channels
        :return: complex convolution result
        """
        # reflect padding
        if self.padding:
            x = F.pad(x, (self.padding, self.padding), "reflect")
        # real part
        real_part = F.conv1d(
            x,
            self.A,
            None,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=2,
        )
        # imaginary part; negate the first half of B, keeping it in the autograd
        # graph (the original used `.data`, which detached B from gradients)
        spl = self.in_channels // 2
        weight_B = torch.cat([self.B[:spl] * (-1), self.B[spl:]])
        imag_part = F.conv1d(
            x,
            weight_B,
            None,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=2,
        )
        return real_part + imag_part
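
# A small sketch (assumed sizes; hypothetical helper name) of the grouped-
# convolution trick: with groups=2 the first half of the filters sees only the
# real channels and the second half only the imaginary channels of the input.
def _demo_complex_conv1d():
    conv = ComplexConv1d(in_channels=16, out_channels=16, kernel_size=3, padding=1)
    x = torch.randn(4, 16, 50)  # 8 real + 8 imaginary channels
    y = conv(x)
    assert y.shape == (4, 16, 50)  # padding = kernel_size // 2 keeps T fixed
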
class ComplexTransposedConv1d(_ComplexConvNd):
"""
Complex Transposed Convolution 1d
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
dilation=1,
):
        kernel_size = single(kernel_size)
        stride = single(stride)
        # padding is kept as an int; reflect padding is applied manually in forward
        dilation = single(dilation)
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
True,
output_padding,
)
    def forward(self, x, output_size=None):
        """
        Complex transposed convolution implemented by combining grouped
        transposed convolution with separate real / imaginary weights.
        :param x: data (N, C, T); C concatenates C/2 real and C/2 imaginary channels
        :return: complex transposed convolution result
        """
        # reflect padding
        if self.padding:
            x = F.pad(x, (self.padding, self.padding), "reflect")
        # real part
real_part = F.conv_transpose1d(
x,
self.A,
None,
stride=self.stride,
padding=0,
dilation=self.dilation,
groups=2,
)
        # imaginary part
        spl = self.out_channels // 2
        weight_B = torch.cat([self.B[:spl] * (-1), self.B[spl:]])
        imag_part = F.conv_transpose1d(
            x,
            weight_B,
            None,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=2,
        )
        if self.output_padding:
            real_part = F.pad(
                real_part, (self.output_padding, self.output_padding), "reflect"
            )
            imag_part = F.pad(
                imag_part, (self.output_padding, self.output_padding), "reflect"
            )
        return real_part + imag_part
class ComplexActLayer(nn.Module):
    """
    Activate the 'real' and 'imaginary' parts differently:
    the real (log-magnitude) part goes through a leaky ReLU, and the
    imaginary (phase) part is squashed by tanh into [-pi, pi].
    """

    def __init__(self, is_out: bool = False):
        # accept `is_out` so the `act_out` factory above can instantiate this
        # layer; the activation itself is currently the same either way
        super().__init__()
        self.is_out = is_out

    def forward(self, x):
        real, img = x.chunk(2, 1)
        return torch.cat([F.leaky_relu(real), torch.tanh(img) * np.pi], dim=1)
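
# A tiny bound check (hypothetical helper name): the imaginary half stays in
# [-pi, pi] after activation, matching the phase range assumed elsewhere.
def _demo_complex_act_layer():
    act = ComplexActLayer()
    x = torch.randn(2, 8, 10) * 100.0  # deliberately large inputs
    _, img = act(x).chunk(2, 1)
    # small tolerance for float32 rounding when tanh saturates to 1.0
    assert float(img.abs().max()) <= np.pi + 1e-6
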
class STFT(nn.Module):
    """
    STFT re-implemented with convolutions so gradients can flow through the
    transform and its inverse.
    Based on: https://github.com/pseeth/torch-stft/blob/master/torch_stft/stft.py
    """
def __init__(
self,
filter_length: int = 1024,
hop_length: int = 512,
        win_length: Optional[int] = None,
window: str = "hann",
):
super().__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length if win_length else filter_length
self.window = window
self.pad_amount = self.filter_length // 2
        # make fft window
        assert filter_length >= self.win_length
        # get window and zero-center pad it to filter_length
        fft_window = get_window(window, self.win_length, fftbins=True)
        # keyword argument for compatibility with newer librosa versions
        fft_window = pad_center(fft_window, size=filter_length)
        fft_window = torch.from_numpy(fft_window).float()
        # calculate fourier basis
cut_off = int((self.filter_length / 2 + 1))
fourier_basis = np.fft.fft(np.eye(self.filter_length))
fourier_basis = np.vstack(
[np.real(fourier_basis[:cut_off, :]), np.imag(fourier_basis[:cut_off, :])]
)
# make forward & inverse basis
self.register_buffer("square_window", fft_window**2)
forward_basis = torch.FloatTensor(fourier_basis[:, np.newaxis, :]) * fft_window
inverse_basis = (
torch.FloatTensor(
np.linalg.pinv(self.filter_length / self.hop_length * fourier_basis).T[
:, np.newaxis, :
]
)
* fft_window
)
        # torch.pinverse had a bug at the time of writing, so numpy's pinv is used instead
self.register_buffer("forward_basis", forward_basis)
self.register_buffer("inverse_basis", inverse_basis)
def transform(self, wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# reflect padding
wav = wav.unsqueeze(1).unsqueeze(1)
wav = F.pad(
wav, (self.pad_amount, self.pad_amount, 0, 0), mode="reflect"
).squeeze(1)
# conv
forward_trans = F.conv1d(
wav, self.forward_basis, stride=self.hop_length, padding=0
)
        real_part, imag_part = forward_trans.chunk(2, 1)
        # phase is computed on detached tensors (as in torch-stft), so no
        # gradient flows through atan2
        return torch.sqrt(real_part**2 + imag_part**2), torch.atan2(
            imag_part.data, real_part.data
        )
def inverse(
self, magnitude: torch.Tensor, phase: torch.Tensor, eps: float = 1e-9
) -> torch.Tensor:
comp = torch.cat(
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
)
inverse_transform = F.conv_transpose1d(
comp, self.inverse_basis, stride=self.hop_length, padding=0
)
# remove window effect
n_frames = comp.size(-1)
inverse_size = inverse_transform.size(-1)
window_filter = torch.ones(1, 1, n_frames).type_as(inverse_transform)
weight = self.square_window[: self.filter_length].unsqueeze(0).unsqueeze(0)
window_filter = F.conv_transpose1d(
window_filter, weight, stride=self.hop_length, padding=0
)
window_filter = window_filter.squeeze()[:inverse_size] + eps
inverse_transform /= window_filter
# scale by hop ratio
inverse_transform *= self.filter_length / self.hop_length
return inverse_transform[..., self.pad_amount : -self.pad_amount].squeeze(1)
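
# A hedged round-trip smoke test (assumed sizes; guarded so importing this
# module stays side-effect free): with a hann window and 75% overlap the
# inverse should reconstruct the waveform up to numerical error.
if __name__ == "__main__":
    stft = STFT(filter_length=1024, hop_length=256)
    wav = torch.randn(2, 16384)
    mag, phase = stft.transform(wav)  # each of shape (2, 513, 65)
    rec = stft.inverse(mag, phase)  # (2, 16384)
    assert rec.shape == wav.shape
    print("max reconstruction error:", (rec - wav).abs().max().item())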