|
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from models.scnet_unofficial.utils import create_intervals
|
|
|
|
|
class Downsample(nn.Module):
    """
    Strided pointwise 2D convolution that shrinks the frequency axis.

    A kernel-size-1 ``nn.Conv2d`` maps ``input_dim`` channels to ``output_dim``
    channels while striding by ``stride`` along dim 2 (frequency) and by 1
    along dim 3 (time).

    Args:
        - input_dim (int): Number of input channels.
        - output_dim (int): Number of output channels.
        - stride (int): Stride applied along the frequency axis.

    Shapes:
        - Input: (B, C_in, F, T) where
            B is batch size,
            C_in is the number of input channels,
            F is the frequency dimension,
            T is the time dimension.
        - Output: (B, C_out, F_out, T) where
            C_out is the number of output channels and
            F_out = (F - 1) // stride + 1 (i.e. ceil(F / stride)); this equals
            F // stride only when stride divides F.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        stride: int,
    ):
        """
        Create the strided pointwise convolution.
        """
        super().__init__()
        # The attribute is named ``conv`` so the state_dict keys stay "conv.*".
        self.conv = nn.Conv2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=1,
            stride=(stride, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the strided pointwise convolution.

        Args:
            - x (torch.Tensor): Tensor of shape (B, C_in, F, T).

        Returns:
            - torch.Tensor: Tensor of shape (B, C_out, (F - 1) // stride + 1, T).
        """
        downsampled = self.conv(x)
        return downsampled
|
|
|
|
|
class ConvolutionModule(nn.Module):
    """
    Conformer-style convolution block with a residual connection.

    On a channels-first view of the input, the block applies
    GroupNorm -> Conv1d (input_dim -> 2 * hidden_dim) -> GLU ->
    depthwise Conv1d (hidden_dim) -> GroupNorm -> SiLU ->
    Conv1d (hidden_dim -> input_dim). The result is added back to the input.

    Args:
        - input_dim (int): Dimensionality of the input features.
        - hidden_dim (int): Dimensionality of the hidden features.
        - kernel_sizes (List[int]): Kernel sizes of the three convolutions, in
          order (input projection, depthwise, output projection). Only the
          first three entries are used.
        - bias (bool, optional): If True, adds a learnable bias to the
          convolutions. Default is False.

    Shapes:
        - Input: (B, T, D) where
            B is batch size,
            T is sequence length,
            D is input dimensionality.
        - Output: (B, T, D), same as the input.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        kernel_sizes: List[int],
        bias: bool = False,
    ) -> None:
        """
        Build the normalization / convolution / gating stack.

        All convolutions run with stride 1 and ``padding="same"``, so the
        sequence length T is preserved for every kernel size. The residual
        addition in ``forward`` requires this. Symmetric padding of
        ``(k - 1) // 2`` is identical for odd ``k`` but shortens the sequence
        by one for even ``k``, which would break the residual add.
        """
        super().__init__()
        # Layer order is kept fixed so the state_dict keys ("sequential.<i>.*")
        # remain compatible with existing checkpoints.
        self.sequential = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=input_dim),
            nn.Conv1d(
                input_dim,
                2 * hidden_dim,
                kernel_sizes[0],
                stride=1,
                padding="same",
                bias=bias,
            ),
            # GLU halves dim 1: 2 * hidden_dim -> hidden_dim.
            nn.GLU(dim=1),
            # Depthwise convolution: one filter per channel.
            nn.Conv1d(
                hidden_dim,
                hidden_dim,
                kernel_sizes[1],
                stride=1,
                padding="same",
                groups=hidden_dim,
                bias=bias,
            ),
            nn.GroupNorm(num_groups=1, num_channels=hidden_dim),
            nn.SiLU(),
            nn.Conv1d(
                hidden_dim,
                input_dim,
                kernel_sizes[2],
                stride=1,
                padding="same",
                bias=bias,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the ConvolutionModule.

        Args:
            - x (torch.Tensor): Input tensor of shape (B, T, D).

        Returns:
            - torch.Tensor: Output tensor of shape (B, T, D).
        """
        # Conv1d / GroupNorm expect channels in dim 1: (B, T, D) -> (B, D, T).
        channels_first = x.transpose(1, 2)
        out = channels_first + self.sequential(channels_first)
        return out.transpose(1, 2)
