# Copyright (c) Meta Platforms, Inc. and affiliates.

from pathlib import Path

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from torchmetrics import MeanMetric, MetricCollection

from . import logger
from .models import get_model


class AverageKeyMeter(MeanMetric):
    """Average the values stored under a single key of a loss/metric dict,
    ignoring non-finite entries (NaN/Inf)."""

    def __init__(self, key, *args, **kwargs):
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, dict):
        value = dict[self.key]
        value = value[torch.isfinite(value)]
        return super().update(value)
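

# Usage sketch for AverageKeyMeter (illustrative only, not used by the module
# itself): each meter tracks one key of a loss dictionary and drops NaN/Inf
# entries before averaging.
#
#   meter = AverageKeyMeter("total")
#   meter.update({"total": torch.tensor([1.0, float("nan"), 3.0])})
#   meter.compute()  # -> tensor(2.), the NaN entry is ignored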


class GenericModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        name = cfg.model.get("name")
        name = "map_perception_net" if name is None else name
        self.model = get_model(name)(cfg.model)
        self.cfg = cfg
        self.save_hyperparameters(cfg)
        self.metrics_val = MetricCollection(self.model.metrics(), prefix="val/")
        self.losses_val = None  # we do not know the loss keys in advance

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        self.log_dict(
            {f"train/loss/{k}": v.mean() for k, v in losses.items()},
            prog_bar=True,
            rank_zero_only=True,
            on_epoch=True,
            sync_dist=True,
        )
        return losses["total"].mean()

    def validation_step(self, batch, batch_idx):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        if self.losses_val is None:
            # Lazily build the loss meters once the loss keys are known.
            self.losses_val = MetricCollection(
                {k: AverageKeyMeter(k).to(self.device) for k in losses},
                prefix="val/",
                postfix="/loss",
            )
        self.metrics_val(pred, batch)
        self.log_dict(self.metrics_val, on_epoch=True)
        self.losses_val.update(losses)
        self.log_dict(self.losses_val, on_epoch=True)
        return pred

    def test_step(self, batch, batch_idx):
        pred = self(batch)
        return pred

    def on_validation_epoch_start(self):
        # Reset the per-epoch loss meters; they are rebuilt in validation_step.
        self.losses_val = None

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.training.lr)
        ret = {"optimizer": optimizer}
        cfg_scheduler = self.cfg.training.get("lr_scheduler")
        if cfg_scheduler is not None:
            scheduler_args = cfg_scheduler.get("args", {})
            # Allow the config to reference the total number of training epochs.
            for key in scheduler_args:
                if scheduler_args[key] == "$total_epochs":
                    scheduler_args[key] = int(self.trainer.max_epochs)
            scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)(
                optimizer=optimizer, **scheduler_args
            )
            ret["lr_scheduler"] = {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "monitor": "loss/total/val",
                "strict": True,
                "name": "learning_rate",
            }
        return ret
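
    # Example of a (hypothetical) training config consumed by configure_optimizers
    # above; the scheduler name must match a class in torch.optim.lr_scheduler and
    # the "$total_epochs" placeholder is replaced by trainer.max_epochs at runtime:
    #
    #   training:
    #     lr: 1.0e-4
    #     lr_scheduler:
    #       name: CosineAnnealingLR
    #       args:
    #         T_max: $total_epochs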

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path,
        map_location=None,
        hparams_file=None,
        strict=True,
        cfg=None,
        find_best=False,
    ):
        assert hparams_file is None, "hparams are not supported."
        checkpoint = torch.load(
            checkpoint_path,
            map_location=map_location or (lambda storage, loc: storage),
        )
        if find_best:
            # Pick the best checkpoint among all ModelCheckpoint callbacks stored
            # in this checkpoint, then reload it.
            best_score, best_name = None, None
            modes = {"min": torch.lt, "max": torch.gt}
            for key, state in checkpoint["callbacks"].items():
                if not key.startswith("ModelCheckpoint"):
                    continue
                # The callback state key embeds its kwargs, e.g.
                # "ModelCheckpoint{'monitor': ..., 'mode': 'min', ...}".
                mode = eval(key.replace("ModelCheckpoint", ""))["mode"]
                if best_score is None or modes[mode](
                    state["best_model_score"], best_score
                ):
                    best_score = state["best_model_score"]
                    best_name = Path(state["best_model_path"]).name
                    logger.info("Loading best checkpoint %s", best_name)
            if best_name != checkpoint_path:
                return cls.load_from_checkpoint(
                    Path(checkpoint_path).parent / best_name,
                    map_location,
                    hparams_file,
                    strict,
                    cfg,
                    find_best=False,
                )
        logger.info(
            "Using checkpoint %s from epoch %d and step %d.",
            checkpoint_path,
            checkpoint["epoch"],
            checkpoint["global_step"],
        )
        cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
        if list(cfg_ckpt.keys()) == ["cfg"]:  # backward compatibility
            cfg_ckpt = cfg_ckpt["cfg"]
        cfg_ckpt = OmegaConf.create(cfg_ckpt)
        if cfg is None:
            cfg = {}
        if not isinstance(cfg, DictConfig):
            cfg = OmegaConf.create(cfg)
        # Caller-provided overrides take precedence over the stored config.
        with open_dict(cfg_ckpt):
            cfg = OmegaConf.merge(cfg_ckpt, cfg)
        return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg)
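

# Usage sketch (assumptions: the checkpoint path and config override below are
# hypothetical; the stored config is merged with the override, and the override wins):
#
#   model = GenericModule.load_from_checkpoint(
#       "experiments/my_run/checkpoints/last.ckpt",
#       cfg={"model": {"name": "map_perception_net"}},
#       find_best=True,  # follow the best ModelCheckpoint entry instead of last.ckpt
#   )
#   model.eval()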