import torch
import torch.nn as nn
from asteroid.models.base_models import (
    BaseEncoderMaskerDecoder,
    _unsqueeze_to_3d,
    _shape_reconstructed,
)
from asteroid.utils.torch_utils import pad_x_to_y, jitable_shape
from einops import rearrange


# Maps the ``decoder_activation`` config string to the module class that is
# applied on the decoder output.
_DECODER_ACTIVATIONS = {
    "sigmoid": nn.Sigmoid,
    "relu": nn.ReLU,
    "relu6": nn.ReLU6,
    "tanh": nn.Tanh,
    "none": nn.Identity,
}


def _build_decoder_activation(name):
    """Instantiate the activation module selected by ``name``.

    Args:
        name: One of "sigmoid", "relu", "relu6", "tanh" or "none".

    Returns:
        nn.Module: A fresh activation module. Any unrecognized name falls
        back to ``nn.Sigmoid`` (kept for backward compatibility).
    """
    return _DECODER_ACTIVATIONS.get(name, nn.Sigmoid)()


class BaseEncoderMaskerDecoderWithConfigs(BaseEncoderMaskerDecoder):
    """Enc/Mask/Dec wrapper whose encoder, masking and decoder steps are optional.

    Keyword Args:
        use_encoder (bool): Run ``forward_encoder`` on the input. Default True.
        apply_mask (bool): Combine the masker output with the encoder output
            through ``apply_masks``. Default True.
        use_decoder (bool): Run ``forward_decoder`` on the result. Default True.
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
        super().__init__(encoder, masker, decoder, encoder_activation)
        self.use_encoder = kwargs.get("use_encoder", True)
        self.apply_mask = kwargs.get("apply_mask", True)
        self.use_decoder = kwargs.get("use_decoder", True)

    def forward(self, wav: torch.Tensor):
        """Enc/Mask/Dec model forward with some additional options.

        Some of the models we use, like TFC-TDF-UNet, have no masker.
        In UMX or X-UMX, they already use masking in their model
        implementation. Since we do not want to manipulate the model codes,
        we use this wrapper.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            If ``use_decoder`` is True, a tuple
            ``(masked_tf_rep, reconstructed)`` where ``reconstructed`` is the
            decoder output padded to the input length and reshaped according
            to the input's original shape.
            Otherwise only ``masked_tf_rep`` (the decoder is skipped).
        """
        # Remember the input shape so the output can be reshaped accordingly
        # (cast to Tensor for torchscript).
        shape = jitable_shape(wav)
        # Reshape to (batch, n_mix, time)
        wav = _unsqueeze_to_3d(wav)

        # Real forward
        tf_rep = self.forward_encoder(wav) if self.use_encoder else wav
        est_masks = self.forward_masker(tf_rep)

        if self.apply_mask:
            masked_tf_rep = self.apply_masks(tf_rep, est_masks)
        else:
            # The model already used masking internally.
            masked_tf_rep = est_masks

        if not self.use_decoder:
            # In UMX or X-UMX, decoder is not used
            return masked_tf_rep

        decoded = self.forward_decoder(masked_tf_rep)
        reconstructed = pad_x_to_y(decoded, wav)
        return masked_tf_rep, _shape_reconstructed(reconstructed, shape)


class BaseEncoderMaskerDecoder_mixture_consistency(BaseEncoderMaskerDecoder):
    """Enc/Mask/Dec model whose estimates are projected to sum to the mixture."""

    def __init__(self, encoder, masker, decoder, encoder_activation=None):
        super().__init__(encoder, masker, decoder, encoder_activation)

    def forward(self, wav: torch.Tensor):
        """Enc/Mask/Dec model forward with mixture consistent output

        References:
            [1] : Wisdom, Scott, et al. "Differentiable consistency constraints
            for improved deep speech enhancement." ICASSP 2019.
            [2] : Wisdom, Scott, et al. "Unsupervised sound separation using
            mixture invariant training." NeurIPS 2020.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
        """
        # Remember the input shape so the output can be reshaped accordingly
        # (cast to Tensor for torchscript).
        shape = jitable_shape(wav)
        # Reshape to (batch, n_mix, time)
        wav = _unsqueeze_to_3d(wav)

        # Real forward
        tf_rep = self.forward_encoder(wav)
        est_masks = self.forward_masker(tf_rep)
        masked_tf_rep = self.apply_masks(tf_rep, est_masks)
        decoded = self.forward_decoder(masked_tf_rep)
        reconstructed = pad_x_to_y(decoded, wav)

        # Mixture consistency: spread the residual (mixture minus the sum of
        # the estimates) equally over the sources. This must happen while the
        # estimate still has its explicit source axis (dim=1); reshaping first
        # would, for a 1D input, drop the batch axis and make dim=1 the time
        # axis.
        n_src = reconstructed.shape[1]
        residual = wav - reconstructed.sum(dim=1, keepdim=True)
        reconstructed = reconstructed + 1 / n_src * residual

        return _shape_reconstructed(reconstructed, shape)


class BaseEncoderMaskerDecoderWithConfigsMaskOnOutput(BaseEncoderMaskerDecoder):
    """Enc/Mask/Dec wrapper that applies the decoded mask on the input waveform.

    Keyword Args:
        use_encoder (bool): Run ``forward_encoder`` on the input. Default True.
        apply_mask (bool): Stored on the instance; not read by ``forward``.
            Default True.
        use_decoder (bool): Run ``forward_decoder`` on the masker output.
            Default True.
        nb_channels (int): When equal to 2, the channel and frequency axes are
            merged before the masker. Default 2.
        decoder_activation (str): "sigmoid", "relu", "relu6", "tanh" or
            "none"; any other value falls back to sigmoid. Default "sigmoid".
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
        super().__init__(encoder, masker, decoder, encoder_activation)
        self.use_encoder = kwargs.get("use_encoder", True)
        self.apply_mask = kwargs.get("apply_mask", True)
        self.use_decoder = kwargs.get("use_decoder", True)
        self.nb_channels = kwargs.get("nb_channels", 2)
        self.decoder_activation = kwargs.get("decoder_activation", "sigmoid")
        self.act_after_dec = _build_decoder_activation(self.decoder_activation)

    def forward(self, wav: torch.Tensor):
        """Forward pass that masks the waveform with the decoder output.

        For the De-limit task, we will apply the mask on the output of the
        decoder. We want decoder to learn the sample-wise ratio of the sources.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            If ``use_decoder`` is True, a tuple
            ``(est_masks_decoded, decoded)``: the activated decoder output
            padded to the input length, and the input waveform multiplied by
            it. Otherwise a 1-tuple holding the masker output.
            Outputs are not reshaped back to the input's original
            dimensionality.
        """
        # Reshape to (batch, n_mix, time)
        wav = _unsqueeze_to_3d(wav)  # (batch, n_channels, time)

        # Real forward
        if self.use_encoder:
            tf_rep = self.forward_encoder(wav)  # (batch, n_channels, freq, time)
        else:
            tf_rep = wav

        stereo = self.nb_channels == 2
        if stereo:
            # c == 2 when stereo input.
            tf_rep = rearrange(tf_rep, "b c f t -> b (c f) t")

        est_masks = self.forward_masker(tf_rep)  # (batch, 1, freq, time)

        if not self.use_decoder:
            return (est_masks,)

        # We are going to apply the mask on the output of the decoder.
        if stereo:
            est_masks = rearrange(est_masks, "b 1 f t -> b f t")
        est_masks_decoded = self.forward_decoder(est_masks)
        est_masks_decoded = pad_x_to_y(est_masks_decoded, wav)  # (batch, 1, time)
        est_masks_decoded = self.act_after_dec(est_masks_decoded)  # (batch, 1, time)
        decoded = wav * est_masks_decoded  # (batch, n_channels, time)

        return (
            est_masks_decoded,
            decoded,
        )


class BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid(BaseEncoderMaskerDecoder):
    """Enc/Mask/Dec wrapper for multi-channel use of asteroid-based models.

    Keyword Args:
        use_encoder (bool): Run ``forward_encoder`` on the input. Default True.
        apply_mask (bool): Multiply the masker output with the encoder output.
            Default True.
        use_decoder (bool): Run ``forward_decoder`` on the result. Default True.
        nb_channels (int): When equal to 2, the channel and frequency axes are
            merged before the masker and split again afterwards. Default 2.
        decoder_activation (str): "sigmoid", "relu", "relu6", "tanh" or
            "none"; any other value falls back to sigmoid. Default "none".
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
        super().__init__(encoder, masker, decoder, encoder_activation)
        self.use_encoder = kwargs.get("use_encoder", True)
        self.apply_mask = kwargs.get("apply_mask", True)
        self.use_decoder = kwargs.get("use_decoder", True)
        self.nb_channels = kwargs.get("nb_channels", 2)
        self.decoder_activation = kwargs.get("decoder_activation", "none")
        self.act_after_dec = _build_decoder_activation(self.decoder_activation)

    def forward(self, wav: torch.Tensor):
        """Enc/Mask/Dec model forward with some additional options.

        For MultiChannel usage of asteroid-based models. (e.g. ConvTasNet)

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            If ``use_decoder`` is True, a tuple
            ``(masked_tf_rep, reconstructed)`` where ``reconstructed`` is the
            activated decoder output padded to the input length and reshaped
            according to the input's original shape.
            Otherwise only ``masked_tf_rep`` (the decoder is skipped).
        """
        # Remember the input shape so the output can be reshaped accordingly
        # (cast to Tensor for torchscript).
        shape = jitable_shape(wav)
        # Reshape to (batch, n_mix, time)
        wav = _unsqueeze_to_3d(wav)

        # Real forward
        tf_rep = self.forward_encoder(wav) if self.use_encoder else wav

        stereo = self.nb_channels == 2
        if stereo:
            # c == 2 when stereo input.
            tf_rep = rearrange(tf_rep, "b c f t -> b (c f) t")

        est_masks = self.forward_masker(tf_rep)

        if stereo:
            # Restore the separate channel axis after the masker has run.
            tf_rep = rearrange(tf_rep, "b (c f) t -> b c f t", c=self.nb_channels)

        if self.apply_mask:
            # Since original asteroid implementation of masking includes an
            # unnecessary unsqueeze operation, we will do it manually.
            masked_tf_rep = est_masks * tf_rep
        else:
            masked_tf_rep = est_masks

        if not self.use_decoder:
            return masked_tf_rep

        decoded = self.forward_decoder(masked_tf_rep)
        reconstructed = pad_x_to_y(decoded, wav)
        reconstructed = self.act_after_dec(reconstructed)
        return masked_tf_rep, _shape_reconstructed(reconstructed, shape)