from __future__ import annotations

import copy
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from mmcv.utils import TORCH_VERSION, digit_version

from det_map.data.datasets.dataclasses import SensorConfig, Scene
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
from det_map.det.dal.mmdet3d.models import builder
from det_map.det.dal.mmdet3d.models.utils.grid_mask import GridMask
from det_map.det.dal.mmdet3d.ops import Voxelization, DynamicScatter
from det_map.map.map_target import MapTargetBuilder
from navsim.agents.abstract_agent import AbstractAgent
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder

# Import the map assigners, heads, losses and modules; re-raise the original
# error instead of masking it with a bare `raise Exception`.
try:
    from det_map.map.assigners import *
    from det_map.map.dense_heads import *
    from det_map.map.losses import *
    from det_map.map.modules import *
except Exception:
    raise

class MapAgent(AbstractAgent):
    def __init__(
        self,
        model,
        pipelines,
        lr: float,
        checkpoint_path: str | None = None,
        **kwargs,
    ):
        super().__init__()

        self.model = model
        self.pipelines = pipelines
        self._checkpoint_path = checkpoint_path
        self._lr = lr
    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass."""
        state_dict: Dict[str, Any] = torch.load(self._checkpoint_path)["state_dict"]
        # Strip the "agent." prefix from the checkpointed parameter names.
        self.load_state_dict({k.replace("agent.", ""): v for k, v in state_dict.items()})

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        return SensorConfig.build_all_sensors(True)

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        return [
            MapTargetBuilder(),
        ]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        return [
            LiDARCameraFeatureBuilder(self.pipelines),
        ]
    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return self.model(features)
    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        tokens=None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        losses = dict()

        gt_bboxes_3d = targets["gt_bboxes_3d"]
        gt_labels_3d = targets["gt_labels_3d"]

        # No segmentation supervision is used; the head accepts None for both masks.
        gt_seg_mask = None
        gt_pv_seg_mask = None

        # Losses for the primary (one-to-one) predictions.
        loss_inputs = [gt_bboxes_3d, gt_labels_3d, gt_seg_mask, gt_pv_seg_mask, predictions]
        losses_pts = self.model.pts_bbox_head.loss(*loss_inputs, img_metas=None)
        losses.update(losses_pts)

        # Auxiliary one-to-many supervision: every ground-truth instance is
        # duplicated k_one2many times and matched against the one-to-many outputs.
        k_one2many = self.model.pts_bbox_head.k_one2many
        multi_gt_bboxes_3d = copy.deepcopy(gt_bboxes_3d)
        multi_gt_labels_3d = copy.deepcopy(gt_labels_3d)
        for i, (each_gt_bboxes_3d, each_gt_labels_3d) in enumerate(zip(multi_gt_bboxes_3d, multi_gt_labels_3d)):
            each_gt_bboxes_3d.instance_list = each_gt_bboxes_3d.instance_list * k_one2many
            each_gt_bboxes_3d.instance_labels = each_gt_bboxes_3d.instance_labels * k_one2many
            multi_gt_labels_3d[i] = each_gt_labels_3d.repeat(k_one2many)

        one2many_outs = predictions['one2many_outs']
        loss_one2many_inputs = [multi_gt_bboxes_3d, multi_gt_labels_3d, gt_seg_mask, gt_pv_seg_mask, one2many_outs]
        loss_dict_one2many = self.model.pts_bbox_head.loss(*loss_one2many_inputs, img_metas=None)

        # Merge the auxiliary terms under a "_one2many" suffix, scaled by lambda_one2many.
        lambda_one2many = self.model.pts_bbox_head.lambda_one2many
        for key, value in loss_dict_one2many.items():
            if key + "_one2many" in losses:
                losses[key + "_one2many"] += value * lambda_one2many
            else:
                losses[key + "_one2many"] = value * lambda_one2many

        # The total loss is the plain sum of all individual loss terms.
        loss = 0
        for k, v in losses.items():
            loss = loss + v
        return loss, losses
    def get_optimizers(self) -> Optimizer | Dict[str, Optimizer | LRScheduler]:
        optimizer = initialize_optimizer(self.model, self._lr)
        return {'optimizer': optimizer}

def initialize_optimizer(model, lr):
    # AdamW with two parameter groups: image-backbone parameters are trained
    # with a 10x smaller learning rate than the rest of the model.
    optimizer = optim.AdamW([
        {'params': [param for name, param in model.named_parameters() if 'img_backbone' in name], 'lr': lr * 0.1},
        {'params': [param for name, param in model.named_parameters() if 'img_backbone' not in name], 'lr': lr},
    ], weight_decay=0.01)
    return optimizer
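
# A minimal, hypothetical sanity check (not part of the original agent): build a
# tiny stand-in model and verify that `initialize_optimizer` places every
# `img_backbone` parameter on a 10x smaller learning rate than the rest.
if __name__ == "__main__":
    class _ToyModel(nn.Module):
        """Hypothetical module used only to exercise the optimizer factory."""

        def __init__(self):
            super().__init__()
            self.img_backbone = nn.Linear(4, 4)
            self.pts_bbox_head = nn.Linear(4, 2)

    toy_optimizer = initialize_optimizer(_ToyModel(), lr=1e-4)
    for group in toy_optimizer.param_groups:
        print(f"lr={group['lr']:.0e}, num_params={len(group['params'])}")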