import os
import math
import itertools
import traceback

import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.generator import Generator
from model.multiscale import MultiScaleDiscriminator
from .utils import get_commit_hash
from .validation import validate

def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str):
    model_g = Generator(hp.audio.n_mel_channels)  # cuda()
    model_d = MultiScaleDiscriminator()  # cuda()

    optim_g = torch.optim.Adam(model_g.parameters(),
                               lr=hp.train.adam.lr,
                               betas=(hp.train.adam.beta1, hp.train.adam.beta2))
    optim_d = torch.optim.Adam(model_d.parameters(),
                               lr=hp.train.adam.lr,
                               betas=(hp.train.adam.beta1, hp.train.adam.beta2))

    githash = get_commit_hash()

    init_epoch = -1
    step = 0
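
    # Note: as is standard for GAN training, the generator and discriminator get
    # separate Adam optimizers (both reading the same lr/betas from the config),
    # so each network's update only steps its own parameters. The git hash is
    # stored in checkpoints below so a resumed run can warn if the code changed.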
    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model_g.load_state_dict(checkpoint['model_g'])
        model_d.load_state_dict(checkpoint['model_d'])
        optim_g.load_state_dict(checkpoint['optim_g'])
        optim_d.load_state_dict(checkpoint['optim_d'])
        step = checkpoint['step']
        init_epoch = checkpoint['epoch']

        if hp_str != checkpoint['hp_str']:
            logger.warning("New hparams are different from checkpoint. Will use new.")

        if githash != checkpoint['githash']:
            logger.warning("Code might be different: git hash is different.")
            logger.warning("%s -> %s" % (checkpoint['githash'], githash))
    else:
        logger.info("Starting new training run.")
    # cudnn.benchmark speeds up training when the minibatch size is always the
    # same; with inconsistent input sizes it re-benchmarks constantly and
    # becomes horribly slow instead.
    torch.backends.cudnn.benchmark = True
    try:
        model_g.train()
        model_d.train()

        for epoch in itertools.count(init_epoch + 1):
            if epoch % hp.log.validation_interval == 0:
                with torch.no_grad():
                    validate(hp, args, model_g, model_d, valloader, writer, step)

            trainloader.dataset.shuffle_mapping()
            loader = tqdm.tqdm(trainloader, desc='Loading train data')
            for (melG, audioG), (melD, audioD) in loader:
                # melG = melG.cuda()
                # audioG = audioG.cuda()
                # melD = melD.cuda()
                # audioD = audioD.cuda()

                # generator
                optim_g.zero_grad()
                fake_audio = model_g(melG)[:, :, :hp.audio.segment_length]
                disc_fake = model_d(fake_audio)
                disc_real = model_d(audioG)
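                # Generator objective (a note derived from the loop below): per
                # discriminator scale k, the LSGAN least-squares loss -- the
                # batch mean of the summed squared error (D_k(G(mel)) - 1)^2 --
                # plus an L1 feature-matching term between real and fake
                # intermediate discriminator features, weighted by
                # hp.model.feat_match. The generator output is sliced to
                # hp.audio.segment_length above so its length matches the real
                # audio segments the discriminator sees.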
                loss_g = 0.0
                for (feats_fake, score_fake), (feats_real, _) in zip(disc_fake, disc_real):
                    loss_g += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
                    for feat_f, feat_r in zip(feats_fake, feats_real):
                        loss_g += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))

                loss_g.backward()
                optim_g.step()
                # discriminator
                fake_audio = model_g(melD)[:, :, :hp.audio.segment_length]
                fake_audio = fake_audio.detach()  # block gradients into the generator during D updates
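                # Discriminator objective (again per the loop below): the LSGAN
                # loss per scale k -- batch mean of summed (D_k(audio) - 1)^2
                # plus summed D_k(G(mel))^2 -- pushing real scores toward 1 and
                # fake scores toward 0. hp.train.rep_discriminator controls how
                # many D updates are taken per generator update.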
                loss_d_sum = 0.0
                for _ in range(hp.train.rep_discriminator):
                    optim_d.zero_grad()
                    disc_fake = model_d(fake_audio)
                    disc_real = model_d(audioD)

                    loss_d = 0.0
                    for (_, score_fake), (_, score_real) in zip(disc_fake, disc_real):
                        loss_d += torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
                        loss_d += torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))

                    loss_d.backward()
                    optim_d.step()
                    loss_d_sum += loss_d.item()  # .item() so the autograd graph is not retained across reps

                step += 1
                # logging
                loss_g = loss_g.item()
                loss_d_avg = loss_d_sum / hp.train.rep_discriminator

                if any([loss_g > 1e8, math.isnan(loss_g), loss_d_avg > 1e8, math.isnan(loss_d_avg)]):
                    logger.error("loss_g %.01f loss_d_avg %.01f at step %d!" % (loss_g, loss_d_avg, step))
                    raise Exception("Loss exploded")

                if step % hp.log.summary_interval == 0:
                    writer.log_training(loss_g, loss_d_avg, step)
                    loader.set_description("g %.04f d %.04f | step %d" % (loss_g, loss_d_avg, step))
            if epoch % hp.log.save_interval == 0:
                save_path = os.path.join(pt_dir, '%s_%s_%04d.pt'
                                         % (args.name, githash, epoch))
                torch.save({
                    'model_g': model_g.state_dict(),
                    'model_d': model_d.state_dict(),
                    'optim_g': optim_g.state_dict(),
                    'optim_d': optim_d.state_dict(),
                    'step': step,
                    'epoch': epoch,
                    'hp_str': hp_str,
                    'githash': githash,
                }, save_path)
                logger.info("Saved checkpoint to: %s" % save_path)
    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()
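
# Usage sketch (hypothetical; the entry-point and helper names below are
# assumptions about the surrounding repo, not taken from this file):
#
#   hp, hp_str = load_hparams('config/default.yaml')   # parsed config + raw YAML string
#   train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str)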