|
import torch |
|
import torch.nn as nn |
|
import segmentation_models_pytorch as smp |
|
from utils import prefer_target_instrument |
|
|
|
class STFT:
    """STFT / inverse-STFT pair that stores real and imaginary parts as channels.

    ``__call__`` maps a waveform tensor of shape ``(*batch, channels, samples)``
    to a real tensor of shape ``(*batch, channels * 2, F, frames)`` where
    ``F = min(dim_f, n_fft // 2 + 1)``. ``inverse`` undoes that packing,
    zero-filling any frequency bins that were cut off by ``dim_f``, and returns
    ``(*batch, channels, samples)``.
    """

    def __init__(self, config):
        """Read STFT settings from ``config``.

        Args:
            config: object exposing ``n_fft``, ``hop_length`` and ``dim_f``
                (number of frequency bins kept by ``__call__``).
        """
        self.n_fft = config.n_fft
        self.hop_length = config.hop_length
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
        self.dim_f = config.dim_f

    def __call__(self, x):
        """Compute the packed real/imag spectrogram of ``x``.

        Args:
            x: waveform tensor of shape ``(*batch, channels, samples)``.

        Returns:
            Real tensor ``(*batch, channels * 2, F, frames)``; channel ``2*k``
            is the real part and ``2*k + 1`` the imaginary part of audio
            channel ``k``. Only the first ``dim_f`` frequency bins are kept.
        """
        window = self.window.to(x.device)
        batch_dims = x.shape[:-2]
        c, t = x.shape[-2:]
        # torch.stft works on (B, samples): fold every leading dim into B.
        x = x.reshape([-1, t])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
            return_complex=True,
        )
        # (B, F, T) complex -> (B, F, T, 2) real -> (B, 2, F, T)
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        # Restore the leading dims and merge (channel, re/im) into one axis.
        x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
        return x[..., :self.dim_f, :]

    def inverse(self, x):
        """Turn a packed real/imag spectrogram back into waveforms.

        Args:
            x: real tensor ``(*batch, 2 * channels, f, frames)`` laid out as
                ``__call__`` produces it; ``f`` may be smaller than
                ``n_fft // 2 + 1`` (missing high bins are treated as zero).

        Returns:
            Waveform tensor ``(*batch, channels, samples)``. No ``length`` is
            passed to ``torch.istft``, so ``samples`` is derived from the frame
            count.
        """
        window = self.window.to(x.device)
        batch_dims = x.shape[:-3]
        c, f, t = x.shape[-3:]
        n = self.n_fft // 2 + 1
        # Zero-fill the bins dropped by dim_f. Allocate directly on x's device
        # (same default dtype as before) instead of creating on host and copying.
        f_pad = torch.zeros([*batch_dims, c, n - f, t], device=x.device)
        x = torch.cat([x, f_pad], -2)
        # Split channels back into (audio_channel, re/im), flatten for istft.
        x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
        x = x.permute([0, 2, 3, 1])
        x = x[..., 0] + x[..., 1] * 1.j
        x = torch.istft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
        )
        # c // 2 audio channels; this was hard-coded to 2, which mangled
        # anything that is not stereo (e.g. mono came back as 2 half-length rows).
        x = x.reshape([*batch_dims, c // 2, -1])
        return x
|
|
|
|
|
def get_act(act_type):
    """Build the activation module named by ``act_type``.

    Args:
        act_type: ``'gelu'``, ``'relu'``, or ``'elu'`` optionally followed by
            the ELU alpha as a float literal (e.g. ``'elu0.5'``). A bare
            ``'elu'`` uses ``alpha=1.0``.

    Returns:
        A newly constructed ``nn.Module``.

    Raises:
        ValueError: if ``act_type`` is not recognised, or the ELU alpha suffix
            is not a valid float.
    """
    if act_type == 'gelu':
        return nn.GELU()
    if act_type == 'relu':
        return nn.ReLU()
    if isinstance(act_type, str) and act_type.startswith('elu'):
        alpha_suffix = act_type[3:]
        # Previously a bare 'elu' crashed on float(''); fall back to 1.0.
        alpha = float(alpha_suffix) if alpha_suffix else 1.0
        return nn.ELU(alpha)
    raise ValueError(f"Unknown activation type: {act_type!r}")
|
|
|
|
|
# decoder_type -> (smp class name, optional config section with extra kwargs).
# Class names are resolved lazily with getattr so that a missing class in the
# installed smp version only fails when that decoder is actually requested.
_SMP_DECODERS = {
    'unet': ('Unet', 'decoder_unet'),
    'fpn': ('FPN', 'decoder_fpn'),
    'unet++': ('UnetPlusPlus', 'decoder_unet_plus_plus'),
    'manet': ('MAnet', 'decoder_manet'),
    'linknet': ('Linknet', 'decoder_linknet'),
    'pspnet': ('PSPNet', 'decoder_pspnet'),
    'pan': ('PAN', 'decoder_pan'),
    'deeplabv3': ('DeepLabV3', 'decoder_deeplabv3'),
    'deeplabv3plus': ('DeepLabV3Plus', 'decoder_deeplabv3plus'),
}


def _get_decoder_options(config, section):
    """Return ``config.<section>`` as a dict, or ``{}`` when it is absent."""
    try:
        return dict(getattr(config, section))
    except (AttributeError, KeyError, TypeError):
        # A missing section may surface as AttributeError/KeyError, or as a
        # None value that dict() rejects with TypeError. The options are optional.
        return {}


def get_decoder(config, c):
    """Build the segmentation_models_pytorch network selected by ``config``.

    Args:
        config: config whose ``model.decoder_type`` selects the architecture
            (one of the keys of ``_SMP_DECODERS``) and ``model.encoder_name``
            the encoder. An optional per-architecture section
            (e.g. ``config.decoder_unet``) supplies extra constructor kwargs.
        c: number of input channels; also passed as ``classes`` so the network
            maps ``c`` channels to ``c`` channels.

    Returns:
        The constructed smp model, created with ``encoder_weights="imagenet"``.

    Raises:
        ValueError: if ``config.model.decoder_type`` is not supported
            (previously this silently returned ``None``).
    """
    decoder_type = config.model.decoder_type
    try:
        class_name, options_section = _SMP_DECODERS[decoder_type]
    except KeyError:
        raise ValueError(
            f"Unsupported decoder_type {decoder_type!r}; "
            f"expected one of {sorted(_SMP_DECODERS)}"
        ) from None

    decoder_cls = getattr(smp, class_name)
    decoder_options = _get_decoder_options(config, options_section)
    return decoder_cls(
        encoder_name=config.model.encoder_name,
        encoder_weights="imagenet",
        in_channels=c,
        classes=c,
        **decoder_options,
    )
|
|
|
|
|
class Segm_Models_Net(nn.Module):
    """Spectrogram-domain source separator built around an smp network.

    Pipeline in ``forward``: STFT (real/imag packed as channels) -> split the
    frequency axis into ``num_subbands`` bands stacked as channels -> 1x1 conv
    -> smp encoder/decoder (from ``get_decoder``) -> multiplicative skip with
    the 1x1-conv features -> 1x1 conv head over ``[mix, features]`` -> undo the
    sub-band stacking -> inverse STFT per target instrument.
    """

    def __init__(self, config):
        """Build the layers from ``config``.

        Reads ``config.model.{act, num_subbands, num_channels, decoder_type,
        encoder_name}`` and ``config.audio.{num_channels, n_fft, hop_length,
        dim_f}``. The number of outputs is the length of
        ``prefer_target_instrument(config)``.
        """
        super().__init__()
        self.config = config

        act = get_act(act_type=config.model.act)

        self.num_target_instruments = len(prefer_target_instrument(config))
        self.num_subbands = config.model.num_subbands

        # Channels after STFT (real + imag per audio channel) and sub-band stacking.
        dim_c = self.num_subbands * config.audio.num_channels * 2
        # Hidden width of the smp network.
        c = config.model.num_channels

        self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)

        self.unet_model = get_decoder(config, c)

        # Head sees the mixture spectrogram concatenated with the decoded
        # features and predicts dim_c channels per target instrument.
        self.final_conv = nn.Sequential(
            nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
            act,
            nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False),
        )

        self.stft = STFT(config.audio)

    def cac2cws(self, x):
        """Stack ``num_subbands`` equal frequency bands into the channel axis.

        ``(b, c, f, t) -> (b, c * k, f // k, t)``; requires ``f % k == 0``.
        """
        k = self.num_subbands
        b, c, f, t = x.shape
        x = x.reshape(b, c, k, f // k, t)
        x = x.reshape(b, c * k, f // k, t)
        return x

    def cws2cac(self, x):
        """Inverse of ``cac2cws``: ``(b, c, f, t) -> (b, c // k, f * k, t)``."""
        k = self.num_subbands
        b, c, f, t = x.shape
        x = x.reshape(b, c // k, k, f, t)
        x = x.reshape(b, c // k, f * k, t)
        return x

    def forward(self, x):
        """Separate the mixture waveform ``x``.

        Args:
            x: waveform tensor; assumed ``(batch, audio_channels, samples)``
                so that the STFT output is 4-D as ``cac2cws`` requires.

        Returns:
            Waveforms from ``STFT.inverse``: with more than one target
            instrument an extra instrument axis follows the batch axis,
            otherwise the output has no instrument axis.
        """
        x = self.stft(x)

        mix = x = self.cac2cws(x)

        first_conv_out = x = self.first_conv(x)

        # Swap the frequency and time axes around the smp network.
        x = x.transpose(-1, -2)
        x = self.unet_model(x)
        x = x.transpose(-1, -2)

        # Multiplicative skip connection with the first-conv features.
        x = x * first_conv_out

        x = self.final_conv(torch.cat([mix, x], 1))

        x = self.cws2cac(x)

        if self.num_target_instruments > 1:
            b, c, f, t = x.shape
            x = x.reshape(b, self.num_target_instruments, -1, f, t)

        x = self.stft.inverse(x)
        return x
|
|