AlexK-PL's picture
Upload 72 files
c61c48a
raw
history blame
1.87 kB
import os
import time
import logging
import argparse
from utils.train import train
from utils.hparams import HParam
from utils.writer import MyWriter
from datasets.dataloader import create_dataloader
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, required=True,
help="yaml file for configuration")
parser.add_argument('-p', '--checkpoint_path', type=str, default=None,
help="path of checkpoint pt file to resume training")
parser.add_argument('-n', '--name', type=str, required=True,
help="name of the model for logging, saving checkpoint")
args = parser.parse_args()
hp = HParam(args.config)
with open(args.config, 'r') as f:
hp_str = ''.join(f.readlines())
pt_dir = os.path.join(hp.log.chkpt_dir, args.name)
log_dir = os.path.join(hp.log.log_dir, args.name)
os.makedirs(pt_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(os.path.join(log_dir,
'%s-%d.log' % (args.name, time.time()))),
logging.StreamHandler()
]
)
logger = logging.getLogger()
writer = MyWriter(hp, log_dir)
assert hp.audio.hop_length == 256, \
'hp.audio.hop_length must be equal to 256, got %d' % hp.audio.hop_length
assert hp.data.train != '' and hp.data.validation != '', \
'hp.data.train and hp.data.validation can\'t be empty: please fix %s' % args.config
trainloader = create_dataloader(hp, args, True)
valloader = create_dataloader(hp, args, False)
train(args, pt_dir, args.checkpoint_path, trainloader, valloader, writer, logger, hp, hp_str)