|
from typing import List, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from models.scnet_unofficial.utils import get_convtranspose_output_padding |
|
|
|
|
|
class FusionLayer(nn.Module):
    """
    Fuses two equally-shaped tensors: sums them, duplicates the channel axis,
    applies a frequency-axis 2D convolution, and gates the result with a GLU
    so the channel count returns to its original size.

    Args:
    - input_dim (int): Channel dimensionality C of each input tensor.
    - kernel_size (int, optional): Frequency-axis kernel size of the conv. Default: 3.
    - stride (int, optional): Frequency-axis stride of the conv. Default: 1.
    - padding (int, optional): Frequency-axis padding of the conv. Default: 1.

    Shapes:
    - Input: two tensors of shape (B, F, T, C), where
        B is batch size,
        F is the number of features,
        T is sequence length,
        C is input dimensionality.
    - Output: (B, F, T, C) — the GLU halves the doubled channel axis back to C.
    """

    def __init__(
        self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1
    ):
        """
        Builds the (2C -> 2C) frequency-axis convolution and the GLU gate.
        """
        super().__init__()
        # Kernel/stride/padding act on the frequency axis only; the time axis
        # always uses size 1 so sequence length is preserved.
        self.conv = nn.Conv2d(
            input_dim * 2,
            input_dim * 2,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(padding, 0),
        )
        # GLU over the last axis (default dim=-1): splits 2C into C values + C gates.
        self.activation = nn.GLU()

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Performs the fusion of the two inputs.

        Args:
        - x1 (torch.Tensor): First input tensor of shape (B, F, T, C).
        - x2 (torch.Tensor): Second input tensor of shape (B, F, T, C).

        Returns:
        - torch.Tensor: Fused output tensor of shape (B, F, T, C).
        """
        summed = x1 + x2
        doubled = summed.repeat(1, 1, 1, 2)  # (B, F, T, 2C)
        # Conv2d wants channels first: (B, 2C, F, T); permute back afterwards.
        convolved = self.conv(doubled.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        return self.activation(convolved)
|
|
|
|
|
class Upsample(nn.Module):
    """
    Upsamples the frequency axis of a channels-first tensor using a 1x1
    transposed 2D convolution strided along frequency only.

    Args:
    - input_dim (int): Number of input channels.
    - output_dim (int): Number of output channels.
    - stride (int): Frequency-axis stride of the transposed convolution.
    - output_padding (int): Frequency-axis output padding of the transposed convolution.

    Shapes:
    - Input: (B, C_in, F, T) where
        B is batch size,
        C_in is the number of input channels,
        F is the frequency dimension,
        T is the time dimension.
    - Output: (B, C_out, (F - 1) * stride + 1 + output_padding, T) where
        B is batch size,
        C_out is the number of output channels; the time dimension is unchanged.
    """

    def __init__(
        self, input_dim: int, output_dim: int, stride: int, output_padding: int
    ):
        """
        Builds the transposed convolution used for frequency upsampling.
        """
        super().__init__()
        # kernel_size=1 with stride/output_padding applied on the frequency
        # axis only, so the time axis passes through untouched.
        self.conv = nn.ConvTranspose2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=1,
            stride=(stride, 1),
            output_padding=(output_padding, 0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies the transposed convolution.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).

        Returns:
        - torch.Tensor: Upsampled tensor of shape (B, C_out, F_up, T).
        """
        return self.conv(x)
|
|
|
|
|
class SULayer(nn.Module):
    """
    Subband upsampling layer: slices a subband out of the feature axis and
    upsamples it back to the subband's original shape via transposed convolution.

    Args:
    - input_dim (int): Dimensionality of the input channels.
    - output_dim (int): Dimensionality of the output channels.
    - upsample_stride (int): Stride value for the upsampling operation.
    - subband_shape (int): Target feature size of the upsampled subband.
    - sd_interval (Tuple[int, int]): Start and end indices of the subband interval.

    Shapes:
    - Input: (B, F, T, C) where
        B is batch size,
        F is the number of features,
        T is sequence length,
        C is input dimensionality.
    - Output: (B, subband_shape, T, output_dim) where
        B is batch size and T is sequence length.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_stride: int,
        subband_shape: int,
        sd_interval: Tuple[int, int],
    ):
        """
        Initializes SULayer with input dimension, output dimension, upsample
        stride, subband shape, and subband interval.
        """
        super().__init__()
        start, end = sd_interval
        # Solve for the output_padding that makes the transposed conv map the
        # sliced interval (end - start features) exactly onto subband_shape.
        output_padding = get_convtranspose_output_padding(
            input_shape=end - start,
            output_shape=subband_shape,
            stride=upsample_stride,
        )
        self.upsample = Upsample(
            input_dim=input_dim,
            output_dim=output_dim,
            stride=upsample_stride,
            output_padding=output_padding,
        )
        self.sd_interval = sd_interval

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Slices the configured subband and upsamples it.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, F, T, C).

        Returns:
        - torch.Tensor: Output tensor of shape (B, subband_shape, T, output_dim).
        """
        lo, hi = self.sd_interval
        subband = x[:, lo:hi]
        # Upsample works channels-first; permute there and back.
        channels_first = subband.permute(0, 3, 1, 2)
        upsampled = self.upsample(channels_first)
        return upsampled.permute(0, 2, 3, 1)
|
|
|
|
|
class SUBlock(nn.Module):
    """
    SUBlock class implements a block with a fusion layer followed by parallel
    subband upsampling layers whose outputs are concatenated along the feature axis.

    Args:
    - input_dim (int): Dimensionality of the input channels.
    - output_dim (int): Dimensionality of the output channels.
    - upsample_strides (List[int]): List of stride values for the upsampling operations.
    - subband_shapes (List[int]): List of shapes for the subbands.
    - sd_intervals (List[Tuple[int, int]]): List of intervals for subband decomposition.

    Shapes:
    - Input: (B, Fi-1, T, Ci-1) and (B, Fi-1, T, Ci-1) where
        B is batch size,
        Fi-1 is the number of input subbands,
        T is sequence length,
        Ci-1 is the number of input channels.
    - Output: (B, Fi, T, Ci) where
        B is batch size,
        Fi is the number of output subbands,
        T is sequence length,
        Ci is the number of output channels.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_strides: List[int],
        subband_shapes: List[int],
        sd_intervals: List[Tuple[int, int]],
    ):
        """
        Initializes SUBlock with input dimension, output dimension,
        upsample strides, subband shapes, and subband intervals.
        """
        super().__init__()
        self.fusion_layer = FusionLayer(input_dim=input_dim)
        # One SULayer per subband configuration; the three lists are iterated
        # in lockstep (the previously used enumerate index was unused).
        self.su_layers = nn.ModuleList(
            SULayer(
                input_dim=input_dim,
                output_dim=output_dim,
                upsample_stride=stride,
                subband_shape=shape,
                sd_interval=interval,
            )
            for stride, shape, interval in zip(
                upsample_strides, subband_shapes, sd_intervals
            )
        )

    def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the SUBlock.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, Fi-1, T, Ci-1).
        - x_skip (torch.Tensor): Input skip connection tensor of shape (B, Fi-1, T, Ci-1).

        Returns:
        - torch.Tensor: Output tensor of shape (B, Fi, T, Ci).
        """
        x = self.fusion_layer(x, x_skip)
        # Concatenate the per-subband outputs along the feature axis (dim=1).
        # torch.cat is the canonical spelling of the torch.concat alias.
        x = torch.cat([layer(x) for layer in self.su_layers], dim=1)
        return x
|
|