import logging
from typing import List

import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning.utilities import rank_zero_only


def get_logger(name: str = __name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly Python command line logger."""
    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    ):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger


log = get_logger(__name__)
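
# Illustrative note (assumption, not part of the original module): because every
# logging method above is wrapped with `rank_zero_only`, a call such as
#
#     log.info("Starting training!")
#
# is emitted once from rank 0 rather than once per process in a multi-GPU (DDP) run.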


def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: pl.loggers.logger.Logger,
) -> None:
    """Controls which parts of the config are saved by Lightning loggers.

    Additionally saves:
    - number of model parameters
    """

    if not trainer.logger:
        return

    hparams = {}

    # choose which parts of hydra config will be saved to loggers
    hparams["model"] = config["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    hparams["datamodule"] = config["datamodule"]
    hparams["trainer"] = config["trainer"]

    if "seed" in config:
        hparams["seed"] = config["seed"]
    if "callbacks" in config:
        hparams["callbacks"] = config["callbacks"]

    # send hparams to the logger; `experiment.config.update` matches the W&B Run
    # API exposed by `WandbLogger.experiment`
    logger.experiment.config.update(hparams)
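

# Minimal usage sketch (an assumption for illustration, not part of the original
# module): in a Hydra-driven training entrypoint, `log_hyperparameters` would
# typically be called after everything is instantiated and before `trainer.fit`.
# The entrypoint name, config layout, and `hydra.utils.instantiate` targets below
# are hypothetical.
#
# import hydra
#
# @hydra.main(config_path="configs", config_name="train.yaml")
# def train(config: DictConfig) -> None:
#     datamodule = hydra.utils.instantiate(config.datamodule)
#     model = hydra.utils.instantiate(config.model)
#     callbacks = [hydra.utils.instantiate(cb) for cb in config.callbacks.values()]
#     logger = hydra.utils.instantiate(config.logger)
#     trainer = pl.Trainer(**config.trainer, callbacks=callbacks, logger=logger)
#
#     log_hyperparameters(
#         config=config,
#         model=model,
#         datamodule=datamodule,
#         trainer=trainer,
#         callbacks=callbacks,
#         logger=logger,
#     )
#
#     trainer.fit(model=model, datamodule=datamodule)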