from __future__ import annotations

import copy
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from mmcv.utils import TORCH_VERSION, digit_version

from det_map.data.datasets.dataclasses import SensorConfig, Scene
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
from det_map.det.dal.mmdet3d.models import builder
from det_map.det.dal.mmdet3d.models.utils.grid_mask import GridMask
from det_map.det.dal.mmdet3d.ops import Voxelization, DynamicScatter
from det_map.map.map_target import MapTargetBuilder
from navsim.agents.abstract_agent import AbstractAgent
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder

# Star imports register the map assigners, heads, losses, and modules with the
# mmdet3d registries; re-raise so import failures are not silently swallowed.
try:
    from det_map.map.assigners import *  # noqa: F401,F403
    from det_map.map.dense_heads import *  # noqa: F401,F403
    from det_map.map.losses import *  # noqa: F401,F403
    from det_map.map.modules import *  # noqa: F401,F403
except Exception:
    raise


class MapAgent(AbstractAgent):
    def __init__(
            self,
            model,
            pipelines,
            lr: float,
            checkpoint_path: Optional[str] = None,
            **kwargs
    ):
        super().__init__()
        # todo eval everything
        self.model = model
        self.pipelines = pipelines
        self._checkpoint_path = checkpoint_path
        self._lr = lr

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass."""
        state_dict: Dict[str, Any] = torch.load(self._checkpoint_path)["state_dict"]
        self.load_state_dict({k.replace("agent.", ""): v for k, v in state_dict.items()})

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        return SensorConfig.build_all_sensors(True)

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        return [
            MapTargetBuilder(),
        ]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        return [
            LiDARCameraFeatureBuilder(self.pipelines)
        ]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return self.model(features)
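    # The loss below pairs the head's standard one-to-one matching loss with an
    # auxiliary "one2many" branch (hybrid-matching style): ground-truth instances
    # are duplicated k_one2many times, e.g. [gt_a, gt_b] * 3 ->
    # [gt_a, gt_b, gt_a, gt_b, gt_a, gt_b], the auxiliary outputs are supervised
    # against this enlarged set, and the resulting terms are weighted by
    # lambda_one2many before everything is summed into a single scalar loss.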
    def compute_loss(
            self,
            features: Dict[str, torch.Tensor],
            targets: Dict[str, torch.Tensor],
            predictions: Dict[str, torch.Tensor],
            tokens=None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        losses = dict()

        # Optional depth supervision, currently disabled.
        # depth = predictions.pop('depth')
        # if "gt_depth" in targets:
        #     gt_depth = targets["gt_depth"]
        #     loss_depth = self.pts_bbox_head.transformer.encoder.get_depth_loss(gt_depth, depth)
        #     if digit_version(TORCH_VERSION) >= digit_version('1.8'):
        #         loss_depth = torch.nan_to_num(loss_depth)
        #     losses.update(loss_depth=loss_depth)

        gt_bboxes_3d = targets["gt_bboxes_3d"]
        gt_labels_3d = targets["gt_labels_3d"]

        # Segmentation masks are not supervised at the moment.
        gt_seg_mask = None
        gt_pv_seg_mask = None
        # gt_seg_mask = targets["gt_seg_mask"]
        # gt_pv_seg_mask = targets["gt_pv_seg_mask"]

        # One-to-one matching loss on the primary predictions.
        loss_inputs = [gt_bboxes_3d, gt_labels_3d, gt_seg_mask, gt_pv_seg_mask, predictions]
        losses_pts = self.model.pts_bbox_head.loss(*loss_inputs, img_metas=None)
        losses.update(losses_pts)

        # One-to-many branch: duplicate the ground truth k_one2many times and
        # supervise the auxiliary outputs against the enlarged set.
        k_one2many = self.model.pts_bbox_head.k_one2many
        multi_gt_bboxes_3d = copy.deepcopy(gt_bboxes_3d)
        multi_gt_labels_3d = copy.deepcopy(gt_labels_3d)
        for i, (each_gt_bboxes_3d, each_gt_labels_3d) in enumerate(
                zip(multi_gt_bboxes_3d, multi_gt_labels_3d)):
            each_gt_bboxes_3d.instance_list = each_gt_bboxes_3d.instance_list * k_one2many
            each_gt_bboxes_3d.instance_labels = each_gt_bboxes_3d.instance_labels * k_one2many
            multi_gt_labels_3d[i] = each_gt_labels_3d.repeat(k_one2many)

        one2many_outs = predictions['one2many_outs']
        loss_one2many_inputs = [multi_gt_bboxes_3d, multi_gt_labels_3d, gt_seg_mask, gt_pv_seg_mask, one2many_outs]
        loss_dict_one2many = self.model.pts_bbox_head.loss(*loss_one2many_inputs, img_metas=None)

        # Weight the one-to-many terms by lambda_one2many and merge them into the loss dict.
        lambda_one2many = self.model.pts_bbox_head.lambda_one2many
        for key, value in loss_dict_one2many.items():
            if key + "_one2many" in losses.keys():
                losses[key + "_one2many"] += value * lambda_one2many
            else:
                losses[key + "_one2many"] = value * lambda_one2many

        # Total loss is the sum of all individual terms.
        loss = 0
        for k, v in losses.items():
            loss = loss + v
        return loss, losses

    def get_optimizers(self) -> Optimizer | Dict[str, Optimizer | LRScheduler]:
        optimizer = initialize_optimizer(self.model, self._lr)
        return {'optimizer': optimizer}


def initialize_optimizer(model, lr):
    # Two parameter groups: the image backbone trains with a 10x smaller
    # learning rate than the rest of the model.
    optimizer = optim.AdamW([
        {'params': [param for name, param in model.named_parameters() if 'img_backbone' in name],
         'lr': lr * 0.1},
        {'params': [param for name, param in model.named_parameters() if 'img_backbone' not in name],
         'lr': lr},
    ], weight_decay=0.01)
    return optimizer
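

# A minimal usage sketch (not part of the original module) showing how the two
# parameter groups above split the learning rate; the toy model, the lr value,
# and the __main__ guard are hypothetical and only illustrate the backbone lr scaling.
if __name__ == "__main__":
    class _ToyModel(nn.Module):
        """Hypothetical stand-in with an `img_backbone` submodule and a head."""

        def __init__(self):
            super().__init__()
            self.img_backbone = nn.Linear(4, 4)
            self.pts_bbox_head = nn.Linear(4, 2)

    opt = initialize_optimizer(_ToyModel(), lr=2e-4)
    for group in opt.param_groups:
        # Expect lr=2e-5 for the img_backbone group and lr=2e-4 for everything else.
        print(group["lr"], len(group["params"]))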