from typing import Tuple
import logging
from pathlib import Path

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from det_map.data.datasets.dataset_det import DetDataset
from det_map.utils import collate_fn_pad_lidar
from det_map.data.datasets.dataset import Dataset
from navsim.planning.training.agent_lightning_module import AgentLightningModule
from det_map.data.datasets.dataloader import SceneLoader
from det_map.data.datasets.dataclasses import SceneFilter
from navsim.agents.abstract_agent import AbstractAgent

logger = logging.getLogger(__name__)

CONFIG_PATH = "config/"
CONFIG_NAME = "train_det"


def _build_split_dataset(
    cfg: DictConfig,
    agent: AbstractAgent,
    log_names,
    *,
    is_train: bool,
) -> DetDataset:
    """Build one ``DetDataset`` split restricted to ``log_names``.

    A fresh ``SceneFilter`` is instantiated from ``cfg.scene_filter`` for
    every call, so the train and val splits never share a mutable filter
    object.
    """
    scene_filter: SceneFilter = instantiate(cfg.scene_filter)
    scene_filter.log_names = log_names

    scene_loader = SceneLoader(
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        data_path=Path(cfg.navsim_log_path),
        scene_filter=scene_filter,
        sensor_config=agent.get_sensor_config(),
    )
    return DetDataset(
        scene_loader=scene_loader,
        feature_builders=agent.get_feature_builders(),
        target_builders=agent.get_target_builders(),
        pipelines=agent.pipelines,
        is_train=is_train,
    )


def build_datasets(cfg: DictConfig, agent: AbstractAgent) -> Tuple[Dataset, Dataset]:
    """Build the training and validation datasets for ``agent``.

    Both splits use the same ``cfg.scene_filter`` template, data path
    (``cfg.navsim_log_path``) and sensor blob path (``cfg.sensor_blobs_path``).
    They differ only in the log names (``cfg.train_logs`` vs ``cfg.val_logs``)
    and in the ``is_train`` flag handed to ``DetDataset``.

    :param cfg: Hydra configuration for the run.
    :param agent: agent supplying sensor config, feature/target builders
        and ``pipelines``.
    :return: ``(train_data, val_data)``.
    """
    train_data = _build_split_dataset(cfg, agent, cfg.train_logs, is_train=True)
    val_data = _build_split_dataset(cfg, agent, cfg.val_logs, is_train=False)
    return train_data, val_data


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg: DictConfig) -> None:
    """Hydra entry point: build agent, data and trainer, then train.

    Seeds all RNGs (including dataloader workers) with 0 and instantiates the
    agent from ``cfg.agent``, wrapping it in an ``AgentLightningModule``. It
    then builds the train/val datasets and DataLoaders from
    ``cfg.dataloader.params``, and builds a ``pl.Trainer`` from
    ``cfg.trainer.params`` before calling ``fit``.
    """
    logger.info("Global Seed set to 0")
    pl.seed_everything(0, workers=True)
    logger.info("Path where all results are stored: %s", cfg.output_dir)

    logger.info("Building Agent")
    agent: AbstractAgent = instantiate(cfg.agent)

    logger.info("Building Lightning Module")
    lightning_module = AgentLightningModule(
        agent=agent,
    )

    # Scene loaders are created inside build_datasets together with the datasets.
    logger.info("Building Datasets")
    train_data, val_data = build_datasets(cfg, agent)

    logger.info("Building DataLoaders")
    # shuffle/collate_fn are passed explicitly, so cfg.dataloader.params must
    # not also define them (Python would raise "multiple values" TypeError).
    train_dataloader = DataLoader(
        train_data, **cfg.dataloader.params, shuffle=True, collate_fn=collate_fn_pad_lidar
    )
    logger.info("Num training samples: %d", len(train_data))
    val_dataloader = DataLoader(
        val_data, **cfg.dataloader.params, shuffle=False, collate_fn=collate_fn_pad_lidar
    )
    logger.info("Num validation samples: %d", len(val_data))

    logger.info("Building Trainer")
    trainer = pl.Trainer(**cfg.trainer.params)

    logger.info("Starting Training")
    trainer.fit(
        model=lightning_module,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )


if __name__ == "__main__":
    main()