# Copyright (c) Meta Platforms, Inc. and affiliates.

import ast
from pathlib import Path

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from torchmetrics import MeanMetric, MetricCollection

import logger
from models import get_model


class AverageKeyMeter(MeanMetric):
    """Mean metric that reads its values from one key of a dictionary.

    Non-finite entries (NaN / inf) of the selected value are dropped before
    they are accumulated.
    """

    def __init__(self, key, *args, **kwargs):
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, dict):
        """Accumulate the finite entries of ``dict[self.key]``."""
        value = dict[self.key]
        value = value[torch.isfinite(value)]
        return super().update(value)


class GenericModule(pl.LightningModule):
    """Lightning wrapper around a model obtained from ``get_model``.

    The wrapped model is expected to provide ``loss(pred, batch)`` and
    ``metrics()``, which are used for training and validation logging.
    """

    def __init__(self, cfg):
        """Build the model named in ``cfg.model`` and the validation metrics."""
        super().__init__()
        name = cfg.model.get("name")
        # A missing name and the legacy name both map to "orienternet".
        name = "orienternet" if name in ("localizer_bev_depth", None) else name
        self.model = get_model(name)(cfg.model)
        self.cfg = cfg
        self.save_hyperparameters(cfg)
        self.metrics_val = MetricCollection(self.model.metrics(), prefix="val/")
        self.losses_val = None  # we do not know the loss keys in advance

        # Disabled per-city validation variant, kept for reference:
        # self.citys = self.cfg.data.val_citys
        # for i in range(len(self.citys)):
        #     city=self.citys[i]
        #     setattr(self, "metric_vals_{}".format(i), MetricCollection(self.model.metrics(), prefix="val_{}/".format(city)))
        # self.losse_vals = [None for city in self.cfg.data.val_citys]

    def forward(self, batch):
        """Run the wrapped model on ``batch``."""
        return self.model(batch)

    def training_step(self, batch):
        """Compute the losses, log their means, and return the mean total loss."""
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        self.log_dict(
            {f"loss/{k}/train": v.mean() for k, v in losses.items()},
            prog_bar=True,
            rank_zero_only=True,
        )
        return losses["total"].mean()

    # Disabled per-city validation variant, kept for reference:
    # def validation_step(self, batch, batch_idx,dataloader_idx):
    #     city=self.citys[dataloader_idx]
    #
    #     pred = self(batch)
    #     losses = self.model.loss(pred, batch)
    #
    #     if hasattr(self,"losse_val_{}".format(dataloader_idx)) is False:
    #         setattr(self,"losse_val_{}".format(dataloader_idx),MetricCollection(
    #             {k: AverageKeyMeter(k).to(self.device) for k in losses},
    #             prefix="loss_{}/".format(city),
    #             postfix="/val_{}".format(city),
    #         ))
    #
    #     # print(pred, batch)
    #     getattr(self,"metric_vals_{}".format(dataloader_idx))(pred, batch)
    #     self.log_dict(getattr(self,"metric_vals_{}".format(dataloader_idx))(pred, batch), sync_dist=True)
    #
    #     getattr(self,"losse_val_{}".format(dataloader_idx)).update(losses)
    #     # print(getattr(self,"losse_val_{}".format(dataloader_idx)))
    #     self.log_dict(getattr(self,"losse_val_{}".format(dataloader_idx)).compute(), sync_dist=True)

    def validation_step(self, batch, batch_idx):
        """Update and log the validation metrics and per-key loss averages."""
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        if self.losses_val is None:
            # The loss keys are only known once the first loss dict exists,
            # so the meters are created lazily, one per key.
            self.losses_val = MetricCollection(
                {k: AverageKeyMeter(k).to(self.device) for k in losses},
                prefix="loss/",
                postfix="/val",
            )
        self.metrics_val(pred, batch)
        self.log_dict(self.metrics_val, sync_dist=True)
        self.losses_val.update(losses)
        self.log_dict(self.losses_val, sync_dist=True)

    def validation_epoch_start(self, batch):
        """Drop the loss meters so that they are rebuilt on the next step.

        NOTE(review): Lightning's documented hook is
        ``on_validation_epoch_start(self)``; confirm whether anything
        actually invokes this method under its current name.
        """
        self.losses_val = None
        # self.losse_val = [None for city in self.cfg.data.val_citys]

    def configure_optimizers(self):
        """Create the Adam optimizer and, if configured, an LR scheduler.

        The scheduler class is looked up by name in
        ``torch.optim.lr_scheduler`` and stepped once per epoch, monitoring
        ``loss/total/val``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.training.lr)
        ret = {"optimizer": optimizer}
        cfg_scheduler = self.cfg.training.get("lr_scheduler")
        if cfg_scheduler is not None:
            scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)(
                optimizer=optimizer, **cfg_scheduler.get("args", {})
            )
            ret["lr_scheduler"] = {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "monitor": "loss/total/val",
                "strict": True,
                "name": "learning_rate",
            }
        return ret

    @staticmethod
    def _best_checkpoint_name(checkpoint):
        """Return the file name of the best checkpoint, or None if unknown.

        Scans the ``ModelCheckpoint`` callback states stored in ``checkpoint``
        and compares their ``best_model_score`` according to each callback's
        ``mode`` ("min" or "max").
        """
        modes = {"min": torch.lt, "max": torch.gt}
        best_score, best_name = None, None
        for key, state in checkpoint["callbacks"].items():
            if not key.startswith("ModelCheckpoint"):
                continue
            score = state["best_model_score"]
            if score is None:
                # This callback never recorded a best model; nothing to compare.
                continue
            # The key suffix is the repr of the callback's settings dict.
            # SECURITY: this text comes from the checkpoint file, so it is
            # parsed with ast.literal_eval instead of eval to avoid executing
            # arbitrary code.
            mode = ast.literal_eval(key.replace("ModelCheckpoint", ""))["mode"]
            if best_score is None or modes[mode](score, best_score):
                best_score = score
                best_name = Path(state["best_model_path"]).name
        return best_name

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path,
        map_location=None,
        hparams_file=None,
        strict=True,
        cfg=None,
        find_best=False,
    ):
        """Load a module from a checkpoint, optionally overriding its config.

        Args:
            checkpoint_path: path (``str`` or ``Path``) of the checkpoint file.
            map_location: forwarded to ``torch.load``; defaults to keeping
                the storages where they were deserialized.
            hparams_file: must be None (not supported).
            strict: forwarded to the state loading.
            cfg: optional overrides merged on top of the checkpoint config.
            find_best: if True, load the best checkpoint recorded by the
                ``ModelCheckpoint`` callbacks, looked up in the same directory
                as ``checkpoint_path``.

        Raises:
            ValueError: if ``find_best`` is set but the checkpoint does not
                record any best model.
        """
        assert hparams_file is None, "hparams are not supported."
        checkpoint_path = Path(checkpoint_path)

        checkpoint = torch.load(
            checkpoint_path, map_location=map_location or (lambda storage, loc: storage)
        )
        if find_best:
            best_name = cls._best_checkpoint_name(checkpoint)
            if best_name is None:
                raise ValueError(
                    f"No best ModelCheckpoint score recorded in {checkpoint_path}."
                )
            logger.info("Loading best checkpoint %s", best_name)
            # Compare file names: the best checkpoint is resolved in the same
            # directory, so an equal name means it is already loaded.
            if best_name != checkpoint_path.name:
                return cls.load_from_checkpoint(
                    checkpoint_path.parent / best_name,
                    map_location,
                    hparams_file,
                    strict,
                    cfg,
                    find_best=False,
                )

        logger.info(
            "Using checkpoint %s from epoch %d and step %d.",
            checkpoint_path.name,
            checkpoint["epoch"],
            checkpoint["global_step"],
        )
        cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
        if list(cfg_ckpt.keys()) == ["cfg"]:  # backward compatibility
            cfg_ckpt = cfg_ckpt["cfg"]
        cfg_ckpt = OmegaConf.create(cfg_ckpt)

        if cfg is None:
            cfg = {}
        if not isinstance(cfg, DictConfig):
            cfg = OmegaConf.create(cfg)
        # The caller-provided values take precedence over the stored ones.
        with open_dict(cfg_ckpt):
            cfg = OmegaConf.merge(cfg_ckpt, cfg)

        return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg)