import torch
import torch.nn as nn
import torch.nn.functional as F

from .res_stack import ResStack
# from res_stack import ResStack

# Scale factor for converting [-1, 1] float audio to 16-bit signed integers.
MAX_WAV_VALUE = 32768.0


class Generator(nn.Module):
    """MelGAN generator: upsamples a mel spectrogram to a raw waveform.

    The transposed-conv upsampling factors (8 * 8 * 2 * 2 = 256) match the
    hop length used by `inference`, so one mel frame maps to 256 samples.
    """

    def __init__(self, mel_channel):
        """
        Args:
            mel_channel: number of mel bins in the input spectrogram
                (channel dimension of the (B, mel_channel, T) input).
        """
        super(Generator, self).__init__()
        self.mel_channel = mel_channel

        self.generator = nn.Sequential(
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1)),

            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(512, 256, kernel_size=16, stride=8, padding=4)),

            ResStack(256),

            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(256, 128, kernel_size=16, stride=8, padding=4)),

            ResStack(128),

            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1)),

            ResStack(64),

            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1)),

            ResStack(32),

            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(32, 1, kernel_size=7, stride=1)),
            nn.Tanh(),
        )

    def forward(self, mel):
        """Map a mel spectrogram (B, mel_channel, T) to audio (B, 1, T * 256)."""
        mel = (mel + 5.0) / 5.0  # roughly normalize spectrogram
        return self.generator(mel)

    def eval(self, inference=False):
        """Switch to eval mode; strip weight norm only for final inference.

        Args:
            inference: when True, also remove weight normalization — don't
                remove weight norm while validating inside the training loop.
        """
        super(Generator, self).eval()
        if inference:
            self.remove_weight_norm()

    def remove_weight_norm(self):
        """Remove weight normalization from every parameterized layer."""
        for idx, layer in enumerate(self.generator):
            # Skip parameter-free layers (padding, activations, Tanh).
            if len(layer.state_dict()) != 0:
                try:
                    nn.utils.remove_weight_norm(layer)
                except ValueError:
                    # remove_weight_norm raises ValueError for modules without
                    # a direct weight_norm hook; ResStack containers expose
                    # their own remove_weight_norm instead.
                    layer.remove_weight_norm()

    def inference(self, mel):
        """Generate int16 audio from a single mel spectrogram (1, mel_channel, T).

        Returns a 1-D torch.int16 tensor of length T * 256.
        """
        hop_length = 256
        # pad input mel with "zeros" (log of a tiny magnitude) to cut artifact
        # at the end of the utterance;
        # see https://github.com/seungwonpark/melgan/issues/8
        zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device)
        mel = torch.cat((mel, zero), dim=2)

        audio = self.forward(mel)
        audio = audio.squeeze()  # collapse all dimension except time axis
        # drop the samples produced by the 10 padding frames
        audio = audio[:-(hop_length * 10)]
        audio = MAX_WAV_VALUE * audio
        audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
        audio = audio.short()

        return audio


'''
to run this, fix
from .res_stack import ResStack
into
from res_stack import ResStack
'''
if __name__ == '__main__':
    model = Generator(80)

    x = torch.randn(3, 80, 10)
    print(x.shape)

    y = model(x)
    print(y.shape)
    assert y.shape == torch.Size([3, 1, 2560])

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)