|
"""End-to-end band-split source separation models (Bandit)."""

from typing import Dict, List, Optional

import torch
import torchaudio as ta
from torch import nn
import pytorch_lightning as pl

from .bandsplit import BandSplitModule
from .maskestim import OverlappingMaskEstimationModule
from .tfmodel import SeqBandModellingModule
from .utils import MusicalBandsplitSpecification


class BaseEndToEndModule(pl.LightningModule): |
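    """Common LightningModule base for the end-to-end separation models."""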
|
    def __init__(self) -> None:
        super().__init__()


class BaseBandit(BaseEndToEndModule): |
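    """Shared Bandit plumbing: the STFT/ISTFT pair, the band-split encoder,
    and the sequence-band (time-frequency) modelling stack. Subclasses
    provide separate().
    """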
|
    def __init__(
        self,
        in_channels: int,
        fs: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels

        self.instantiate_spectral(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        self.instantiate_bandsplit(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
            n_fft=n_fft,
            fs=fs,
        )

        self.instantiate_tf_modelling(
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

    def instantiate_spectral(
        self,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        normalized: bool = True,
        center: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        # A complex spectrogram (power=None) is required so that masks can
        # be applied to it and inverted back to audio.
        assert power is None

        # Resolve the window constructor (e.g. torch.hann_window) by name.
        window_fn = getattr(torch, window_fn)

        self.stft = ta.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

        self.istft = ta.transforms.InverseSpectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

    def instantiate_bandsplit(
        self,
        in_channels: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        emb_dim: int = 128,
        n_fft: int = 2048,
        fs: int = 44100,
    ):
        # Only the musically motivated band layout is supported here.
        assert band_type == "musical", f"Unsupported band_type: {band_type}"

        self.band_specs = MusicalBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )

        self.band_split = BandSplitModule(
            in_channels=in_channels,
            band_specs=self.band_specs.get_band_specs(),
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

    def instantiate_tf_modelling(
        self,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
    ):
        # torch.compile(..., disable=True) is a no-op wrapper; the
        # try/except guards against torch versions without torch.compile.
        try:
            self.tf_model = torch.compile(
                SeqBandModellingModule(
                    n_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                ),
                disable=True,
            )
        except Exception:
            self.tf_model = SeqBandModellingModule(
                n_modules=n_sqm_modules,
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                rnn_type=rnn_type,
            )

    def mask(self, x, m):
        # Elementwise (complex) masking of the mixture spectrogram.
        return x * m

    def forward(self, batch, mode="train"):
        # Expected dict layout (inferred from the access patterns below):
        #   batch["mixture"]["audio"]: (batch, channels, samples) waveform
        #   batch["sources"][stem]["audio"]: optional per-stem references
        # A raw tensor input is wrapped into this layout, with channels
        # folded into the batch dimension.
        input_was_tensor = not isinstance(batch, dict)
        if input_was_tensor:
            init_shape = batch.shape
            mono = batch.view(-1, 1, batch.shape[-1])
            batch = {"mixture": {"audio": mono}}

        with torch.no_grad():
            mixture = batch["mixture"]["audio"]

            x = self.stft(mixture)
            batch["mixture"]["spectrogram"] = x

            if "sources" in batch:
                for stem in batch["sources"]:
                    s = batch["sources"][stem]["audio"]
                    batch["sources"][stem]["spectrogram"] = self.stft(s)

        batch = self.separate(batch)

        if input_was_tensor:
            # Undo the channel folding and stack the per-stem estimates
            # along a new stem dimension.
            estimates = [
                batch["estimates"][stem]["audio"].view(
                    -1, init_shape[1], init_shape[2]
                )
                for stem in self.stems
            ]
            batch = torch.stack(estimates, dim=1)

        return batch

    def encode(self, batch):
        x = batch["mixture"]["spectrogram"]
        length = batch["mixture"]["audio"].shape[-1]

        z = self.band_split(x)
        q = self.tf_model(z)

        return x, q, length

    def separate(self, batch):
        raise NotImplementedError


class Bandit(BaseBandit): |
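    """Bandit model: the shared BaseBandit encoder plus one overlapping
    mask estimator per output stem.
    """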
|
    def __init__(
        self,
        in_channels: int,
        stems: List[str],
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        use_freq_weights: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
        # The *_precisions arguments are accepted for configuration
        # compatibility but are not used within this module.
        stft_precisions: str = "32",
        bandsplit_precisions: str = "bf16",
        tf_model_precisions: str = "bf16",
        mask_estim_precisions: str = "bf16",
    ):
        super().__init__(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )

        self.stems = stems

        self.instantiate_mask_estim(
            in_channels=in_channels,
            stems=stems,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
        )

    def instantiate_mask_estim(
        self,
        in_channels: int,
        stems: List[str],
        emb_dim: int,
        mlp_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = False,
    ):
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        assert n_freq is not None

        self.mask_estim = nn.ModuleDict(
            {
                stem: OverlappingMaskEstimationModule(
                    band_specs=self.band_specs.get_band_specs(),
                    freq_weights=self.band_specs.get_freq_weights(),
                    n_freq=n_freq,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    use_freq_weights=use_freq_weights,
                )
                for stem in stems
            }
        )

    def separate(self, batch):
        batch["estimates"] = {}

        x, q, length = self.encode(batch)

        for stem, mem in self.mask_estim.items():
            m = mem(q)

            s = self.mask(x, m.to(x.dtype))
            s = torch.reshape(s, x.shape)
            batch["estimates"][stem] = {
                "audio": self.istft(s, length),
                "spectrogram": s,
            }

        return batch
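

# A minimal smoke-test sketch, not part of the model proper; the stem names
# and shapes below are illustrative assumptions. Note that the raw-tensor
# path in BaseBandit.forward() folds channels into the batch dimension, so
# the model is built with in_channels=1 here.
if __name__ == "__main__":
    model = Bandit(
        in_channels=1,
        stems=["speech", "music", "effects"],
        fs=44100,
    )
    model.eval()

    # One second of random stereo audio: (batch, channels, samples).
    audio = torch.randn(1, 2, 44100)

    with torch.no_grad():
        estimates = model(audio)

    # Tensor input yields stacked estimates:
    # (batch, n_stems, channels, samples).
    print(estimates.shape)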
|
|
|
|