atsushieee's picture
Upload folder using huggingface_hub
9791162
raw
history blame
13 kB
import logging
import math
import os
import re
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel

from vits import commons
from vits.commons import clip_grad_value_
from vits.losses import kl_loss
from vits.models import SynthesizerTrn
from vits_decoder.discriminator import Discriminator
from vits_extend.dataloader import create_dataloader_eval
from vits_extend.dataloader import create_dataloader_train
from vits_extend.stft import TacotronSTFT
from vits_extend.stft_loss import MultiResolutionSTFTLoss
from vits_extend.validation import validate
from vits_extend.writer import MyWriter
def load_part(model, saved_state_dict):
    """Load weights from ``saved_state_dict`` into ``model``, skipping 'TODO'-prefixed keys.

    Every entry of the model's own state dict whose name starts with ``'TODO'``
    keeps its current value; every other entry is replaced by the tensor of the
    same name in ``saved_state_dict`` (a key missing there raises KeyError).
    If ``model`` exposes a ``.module`` attribute (e.g. a DistributedDataParallel
    wrapper), the wrapped module is the one read from and loaded into.

    NOTE(review): the 'TODO' prefix looks like a placeholder; confirm which
    parameters were meant to be excluded from loading.

    Returns:
        The same ``model`` object that was passed in.
    """
    target = model.module if hasattr(model, 'module') else model
    merged = {
        name: (tensor if name.startswith('TODO') else saved_state_dict[name])
        for name, tensor in target.state_dict().items()
    }
    target.load_state_dict(merged)
    return model
