|
from typing import List, Tuple

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

from .utils import (
    band_widths_from_specs,
    check_no_gap,
    check_no_overlap,
    check_nonzero_bandwidth,
)
|
|
|
|
|
class NormFC(nn.Module):
    """LayerNorm followed by a Linear projection for a single frequency band.

    The band input is expected to be flattened so its last dimension holds
    ``in_channels * bandwidth * 2`` features (the factor 2 covers the real
    and imaginary parts). That is the size both the LayerNorm and the Linear
    layer below are built with. The projection maps it to ``emb_dim``.

    Only the "channels as features" layout with joint normalization is
    implemented. The other options raise ``NotImplementedError``.
    """

    def __init__(
        self,
        emb_dim: int,
        bandwidth: int,
        in_channels: int,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        """Build the norm + projection stack.

        Args:
            emb_dim: Output feature size of the projection.
            bandwidth: Number of frequency bins in this band.
            in_channels: Number of input channels folded into the features.
            normalize_channel_independently: Not implemented; must be False.
            treat_channel_as_feature: Only True is implemented.

        Raises:
            NotImplementedError: If an unsupported option is requested.
        """
        super().__init__()

        if not treat_channel_as_feature:
            raise NotImplementedError

        self.treat_channel_as_feature = treat_channel_as_feature

        if normalize_channel_independently:
            raise NotImplementedError

        reim = 2  # real and imaginary parts become separate features

        # Channels are always folded into the feature axis (the only supported
        # layout, enforced above), so the normalized size and the projection's
        # input size are the same number.
        fc_in = in_channels * bandwidth * reim

        norm = nn.LayerNorm(fc_in)
        fc = nn.Linear(fc_in, emb_dim)

        # Kept as nn.Sequential(norm, fc) so parameter names stay
        # ``combined.0.*`` / ``combined.1.*`` for existing checkpoints.
        self.combined = nn.Sequential(norm, fc)

    def forward(self, xb: torch.Tensor) -> torch.Tensor:
        """Normalize and project one band, with activation checkpointing.

        Args:
            xb: Tensor whose last dimension is ``in_channels * bandwidth * 2``.

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``emb_dim``.
        """
        # ``checkpoint_sequential`` never checkpoints its final segment, so
        # with a single segment it would run un-checkpointed. ``checkpoint``
        # actually drops the intermediate activations and recomputes them
        # during backward.
        return checkpoint(self.combined, xb, use_reentrant=False)
|
|
|
|
|
class BandSplitModule(nn.Module):
    """Split a complex spectrogram into frequency bands and embed each band.

    Each entry ``(fstart, fend)`` of ``band_specs`` selects frequency bins
    ``fstart:fend``. The bounds are used directly as slice indices, so in
    practice they must be integers. Each band is projected to ``emb_dim``
    features by its own ``NormFC``.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        in_channels: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        """Validate the band layout and build one ``NormFC`` per band.

        Args:
            band_specs: ``(start, end)`` frequency-bin bounds for each band.
            emb_dim: Embedding size produced for every band.
            in_channels: Number of input channels.
            require_no_overlap: If True, run ``check_no_overlap`` on the bands.
            require_no_gap: If True, run ``check_no_gap`` on the bands.
            normalize_channel_independently: Forwarded to ``NormFC``.
            treat_channel_as_feature: Forwarded to ``NormFC``.
        """
        super().__init__()

        check_nonzero_bandwidth(band_specs)
        if require_no_gap:
            check_no_gap(band_specs)
        if require_no_overlap:
            check_no_overlap(band_specs)

        self.band_specs = band_specs
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        # One projection per band: its input width depends on that band's width.
        self.norm_fc_modules = nn.ModuleList(
            [
                NormFC(
                    emb_dim=emb_dim,
                    bandwidth=bw,
                    in_channels=in_channels,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                )
                for bw in self.band_widths
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed every band of ``x``.

        Args:
            x: Complex 4-D tensor. The code treats its axes as
                ``(batch, channels, freq, time)``: frequency is the sliced
                axis and time is kept in the output. Presumably a complex
                STFT; confirm with the caller.

        Returns:
            Tensor of shape ``(batch, n_bands, n_time, emb_dim)``. Its dtype
            follows the per-band outputs.
        """
        batch, _, _, n_time = x.shape

        # -> (batch, time, channels, freq), so each band flattens to
        # (batch, time, channels * band_bins * 2) for the per-band projection.
        x = torch.permute(x, (0, 3, 1, 2)).contiguous()

        band_embeddings = []
        for nfm, (fstart, fend) in zip(self.norm_fc_modules, self.band_specs):
            xb = torch.view_as_real(x[:, :, :, fstart:fend])
            xb = torch.reshape(xb, (batch, n_time, -1))
            band_embeddings.append(nfm(xb))

        if not band_embeddings:
            # Degenerate band-less module: return an empty band axis.
            return torch.zeros(
                size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
            )

        # Stack instead of copying into a float32 buffer, so the output keeps
        # the band outputs' dtype (e.g. float64 for a double-precision model).
        return torch.stack(band_embeddings, dim=1)
|
|