import warnings

import torch
import torch.backends.cuda
from torch import nn
from torch.nn.modules import rnn
from torch.utils.checkpoint import checkpoint_sequential


class TimeFrequencyModellingModule(nn.Module):
    """Base class for modules that model the (band, time) grid of embeddings."""

    def __init__(self) -> None:
        super().__init__()


class ResidualRNN(nn.Module):
    def __init__(
        self,
        emb_dim: int,
        rnn_dim: int,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        use_batch_trick: bool = True,
        use_layer_norm: bool = True,
    ) -> None:
        # The "batch trick" folds the 2nd dim of the forward input
        # (n_uncrossed) into the batch dim so one RNN call covers every
        # row of the grid at once.
        super().__init__()

        # Only the layer-norm + batch-trick configuration is supported.
        assert use_layer_norm
        assert use_batch_trick

        self.use_layer_norm = use_layer_norm
        self.norm = nn.LayerNorm(emb_dim)

        self.rnn = getattr(rnn, rnn_type)(
            input_size=emb_dim,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        # Project the (doubled, if bidirectional) RNN output back to emb_dim
        # so the residual connection is well defined.
        self.fc = nn.Linear(
            in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
        )

        self.use_batch_trick = use_batch_trick
        if not self.use_batch_trick:
            warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")

    def forward(self, z):
        # z: (batch, n_uncrossed, n_across, emb_dim)
        z0 = torch.clone(z)
        z = self.norm(z)

        # Batch trick: fold the uncrossed dim into the batch dim so the RNN
        # runs over all rows in a single call, then unfold afterwards.
        batch, n_uncrossed, n_across, emb_dim = z.shape
        z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
        z = self.rnn(z)[0]
        z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))

        z = self.fc(z)  # (batch, n_uncrossed, n_across, emb_dim)

        return z + z0  # residual connection
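
# A minimal ResidualRNN usage sketch (the shapes below are illustrative
# assumptions, not values from any real config):
#
#     block = ResidualRNN(emb_dim=128, rnn_dim=256)
#     x = torch.randn(4, 8, 100, 128)   # (batch, n_uncrossed, n_across, emb_dim)
#     assert block(x).shape == x.shape  # shape-preserving, so blocks stack freely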


class Transpose(nn.Module):
    """Swap two dims, so nn.Sequential can alternate time- and band-axis RNNs."""

    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, z):
        return z.transpose(self.dim0, self.dim1)


class SeqBandModellingModule(TimeFrequencyModellingModule):
    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        parallel_mode: bool = False,
    ) -> None:
        super().__init__()
        self.n_modules = n_modules

        if parallel_mode:
            # Each module holds a (time RNN, band RNN) pair applied to the
            # same input in parallel; the outputs are summed in forward().
            self.seqband = nn.ModuleList([])
            for _ in range(n_modules):
                self.seqband.append(
                    nn.ModuleList(
                        [
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                        ]
                    )
                )
        else:
            # Sequential mode interleaves RNN blocks with transposes so the
            # modelled axis alternates between time and bands. Each of the
            # n_modules checkpoint segments then covers one
            # (RNN, Transpose, RNN, Transpose) group.
            seqband = []
            for _ in range(2 * n_modules):
                seqband += [
                    ResidualRNN(
                        emb_dim=emb_dim,
                        rnn_dim=rnn_dim,
                        bidirectional=bidirectional,
                        rnn_type=rnn_type,
                    ),
                    Transpose(1, 2),
                ]

            self.seqband = nn.Sequential(*seqband)

        self.parallel_mode = parallel_mode

    def forward(self, z):
        # z: (batch, n_bands, n_time, emb_dim)
        if self.parallel_mode:
            for sbm_pair in self.seqband:
                # Run the time-axis and band-axis RNNs on the same input and
                # sum their outputs.
                sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
                zt = sbm_t(z)  # (batch, n_bands, n_time, emb_dim)
                zf = sbm_f(z.transpose(1, 2))  # (batch, n_time, n_bands, emb_dim)
                z = zt + zf.transpose(1, 2)
        else:
            # Gradient checkpointing over n_modules segments trades compute
            # for activation memory during training.
            z = checkpoint_sequential(
                self.seqband, self.n_modules, z, use_reentrant=False
            )

        return z  # (batch, n_bands, n_time, emb_dim)
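

# A hedged smoke test, not part of the original module: the sizes below are
# illustrative assumptions chosen to run quickly on CPU.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = SeqBandModellingModule(n_modules=2, emb_dim=16, rnn_dim=32)
    # (batch, n_bands, n_time, emb_dim); requires_grad so checkpointing is exercised
    z = torch.randn(2, 8, 24, 16, requires_grad=True)
    out = model(z)
    assert out.shape == z.shape
    print("output shape:", tuple(out.shape))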