Spaces:
Running
Running
import torch | |
import torchaudio | |
import typing as T | |
class MelspecDiscriminator(torch.nn.Module):
    """Mel spectrogram (frequency domain) discriminator.

    The input waveform is turned into a log-magnitude mel spectrogram and
    scored by a stack of 2D convolutions (BatchNorm + GLU gating), followed by
    a single-channel output projection and global average pooling.

    Args:
        sample_rate: sample rate of the input audio in Hz (default 48000).
            The 25 ms analysis window must fit inside the fixed 2048-point
            FFT, so rates above 81920 Hz are not supported.
    """

    def __init__(self, sample_rate: int = 48000) -> None:
        super().__init__()

        # kept under its original attribute name for existing callers
        self.SAMPLE_RATE = sample_rate

        # mel filterbank transform: 25 ms windows, 10 ms hop, 128 mel bands;
        # power=1 yields a magnitude (not power) spectrogram
        self._melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.SAMPLE_RATE,
            n_fft=2048,
            win_length=int(0.025 * self.SAMPLE_RATE),
            hop_length=int(0.010 * self.SAMPLE_RATE),
            n_mels=128,
            power=1,
        )

        # time-frequency 2D convolutions. Stride (1, 2) downsamples only the
        # last axis (the frame axis of torchaudio's (..., n_mels, time)
        # output). Each conv emits 64 channels, which GLU(dim=1) halves to
        # 32 -- hence in_channels=32 for every layer after the first.
        kernel_sizes = [(7, 7), (4, 4), (4, 4), (4, 4)]
        strides = [(1, 2), (1, 2), (1, 2), (1, 2)]
        self._convs = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_channels=1 if i == 0 else 32,
                        out_channels=64,
                        kernel_size=k,
                        stride=s,
                        padding=(1, 2),
                        # a conv bias would be cancelled by the BatchNorm below
                        bias=False,
                    ),
                    torch.nn.BatchNorm2d(num_features=64),
                    torch.nn.GLU(dim=1),
                )
                for i, (k, s) in enumerate(zip(kernel_sizes, strides))
            ]
        )

        # output adversarial projection down to a single score channel
        self._postnet = torch.nn.Conv2d(
            in_channels=32,
            out_channels=1,
            kernel_size=(15, 3),
            stride=(1, 2),
        )

    def forward(
        self, x: torch.Tensor
    ) -> T.List[T.Tuple[T.List[torch.Tensor], torch.Tensor]]:
        """Score a batch of waveforms.

        Args:
            x: audio waveforms, presumably shaped (batch, 1, time) -- TODO
                confirm against callers. A 2D (batch, time) input is given a
                singleton channel axis so the 2D layers receive 4D tensors.

        Returns:
            A one-element list holding ``(feature_maps, scores)``:
            ``feature_maps`` is the list of the conv-stage outputs (one per
            stage) and ``scores`` is the post-net output averaged over its
            last two axes, i.e. shape (batch, 1) for 4D activations.
        """
        # Conv2d/BatchNorm2d need (batch, channel, mel, time); a bare
        # (batch, time) waveform would otherwise give a 3D spectrogram
        if x.dim() == 2:
            x = x.unsqueeze(1)

        # log-scale mel spectrogram; the 1e-5 floor avoids log(0) on silence
        x = torch.log(self._melspec(x) + 1e-5)

        # hidden layers, keeping every intermediate activation
        feature_maps = []
        for conv in self._convs:
            x = conv(x)
            feature_maps.append(x)

        # output projection, then global average pooling over the last two axes
        x = self._postnet(x)
        x = x.mean(dim=[-2, -1])

        return [(feature_maps, x)]