def load_model(model, saved_state_dict):
    """Load every matching weight from ``saved_state_dict`` into ``model``.

    Each entry of the model's own state dict is replaced by the tensor with the
    same name in ``saved_state_dict``. Entries absent from the checkpoint keep
    their current value and are reported with a printed message. If ``model``
    exposes a ``.module`` attribute (e.g. a DistributedDataParallel wrapper),
    the wrapped module is the one read from and loaded into.

    Only a missing key is tolerated. Other errors (for example a checkpoint
    that is not a mapping) propagate instead of being silently swallowed.

    Returns:
        The same ``model`` object that was passed in.
    """
    target = model.module if hasattr(model, 'module') else model
    new_state_dict = {}
    for k, v in target.state_dict().items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            print("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    target.load_state_dict(new_state_dict)
    return model
def _clean_checkpoints(path_to_models, name, n_ckpts_to_keep, logger, sort_by_time=True):
    """Free up disk space by deleting old checkpoints, keeping the newest ones.

    Only regular files named exactly ``<name>_<digits>.pt`` (the pattern ``train``
    uses when saving) are candidates. Other files in the directory are never
    deleted, including ``.pth`` files such as ``sovits5.0_0.pth``.

    Arguments:
        path_to_models -- directory holding the checkpoints
        name -- run name used as the checkpoint filename prefix
        n_ckpts_to_keep -- number of checkpoints to keep; 0 deletes nothing
        logger -- logger used to report every deleted file
        sort_by_time -- True: delete the oldest by modification time;
                        False: delete the lowest epoch number in the filename

    Raises:
        ValueError -- if n_ckpts_to_keep is not a non-negative int
    """
    if not isinstance(n_ckpts_to_keep, int) or n_ckpts_to_keep < 0:
        raise ValueError("n_ckpts_to_keep must be a non-negative int, got %r" % (n_ckpts_to_keep,))
    if n_ckpts_to_keep == 0:
        return
    # Escape the run name: it may contain regex metacharacters such as '.'.
    pattern = re.compile(r'%s_(\d+)\.pt' % re.escape(name))
    ckpts = [f for f in os.listdir(path_to_models)
             if pattern.fullmatch(f) and os.path.isfile(os.path.join(path_to_models, f))]
    if sort_by_time:
        def sort_key(f):
            return os.path.getmtime(os.path.join(path_to_models, f))
    else:
        def sort_key(f):
            return int(pattern.fullmatch(f).group(1))
    for fn in sorted(ckpts, key=sort_key)[:-n_ckpts_to_keep]:
        path = os.path.join(path_to_models, fn)
        os.remove(path)
        logger.info(f"Free up space by deleting ckpt {path}")


def train(rank, args, chkpt_path, hp, hp_str):
    """Train the generator/discriminator pair in one process on one GPU.

    Builds the models, optimizers and data loaders. It optionally initialises
    from ``hp.train.pretrain`` and/or resumes from ``chkpt_path``, then runs the
    epoch loop:
      - periodic validation,
      - a generator update with gradient accumulation,
      - a discriminator update on every batch,
      - logging, checkpoint saving and old-checkpoint cleanup.
    Logging, validation, saving and cleanup happen only on rank 0.

    Arguments:
        rank -- process index; also selects the CUDA device ``cuda:<rank>``
        args -- run arguments; ``args.num_gpus`` (> 1 enables DDP) and
                ``args.name`` (run name used in log/checkpoint paths) are read
        chkpt_path -- checkpoint to resume from, or None for a new run
        hp -- hyperparameter config object
        hp_str -- textual form of ``hp``; stored in every checkpoint and
                  compared against the resumed checkpoint's copy
    """
    if args.num_gpus > 1:
        init_process_group(backend=hp.dist_config.dist_backend, init_method=hp.dist_config.dist_url,
                           world_size=hp.dist_config.world_size * args.num_gpus, rank=rank)

    torch.cuda.manual_seed(hp.train.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    model_g = SynthesizerTrn(
        hp.data.filter_length // 2 + 1,
        hp.data.segment_size // hp.data.hop_length,
        hp).to(device)
    model_d = Discriminator(hp).to(device)

    optim_g = torch.optim.AdamW(model_g.parameters(),
                                lr=hp.train.learning_rate, betas=hp.train.betas, eps=hp.train.eps)
    # NOTE(review): the discriminator lr is divided by accum_step although the
    # discriminator is stepped on every batch; presumably intentional, confirm.
    optim_d = torch.optim.AdamW(model_d.parameters(),
                                lr=(hp.train.learning_rate / hp.train.accum_step),
                                betas=hp.train.betas, eps=hp.train.eps)

    init_epoch = 1
    step = 0

    stft = TacotronSTFT(filter_length=hp.data.filter_length,
                        hop_length=hp.data.hop_length,
                        win_length=hp.data.win_length,
                        n_mel_channels=hp.data.mel_channels,
                        sampling_rate=hp.data.sampling_rate,
                        mel_fmin=hp.data.mel_fmin,
                        mel_fmax=hp.data.mel_fmax,
                        center=False,
                        device=device)

    # define logger, writer, valloader at rank_zero
    if rank == 0:
        pth_dir = os.path.join(hp.log.pth_dir, args.name)
        log_dir = os.path.join(hp.log.log_dir, args.name)
        os.makedirs(pth_dir, exist_ok=True)
        os.makedirs(log_dir, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(os.path.join(log_dir, '%s-%d.log' % (args.name, time.time()))),
                logging.StreamHandler()
            ]
        )
        logger = logging.getLogger()
        writer = MyWriter(hp, log_dir)
        valloader = create_dataloader_eval(hp)

    if os.path.isfile(hp.train.pretrain):
        if rank == 0:
            logger.info("Start from 32k pretrain model: %s" % hp.train.pretrain)
        checkpoint = torch.load(hp.train.pretrain, map_location='cpu')
        load_model(model_g, checkpoint['model_g'])
        load_model(model_d, checkpoint['model_d'])

    if chkpt_path is not None:
        if rank == 0:
            logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path, map_location='cpu')
        load_model(model_g, checkpoint['model_g'])
        load_model(model_d, checkpoint['model_d'])
        optim_g.load_state_dict(checkpoint['optim_g'])
        optim_d.load_state_dict(checkpoint['optim_d'])
        init_epoch = checkpoint['epoch']
        step = checkpoint['step']

        if rank == 0:
            if hp_str != checkpoint['hp_str']:
                logger.warning("New hparams is different from checkpoint. Will use new.")
    else:
        if rank == 0:
            logger.info("Starting new training run.")

    if args.num_gpus > 1:
        model_g = DistributedDataParallel(model_g, device_ids=[rank])
        model_d = DistributedDataParallel(model_d, device_ids=[rank])

    # this accelerates training when the size of minibatch is always consistent.
    # if not consistent, it'll horribly slow down.
    torch.backends.cudnn.benchmark = True

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hp.train.lr_decay, last_epoch=init_epoch-2)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hp.train.lr_decay, last_epoch=init_epoch-2)

    # NOTE(review): eval() executes arbitrary code from the config. If
    # hp.mrd.resolutions is always a plain literal, ast.literal_eval would be
    # safer; confirm the config source is trusted.
    stft_criterion = MultiResolutionSTFTLoss(device, eval(hp.mrd.resolutions))
    spkc_criterion = nn.CosineEmbeddingLoss()

    trainloader = create_dataloader_train(hp, args.num_gpus, rank)

    accum_step = hp.train.accum_step

    for epoch in range(init_epoch, hp.train.epochs):

        trainloader.batch_sampler.set_epoch(epoch)

        if rank == 0 and epoch % hp.log.eval_interval == 0:
            with torch.no_grad():
                validate(hp, args, model_g, model_d, valloader, stft, writer, step, device)

        if rank == 0:
            loader = tqdm.tqdm(trainloader, desc='Loading train data')
        else:
            loader = trainloader

        model_g.train()
        model_d.train()

        # Accumulation groups are counted per epoch so that the forced update on
        # the last batch really flushes the leftover gradient every epoch (a
        # global step compared with len(loader) only matches in the first epoch).
        num_batches = len(loader)

        for batch_idx, (ppg, ppg_l, vec, pit, spk, spec, spec_l, audio, audio_l) in enumerate(loader):

            ppg = ppg.to(device)
            vec = vec.to(device)
            pit = pit.to(device)
            spk = spk.to(device)
            spec = spec.to(device)
            audio = audio.to(device)
            ppg_l = ppg_l.to(device)
            spec_l = spec_l.to(device)
            audio_l = audio_l.to(device)

            # generator
            (fake_audio, ids_slice, z_mask,
             (z_f, z_r, z_p, m_p, logs_p, z_q, m_q, logs_q, logdet_f, logdet_r),
             spk_preds) = model_g(ppg, vec, pit, spec, spk, ppg_l, spec_l)

            audio = commons.slice_segments(
                audio, ids_slice * hp.data.hop_length, hp.data.segment_size)  # slice

            # Spk Loss: target +1 makes CosineEmbeddingLoss penalise 1 - cos(spk, spk_preds)
            spk_loss = spkc_criterion(spk, spk_preds, torch.ones(spk_preds.size(0), device=device))

            # Mel Loss
            mel_fake = stft.mel_spectrogram(fake_audio.squeeze(1))
            mel_real = stft.mel_spectrogram(audio.squeeze(1))
            mel_loss = F.l1_loss(mel_fake, mel_real) * hp.train.c_mel

            # Multi-Resolution STFT Loss
            sc_loss, mag_loss = stft_criterion(fake_audio.squeeze(1), audio.squeeze(1))
            stft_loss = (sc_loss + mag_loss) * hp.train.c_stft

            # Generator Loss (least-squares: push discriminator scores on fakes toward 1)
            disc_fake = model_d(fake_audio)
            score_loss = 0.0
            for (_, score_fake) in disc_fake:
                score_loss += torch.mean(torch.pow(score_fake - 1.0, 2))
            score_loss = score_loss / len(disc_fake)

            # Feature Loss: L1 between discriminator feature maps of fake and real audio
            disc_real = model_d(audio)
            feat_loss = 0.0
            for (feat_fake, _), (feat_real, _) in zip(disc_fake, disc_real):
                for fake, real in zip(feat_fake, feat_real):
                    feat_loss += torch.mean(torch.abs(fake - real))
            feat_loss = feat_loss / len(disc_fake)
            feat_loss = feat_loss * 2

            # Kl Loss
            loss_kl_f = kl_loss(z_f, logs_q, m_p, logs_p, logdet_f, z_mask) * hp.train.c_kl
            loss_kl_r = kl_loss(z_r, logs_p, m_q, logs_q, logdet_r, z_mask) * hp.train.c_kl

            # Loss
            loss_g = score_loss + feat_loss + mel_loss + stft_loss + loss_kl_f + loss_kl_r * 0.5 + spk_loss * 2
            loss_g.backward()

            if (batch_idx + 1) % accum_step == 0 or batch_idx + 1 == num_batches:
                # Average the gradients accumulated since the last update.
                for param in model_g.parameters():
                    # A parameter that received no gradient has grad None; dividing it would raise.
                    if param.grad is not None:
                        param.grad /= accum_step
                clip_grad_value_(model_g.parameters(), None)
                # update model
                optim_g.step()
                optim_g.zero_grad()

            # discriminator (zero_grad also discards grads leaked in by loss_g.backward())
            optim_d.zero_grad()
            disc_fake = model_d(fake_audio.detach())
            disc_real = model_d(audio)

            loss_d = 0.0
            for (_, score_fake), (_, score_real) in zip(disc_fake, disc_real):
                loss_d += torch.mean(torch.pow(score_real - 1.0, 2))
                loss_d += torch.mean(torch.pow(score_fake, 2))
            loss_d = loss_d / len(disc_fake)

            loss_d.backward()
            clip_grad_value_(model_d.parameters(), None)
            optim_d.step()

            step += 1
            # logging
            loss_g = loss_g.item()
            loss_d = loss_d.item()
            loss_s = stft_loss.item()
            loss_m = mel_loss.item()
            loss_k = loss_kl_f.item()
            loss_r = loss_kl_r.item()
            loss_i = spk_loss.item()

            if rank == 0 and step % hp.log.info_interval == 0:
                writer.log_training(
                    loss_g, loss_d, loss_m, loss_s, loss_k, loss_r, score_loss.item(), step)
                logger.info("epoch %d | g %.04f m %.04f s %.04f d %.04f k %.04f r %.04f i %.04f | step %d" % (
                    epoch, loss_g, loss_m, loss_s, loss_d, loss_k, loss_r, loss_i, step))

        if rank == 0 and epoch % hp.log.save_interval == 0:
            save_path = os.path.join(pth_dir, '%s_%04d.pt'
                                     % (args.name, epoch))
            torch.save({
                'model_g': (model_g.module if args.num_gpus > 1 else model_g).state_dict(),
                'model_d': (model_d.module if args.num_gpus > 1 else model_d).state_dict(),
                'optim_g': optim_g.state_dict(),
                'optim_d': optim_d.state_dict(),
                'step': step,
                'epoch': epoch,
                'hp_str': hp_str,
            }, save_path)
            logger.info("Saved checkpoint to: %s" % save_path)

        if rank == 0:
            os.makedirs(f'{pth_dir}', exist_ok=True)
            # keep_ckpts is optional in the config; absent or 0 means keep everything.
            keep_ckpts = getattr(hp.log, 'keep_ckpts', 0)
            if keep_ckpts > 0:
                _clean_checkpoints(f'{pth_dir}', args.name, keep_ckpts, logger, sort_by_time=True)

        scheduler_g.step()
        scheduler_d.step()