import numpy as np | |
import torch | |
from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator | |
def test_pwgan_generator():
    """Check the output shapes of ParallelWaveganGenerator.

    Builds a generator, feeds it a random conditioning tensor through
    both the regular forward pass and ``inference`` (after removing
    weight norm), and asserts the shape of each result.
    """
    batch_size = 2
    aux_channels = 80
    num_frames = 5
    out_channels = 1
    upsample_factors = [4, 4, 4, 4]
    # Samples produced per conditioning frame: 4 * 4 * 4 * 4 == 256.
    hop_length = int(np.prod(upsample_factors))
    # The expected inference output is 4 frames longer than the input.
    # NOTE(review): presumably `inference` pads the conditioning features
    # on both sides -- confirm against ParallelWaveganGenerator.inference.
    inference_extra_frames = 4

    model = ParallelWaveganGenerator(
        in_channels=1,
        out_channels=out_channels,
        kernel_size=3,
        num_res_blocks=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=aux_channels,
        dropout=0.0,
        bias=True,
        use_weight_norm=True,
        upsample_factors=upsample_factors,
    )
    dummy_c = torch.rand((batch_size, aux_channels, num_frames))

    # Forward pass: every conditioning frame expands to `hop_length` samples.
    output = model(dummy_c)
    expected_shape = (batch_size, out_channels, num_frames * hop_length)
    # torch.Size is a tuple subclass, so a direct comparison is sufficient.
    assert tuple(output.shape) == expected_shape, output.shape

    # Inference path is exercised with weight normalization removed.
    model.remove_weight_norm()
    output = model.inference(dummy_c)
    expected_shape = (
        batch_size,
        out_channels,
        (num_frames + inference_extra_frames) * hop_length,
    )
    assert tuple(output.shape) == expected_shape, output.shape