# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import argparse
import logging
import os

# init logger before importing other modules
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)
logger = logging.getLogger("repcodec_train")

import random

import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader

from dataloader import ReprDataset, ReprCollater
from losses.repr_reconstruct_loss import ReprReconstructLoss
from repcodec.RepCodec import RepCodec
from trainer.autoencoder import Trainer


class TrainMain:
    def __init__(self, args):
        # fix random seeds for reproducibility
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if not torch.cuda.is_available():
            self.device = torch.device("cpu")
            logger.info("device: cpu")
        else:
            self.device = torch.device("cuda:0")  # only supports single gpu for now
            logger.info("device: gpu")
            torch.cuda.manual_seed_all(args.seed)
            if args.disable_cudnn == "False":
                # cudnn autotuning: faster for fixed-size inputs, but not deterministic
                torch.backends.cudnn.benchmark = True

        # initialize config
        with open(args.config, "r") as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)
        self.config.update(vars(args))

        # initialize model folder
        expdir = os.path.join(args.exp_root, args.tag)
        os.makedirs(expdir, exist_ok=True)
        self.config["outdir"] = expdir

        # save config
        with open(os.path.join(expdir, "config.yml"), "w") as f:
            yaml.dump(self.config, f, Dumper=yaml.Dumper)
        for key, value in self.config.items():
            logger.info(f"{key} = {value}")

        # initialize attributes
        self.resume: str = args.resume
        self.data_loader = None
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.criterion = None
        self.trainer = None

        # initialize batch_length
        self.batch_length: int = self.config["batch_length"]
        self.data_path: str = self.config["data"]["path"]

    def initialize_data_loader(self):
        train_set = self._build_dataset("train")
        valid_set = self._build_dataset("valid")
        collater = ReprCollater()
        logger.info(f"The number of training files = {len(train_set)}.")
        logger.info(f"The number of validation files = {len(valid_set)}.")
        dataset = {"train": train_set, "dev": valid_set}
        self._set_data_loader(dataset, collater)

    def define_model_optimizer_scheduler(self):
        # model architecture
        self.model = {
            "repcodec": RepCodec(**self.config["model_params"]).to(self.device)
        }
        logger.info(f"Model Arch:\n{self.model['repcodec']}")

        # optimizer
        optimizer_class = getattr(
            torch.optim,
            self.config["model_optimizer_type"]
        )
        self.optimizer = {
            "repcodec": optimizer_class(
                self.model["repcodec"].parameters(),
                **self.config["model_optimizer_params"]
            )
        }

        # scheduler
        scheduler_class = getattr(
            torch.optim.lr_scheduler,
            self.config.get("model_scheduler_type", "StepLR"),
        )
        self.scheduler = {
            "repcodec": scheduler_class(
                optimizer=self.optimizer["repcodec"],
                **self.config["model_scheduler_params"]
            )
        }

    def define_criterion(self):
        self.criterion = {
            "repr_reconstruct_loss": ReprReconstructLoss(
                **self.config.get("repr_reconstruct_loss_params", {}),
            ).to(self.device)
        }

    def define_trainer(self):
        self.trainer = Trainer(
            steps=0,
            epochs=0,
            data_loader=self.data_loader,
            model=self.model,
            criterion=self.criterion,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            config=self.config,
            device=self.device
        )

    def initialize_model(self):
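        # Checkpoint precedence: an explicit --resume path restores the full
        # training state; otherwise an optional "initial" path in the config
        # loads model parameters only; with neither, train from scratch.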
        initial = self.config.get("initial", "")
        if os.path.exists(self.resume):
            # resume from a trained model
            self.trainer.load_checkpoint(self.resume)
            logger.info(f"Successfully resumed from {self.resume}.")
        elif os.path.exists(initial):
            # initialize a new model with pre-trained parameters
            self.trainer.load_checkpoint(initial, load_only_params=True)
            logger.info(f"Successfully initialized parameters from {initial}.")
        else:
            logger.info("Train from scratch.")

    def run(self):
        assert self.trainer is not None
        self.trainer: Trainer
        try:
            logger.info(f"The current training step: {self.trainer.steps}")
            self.trainer.train_max_steps = self.config["train_max_steps"]
            if not self.trainer._check_train_finish():
                self.trainer.run()
        finally:
            # always save a checkpoint, even if training is interrupted
            self.trainer.save_checkpoint(
                os.path.join(self.config["outdir"], f"checkpoint-{self.trainer.steps}steps.pkl")
            )
            logger.info(f"Successfully saved checkpoint @ {self.trainer.steps}steps.")

    def _build_dataset(self, subset: str) -> ReprDataset:
        data_dir = os.path.join(
            self.data_path,
            self.config["data"]["subset"][subset]
        )
        params = {
            "data_dir": data_dir,
            "batch_len": self.batch_length
        }
        return ReprDataset(**params)

    def _set_data_loader(self, dataset, collater):
        self.data_loader = {
            "train": DataLoader(
                dataset=dataset["train"],
                shuffle=True,
                collate_fn=collater,
                batch_size=self.config["batch_size"],
                num_workers=self.config["num_workers"],
                pin_memory=self.config["pin_memory"],
            ),
            "dev": DataLoader(
                dataset=dataset["dev"],
                shuffle=False,
                collate_fn=collater,
                batch_size=self.config["batch_size"],
                num_workers=0,
                pin_memory=False,  # save some memory. set to True if you have enough memory.
            ),
        }


def train():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, required=True,
        help="the path of config yaml file."
    )
    parser.add_argument(
        "--tag", type=str, required=True,
        help="the outputs will be saved to exp_root/tag/"
    )
    parser.add_argument(
        "--exp_root", type=str, default="exp"
    )
    parser.add_argument(
        "--resume", default="", type=str, nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument("--seed", default=1337, type=int)
    parser.add_argument(
        "--disable_cudnn", choices=("True", "False"), default="False",
        help="disable cudnn benchmark mode."
    )
    args = parser.parse_args()

    train_main = TrainMain(args)
    train_main.initialize_data_loader()
    train_main.define_model_optimizer_scheduler()
    train_main.define_criterion()
    train_main.define_trainer()
    train_main.initialize_model()
    train_main.run()


if __name__ == "__main__":
    train()
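# For reference, a minimal sketch of the YAML config this script reads.
# The keys mirror the self.config lookups above; the values are illustrative
# placeholders, not the repo's published settings, and "initial" is optional:
#
#   batch_length: 96
#   batch_size: 32
#   num_workers: 4
#   pin_memory: true
#   train_max_steps: 200000
#   initial: ""                      # optional pre-trained checkpoint path
#   data:
#     path: /path/to/representations
#     subset:
#       train: train
#       valid: valid
#   model_params: {}                 # kwargs forwarded to RepCodec(...)
#   model_optimizer_type: Adam       # any torch.optim class name
#   model_optimizer_params:
#     lr: 1.0e-4
#   model_scheduler_type: StepLR     # any torch.optim.lr_scheduler class name
#   model_scheduler_params:
#     step_size: 200000
#   repr_reconstruct_loss_params: {}
#
# Example invocation (paths and tag are placeholders):
#   python train.py -c conf/repcodec.yaml --tag my_run --exp_root exp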