from typing import Dict, List, Optional

import torch
import torchaudio as ta
from torch import nn
import pytorch_lightning as pl

from .bandsplit import BandSplitModule
from .maskestim import OverlappingMaskEstimationModule
from .tfmodel import SeqBandModellingModule
from .utils import MusicalBandsplitSpecification


class BaseEndToEndModule(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()


class BaseBandit(BaseEndToEndModule):
    def __init__(
        self,
        in_channels: int,
        fs: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels

        self.instantiate_spectral(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
            onesided=onesided,
        )
        self.instantiate_bandsplit(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
            n_fft=n_fft,
            fs=fs,
        )
        self.instantiate_tf_modelling(
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

    def instantiate_spectral(
        self,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        normalized: bool = True,
        center: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        # power=None keeps the spectrogram complex-valued, which is required
        # both for complex masking and for inversion via InverseSpectrogram.
        assert power is None

        window_fn = torch.__dict__[window_fn]

        self.stft = ta.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )
        self.istft = ta.transforms.InverseSpectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

    def instantiate_bandsplit(
        self,
        in_channels: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        emb_dim: int = 128,
        n_fft: int = 2048,
        fs: int = 44100,
    ):
        # Only the musical band layout is currently supported.
        assert band_type == "musical"

        self.band_specs = MusicalBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )

        self.band_split = BandSplitModule(
            in_channels=in_channels,
            band_specs=self.band_specs.get_band_specs(),
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

    def instantiate_tf_modelling(
        self,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
    ):
        # torch.compile is applied with disable=True; fall back to the plain
        # module on builds where torch.compile is unavailable.
        try:
            self.tf_model = torch.compile(
                SeqBandModellingModule(
                    n_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                ),
                disable=True,
            )
        except Exception:
            self.tf_model = SeqBandModellingModule(
                n_modules=n_sqm_modules,
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                rnn_type=rnn_type,
            )

    def mask(self, x, m):
        return x * m

    def forward(self, batch, mode="train"):
        # The model operates on mono input. Multichannel (e.g. stereo) audio
        # is handled by folding the channel axis into the batch axis and
        # processing each channel independently.
        is_tensor_input = not isinstance(batch, dict)
        if is_tensor_input:
            init_shape = batch.shape  # (batch, channels, time)
            mono = batch.view(-1, 1, batch.shape[-1])
            batch = {"mixture": {"audio": mono}}

        with torch.no_grad():
            mixture = batch["mixture"]["audio"]
            x = self.stft(mixture)
            batch["mixture"]["spectrogram"] = x

            if "sources" in batch.keys():
                for stem in batch["sources"].keys():
                    s = batch["sources"][stem]["audio"]
                    s = self.stft(s)
                    batch["sources"][stem]["spectrogram"] = s

        batch = self.separate(batch)

        if is_tensor_input:
            # Restore the original (batch, channels, time) layout for each
            # stem and return a single stacked tensor rather than a dict.
            b = []
            for s in self.stems:
                r = batch["estimates"][s]["audio"].view(
                    -1, init_shape[1], init_shape[2]
                )
                b.append(r)
            batch = torch.stack(b, dim=1)  # (batch, n_stems, channels, time)

        return batch

    def encode(self, batch):
        x = batch["mixture"]["spectrogram"]
        length = batch["mixture"]["audio"].shape[-1]

        z = self.band_split(x)  # (batch, emb_dim, n_band, n_time)
        q = self.tf_model(z)  # (batch, emb_dim, n_band, n_time)

        return x, q, length

    def separate(self, batch):
        raise NotImplementedError
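
# Illustrative sketch (an addition for exposition, not part of the model):
# `BaseBandit.forward` accepts a raw (batch, channels, time) tensor by folding
# the channel axis into the batch axis, running the mono pipeline on each
# channel independently, and unfolding the estimates afterwards. The reshape
# round-trip it relies on looks like this:
def _channel_folding_sketch() -> None:
    x = torch.randn(4, 2, 44100)  # (batch, channels, time)
    mono = x.view(-1, 1, x.shape[-1])  # (batch * channels, 1, time)
    y = mono  # stand-in for a per-stem estimate of the same shape
    restored = y.view(-1, x.shape[1], x.shape[2])  # (batch, channels, time)
    assert restored.shape == x.shape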
class Bandit(BaseBandit):
    def __init__(
        self,
        in_channels: int,
        stems: List[str],
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        use_freq_weights: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
        # Precision settings are accepted for configuration compatibility but
        # are not referenced elsewhere in this module.
        stft_precisions: str = "32",
        bandsplit_precisions: str = "bf16",
        tf_model_precisions: str = "bf16",
        mask_estim_precisions: str = "bf16",
    ):
        super().__init__(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )

        self.stems = stems

        self.instantiate_mask_estim(
            in_channels=in_channels,
            stems=stems,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
        )

    def instantiate_mask_estim(
        self,
        in_channels: int,
        stems: List[str],
        emb_dim: int,
        mlp_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = False,
    ):
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        assert n_freq is not None

        # One mask estimator per output stem.
        self.mask_estim = nn.ModuleDict(
            {
                stem: OverlappingMaskEstimationModule(
                    band_specs=self.band_specs.get_band_specs(),
                    freq_weights=self.band_specs.get_freq_weights(),
                    n_freq=n_freq,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    use_freq_weights=use_freq_weights,
                )
                for stem in stems
            }
        )
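
    # Note on mask application: each `OverlappingMaskEstimationModule` maps
    # the shared TF-model embedding `q` to a per-stem time-frequency mask
    # (complex-valued when `complex_mask=True`); `separate` below applies it
    # multiplicatively to the mixture spectrogram (see `BaseBandit.mask`) and
    # inverts the result with the ISTFT to recover the stem waveform.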
    def separate(self, batch):
        batch["estimates"] = {}

        x, q, length = self.encode(batch)

        for stem, mem in self.mask_estim.items():
            m = mem(q)
            s = self.mask(x, m.to(x.dtype))
            s = torch.reshape(s, x.shape)
            batch["estimates"][stem] = {
                "audio": self.istft(s, length),
                "spectrogram": s,
            }

        return batch
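

# Usage sketch (an addition for exposition): this assumes the file is run in
# its original package context (the relative imports above require it) and
# uses purely illustrative stem names and default hyperparameters.
if __name__ == "__main__":
    model = Bandit(in_channels=1, stems=["speech", "music", "effects"], fs=44100)
    model.eval()

    audio = torch.randn(1, 2, 44100)  # one second of stereo at 44.1 kHz
    with torch.no_grad():
        estimates = model(audio)  # (batch, n_stems, channels, time)
    print(estimates.shape)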