from typing import List, Tuple import torch from torch import nn from torch.utils.checkpoint import checkpoint_sequential from .utils import ( band_widths_from_specs, check_no_gap, check_no_overlap, check_nonzero_bandwidth, ) class NormFC(nn.Module): def __init__( self, emb_dim: int, bandwidth: int, in_channels: int, normalize_channel_independently: bool = False, treat_channel_as_feature: bool = True, ) -> None: super().__init__() if not treat_channel_as_feature: raise NotImplementedError self.treat_channel_as_feature = treat_channel_as_feature if normalize_channel_independently: raise NotImplementedError reim = 2 norm = nn.LayerNorm(in_channels * bandwidth * reim) fc_in = bandwidth * reim if treat_channel_as_feature: fc_in *= in_channels else: assert emb_dim % in_channels == 0 emb_dim = emb_dim // in_channels fc = nn.Linear(fc_in, emb_dim) self.combined = nn.Sequential(norm, fc) def forward(self, xb): return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False) class BandSplitModule(nn.Module): def __init__( self, band_specs: List[Tuple[float, float]], emb_dim: int, in_channels: int, require_no_overlap: bool = False, require_no_gap: bool = True, normalize_channel_independently: bool = False, treat_channel_as_feature: bool = True, ) -> None: super().__init__() check_nonzero_bandwidth(band_specs) if require_no_gap: check_no_gap(band_specs) if require_no_overlap: check_no_overlap(band_specs) self.band_specs = band_specs # list of [fstart, fend) in index. # Note that fend is exclusive. self.band_widths = band_widths_from_specs(band_specs) self.n_bands = len(band_specs) self.emb_dim = emb_dim try: self.norm_fc_modules = nn.ModuleList( [ # type: ignore torch.compile( NormFC( emb_dim=emb_dim, bandwidth=bw, in_channels=in_channels, normalize_channel_independently=normalize_channel_independently, treat_channel_as_feature=treat_channel_as_feature, ), disable=True, ) for bw in self.band_widths ] ) except Exception as e: self.norm_fc_modules = nn.ModuleList( [ # type: ignore NormFC( emb_dim=emb_dim, bandwidth=bw, in_channels=in_channels, normalize_channel_independently=normalize_channel_independently, treat_channel_as_feature=treat_channel_as_feature, ) for bw in self.band_widths ] ) def forward(self, x: torch.Tensor): # x = complex spectrogram (batch, in_chan, n_freq, n_time) batch, in_chan, band_width, n_time = x.shape z = torch.zeros( size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device ) x = torch.permute(x, (0, 3, 1, 2)).contiguous() for i, nfm in enumerate(self.norm_fc_modules): fstart, fend = self.band_specs[i] xb = x[:, :, :, fstart:fend] xb = torch.view_as_real(xb) xb = torch.reshape(xb, (batch, n_time, -1)) z[:, i, :, :] = nfm(xb) return z