# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
import torch.nn as nn

from modules.vocoder_blocks import *

from einops import rearrange
import torchaudio.transforms as T

from nnAudio import features

LRELU_SLOPE = 0.1


class DiscriminatorCQT(nn.Module):
    def __init__(self, cfg, hop_length, n_octaves, bins_per_octave):
        super(DiscriminatorCQT, self).__init__()
        self.cfg = cfg

        self.filters = cfg.model.mssbcqtd.filters
        self.max_filters = cfg.model.mssbcqtd.max_filters
        self.filters_scale = cfg.model.mssbcqtd.filters_scale
        self.kernel_size = (3, 9)
        self.dilations = cfg.model.mssbcqtd.dilations
        self.stride = (1, 2)

        self.in_channels = cfg.model.mssbcqtd.in_channels
        self.out_channels = cfg.model.mssbcqtd.out_channels
        self.fs = cfg.preprocess.sample_rate
        self.hop_length = hop_length
        self.n_octaves = n_octaves
        self.bins_per_octave = bins_per_octave

        # CQT front-end; the waveform is resampled to twice the original
        # sample rate before the transform (see self.resample below).
        self.cqt_transform = features.cqt.CQT2010v2(
            sr=self.fs * 2,
            hop_length=self.hop_length,
            n_bins=self.bins_per_octave * self.n_octaves,
            bins_per_octave=self.bins_per_octave,
            output_format="Complex",
            pad_mode="constant",
        )

        # One pre-convolution per octave sub-band.
        self.conv_pres = nn.ModuleList()
        for i in range(self.n_octaves):
            self.conv_pres.append(
                NormConv2d(
                    self.in_channels * 2,
                    self.in_channels * 2,
                    kernel_size=self.kernel_size,
                    padding=get_2d_padding(self.kernel_size),
                )
            )

        # Shared stack of strided, dilated 2D convolutions.
        self.convs = nn.ModuleList()

        self.convs.append(
            NormConv2d(
                self.in_channels * 2,
                self.filters,
                kernel_size=self.kernel_size,
                padding=get_2d_padding(self.kernel_size),
            )
        )

        in_chs = min(self.filters_scale * self.filters, self.max_filters)
        for i, dilation in enumerate(self.dilations):
            out_chs = min(
                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
            )
            self.convs.append(
                NormConv2d(
                    in_chs,
                    out_chs,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    dilation=(dilation, 1),
                    padding=get_2d_padding(self.kernel_size, (dilation, 1)),
                    norm="weight_norm",
                )
            )
            in_chs = out_chs

        out_chs = min(
            (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
            self.max_filters,
        )
        self.convs.append(
            NormConv2d(
                in_chs,
                out_chs,
                kernel_size=(self.kernel_size[0], self.kernel_size[0]),
                padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
                norm="weight_norm",
            )
        )

        # Final projection to the configured number of output channels.
        self.conv_post = NormConv2d(
            out_chs,
            self.out_channels,
            kernel_size=(self.kernel_size[0], self.kernel_size[0]),
            padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
            norm="weight_norm",
        )

        self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE)
        self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2)

    def forward(self, x):
        fmap = []

        # Upsample to 2x the sample rate, then take the CQT.
        x = self.resample(x)
        z = self.cqt_transform(x)

        # Stack the two real-valued channels of the "Complex" CQT output as
        # input channels, then put time before frequency.
        z_amplitude = z[:, :, :, 0].unsqueeze(1)
        z_phase = z[:, :, :, 1].unsqueeze(1)
        z = torch.cat([z_amplitude, z_phase], dim=1)
        z = rearrange(z, "b c w t -> b c t w")

        # Process each octave sub-band with its own pre-convolution, then
        # concatenate the sub-bands back along the frequency axis.
        latent_z = []
        for i in range(self.n_octaves):
            latent_z.append(
                self.conv_pres[i](
                    z[
                        :,
                        :,
                        :,
                        i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
                    ]
                )
            )
        latent_z = torch.cat(latent_z, dim=-1)

        for layer in self.convs:
            latent_z = layer(latent_z)
            latent_z = self.activation(latent_z)
            fmap.append(latent_z)

        latent_z = self.conv_post(latent_z)

        return latent_z, fmap
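

# Shape notes for DiscriminatorCQT.forward (descriptive only; exact frame
# counts depend on nnAudio's padding): a mono waveform of shape (B, 1, T) at
# cfg.preprocess.sample_rate is resampled to twice the rate, and CQT2010v2
# with output_format="Complex" is expected to return a (B, n_bins, frames, 2)
# tensor, where n_bins = n_octaves * bins_per_octave. Its two trailing
# channels are stacked into a (B, 2, frames, n_bins) map, split into
# n_octaves sub-bands of bins_per_octave frequency bins, passed through the
# per-octave pre-convolutions, and re-concatenated before the shared stack.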


class MultiScaleSubbandCQTDiscriminator(nn.Module):
    def __init__(self, cfg):
        super(MultiScaleSubbandCQTDiscriminator, self).__init__()
        self.cfg = cfg
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorCQT(
                    cfg,
                    hop_length=cfg.model.mssbcqtd.hop_lengths[i],
                    n_octaves=cfg.model.mssbcqtd.n_octaves[i],
                    bins_per_octave=cfg.model.mssbcqtd.bins_per_octaves[i],
                )
                for i in range(len(cfg.model.mssbcqtd.hop_lengths))
            ]
        )

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for disc in self.discriminators:
            y_d_r, fmap_r = disc(y)
            y_d_g, fmap_g = disc(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
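

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a
    # config object exposing the attributes read above; Amphion loads these
    # from JSON config files, so the SimpleNamespace and the numbers below
    # are illustrative stand-ins only.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        preprocess=SimpleNamespace(sample_rate=24000),  # assumed sample rate
        model=SimpleNamespace(
            mssbcqtd=SimpleNamespace(
                filters=32,
                max_filters=1024,
                filters_scale=1,
                dilations=[1, 2, 4],
                in_channels=1,
                out_channels=1,
                hop_lengths=[512, 256, 256],
                n_octaves=[9, 9, 9],
                bins_per_octaves=[24, 36, 48],
            )
        ),
    )

    disc = MultiScaleSubbandCQTDiscriminator(cfg)
    y = torch.randn(2, 1, cfg.preprocess.sample_rate)      # reference audio (B, 1, T)
    y_hat = torch.randn(2, 1, cfg.preprocess.sample_rate)  # generated audio (B, 1, T)

    y_d_rs, y_d_gs, fmap_rs, fmap_gs = disc(y, y_hat)
    print(len(y_d_rs), y_d_rs[0].shape, len(fmap_rs[0]))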