De-limiter / models /base_models.py
jeonchangbin49's picture
first commit
a00b67a
import warnings

import torch
import torch.nn as nn
from asteroid.models.base_models import (
    BaseEncoderMaskerDecoder,
    _unsqueeze_to_3d,
    _shape_reconstructed,
)
from asteroid.utils.torch_utils import pad_x_to_y, jitable_shape
from einops import rearrange
class BaseEncoderMaskerDecoderWithConfigs(BaseEncoderMaskerDecoder):
    """Encoder/masker/decoder wrapper whose stages can be switched off by config.

    Some of the models used here (e.g. TFC-TDF-UNet) have no separate masker,
    and UMX / X-UMX already apply masking inside their own implementation.
    This wrapper drives such models through the asteroid
    ``BaseEncoderMaskerDecoder`` interface without editing the model code.

    Args:
        encoder, masker, decoder, encoder_activation: forwarded unchanged to
            ``BaseEncoderMaskerDecoder.__init__``.
        **kwargs: optional boolean switches, each defaulting to ``True``:

            - ``use_encoder``: run ``forward_encoder``; when false the
              (3D-normalised) waveform is fed to the masker directly.
            - ``apply_mask``: combine encoder output and masker output via
              ``apply_masks``; when false the masker output is used as-is.
            - ``use_decoder``: run ``forward_decoder``; when false the masked
              representation is returned without decoding.

            Any other keys are ignored.
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
        super().__init__(encoder, masker, decoder, encoder_activation)
        self.use_encoder = kwargs.get("use_encoder", True)
        self.apply_mask = kwargs.get("apply_mask", True)
        self.use_decoder = kwargs.get("use_decoder", True)

    def forward(self, wav):
        """Enc/Mask/Dec forward pass with optional stages.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            If ``use_decoder`` is true, a tuple ``(masked_tf_rep, reconstructed)``.
            ``masked_tf_rep`` is the (masked) intermediate representation.
            ``reconstructed`` is the decoder output, padded to the input length
            and reshaped like the input: (batch, n_src, time) or (n_src, time).
            If ``use_decoder`` is false, only ``masked_tf_rep`` is returned
            (not wrapped in a tuple).
        """
        # Remember the input shape so the output can be restored to it
        # (kept as a Tensor for TorchScript).
        shape = jitable_shape(wav)
        # Normalise to (batch, n_mix, time).
        wav = _unsqueeze_to_3d(wav)

        tf_rep = self.forward_encoder(wav) if self.use_encoder else wav
        est_masks = self.forward_masker(tf_rep)
        if self.apply_mask:
            masked_tf_rep = self.apply_masks(tf_rep, est_masks)
        else:
            # The wrapped model already performed masking internally.
            masked_tf_rep = est_masks

        if not self.use_decoder:
            # E.g. UMX / X-UMX: no decoder, return the masked representation.
            return masked_tf_rep

        decoded = self.forward_decoder(masked_tf_rep)
        reconstructed = pad_x_to_y(decoded, wav)
        return masked_tf_rep, _shape_reconstructed(reconstructed, shape)
class BaseEncoderMaskerDecoder_mixture_consistency(BaseEncoderMaskerDecoder):
    """Encoder/masker/decoder model with mixture-consistent source estimates.

    After decoding, the per-source estimates are projected so that they sum
    to the input mixture. The residual ``mixture - sum(estimates)`` is spread
    evenly over the sources.

    References:
        [1] : Wisdom, Scott, et al. "Differentiable consistency constraints for improved deep speech enhancement." ICASSP 2019.
        [2] : Wisdom, Scott, et al. "Unsupervised sound separation using mixture invariant training." NeurIPS 2020.
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None):
        super().__init__(encoder, masker, decoder, encoder_activation)

    def forward(self, wav):
        """Enc/Mask/Dec model forward with mixture-consistent output.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
        """
        # Remember the input shape so the output can be restored to it
        # (kept as a Tensor for TorchScript).
        shape = jitable_shape(wav)
        # Normalise to (batch, n_mix, time).
        wav = _unsqueeze_to_3d(wav)

        tf_rep = self.forward_encoder(wav)
        est_masks = self.forward_masker(tf_rep)
        masked_tf_rep = self.apply_masks(tf_rep, est_masks)
        decoded = self.forward_decoder(masked_tf_rep)
        reconstructed = pad_x_to_y(decoded, wav)

        # Project while both tensors are still 3D. _shape_reconstructed can
        # return a tensor without the batch axis (the "(n_src, time)" case).
        # After that, shape[1] would no longer be n_src and `wav` would not
        # broadcast against the source sum.
        n_src = reconstructed.shape[1]
        residual = wav - reconstructed.sum(dim=1, keepdim=True)
        reconstructed = reconstructed + 1 / n_src * residual
        return _shape_reconstructed(reconstructed, shape)
