# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import torch
from torch.nn.functional import normalize

from . import get_model
from models.base import BaseModel
# The BEV-net / BEV-projection path (models.bev_net, models.bev_projection)
# is disabled in this variant: image features are matched against the map
# directly, without projecting them to a bird's-eye view first.
from models.voting import (
    argmax_xyr,
    conv2d_fft_batchwise,
    expectation_xyr,
    log_softmax_spatial,
    mask_yaw_prior,
    nll_loss_xyr,
    nll_loss_xyr_smoothed,
    TemplateSampler,
    UAVTemplateSampler,
    UAVTemplateSamplerFast,
)
from .map_encoder import MapEncoder
from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall


class MapLocNet(BaseModel):
    """Localize an image against an encoded map by exhaustive template matching.

    The image feature map is turned into a set of rotated templates
    (``num_rotations`` of them) and correlated with the map features via an
    FFT convolution. The result is a score volume over map position and
    rotation. It is normalized with ``log_softmax_spatial`` and read out both
    as an argmax and as an expectation.
    """

    default_conf = {
        "image_size": "???",
        "val_citys": "???",
        "image_encoder": "???",
        "map_encoder": "???",
        "bev_net": "???",
        "latent_dim": "???",
        "matching_dim": "???",
        "scale_range": [0, 9],
        "num_scale_bins": "???",
        "z_min": None,
        "z_max": "???",
        "x_max": "???",
        "pixel_per_meter": "???",
        "num_rotations": "???",
        "add_temperature": False,
        "normalize_features": False,
        "padding_matching": "replicate",
        "apply_map_prior": True,
        "do_label_smoothing": False,
        "sigma_xy": 1,
        "sigma_r": 2,
        # Deprecated options. _init asserts that the next five keep exactly
        # these values.
        "depth_parameterization": "scale",
        "norm_depth_scores": False,
        "normalize_scores_by_dim": False,
        "normalize_scores_by_num_valid": True,
        "prior_renorm": True,
        "retrieval_dim": None,
    }

    def _init(self, conf):
        """Build the image encoder, map encoder and template sampler from ``conf``.

        Raises:
            AssertionError: if any deprecated option differs from its
                supported value.
        """
        # Deprecated switches: only the default values are supported.
        assert not self.conf.norm_depth_scores
        assert self.conf.depth_parameterization == "scale"
        assert not self.conf.normalize_scores_by_dim
        # NOTE(review): this must be True, yet the per-template valid-pixel
        # normalization in exhaustive_voting is currently disabled — confirm
        # which one is intended.
        assert self.conf.normalize_scores_by_num_valid
        assert self.conf.prior_renorm

        encoder_name = conf.image_encoder.get("name", "feature_extractor_v2")
        encoder_cls = get_model(encoder_name)
        self.image_encoder = encoder_cls(conf.image_encoder.backbone)
        self.map_encoder = MapEncoder(conf.map_encoder)

        # NOTE(review): `ppm` is unused now that the BEV projection is
        # disabled. The read is kept as-is; if `conf` is an OmegaConf config,
        # it presumably makes a missing mandatory ("???") value fail here.
        ppm = conf.pixel_per_meter

        # Alternative kept for reference:
        # UAVTemplateSamplerFast(conf.num_rotations, w=conf.image_size // 2)
        self.template_sampler = UAVTemplateSampler(conf.num_rotations)

        if conf.add_temperature:
            # Learnable log-temperature starting at 0, i.e. a scale of exp(0) = 1.
            self.register_parameter(
                "temperature", torch.nn.Parameter(torch.tensor(0.0))
            )

    def exhaustive_voting(self, f_bev, f_map):
        """Correlate rotated templates of ``f_bev`` with ``f_map``.

        Args:
            f_bev: query feature map. In this model it is the image feature
                map. Channels are assumed on dim 1, because that is the dim
                ``normalize`` uses.
            f_map: map feature map, with channels presumably on dim 1 as well.

        Returns:
            The score tensor from ``conv2d_fft_batchwise``. It is scaled by
            ``exp(self.temperature)`` when ``add_temperature`` is set.
        """
        if self.conf.normalize_features:
            f_bev, f_map = normalize(f_bev, dim=1), normalize(f_map, dim=1)

        # An earlier inline note gave the shape as [batch, 256, 8, 129, 129]
        # (presumably batch, channels, rotations, height, width) — TODO confirm.
        rotated_templates = self.template_sampler(f_bev)

        # The FFT correlation runs in float32 with CUDA autocast disabled.
        with torch.autocast("cuda", enabled=False):
            correlation = conv2d_fft_batchwise(
                f_map.float(),
                rotated_templates.float(),
                padding_mode=self.conf.padding_matching,
            )

        if self.conf.add_temperature:
            correlation = correlation * torch.exp(self.temperature)

        # Dividing each rotation's scores by its number of valid template
        # pixels is currently disabled (see the NOTE in _init).
        return correlation

    def _forward(self, data):
        """Score every map position and rotation for the image in ``data``.

        Reads the optional ``data`` keys "map_mask" and "yaw_prior". Returns
        the map-encoder output under "map", plus raw/normalized scores, argmax
        and expectation pose estimates, and the image features.
        """
        map_outputs = self.map_encoder(data)
        map_features = map_outputs["map_features"][0]

        # Only the first image feature level (index 0) is used for matching.
        image_features = self.image_encoder(data)["feature_maps"][0]

        # Move axis 1 to the end. The original note labels the result
        # (B, H, W, N), with N presumably the rotation bins — TODO confirm.
        scores = self.exhaustive_voting(image_features, map_features).moveaxis(1, -1)

        use_map_prior = "log_prior" in map_outputs and self.conf.apply_map_prior
        if use_map_prior:
            # The same spatial prior is broadcast across the last (rotation) axis.
            scores = scores + map_outputs["log_prior"][0].unsqueeze(-1)

        if "map_mask" in data:
            # Invalid map cells get -inf, so they presumably end up with zero
            # probability after log_softmax_spatial.
            scores.masked_fill_(~data["map_mask"][..., None], -np.inf)

        if "yaw_prior" in data:
            # Called for its side effect on `scores`; its return value is unused.
            mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations)

        log_probs = log_softmax_spatial(scores)

        # The pose read-outs are not part of the autograd graph.
        with torch.no_grad():
            # .to(scores) matches scores' dtype and device.
            uvr_max = argmax_xyr(scores).to(scores)
            uvr_avg, _ = expectation_xyr(log_probs.exp())

        # The last axis of the read-outs is split into (u, v) and yaw.
        return {
            "map": map_outputs,
            "scores": scores,
            "log_probs": log_probs,
            "uvr_max": uvr_max,
            "uv_max": uvr_max[..., :2],
            "yaw_max": uvr_max[..., 2],
            "uvr_expectation": uvr_avg,
            "uv_expectation": uvr_avg[..., :2],
            "yaw_expectation": uvr_avg[..., 2],
            "features_image": image_features,
        }

    def loss(self, pred, data):
        """Negative log-likelihood of the ground-truth pose under ``pred["log_probs"]``.

        The ground truth is ``data["uv"]`` for position and the last entry of
        ``data["roll_pitch_yaw"]`` for yaw. With ``do_label_smoothing`` the
        smoothed NLL is used: ``sigma_xy`` is divided by ``pixel_per_meter``,
        and ``data["map_mask"]`` is passed as the mask if present.
        """
        target_uv = data["uv"]
        target_yaw = data["roll_pitch_yaw"][..., -1]

        if not self.conf.do_label_smoothing:
            nll = nll_loss_xyr(pred["log_probs"], target_uv, target_yaw)
        else:
            nll = nll_loss_xyr_smoothed(
                pred["log_probs"],
                target_uv,
                target_yaw,
                self.conf.sigma_xy / self.conf.pixel_per_meter,
                self.conf.sigma_r,
                mask=data.get("map_mask"),
            )

        losses = {"total": nll, "nll": nll}
        if self.training and self.conf.add_temperature:
            # Logged alongside the loss, expanded to one value per NLL entry.
            losses["temperature"] = self.temperature.expand(len(nll))
        return losses

    def metrics(self):
        """Return position/yaw error and recall metrics, keyed by name.

        Recalls use thresholds of 1, 3 and 5: metres for position
        (converted with ``pixel_per_meter``) and degrees for yaw.
        """
        ppm = self.conf.pixel_per_meter
        thresholds = (1.0, 3.0, 5.0)
        return {
            "xy_max_error": Location2DError("uv_max", ppm),
            "xy_expectation_error": Location2DError("uv_expectation", ppm),
            "yaw_max_error": AngleError("yaw_max"),
            **{
                f"xy_recall_{t:g}m": Location2DRecall(t, ppm, "uv_max")
                for t in thresholds
            },
            **{f"yaw_recall_{t:g}°": AngleRecall(t, "yaw_max") for t in thresholds},
        }