Spaces:
Sleeping
Sleeping
import tqdm | |
import torch | |
def _model_device(module):
    """Return the device of *module*'s first parameter, or CUDA if it has none."""
    # The original code hard-coded ``.cuda()``. Following the model's own
    # device keeps that behaviour for GPU models and also works on CPU.
    param = next(module.parameters(), None)
    return param.device if param is not None else torch.device('cuda')


def _adversarial_losses(hp, disc_fake, disc_real):
    """Compute one batch's generator and discriminator losses as Python floats.

    ``disc_fake`` / ``disc_real`` are iterables of ``(features, score)`` pairs
    (one pair per discriminator output), zipped together pairwise.
    Returns ``(loss_g, loss_d)``.
    """
    loss_g = 0.0
    loss_d = 0.0
    for (feats_fake, score_fake), (feats_real, score_real) in zip(disc_fake, disc_real):
        # Squared-error adversarial term pushing fake scores towards 1;
        # summed over dims 1 and 2, then averaged over what remains.
        loss_g += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
        # Feature-matching term (L1 between fake/real features), weighted
        # by hp.model.feat_match.
        for feat_f, feat_r in zip(feats_fake, feats_real):
            loss_g += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))
        # Discriminator terms: real scores towards 1, fake scores towards 0.
        loss_d += torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
        loss_d += torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
    # float() works for both 0-dim tensors and the untouched 0.0 start value
    # (the latter happens if the discriminator yields no outputs).
    return float(loss_g), float(loss_d)


def validate(hp, args, generator, discriminator, valloader, writer, step):
    """Run one validation pass and log averaged losses plus one audio sample.

    Args:
        hp: Hyper-parameter object; ``hp.model.feat_match`` weights the
            feature-matching term of the generator loss.
        args: Unused here; kept for interface compatibility.
        generator: Module called as ``generator(mel)`` to produce audio.
        discriminator: Module called on audio; must return an iterable of
            ``(features, score)`` pairs.
        valloader: Iterable of ``(mel, audio)`` batches. ``audio`` is
            indexed with ``.size(2)``, so it is at least 3-D.
        writer: Object exposing ``log_validation(loss_g, loss_d, generator,
            discriminator, audio, fake_audio, step)``.
        step: Training step forwarded to ``writer.log_validation``.

    The logged losses are means over batches. This equals the old per-sample
    division only when the batch size is 1. The logged audio is the first
    item of the last batch, as a numpy array.

    Side effects: the models run in eval mode and without autograd. Their
    previous train/eval modes are restored afterwards.
    ``torch.backends.cudnn.benchmark`` is disabled during validation and set
    back to True on exit, even if an error occurs.

    Raises:
        ValueError: if ``valloader`` yields no batches.
    """
    was_training_g = generator.training
    was_training_d = discriminator.training
    generator.eval()
    discriminator.eval()
    torch.backends.cudnn.benchmark = False
    try:
        device = _model_device(generator)
        loader = tqdm.tqdm(valloader, desc='Validation loop')
        loss_g_sum = 0.0
        loss_d_sum = 0.0
        num_batches = 0
        # No autograd graphs are needed for validation; this saves memory.
        with torch.no_grad():
            for mel, audio in loader:
                mel = mel.to(device)
                audio = audio.to(device)

                fake_audio = generator(mel)
                # Crop generated audio to the reference length before scoring.
                disc_fake = discriminator(fake_audio[:, :, :audio.size(2)])
                disc_real = discriminator(audio)

                loss_g, loss_d = _adversarial_losses(hp, disc_fake, disc_real)
                loss_g_sum += loss_g
                loss_d_sum += loss_d
                num_batches += 1

        if num_batches == 0:
            raise ValueError('validation loader produced no batches')

        # Each accumulated value is already a per-batch mean, so divide by
        # the batch count rather than the dataset size.
        loss_g_avg = loss_g_sum / num_batches
        loss_d_avg = loss_d_sum / num_batches

        audio = audio[0][0].cpu().detach().numpy()
        fake_audio = fake_audio[0][0].cpu().detach().numpy()

        writer.log_validation(loss_g_avg, loss_d_avg, generator, discriminator, audio, fake_audio, step)
    finally:
        # Restore the caller's train/eval modes instead of leaving both models
        # stuck in eval mode. Re-enable cudnn benchmarking even on error.
        generator.train(was_training_g)
        discriminator.train(was_training_d)
        torch.backends.cudnn.benchmark = True