|
|
|
|
|
class SDLayer(nn.Module):
    """
    Handle one frequency sub-band: slice it, downsample it, and convolve it over time.

    The layer first keeps the frequency bins whose index lies in
    ``[int(lo * F), int(hi * F))``, where ``(lo, hi) = subband_interval`` are
    fractions of the input frequency length ``F``. It then applies a strided
    ``Downsample`` followed by GELU. Finally, it runs ``n_conv_modules`` stacked
    ``ConvolutionModule`` blocks along time, independently for every
    remaining frequency bin.

    Args:
        - subband_interval (Tuple[float, float]): (start, end) fractions of the
          frequency axis selecting this sub-band.
        - input_dim (int): Dimensionality of the input channels.
        - output_dim (int): Dimensionality of the output channels after downsampling.
        - downsample_stride (int): Stride along frequency for the downsampling.
        - n_conv_modules (int): Number of stacked convolution modules (0 means none).
        - kernel_sizes (List[int]): Kernel sizes forwarded to each ConvolutionModule.
        - bias (bool, optional): If True, the convolution modules use bias. Default is True.

    Shapes:
        - Input: (B, Fi, T, Ci) where
            B is batch size,
            Fi is the number of input frequency bins,
            T is sequence length, and
            Ci is the number of input channels.
        - Output: (B, Fi+1, T, Ci+1) where
            Fi+1 is the number of bins left after slicing and downsampling the band,
            T is sequence length,
            Ci+1 equals output_dim.
    """

    def __init__(
        self,
        subband_interval: Tuple[float, float],
        input_dim: int,
        output_dim: int,
        downsample_stride: int,
        n_conv_modules: int,
        kernel_sizes: List[int],
        bias: bool = True,
    ):
        """
        Build the downsampling stage and the stack of convolution modules.
        """
        super().__init__()
        self.subband_interval = subband_interval
        self.downsample = Downsample(input_dim, output_dim, downsample_stride)
        self.activation = nn.GELU()
        # Each convolution module uses a hidden width of a quarter of output_dim.
        self.conv_modules = nn.Sequential(
            *(
                ConvolutionModule(
                    input_dim=output_dim,
                    hidden_dim=output_dim // 4,
                    kernel_sizes=kernel_sizes,
                    bias=bias,
                )
                for _ in range(n_conv_modules)
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the sub-band pipeline.

        Args:
            - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).

        Returns:
            - torch.Tensor: Output tensor of shape (B, Fi+1, T, Ci+1).
        """
        n_freq = x.shape[1]
        lo = int(self.subband_interval[0] * n_freq)
        hi = int(self.subband_interval[1] * n_freq)

        # Move channels to dim 1 for the Conv2d: (B, F, T, C) -> (B, C, F, T).
        band = x[:, lo:hi].permute(0, 3, 1, 2)
        band = self.activation(self.downsample(band))
        # Back to channels-last: (B, C', F', T) -> (B, F', T, C').
        band = band.permute(0, 2, 3, 1)

        # Fold frequency into the batch so each bin is convolved over time on its own.
        batch, n_sub, steps, chans = band.shape
        flat = band.reshape((batch * n_sub), steps, chans)
        return self.conv_modules(flat).reshape(batch, n_sub, steps, chans)
|
|
|
|
|
class SDBlock(nn.Module):
    """
    Split the frequency axis into sub-bands, process each one, then mix channels.

    One ``SDLayer`` is built per sub-band. Their outputs are concatenated
    along the frequency axis (dim 1) to form the skip tensor. A 1x1
    ``nn.Conv2d`` over the channel axis then produces the main output.

    Args:
        - input_dim (int): Dimensionality of the input channels.
        - output_dim (int): Dimensionality of the output channels.
        - bandsplit_ratios (List[float]): Fraction of the frequency axis given
          to each sub-band. Must sum to 1 (up to floating-point tolerance).
        - downsample_strides (List[int]): Frequency stride for each sub-band layer.
        - n_conv_modules (List[int]): Number of convolution modules in each sub-band layer.
        - kernel_sizes (List[int], optional): Kernel sizes for the convolution
          modules. Defaults to [3, 3, 1] when None.

    Raises:
        - AssertionError: If ``bandsplit_ratios`` does not sum to 1.
        - ValueError: If ``bandsplit_ratios``, ``downsample_strides`` and
          ``n_conv_modules`` do not all have the same length.

    Shapes:
        - Input: (B, Fi, T, Ci) where
            B is batch size,
            Fi is the number of input frequency bins,
            T is sequence length,
            Ci is the number of input channels.
        - Output: (B, Fi+1, T, Ci+1) where
            Fi+1 is the total number of bins over all processed sub-bands,
            T is sequence length,
            Ci+1 equals output_dim.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        bandsplit_ratios: List[float],
        downsample_strides: List[int],
        n_conv_modules: List[int],
        kernel_sizes: Optional[List[int]] = None,
    ):
        """
        Initializes SDBlock with input dimension, output dimension, band split ratios,
        downsample strides, number of convolutional modules, and kernel sizes.
        """
        super().__init__()
        if kernel_sizes is None:
            kernel_sizes = [3, 3, 1]
        # Compare with a tolerance: e.g. sum([0.1] * 10) == 0.9999999999999999,
        # which an exact ``== 1`` check would wrongly reject.
        assert math.isclose(sum(bandsplit_ratios), 1.0), "The split ratios must sum up to 1."
        # zip() below would silently drop sub-bands if these lengths differed.
        if not (len(bandsplit_ratios) == len(downsample_strides) == len(n_conv_modules)):
            raise ValueError(
                "bandsplit_ratios, downsample_strides and n_conv_modules must have "
                f"the same length, got {len(bandsplit_ratios)}, "
                f"{len(downsample_strides)} and {len(n_conv_modules)}."
            )
        # NOTE(review): create_intervals is assumed to return one (start, end)
        # fraction pair per ratio -- confirm against models.scnet_unofficial.utils.
        subband_intervals = create_intervals(bandsplit_ratios)
        self.sd_layers = nn.ModuleList(
            SDLayer(
                input_dim=input_dim,
                output_dim=output_dim,
                subband_interval=sbi,
                downsample_stride=dss,
                n_conv_modules=ncm,
                kernel_sizes=kernel_sizes,
            )
            for sbi, dss, ncm in zip(
                subband_intervals, downsample_strides, n_conv_modules
            )
        )
        self.global_conv2d = nn.Conv2d(output_dim, output_dim, 1, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Performs forward pass through the SDBlock.

        Args:
            - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).

        Returns:
            - Tuple[torch.Tensor, torch.Tensor]: ``(x, x_skip)``. ``x_skip`` is
              the frequency-wise concatenation of all sub-band outputs, and
              ``x`` is ``x_skip`` after the 1x1 channel-mixing convolution.
              Both have shape (B, Fi+1, T, Ci+1).
        """
        x_skip = torch.concat([layer(x) for layer in self.sd_layers], dim=1)
        # Conv2d expects channels in dim 1: (B, F, T, C) -> (B, C, F, T) and back.
        x = self.global_conv2d(x_skip.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        return x, x_skip
|
|