from __future__ import annotations

from typing import Any, List, Dict

import torch
import torch.optim as optim
import copy
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import torch.nn as nn
from det_map.data.datasets.dataclasses import SensorConfig, Scene
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
from navsim.agents.abstract_agent import AbstractAgent
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder

from det_map.det.dal.mmdet3d.models.utils.grid_mask import GridMask

import torch.nn.functional as F

from det_map.det.dal.mmdet3d.ops import Voxelization, DynamicScatter
from det_map.det.dal.mmdet3d.models import builder
from mmcv.utils import TORCH_VERSION, digit_version


class MapModel(nn.Module):
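    """Camera/LiDAR map perception model assembled from mmdet3d-style config dicts.

    Sub-modules (image backbone/neck, point-cloud voxelization/backbone and the
    map head) are built through the mmdet3d registry builders; which branches
    exist depends on the config and on `modality` ('vision' or 'fusion').
    """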
    def __init__(
            self,
            use_grid_mask=False,
            pts_voxel_layer=None,
            pts_voxel_encoder=None,
            pts_middle_encoder=None,
            pts_fusion_layer=None,
            img_backbone=None,
            pts_backbone=None,
            img_neck=None,
            pts_neck=None,
            pts_bbox_head=None,
            img_roi_head=None,
            img_rpn_head=None,
            train_cfg=None,
            test_cfg=None,
            pretrained=None,
            video_test_mode=False,
            modality='vision',
            lidar_encoder=None,
            lr=None,
    ):
        super().__init__()
        # self.pipelines = pipelines
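        # GridMask augmentation, applied to the input images in
        # extract_img_feat when use_grid_mask is enabled.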
        self.grid_mask = GridMask(
            True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
        if pts_voxel_layer:
            self.pts_voxel_layer = Voxelization(**pts_voxel_layer)
        if pts_voxel_encoder:
            self.pts_voxel_encoder = builder.build_voxel_encoder(
                pts_voxel_encoder)
        if pts_middle_encoder:
            self.pts_middle_encoder = builder.build_middle_encoder(
                pts_middle_encoder)
        if pts_backbone:
            self.pts_backbone = builder.build_backbone(pts_backbone)
        if pts_fusion_layer:
            self.pts_fusion_layer = builder.build_fusion_layer(
                pts_fusion_layer)
        if pts_neck is not None:
            self.pts_neck = builder.build_neck(pts_neck)
        if pts_bbox_head:
            # Inject (empty) train/test cfgs into the head config before building.
            pts_bbox_head.update(train_cfg=None, test_cfg=None)
            self.pts_bbox_head = builder.build_head(pts_bbox_head)
        if img_backbone:
            self.img_backbone = builder.build_backbone(img_backbone)
        if img_neck is not None:
            self.img_neck = builder.build_neck(img_neck)
        if img_rpn_head is not None:
            self.img_rpn_head = builder.build_head(img_rpn_head)
        if img_roi_head is not None:
            self.img_roi_head = builder.build_head(img_roi_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if pretrained is None:
            img_pretrained = None
            pts_pretrained = None
        elif isinstance(pretrained, dict):
            img_pretrained = pretrained.get('img', None)
            pts_pretrained = pretrained.get('pts', None)
        else:
            raise ValueError(
                f'pretrained should be a dict, got {type(pretrained)}')
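        # The parsed pretrained paths are not forwarded to the builders in this
        # module; they are kept only for config compatibility.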

        self.use_grid_mask = use_grid_mask
        self.fp16_enabled = False

        # Temporal state kept across frames for streaming / video inference.
        self.video_test_mode = video_test_mode
        self.prev_frame_info = {
            'prev_bev': None,
            'scene_token': None,
            'prev_pos': 0,
            'prev_angle': 0,
        }
        self.modality = modality
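        # For camera-LiDAR fusion, build a separate voxelization step and
        # point-cloud backbone used by extract_lidar_feat.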
        if self.modality == 'fusion' and lidar_encoder is not None:
            if lidar_encoder["voxelize"].get("max_num_points", -1) > 0:
                voxelize_module = Voxelization(**lidar_encoder["voxelize"])
            else:
                voxelize_module = DynamicScatter(**lidar_encoder["voxelize"])
            self.lidar_modal_extractor = nn.ModuleDict(
                {
                    "voxelize": voxelize_module,
                    "backbone": builder.build_middle_encoder(lidar_encoder["backbone"]),
                }
            )
            self.voxelize_reduce = lidar_encoder.get("voxelize_reduce", True)

        self._lr = lr


    def extract_img_feat(self, img, img_metas=None, len_queue=None):
        """Extract features of images."""
        if img is None:
            return None
        B = img.size(0)

        # input_shape = img.shape[-2:]
        # # update real input shape of each single img
        # for img_meta in img_metas:
        #     img_meta.update(input_shape=input_shape)

        if img.dim() == 5 and img.size(0) == 1:
            img.squeeze_()
        elif img.dim() == 5 and img.size(0) > 1:
            B, N, C, H, W = img.size()
            img = img.reshape(B * N, C, H, W)
        if self.use_grid_mask:
            img = self.grid_mask(img)

        img_feats = self.img_backbone(img)
        if isinstance(img_feats, dict):
            img_feats = list(img_feats.values())

        if hasattr(self, 'img_neck'):
            img_feats = self.img_neck(img_feats)

        # Restore the (batch, num_views) layout of the multi-view features.
        BN, C, H, W = img_feats[0].shape
        return [tmp.view(B, BN // B, C, H, W) for tmp in img_feats]

    @torch.no_grad()
    def voxelize(self, points):
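        """Voxelize each point cloud in the batch and concatenate the results.

        The batch index is prepended to every voxel coordinate. `sizes`
        (points per voxel) is only produced by hard voxelization; with
        dynamic voxelization it stays an empty list.
        """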
        feats, coords, sizes = [], [], []
        for k, res in enumerate(points):
            ret = self.lidar_modal_extractor["voxelize"](res)
            if len(ret) == 3:
                # hard voxelize
                f, c, n = ret
            else:
                assert len(ret) == 2
                f, c = ret
                n = None
            feats.append(f)
            coords.append(F.pad(c, (1, 0), mode="constant", value=k))
            if n is not None:
                sizes.append(n)

        feats = torch.cat(feats, dim=0)
        coords = torch.cat(coords, dim=0)
        if len(sizes) > 0:
            sizes = torch.cat(sizes, dim=0)
            if self.voxelize_reduce:
                feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
                    -1, 1
                )
                feats = feats.contiguous()

        return feats, coords, sizes

    def extract_lidar_feat(self, points):
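        """Voxelize the point clouds and encode them with the LiDAR backbone."""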
        feats, coords, sizes = self.voxelize(points)
        # voxel_features = self.lidar_modal_extractor["voxel_encoder"](feats, sizes, coords)
        batch_size = coords[-1, 0] + 1
        lidar_feat = self.lidar_modal_extractor["backbone"](feats, coords, batch_size)

        return lidar_feat

    def forward(self, feature_dict=None, points=None, img_metas=None) -> Dict[str, torch.Tensor]:
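        """Run map prediction for one batch.

        `feature_dict['image']` is expected to have shape
        (B, T, N_views, C, H, W); only the last frame of the queue is used.
        When `modality == 'fusion'`, per-sample LiDAR point clouds are taken
        from `points` (or from `feature_dict['lidars_warped']`) and fused in.
        Returns the output dict of `pts_bbox_head`.
        """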
        lidar_feat = None
        if self.modality == 'fusion':
            # Each sample holds a list of (warped) LiDAR sweeps; concatenate
            # them into a single point cloud per sample before voxelization.
            if points is None:
                points = feature_dict['lidars_warped']
            points_input = [torch.cat(tmp, 0) for tmp in points]
            lidar_feat = self.extract_lidar_feat(points_input)

        # Use only the most recent frame of the multi-view image queue.
        img = feature_dict['image']
        len_queue = img.size(1)
        img = img[:, -1, ...]
        img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)

        outs = self.pts_bbox_head(
            img_feats, lidar_feat, feature_dict, None)

        return outs


# class MyLightningModule(pl.LightningModule):
#     def __init__(
#         self,
#         agent: AbstractAgent,
#     ):
#         super().__init__()
#         self.agent = agent

#     def _step(
#         self,
#         batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
#         logging_prefix: str,
#     ):
#         features, targets = batch
#         prediction = self.agent.forward(features)
#         loss = self.agent.compute_loss(features, targets, prediction)
#         self.log(f"{logging_prefix}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
#         return loss

#     def training_step(
#         self,
#         batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
#         batch_idx: int
#     ):
#         return self._step(batch, "train")

#     def validation_step(
#         self,
#         batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
#         batch_idx: int
#     ):
#         return self._step(batch, "val")

#     def configure_optimizers(self):
#         optimizer = self.agent.get_optimizers()
#         # Apply gradient clipping
#         if 'grad_clip' in self.optimizer_config:
#             grad_clip = self.optimizer_config['grad_clip']
#             max_norm = grad_clip.get('max_norm', 1.0)
#             norm_type = grad_clip.get('norm_type', 2)
#             optimizer = optim.Adam(self.parameters(), lr=1e-3)
#             return {
#                 'optimizer': optimizer,
#                 'clip_grad_norm': max_norm,
#                 'clip_grad_value': None,  # 'clip_grad_value' can be used to clip gradients by absolute value instead
#             }
#         else:
#             return optimizer