tts-fastspeech-mydata / vocoder /vocgan_generator.py
94insane's picture
230817,V0.1
320e69e
"""
[VocGAN] Generator
this source code is implemenation of the modified-VocGAN from rishikksh20
git repository: https://github.com/rishikksh20/VocGAN
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
MAX_WAV_VALUE = 32768.0
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
class ResStack(nn.Module):
def __init__(self, channel, dilation=1):
super(ResStack, self).__init__()
self.block = nn.Sequential(
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(dilation),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=3, dilation=dilation)),
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
)
self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
def forward(self, x):
return self.shortcut(x) + self.block(x)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.block[2])
nn.utils.remove_weight_norm(self.block[4])
nn.utils.remove_weight_norm(self.shortcut)
# Modified VocGAN
class Generator(nn.Module):
def __init__(self, mel_channel, n_residual_layers, ratios=[4, 4, 2, 2, 2, 2], mult=256, out_band=1):
super(Generator, self).__init__()
self.mel_channel = mel_channel
self.start = nn.Sequential(
nn.ReflectionPad1d(3),
nn.utils.weight_norm(nn.Conv1d(mel_channel, mult * 2, kernel_size=7, stride=1))
)
r = ratios[0]
self.upsample_1 = nn.Sequential(
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
kernel_size=r * 2, stride=r,
padding=r // 2 + r % 2,
output_padding=r % 2)
)
)
self.res_stack_1 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])
r = ratios[1]
mult = mult // 2
self.upsample_2 = nn.Sequential(
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
kernel_size=r * 2, stride=r,
padding=r // 2 + r % 2,
output_padding=r % 2)
)
)
self.res_stack_2 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])
r = ratios[2]
mult = mult // 2
self.upsample_3 = nn.Sequential(
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
kernel_size=r * 2, stride=r,
padding=r // 2 + r % 2,
output_padding=r % 2)
)
)
self.skip_upsample_1 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult,
kernel_size=64, stride=32,
padding=16,
output_padding=0)
)
self.res_stack_3 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])
r = ratios[3]
mult = mult // 2
self.upsample_4 = nn.Sequential(
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
kernel_size=r * 2, stride=r,
padding=r // 2 + r % 2,
output_padding=r % 2)
)
)
self.skip_upsample_2 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult,
kernel_size=128, stride=64,
padding=32,
output_padding=0)
)
self.res_stack_4 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])
r = ratios[4]
mult = mult // 2
self.upsample_5 = nn.Sequential(
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
kernel_size=r * 2, stride=r,
padding=r // 2 + r % 2,
output_padding=r % 2)
)
)
self.skip_upsample_3 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult,
kernel_size=256, stride=128,
padding=64,
output_padding=0)
)
self.res_stack_5 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])
r = ratios[5]
mult = mult // 2
self.upsample_6 = nn.Sequential(
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
kernel_size=r * 2, stride=r,
padding=r // 2 + r % 2,
output_padding=r % 2)
)
)
self.skip_upsample_4 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult,
kernel_size=512, stride=256,
padding=128,
output_padding=0)
)
self.res_stack_6 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])
self.out = nn.Sequential(
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(3),
nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)),
nn.Tanh(),
)
self.apply(weights_init)
def forward(self, mel):
mel = (mel + 5.0) / 5.0 # roughly normalize spectrogram
# Mel Shape [B, num_mels, T] -> torch.Size([3, 80, 10])
x = self.start(mel) # [B, dim*2, T] -> torch.Size([3, 512, 10])
x = self.upsample_1(x)
x = self.res_stack_1(x) # [B, dim, T*4] -> torch.Size([3, 256, 40])
x = self.upsample_2(x)
x = self.res_stack_2(x) # [B, dim/2, T*16] -> torch.Size([3, 128, 160])
x = self.upsample_3(x)
x = x + self.skip_upsample_1(mel)
x = self.res_stack_3(x) # [B, dim/4, T*32] -> torch.Size([3, 64, 320])
x = self.upsample_4(x)
x = x + self.skip_upsample_2(mel)
x = self.res_stack_4(x) # [B, dim/8, T*64] -> torch.Size([3, 32, 640])
x = self.upsample_5(x)
x = x + self.skip_upsample_3(mel)
x = self.res_stack_5(x) # [B, dim/16, T*128] -> torch.Size([3, 16, 1280])
x = self.upsample_6(x)
x = x + self.skip_upsample_4(mel)
x = self.res_stack_6(x) # [B, dim/32, T*256] -> torch.Size([3, 8, 2560])
out = self.out(x) # [B, 1, T*256] -> torch.Size([3, 1, 2560])
return out
def eval(self, inference=False):
super(Generator, self).eval()
# don't remove weight norm while validation in training loop
if inference:
self.remove_weight_norm()
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m):
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
torch.nn.utils.weight_norm(m)
self.apply(_apply_weight_norm)
def infer(self, mel):
hop_length = 256
# pad input mel with zeros to cut artifact
# see https://github.com/seungwonpark/melgan/issues/8
zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device)
mel = torch.cat((mel, zero), dim=2)
audio = self.forward(mel)
return audio