class BaseEncoderMaskerDecoderWithConfigsMaskOnOutput(BaseEncoderMaskerDecoder):
    """Encoder/masker/decoder model that applies its mask to the output waveform.

    For the De-limit task the mask is not applied to the encoder
    representation. Instead, the masker output is decoded into a sample-wise
    ratio, passed through ``act_after_dec`` and multiplied with the input
    waveform.

    Args:
        encoder, masker, decoder, encoder_activation: forwarded unchanged to
            ``BaseEncoderMaskerDecoder.__init__``.
        **kwargs: optional settings:

            - ``use_encoder`` (default ``True``): run ``forward_encoder``;
              when false the (3D-normalised) waveform goes to the masker.
            - ``apply_mask`` (default ``True``): stored on the instance but
              not read by :meth:`forward`.
            - ``use_decoder`` (default ``True``): decode the mask and apply it
              to the waveform; when false the raw masker output is returned.
            - ``nb_channels`` (default ``2``): when 2, the channel and
              frequency axes are merged before the masker.
            - ``decoder_activation`` (default ``"sigmoid"``): one of
              ``"sigmoid"``, ``"relu"``, ``"relu6"``, ``"tanh"``, ``"none"``.
              Unknown names fall back to sigmoid and emit a warning.
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
        super().__init__(encoder, masker, decoder, encoder_activation)
        self.use_encoder = kwargs.get("use_encoder", True)
        self.apply_mask = kwargs.get("apply_mask", True)  # kept for config compatibility
        self.use_decoder = kwargs.get("use_decoder", True)
        self.nb_channels = kwargs.get("nb_channels", 2)
        self.decoder_activation = kwargs.get("decoder_activation", "sigmoid")

        activations = {
            "sigmoid": nn.Sigmoid,
            "relu": nn.ReLU,
            "relu6": nn.ReLU6,
            "tanh": nn.Tanh,
            "none": nn.Identity,
        }
        act_cls = activations.get(self.decoder_activation)
        if act_cls is None:
            # Preserve the historical sigmoid fallback, but make it visible:
            # a typo in the config would otherwise silently change the model.
            warnings.warn(
                f"Unknown decoder_activation {self.decoder_activation!r}; "
                "falling back to sigmoid.",
                stacklevel=2,
            )
            act_cls = nn.Sigmoid
        self.act_after_dec = act_cls()

    def forward(self, wav):
        """Predict a sample-wise mask and apply it to the input waveform.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            tuple: if ``use_decoder`` is true, ``(est_masks_decoded, decoded)``.
            ``est_masks_decoded`` is the activated decoder output, padded to the
            input length. ``decoded`` is ``wav * est_masks_decoded``.
            Otherwise a 1-tuple ``(est_masks,)`` with the raw masker output.
            Outputs keep the 3D (batch, ...) layout; they are not reshaped back
            to the rank of a 1D/2D input.
        """
        # Normalise to (batch, n_channels, time).
        wav = _unsqueeze_to_3d(wav)

        tf_rep = self.forward_encoder(wav) if self.use_encoder else wav
        if self.nb_channels == 2:
            # Merge the stereo channels into the frequency axis for the masker.
            tf_rep = rearrange(tf_rep, "b c f t -> b (c f) t")
        # Per the author, expected to be (batch, 1, freq, time).
        est_masks = self.forward_masker(tf_rep)

        if not self.use_decoder:
            return (est_masks,)

        if self.nb_channels == 2:
            # Drop the singleton axis (the pattern enforces that it is 1).
            est_masks = rearrange(est_masks, "b 1 f t -> b f t")
        est_masks_decoded = self.forward_decoder(est_masks)
        est_masks_decoded = pad_x_to_y(est_masks_decoded, wav)
        est_masks_decoded = self.act_after_dec(est_masks_decoded)
        # The mask is applied on the decoder output, i.e. directly on the waveform.
        decoded = wav * est_masks_decoded
        return est_masks_decoded, decoded
class BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid(BaseEncoderMaskerDecoder):
    """Configurable Enc/Mask/Dec wrapper for multi-channel asteroid models.

    Intended for asteroid-based models such as ConvTasNet used on stereo
    input. When ``nb_channels == 2`` the channel axis is merged into the
    frequency axis for the masker and split back afterwards.

    Args:
        encoder, masker, decoder, encoder_activation: forwarded unchanged to
            ``BaseEncoderMaskerDecoder.__init__``.
        **kwargs: optional settings:

            - ``use_encoder`` (default ``True``): run ``forward_encoder``;
              when false the (3D-normalised) waveform goes to the masker.
            - ``apply_mask`` (default ``True``): multiply the masker output
              with the encoder output; when false the masker output is used
              as-is.
            - ``use_decoder`` (default ``True``): run ``forward_decoder``;
              when false the masked representation is returned undecoded.
            - ``nb_channels`` (default ``2``): number of input channels; the
              merge/split above only happens when it equals 2.
            - ``decoder_activation`` (default ``"none"``): one of
              ``"sigmoid"``, ``"relu"``, ``"relu6"``, ``"tanh"``, ``"none"``;
              applied to the reconstructed waveform. Unknown names fall back
              to sigmoid and emit a warning.
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
        super().__init__(encoder, masker, decoder, encoder_activation)
        self.use_encoder = kwargs.get("use_encoder", True)
        self.apply_mask = kwargs.get("apply_mask", True)
        self.use_decoder = kwargs.get("use_decoder", True)
        self.nb_channels = kwargs.get("nb_channels", 2)
        self.decoder_activation = kwargs.get("decoder_activation", "none")

        activations = {
            "sigmoid": nn.Sigmoid,
            "relu": nn.ReLU,
            "relu6": nn.ReLU6,
            "tanh": nn.Tanh,
            "none": nn.Identity,
        }
        act_cls = activations.get(self.decoder_activation)
        if act_cls is None:
            # Preserve the historical sigmoid fallback, but make it visible.
            # Here the default is "none", so a typo would silently add a
            # sigmoid to the output.
            warnings.warn(
                f"Unknown decoder_activation {self.decoder_activation!r}; "
                "falling back to sigmoid.",
                stacklevel=2,
            )
            act_cls = nn.Sigmoid
        self.act_after_dec = act_cls()

    def forward(self, wav):
        """Enc/Mask/Dec forward pass with multi-channel support.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            If ``use_decoder`` is true, a tuple ``(masked_tf_rep, reconstructed)``.
            ``reconstructed`` is the decoder output, padded to the input length,
            passed through ``act_after_dec`` and reshaped like the input:
            (batch, n_src, time) or (n_src, time).
            If ``use_decoder`` is false, only ``masked_tf_rep`` is returned
            (not wrapped in a tuple).
        """
        # Remember the input shape so the output can be restored to it
        # (kept as a Tensor for TorchScript).
        shape = jitable_shape(wav)
        # Normalise to (batch, n_mix, time).
        wav = _unsqueeze_to_3d(wav)

        tf_rep = self.forward_encoder(wav) if self.use_encoder else wav
        stereo = self.nb_channels == 2
        if stereo:
            # Merge channels into the frequency axis for the masker...
            tf_rep = rearrange(tf_rep, "b c f t -> b (c f) t")
        est_masks = self.forward_masker(tf_rep)
        if stereo:
            # ...and split them back so the mask can be applied per channel.
            tf_rep = rearrange(tf_rep, "b (c f) t -> b c f t", c=self.nb_channels)

        if self.apply_mask:
            # asteroid's apply_masks adds an extra unsqueeze we do not want,
            # so the masks are multiplied in directly.
            masked_tf_rep = est_masks * tf_rep
        else:
            masked_tf_rep = est_masks

        if not self.use_decoder:
            return masked_tf_rep

        decoded = self.forward_decoder(masked_tf_rep)
        reconstructed = self.act_after_dec(pad_x_to_y(decoded, wav))
        return masked_tf_rep, _shape_reconstructed(reconstructed, shape)