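"""Mask-estimation modules: per-band MLP heads that turn band embeddings into
(complex) time-frequency masks, plus wrappers that assemble the per-band masks
into a full-frequency mask."""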
from typing import Dict, List, Optional, Tuple, Type
import torch
from torch import nn
from torch.nn.modules import activation
from torch.utils.checkpoint import checkpoint_sequential
from .utils import (
band_widths_from_specs,
check_no_gap,
check_no_overlap,
check_nonzero_bandwidth,
)
class BaseNormMLP(nn.Module):
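    """Shared front-end for a per-band mask head: LayerNorm followed by a
    Linear + activation hidden layer, plus bookkeeping for the band's width,
    channel count, and whether a complex-valued mask is produced."""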
def __init__(
self,
emb_dim: int,
mlp_dim: int,
bandwidth: int,
in_channels: Optional[int],
hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
complex_mask: bool = True,
):
super().__init__()
if hidden_activation_kwargs is None:
hidden_activation_kwargs = {}
self.hidden_activation_kwargs = hidden_activation_kwargs
self.norm = nn.LayerNorm(emb_dim)
self.hidden = nn.Sequential(
nn.Linear(in_features=emb_dim, out_features=mlp_dim),
            getattr(activation, hidden_activation)(**self.hidden_activation_kwargs),
)
self.bandwidth = bandwidth
self.in_channels = in_channels
self.complex_mask = complex_mask
self.reim = 2 if complex_mask else 1
self.glu_mult = 2
class NormMLP(BaseNormMLP):
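    """Per-band mask head: normalises the band embedding, passes it through the
    hidden MLP, and projects it (with a GLU gate) to a complex or real
    time-frequency mask for a single frequency band."""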
def __init__(
self,
emb_dim: int,
mlp_dim: int,
bandwidth: int,
in_channels: Optional[int],
hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
complex_mask: bool = True,
) -> None:
super().__init__(
emb_dim=emb_dim,
mlp_dim=mlp_dim,
bandwidth=bandwidth,
in_channels=in_channels,
hidden_activation=hidden_activation,
hidden_activation_kwargs=hidden_activation_kwargs,
complex_mask=complex_mask,
)
self.output = nn.Sequential(
nn.Linear(
in_features=mlp_dim,
out_features=bandwidth * in_channels * self.reim * 2,
),
nn.GLU(dim=-1),
)
        # torch.compile is explicitly disabled here; fall back gracefully on
        # builds where the call itself is unavailable.
        try:
            self.combined = torch.compile(
                nn.Sequential(self.norm, self.hidden, self.output), disable=True
            )
        except Exception:
            self.combined = nn.Sequential(self.norm, self.hidden, self.output)
def reshape_output(self, mb):
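        """Reshape the flat MLP output to (batch, in_channels, bandwidth, n_time),
        converting the trailing real/imag pair into a complex tensor when
        ``complex_mask`` is set."""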
batch, n_time, _ = mb.shape
if self.complex_mask:
mb = mb.reshape(
batch, n_time, self.in_channels, self.bandwidth, self.reim
).contiguous()
mb = torch.view_as_complex(mb) # (batch, n_time, in_channels, bandwidth)
else:
mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)
mb = torch.permute(mb, (0, 2, 3, 1)) # (batch, in_channels, bandwidth, n_time)
return mb
def forward(self, qb):
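        """Map a band embedding (batch, n_time, emb_dim) to a per-band mask of
        shape (batch, in_channels, bandwidth, n_time)."""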
        # qb: (batch, n_time, emb_dim) -> hidden: (batch, n_time, mlp_dim)
        # -> output: (batch, n_time, bandwidth * in_channels * reim).
        # The stack is run through checkpoint_sequential (2 segments) to trade
        # recomputation for activation memory.
mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
mb = self.reshape_output(mb) # (batch, in_channels, bandwidth, n_time)
return mb
class MaskEstimationModuleSuperBase(nn.Module):
    """Marker base class for all mask-estimation modules."""
class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
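    """Holds one NormMLP-style head per frequency band and applies each head to
    the corresponding slice of the band embeddings."""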
def __init__(
self,
band_specs: List[Tuple[float, float]],
emb_dim: int,
mlp_dim: int,
in_channels: Optional[int],
hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
complex_mask: bool = True,
norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
) -> None:
super().__init__()
self.band_widths = band_widths_from_specs(band_specs)
self.n_bands = len(band_specs)
if hidden_activation_kwargs is None:
hidden_activation_kwargs = {}
if norm_mlp_kwargs is None:
norm_mlp_kwargs = {}
self.norm_mlp = nn.ModuleList(
[
norm_mlp_cls(
bandwidth=self.band_widths[b],
emb_dim=emb_dim,
mlp_dim=mlp_dim,
in_channels=in_channels,
hidden_activation=hidden_activation,
hidden_activation_kwargs=hidden_activation_kwargs,
complex_mask=complex_mask,
**norm_mlp_kwargs,
)
for b in range(self.n_bands)
]
)
def compute_masks(self, q):
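        """Run every band head on its slice of q (batch, n_bands, n_time, emb_dim)
        and return the list of per-band masks."""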
batch, n_bands, n_time, emb_dim = q.shape
masks = []
for b, nmlp in enumerate(self.norm_mlp):
# print(f"maskestim/{b:02d}")
qb = q[:, b, :, :]
mb = nmlp(qb)
masks.append(mb)
return masks
def compute_mask(self, q, b):
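        """Compute the mask for a single band index ``b``."""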
batch, n_bands, n_time, emb_dim = q.shape
qb = q[:, b, :, :]
mb = self.norm_mlp[b](qb)
return mb
class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
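    """Mask estimator for band specs that may overlap: per-band masks are
    overlap-added into a full (batch, in_channels, n_freq, n_time) complex mask,
    optionally scaled by per-band frequency weights."""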
def __init__(
self,
in_channels: int,
band_specs: List[Tuple[float, float]],
        freq_weights: Optional[List[torch.Tensor]],
n_freq: int,
emb_dim: int,
mlp_dim: int,
cond_dim: int = 0,
hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
complex_mask: bool = True,
norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
use_freq_weights: bool = False,
) -> None:
check_nonzero_bandwidth(band_specs)
check_no_gap(band_specs)
if cond_dim > 0:
raise NotImplementedError
super().__init__(
band_specs=band_specs,
emb_dim=emb_dim + cond_dim,
mlp_dim=mlp_dim,
in_channels=in_channels,
hidden_activation=hidden_activation,
hidden_activation_kwargs=hidden_activation_kwargs,
complex_mask=complex_mask,
norm_mlp_cls=norm_mlp_cls,
norm_mlp_kwargs=norm_mlp_kwargs,
)
self.n_freq = n_freq
self.band_specs = band_specs
self.in_channels = in_channels
if freq_weights is not None and use_freq_weights:
for i, fw in enumerate(freq_weights):
self.register_buffer(f"freq_weights/{i}", fw)
self.use_freq_weights = use_freq_weights
else:
self.use_freq_weights = False
def forward(self, q):
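        """Overlap-add the per-band masks into a complex mask covering all
        n_freq frequency bins."""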
# q = (batch, n_bands, n_time, emb_dim)
batch, n_bands, n_time, emb_dim = q.shape
masks = torch.zeros(
(batch, self.in_channels, self.n_freq, n_time),
device=q.device,
dtype=torch.complex64,
)
for im in range(n_bands):
fstart, fend = self.band_specs[im]
mask = self.compute_mask(q, im)
if self.use_freq_weights:
fw = self.get_buffer(f"freq_weights/{im}")[:, None]
mask = mask * fw
masks[:, :, fstart:fend, :] += mask
return masks
class MaskEstimationModule(OverlappingMaskEstimationModule):
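    """Special case for non-overlapping, gap-free band specs: per-band masks are
    simply concatenated along the frequency axis."""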
def __init__(
self,
band_specs: List[Tuple[float, float]],
emb_dim: int,
mlp_dim: int,
in_channels: Optional[int],
hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
complex_mask: bool = True,
**kwargs,
) -> None:
check_nonzero_bandwidth(band_specs)
check_no_gap(band_specs)
check_no_overlap(band_specs)
super().__init__(
in_channels=in_channels,
band_specs=band_specs,
freq_weights=None,
n_freq=None,
emb_dim=emb_dim,
mlp_dim=mlp_dim,
hidden_activation=hidden_activation,
hidden_activation_kwargs=hidden_activation_kwargs,
complex_mask=complex_mask,
)
def forward(self, q, cond=None):
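        """Concatenate the per-band masks along frequency. ``cond`` is accepted
        for interface compatibility but is currently unused."""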
# q = (batch, n_bands, n_time, emb_dim)
masks = self.compute_masks(
q
) # [n_bands * (batch, in_channels, bandwidth, n_time)]
# TODO: currently this requires band specs to have no gap and no overlap
masks = torch.concat(masks, dim=2) # (batch, in_channels, n_freq, n_time)
return masks
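

# Usage sketch (illustrative only): the band_specs values below and the exact
# behaviour of the .utils helpers (band_widths_from_specs returning
# fend - fstart per band) are assumptions for the example, not guarantees from
# this file.
#
#     band_specs = [(0, 16), (16, 48), (48, 128)]  # contiguous, non-overlapping bins
#     estim = MaskEstimationModule(
#         band_specs=band_specs, emb_dim=32, mlp_dim=64, in_channels=2
#     )
#     q = torch.randn(1, len(band_specs), 10, 32)   # (batch, n_bands, n_time, emb_dim)
#     masks = estim(q)                              # (1, 2, 128, 10), complex64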