import math

import torch
from torch import nn
from torch.nn.utils.parametrize import remove_parametrizations

from TTS.vocoder.layers.parallel_wavegan import ResidualBlock

class ParallelWaveganDiscriminator(nn.Module):
    """PWGAN discriminator as in https://arxiv.org/abs/1910.11480.

    It classifies each audio window as real or fake and returns a sequence
    of predictions. It is a stack of dilated convolutional blocks.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_layers=10,
        conv_channels=64,
        dilation_factor=1,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
        bias=True,
    ):
        super().__init__()
        assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size."
        assert dilation_factor > 0, " [!] dilation factor must be > 0."
        self.conv_layers = nn.ModuleList()
        conv_in_channels = in_channels
        for i in range(num_layers - 1):
            if i == 0:
                dilation = 1
            else:
                # after the first layer, widen the receptive field with growing dilation
                dilation = i if dilation_factor == 1 else dilation_factor**i
                conv_in_channels = conv_channels
            padding = (kernel_size - 1) // 2 * dilation
            conv_layer = [
                nn.Conv1d(
                    conv_in_channels,
                    conv_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    dilation=dilation,
                    bias=bias,
                ),
                getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
            ]
            self.conv_layers += conv_layer
        padding = (kernel_size - 1) // 2
        last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
        self.conv_layers += [last_conv_layer]
        self.apply_weight_norm()

    def forward(self, x):
        """
        x : (B, 1, T).
        Returns:
            Tensor: (B, 1, T)
        """
        for f in self.conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.parametrizations.weight_norm(m)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                remove_parametrizations(m, "weight")
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
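
# Usage sketch (illustrative only, not part of the library API). Shapes follow the
# forward() docstring above, assuming the default constructor arguments:
#
#   disc = ParallelWaveganDiscriminator()
#   scores = disc(torch.rand(4, 1, 16000))  # (B, 1, T) waveform -> (B, 1, T) real/fake scores
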
class ResidualParallelWaveganDiscriminator(nn.Module):
    """Residual variant of the PWGAN discriminator (https://arxiv.org/abs/1910.11480).

    It stacks non-causal WaveNet-style ``ResidualBlock`` layers and sums their skip
    connections to produce a sequence of frame-level real/fake predictions.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_layers=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        dropout=0.0,
        bias=True,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
    ):
        super().__init__()
        assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size."

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.stacks = stacks
        self.kernel_size = kernel_size
        self.res_factor = math.sqrt(1.0 / num_layers)

        assert num_layers % stacks == 0
        layers_per_stack = num_layers // stacks

        # initial 1x1 projection from the input waveform to the residual channels
        self.first_conv = nn.Sequential(
            nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True),
            getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
        )

        # stack of non-causal residual blocks; dilation grows exponentially within each stack
        self.conv_layers = nn.ModuleList()
        for layer in range(num_layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                res_channels=res_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=-1,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=False,
            )
            self.conv_layers += [conv]

        # final layers mapping the summed skip connections to the output scores
        self.last_conv_layers = nn.ModuleList(
            [
                getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
                nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True),
                getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
                nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True),
            ]
        )

        self.apply_weight_norm()

    def forward(self, x):
        """
        x: (B, 1, T).
        Returns:
            Tensor: (B, 1, T)
        """
        x = self.first_conv(x)

        # accumulate the skip connections from every residual block
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, None)
            skips += h
        skips *= self.res_factor

        # apply the final layers to the aggregated skips
        x = skips
        for f in self.last_conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.parametrizations.weight_norm(m)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                # only report removal after it has actually succeeded
                remove_parametrizations(m, "weight")
                print(f"Weight norm is removed from {m}.")
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
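
# Usage sketch for the residual variant (illustrative only, not part of the library API).
# Shapes follow the forward() docstring above, assuming the default constructor arguments:
#
#   disc = ResidualParallelWaveganDiscriminator()
#   scores = disc(torch.rand(4, 1, 16000))  # (B, 1, T) waveform -> (B, 1, T) real/fake scores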