from typing import List, Tuple

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

from .utils import (
    band_widths_from_specs,
    check_no_gap,
    check_no_overlap,
    check_nonzero_bandwidth,
)


class NormFC(nn.Module):
def __init__(
self,
emb_dim: int,
bandwidth: int,
in_channels: int,
normalize_channel_independently: bool = False,
treat_channel_as_feature: bool = True,
) -> None:
        super().__init__()

        if not treat_channel_as_feature:
            raise NotImplementedError

        self.treat_channel_as_feature = treat_channel_as_feature

        if normalize_channel_independently:
            raise NotImplementedError

        # Each complex bin contributes a real and an imaginary feature.
        reim = 2

        # Normalize over all features of one band:
        # (in_channels, bandwidth, re/im) flattened together.
        norm = nn.LayerNorm(in_channels * bandwidth * reim)

        fc_in = bandwidth * reim

        if treat_channel_as_feature:
            # Fold the channels into the input features of the linear layer.
            fc_in *= in_channels
        else:
            # Currently unreachable: treat_channel_as_feature=False
            # raises NotImplementedError above.
            assert emb_dim % in_channels == 0
            emb_dim = emb_dim // in_channels

        fc = nn.Linear(fc_in, emb_dim)

        self.combined = nn.Sequential(norm, fc)
    def forward(self, xb):
        # Checkpoint the norm+fc pair (one segment) to trade a little compute
        # for activation memory: intermediates are recomputed on backward.
        return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
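
# Usage sketch for NormFC (illustrative values, not taken from this repo):
# a band of 10 frequency bins from a stereo complex spectrogram, embedded
# into 128 dimensions. The input layout matches the flattening performed in
# BandSplitModule.forward below.
#
#   nfm = NormFC(emb_dim=128, bandwidth=10, in_channels=2)
#   xb = torch.randn(4, 100, 2 * 10 * 2)  # (batch, n_time, in_chan*bw*reim)
#   emb = nfm(xb)                         # -> (4, 100, 128)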


class BandSplitModule(nn.Module):
def __init__(
self,
        band_specs: List[Tuple[int, int]],
emb_dim: int,
in_channels: int,
require_no_overlap: bool = False,
require_no_gap: bool = True,
normalize_channel_independently: bool = False,
treat_channel_as_feature: bool = True,
) -> None:
        super().__init__()

        check_nonzero_bandwidth(band_specs)

        if require_no_gap:
            check_no_gap(band_specs)

        if require_no_overlap:
            check_no_overlap(band_specs)

        # List of [fstart, fend) band edges as frequency-bin indices;
        # fend is exclusive.
        self.band_specs = band_specs
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim
        # torch.compile may be unavailable on older PyTorch builds; fall back
        # to plain eager modules if wrapping fails. Note that disable=True
        # means the modules are not actually compiled in either branch.
        try:
            self.norm_fc_modules = nn.ModuleList(
                [  # type: ignore
                    torch.compile(
                        NormFC(
                            emb_dim=emb_dim,
                            bandwidth=bw,
                            in_channels=in_channels,
                            normalize_channel_independently=normalize_channel_independently,
                            treat_channel_as_feature=treat_channel_as_feature,
                        ),
                        disable=True,
                    )
                    for bw in self.band_widths
                ]
            )
        except Exception:
            self.norm_fc_modules = nn.ModuleList(
                [  # type: ignore
                    NormFC(
                        emb_dim=emb_dim,
                        bandwidth=bw,
                        in_channels=in_channels,
                        normalize_channel_independently=normalize_channel_independently,
                        treat_channel_as_feature=treat_channel_as_feature,
                    )
                    for bw in self.band_widths
                ]
            )
    def forward(self, x: torch.Tensor):
        # x: complex spectrogram of shape (batch, in_chan, n_freq, n_time)
        batch, in_chan, n_freq, n_time = x.shape

        # One embedding per (band, frame): (batch, n_bands, n_time, emb_dim)
        z = torch.zeros(
            size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
        )

        # (batch, in_chan, n_freq, n_time) -> (batch, n_time, in_chan, n_freq)
        x = torch.permute(x, (0, 3, 1, 2)).contiguous()

        for i, nfm in enumerate(self.norm_fc_modules):
            fstart, fend = self.band_specs[i]
            xb = x[:, :, :, fstart:fend]
            # Split complex values into (real, imag) and flatten
            # (in_chan, bandwidth, 2) into one feature axis per frame.
            xb = torch.view_as_real(xb)
            xb = torch.reshape(xb, (batch, n_time, -1))
            z[:, i, :, :] = nfm(xb)

        return z
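

if __name__ == "__main__":
    # Minimal smoke test (a sketch; band edges and sizes are illustrative,
    # not the band definitions used by the project). Because of the relative
    # import above, run this as a module, e.g. `python -m <package>.<module>`.
    band_specs = [(0, 5), (5, 12), (12, 20)]  # contiguous, fend-exclusive
    module = BandSplitModule(band_specs=band_specs, emb_dim=32, in_channels=2)

    spec = torch.randn(4, 2, 20, 50, dtype=torch.complex64)  # (B, C, F, T)
    z = module(spec)
    print(z.shape)  # expected: torch.Size([4, 3, 50, 32])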