# Copyright (c) Meta Platforms, Inc. and affiliates.

from pathlib import Path

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from torchmetrics import MeanMetric, MetricCollection

from . import logger
from .models import get_model


class AverageKeyMeter(MeanMetric):
    """Track the running mean of one key of a loss dict, ignoring non-finite values."""

    def __init__(self, key, *args, **kwargs):
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, values):
        value = values[self.key]
        value = value[torch.isfinite(value)]
        return super().update(value)
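

# Example (hypothetical values): the meter averages a single key of the loss dict
# and silently drops non-finite entries, so a NaN batch does not poison the average.
#   meter = AverageKeyMeter("total")
#   meter.update({"total": torch.tensor([1.0, float("nan"), 3.0])})
#   meter.compute()  # -> tensor(2.), the NaN entry was filtered out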


class GenericModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        name = cfg.model.get("name")
        name = "map_perception_net" if name is None else name
        self.model = get_model(name)(cfg.model)
        self.cfg = cfg
        self.save_hyperparameters(cfg)
        self.metrics_val = MetricCollection(self.model.metrics(), prefix="val/")
        self.losses_val = None  # we do not know the loss keys in advance

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        self.log_dict(
            {f"train/loss/{k}": v.mean() for k, v in losses.items()},
            prog_bar=True,
            rank_zero_only=True,
            on_epoch=True,
            sync_dist=True,
        )
        return losses["total"].mean()
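
    # Note: self.model.loss is expected to return a dict of per-sample loss tensors
    # containing at least a "total" entry; any other keys that show up in the logs
    # depend on the wrapped model.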

    def validation_step(self, batch, batch_idx):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        if self.losses_val is None:
            self.losses_val = MetricCollection(
                {k: AverageKeyMeter(k).to(self.device) for k in losses},
                prefix="val/",
                postfix="/loss",
            )
        self.metrics_val(pred, batch)
        self.log_dict(self.metrics_val, on_epoch=True)
        self.losses_val.update(losses)
        self.log_dict(self.losses_val, on_epoch=True)
        return pred

    def test_step(self, batch, batch_idx):
        pred = self(batch)
        return pred

    def on_validation_epoch_start(self):
        # Reset the lazily built loss meters so each validation epoch starts fresh.
        self.losses_val = None

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.training.lr)
        ret = {"optimizer": optimizer}
        cfg_scheduler = self.cfg.training.get("lr_scheduler")
        if cfg_scheduler is not None:
            scheduler_args = cfg_scheduler.get("args", {})
            for key in scheduler_args:
                if scheduler_args[key] == "$total_epochs":
                    scheduler_args[key] = int(self.trainer.max_epochs)
            scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)(
                optimizer=optimizer, **scheduler_args
            )
            ret["lr_scheduler"] = {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "monitor": "val/total/loss",
                "strict": True,
                "name": "learning_rate",
            }
        return ret
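
    # Example config sketch (field names inferred from this method; the scheduler
    # choice and values are illustrative, not the project defaults):
    #   training:
    #     lr: 1.0e-4
    #     lr_scheduler:
    #       name: CosineAnnealingLR
    #       args:
    #         T_max: $total_epochs  # placeholder replaced with trainer.max_epochs above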

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path,
        map_location=None,
        hparams_file=None,
        strict=True,
        cfg=None,
        find_best=False,
    ):
        assert hparams_file is None, "hparams are not supported."
        checkpoint = torch.load(
            checkpoint_path,
            map_location=map_location or (lambda storage, loc: storage),
        )
        if find_best:
            # Callback state keys look like "ModelCheckpoint{'monitor': ..., 'mode': ...}",
            # so the kwargs dict can be recovered by evaluating the suffix.
            best_score, best_name = None, None
            modes = {"min": torch.lt, "max": torch.gt}
            for key, state in checkpoint["callbacks"].items():
                if not key.startswith("ModelCheckpoint"):
                    continue
                mode = eval(key.replace("ModelCheckpoint", ""))["mode"]
                if best_score is None or modes[mode](
                    state["best_model_score"], best_score
                ):
                    best_score = state["best_model_score"]
                    best_name = Path(state["best_model_path"]).name
            logger.info("Loading best checkpoint %s", best_name)
            if best_name != Path(checkpoint_path).name:
                return cls.load_from_checkpoint(
                    Path(checkpoint_path).parent / best_name,
                    map_location=map_location,
                    hparams_file=hparams_file,
                    strict=strict,
                    cfg=cfg,
                    find_best=False,
                )
        logger.info(
            "Using checkpoint %s from epoch %d and step %d.",
            checkpoint_path,
            checkpoint["epoch"],
            checkpoint["global_step"],
        )
        cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
        if list(cfg_ckpt.keys()) == ["cfg"]:  # backward compatibility
            cfg_ckpt = cfg_ckpt["cfg"]
        cfg_ckpt = OmegaConf.create(cfg_ckpt)
        if cfg is None:
            cfg = {}
        if not isinstance(cfg, DictConfig):
            cfg = OmegaConf.create(cfg)
        with open_dict(cfg_ckpt):
            cfg = OmegaConf.merge(cfg_ckpt, cfg)
        return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg)
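

# Example (hypothetical paths and override keys): restore the best checkpoint found
# by a ModelCheckpoint callback and override part of the stored config.
#   model = GenericModule.load_from_checkpoint(
#       "experiments/my_run/last.ckpt",
#       cfg={"model": {"num_scale_bins": 33}},
#       find_best=True,
#   )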