###
# Author: Kai Li
# Date: 2024-01-22 01:16:22
# Email: lk21@mails.tsinghua.edu.cn
# LastEditTime: 2024-01-24 00:05:10
###
"""Training entry point.

Loads an OmegaConf YAML configuration, instantiates the datamodule, the audio
model, the discriminator, their optimizers/schedulers/losses, the metrics and
the Lightning system through Hydra, runs ``Trainer.fit`` and finally exports
the best checkpoint's audio model to ``<exp.dir>/<exp.name>/best_model.pth``.
"""
import json
from typing import Any, Dict, List, Optional, Tuple
import os
from omegaconf import OmegaConf
import argparse
import pytorch_lightning as pl
import torch

torch.set_float32_matmul_precision("highest")
import hydra
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer

# from pytorch_lightning.loggers import Logger
from omegaconf import DictConfig
import look2hear.system
import look2hear.datas
import look2hear.losses
from look2hear.utils import RankedLogger, instantiate, print_only
import warnings

warnings.filterwarnings("ignore")


def _export_best_model(
    cfg: DictConfig,
    system: LightningModule,
    checkpoint: "pl.callbacks.ModelCheckpoint",
) -> None:
    """Write the best-k summary and the serialized best audio model to disk.

    Args:
        cfg: Experiment configuration; ``cfg.exp.dir`` and ``cfg.exp.name``
            select the output directory.
        system: The trained Lightning system. Its weights are overwritten with
            the best checkpoint and the module is moved to CPU.
        checkpoint: The checkpoint callback that was attached to the trainer.
    """
    exp_dir = os.path.join(cfg.exp.dir, cfg.exp.name)

    # Scores are stored as tensors; convert them so they are JSON-serializable.
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w", encoding="utf-8") as f:
        json.dump(best_k, f, indent=0)

    # ``best_model_path`` is empty when no checkpoint was ever written (e.g.
    # training stopped before the first save); there is nothing to export then.
    if not checkpoint.best_model_path:
        print_only("No best checkpoint was saved; skipping best_model.pth export.")
        return

    # Load onto CPU directly: the system is moved to CPU right below, so there
    # is no reason to restore the tensors on the device they were saved from.
    state_dict = torch.load(checkpoint.best_model_path, map_location="cpu")
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.audio_model.serialize()
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))


def _finish_wandb() -> None:
    """Close the active wandb run, if wandb is installed and a run is open."""
    try:
        import wandb
    except ImportError:
        # wandb is optional: a run configured with another logger must not
        # fail after training just because the package is missing.
        return

    if wandb.run:
        print_only("Closing wandb!")
        wandb.finish()


def train(cfg: DictConfig) -> None:
    """Instantiate every component described by ``cfg`` and run training.

    Args:
        cfg: Configuration with the sections ``datas``, ``model``,
            ``discriminator``, ``optimizer_g``/``optimizer_d``,
            ``scheduler_g``/``scheduler_d``, ``loss_g``/``loss_d``,
            ``metrics``, ``system``, ``logger``, ``trainer`` and ``exp``
            (``dir`` and ``name``). ``seed``, ``early_stopping`` and
            ``checkpoint`` are optional.

    Returns:
        None. Artifacts are written under ``<cfg.exp.dir>/<cfg.exp.name>``.
    """
    # ``is not None`` rather than truthiness so that ``seed: 0`` is honoured.
    if cfg.get("seed") is not None:
        pl.seed_everything(cfg.seed, workers=True)

    # instantiate datamodule
    print_only(f"Instantiating datamodule <{cfg.datas._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datas)

    # instantiate model
    print_only(f"Instantiating AudioNet <{cfg.model._target_}>")
    model: torch.nn.Module = hydra.utils.instantiate(cfg.model)
    print_only(f"Instantiating Discriminator <{cfg.discriminator._target_}>")
    discriminator: torch.nn.Module = hydra.utils.instantiate(cfg.discriminator)

    # instantiate optimizer
    print_only(f"Instantiating optimizer <{cfg.optimizer_g._target_}>")
    optimizer_g: torch.optim.Optimizer = hydra.utils.instantiate(
        cfg.optimizer_g, params=model.parameters()
    )
    optimizer_d: torch.optim.Optimizer = hydra.utils.instantiate(
        cfg.optimizer_d, params=discriminator.parameters()
    )

    # instantiate scheduler
    print_only(f"Instantiating scheduler <{cfg.scheduler_g._target_}>")
    scheduler_g: Any = hydra.utils.instantiate(cfg.scheduler_g, optimizer=optimizer_g)
    scheduler_d: Any = hydra.utils.instantiate(cfg.scheduler_d, optimizer=optimizer_d)

    # instantiate loss
    print_only(f"Instantiating loss <{cfg.loss_g._target_}>")
    loss_g: torch.nn.Module = hydra.utils.instantiate(cfg.loss_g)
    loss_d: torch.nn.Module = hydra.utils.instantiate(cfg.loss_d)
    losses = {"g": loss_g, "d": loss_d}

    # instantiate metrics
    print_only(f"Instantiating metrics <{cfg.metrics._target_}>")
    metrics: torch.nn.Module = hydra.utils.instantiate(cfg.metrics)

    # instantiate system
    print_only(f"Instantiating system <{cfg.system._target_}>")
    system: LightningModule = hydra.utils.instantiate(
        cfg.system,
        model=model,
        discriminator=discriminator,
        loss_func=losses,
        metrics=metrics,
        optimizer=[optimizer_g, optimizer_d],
        scheduler=[scheduler_g, scheduler_d],
    )

    # instantiate callbacks
    callbacks: List[Callback] = []
    if cfg.get("early_stopping"):
        print_only(f"Instantiating early_stopping <{cfg.early_stopping._target_}>")
        callbacks.append(hydra.utils.instantiate(cfg.early_stopping))

    # Stays None when the config has no ``checkpoint`` section, so the export
    # step below can be skipped instead of failing on an unbound name.
    checkpoint: Optional[pl.callbacks.ModelCheckpoint] = None
    if cfg.get("checkpoint"):
        print_only(f"Instantiating checkpoint <{cfg.checkpoint._target_}>")
        checkpoint = hydra.utils.instantiate(cfg.checkpoint)
        callbacks.append(checkpoint)

    # instantiate logger
    print_only(f"Instantiating logger <{cfg.logger._target_}>")
    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name, "logs"), exist_ok=True)
    logger = hydra.utils.instantiate(cfg.logger)

    # instantiate trainer
    print_only(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=logger,
        # NOTE(review): presumably needed because each optimization step leaves
        # one of the two networks without gradients -- confirm.
        strategy=DDPStrategy(find_unused_parameters=True),
    )
    trainer.fit(system, datamodule=datamodule)
    print_only("Training finished!")

    if checkpoint is None:
        print_only("No checkpoint callback configured; skipping best-model export.")
    else:
        _export_best_model(cfg, system, checkpoint)

    _finish_wandb()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--conf_dir",
        default="local/conf.yml",
        help="Path to the YAML configuration file",
    )
    args = parser.parse_args()

    cfg = OmegaConf.load(args.conf_dir)

    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name), exist_ok=True)
    # Save a copy of the configuration next to the experiment outputs.
    OmegaConf.save(cfg, os.path.join(cfg.exp.dir, cfg.exp.name, "config.yaml"))

    train(cfg)