import logging
from typing import List

import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning.utilities import rank_zero_only


def get_logger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python command line logger.

    Every logging method of the returned logger is wrapped with Lightning's
    ``rank_zero_only`` decorator, so in a multi-process run messages are
    emitted only by the rank-zero process instead of once per process.

    Args:
        name: Name passed to ``logging.getLogger``.

    Returns:
        The patched ``logging.Logger`` for ``name``.
    """
    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    # NOTE: logging.getLogger returns the same object for the same name, so
    # calling get_logger twice with one name wraps these methods again.
    for level in (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    ):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger


log = get_logger(__name__)


@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: pl.loggers.logger.Logger,
) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
        - number of model parameters (total, trainable, non-trainable)

    Args:
        config: Full Hydra config. The ``model``, ``datamodule`` and ``trainer``
            entries are required; ``seed`` and ``callbacks`` are saved only
            when present.
        model: Module whose parameters are counted.
        datamodule: Unused; kept for interface compatibility.
        trainer: Trainer; nothing is logged when it has no logger attached.
        callbacks: Unused; kept for interface compatibility.
        logger: Logger that receives the hyperparameters. If its
            ``experiment`` exposes a ``config`` object with an ``update``
            method (as a wandb run does), the hparams are written there;
            otherwise they are passed to ``logger.log_hyperparams``.
    """
    # Skip when the trainer has logging disabled, or when no logger was given
    # (the latter previously crashed on ``logger.experiment``).
    if not trainer.logger or logger is None:
        return

    hparams = {}

    # choose which parts of hydra config will be saved to loggers
    hparams["model"] = config["model"]

    # save number of model parameters (single pass over the parameters)
    total_params = 0
    trainable_params = 0
    for p in model.parameters():
        count = p.numel()
        total_params += count
        if p.requires_grad:
            trainable_params += count
    hparams["model/params/total"] = total_params
    hparams["model/params/trainable"] = trainable_params
    hparams["model/params/non_trainable"] = total_params - trainable_params

    hparams["datamodule"] = config["datamodule"]
    hparams["trainer"] = config["trainer"]

    if "seed" in config:
        hparams["seed"] = config["seed"]
    if "callbacks" in config:
        hparams["callbacks"] = config["callbacks"]

    # Prefer the experiment's own config store when it has one (wandb-style);
    # otherwise use the generic Lightning logger API so other loggers
    # (e.g. TensorBoard) do not raise AttributeError.
    experiment_config = getattr(logger.experiment, "config", None)
    if experiment_config is not None and callable(
        getattr(experiment_config, "update", None)
    ):
        experiment_config.update(hparams)
    else:
        logger.log_hyperparams(hparams)