from typing import Dict

import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

from navsim.agents.transfuser.transfuser_config import TransfuserConfig
from navsim.agents.vadv2.vadv2_config import Vadv2Config


def vadv2_loss_pdm_ablate(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: Vadv2Config,
        vocab_pdm_score
):
    """
    Helper function calculating the complete VADv2 loss (PDM-score ablation:
    only the aggregated 'total' score is supervised).
    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global VADv2 config
    :param vocab_pdm_score: dictionary of PDM sub-scores per vocabulary trajectory
    :return: tuple of combined loss value and dictionary of individual losses
    """
    total = predictions['total']
    imi = predictions['imi']
    # 2-class (binary) target
    pdmtotal_loss = F.binary_cross_entropy(total, vocab_pdm_score['total'].to(total.dtype))

    vocab = predictions["trajectory_vocab"]
    # B, 8 (4 s at 0.5 s intervals), 3
    target_traj = targets["trajectory"]
    # indices 4, 9, ..., 39
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    B = target_traj.shape[0]
    l2_distance = -((vocab[:, sampled_timepoints][None].repeat(B, 1, 1, 1) -
                     target_traj[:, None]) ** 2) / config.sigma
    imi_loss = F.cross_entropy(imi, l2_distance.sum((-2, -1)).softmax(1))

    imi_loss_final = config.trajectory_imi_weight * imi_loss
    pdmtotal_loss_final = config.trajectory_pdm_weight['total'] * pdmtotal_loss

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    bev_semantic_loss = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss
    bev_semantic_loss_final = config.bev_semantic_weight * bev_semantic_loss
    loss = (
            imi_loss_final
            + pdmtotal_loss_final
            + agent_class_loss_final
            + agent_box_loss_final
            + bev_semantic_loss_final
    )
    return loss, {
        'imi_loss': imi_loss_final,
        'pdmtotal_loss': pdmtotal_loss_final,
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
        'bev_semantic_loss': bev_semantic_loss_final
    }


def vadv2_loss_center_woper(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: Vadv2Config,
        vocab_pdm_score
):
    """
    Helper function calculating the VADv2 imitation loss only
    (center-based soft target, without perception losses).
    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global VADv2 config
    :param vocab_pdm_score: unused here; kept for a uniform loss interface
    :return: tuple of combined loss value and dictionary of individual losses
    """
    pred_dist = predictions["trajectory_distribution"]
    # cb_weight = predictions["cb_weight"].to(pred_dist.device)
    # vocab_pdm_score = torch.from_numpy(vocab_pdm_score).to(pred_dist.device)
    # TODO: sample weights
    #  https://medium.com/@matrixB/modified-cross-entropy-loss-for-multi-label-classification-with-class-a8afede21eb9
    # TODO: put regressed traj into vocab and calculate loss together
    # TODO: more gaussian parameters

    # center-based loss
    B, N_VOCAB = pred_dist.shape
    # 4096 (N_VOCAB), 40 (4 s at 0.1 s intervals), 3
    vocab = predictions["trajectory_vocab"]
    # B, 8 (4 s at 0.5 s intervals), 3
    target_traj = targets["trajectory"]
    # indices 4, 9, ..., 39
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    l2_distance = -((vocab[:, sampled_timepoints][None].repeat(B, 1, 1, 1) -
                     target_traj[:, None]) ** 2) / config.sigma
    trajectory_loss = F.cross_entropy(pred_dist, l2_distance.sum((-2, -1)).softmax(1))
    trajectory_imi_loss_final = config.trajectory_imi_weight * trajectory_loss
    loss = trajectory_imi_loss_final
    return loss, {
        'trajectory_imi_loss': trajectory_imi_loss_final,
    }

def vadv2_loss_center(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: Vadv2Config,
        vocab_pdm_score
):
    """
    Helper function calculating the complete VADv2 loss with a center-based
    (soft-target) imitation loss plus perception losses.
    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global VADv2 config
    :param vocab_pdm_score: unused here; kept for a uniform loss interface
    :return: tuple of combined loss value and dictionary of individual losses
    """
    pred_dist = predictions["trajectory_distribution"]
    # cb_weight = predictions["cb_weight"].to(pred_dist.device)
    # vocab_pdm_score = torch.from_numpy(vocab_pdm_score).to(pred_dist.device)
    # TODO: sample weights
    #  https://medium.com/@matrixB/modified-cross-entropy-loss-for-multi-label-classification-with-class-a8afede21eb9
    # TODO: put regressed traj into vocab and calculate loss together
    # TODO: more gaussian parameters

    # center-based loss
    B, N_VOCAB = pred_dist.shape
    # 4096 (N_VOCAB), 40 (4 s at 0.1 s intervals), 3
    vocab = predictions["trajectory_vocab"]
    # B, 8 (4 s at 0.5 s intervals), 3
    target_traj = targets["trajectory"]
    # indices 4, 9, ..., 39
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    l2_distance = -((vocab[:, sampled_timepoints][None].repeat(B, 1, 1, 1) -
                     target_traj[:, None]) ** 2) / config.sigma
    trajectory_loss = F.cross_entropy(pred_dist, l2_distance.sum((-2, -1)).softmax(1))

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    bev_semantic_loss = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )
    trajectory_imi_loss_final = config.trajectory_imi_weight * trajectory_loss
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss
    bev_semantic_loss_final = config.bev_semantic_weight * bev_semantic_loss
    loss = (
            trajectory_imi_loss_final
            + agent_class_loss_final
            + agent_box_loss_final
            + bev_semantic_loss_final
    )
    return loss, {
        'trajectory_imi_loss': trajectory_imi_loss_final,
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
        'bev_semantic_loss': bev_semantic_loss_final
    }


def vadv2_loss_ori(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: Vadv2Config,
        vocab_pdm_score
):
    """
    Helper function calculating the complete VADv2 loss with the original
    one-hot VADv2 classification target.
    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global VADv2 config
    :param vocab_pdm_score: unused here; kept for a uniform loss interface
    :return: tuple of combined loss value and dictionary of individual losses
    """
    pred_dist = predictions["trajectory_distribution"]
    # cb_weight = predictions["cb_weight"].to(pred_dist.device)
    # ############################### 2. Ori VADv2 ###############################
    B, N_SAMPLES = pred_dist.shape
    # vocab = predictions["trajectory_vocab"]
    # log_replay_traj = targets["trajectory"]
    # sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    # l2_imi = -((vocab[:, sampled_timepoints][None].repeat(B, 1, 1, 1) -
    #             log_replay_traj[:, None]) ** 2).sum((-2, -1))
    # l2_imi = 1 - l2_imi.exp()

    # one-hot targets over [vocab entries | B appended entries]: the i-th
    # appended entry is the positive class for batch sample i
    target_dist = torch.zeros((B, config.vocab_size), dtype=pred_dist.dtype, device=pred_dist.device)
    mask = torch.eye(B, dtype=pred_dist.dtype, device=pred_dist.device)
    target_dist = torch.cat([target_dist, mask], dim=-1).contiguous()
    trajectory_loss = F.cross_entropy(pred_dist, target_dist, reduction='mean')

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    bev_semantic_loss = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )
    trajectory_pdm_loss_final = config.trajectory_imi_weight * trajectory_loss
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss
    bev_semantic_loss_final = config.bev_semantic_weight * bev_semantic_loss
    loss = (
            trajectory_pdm_loss_final
            + agent_class_loss_final
            + agent_box_loss_final
            + bev_semantic_loss_final
    )
    return loss, {
        'trajectory_pdm_loss': trajectory_pdm_loss_final,
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
        'bev_semantic_loss': bev_semantic_loss_final
    }


def three_to_two_classes(x):
    """
    Collapse the three-valued PDM score to a binary target by mapping the
    intermediate value 0.5 to 0.0.
    Note: modifies ``x`` in place; the mapping is idempotent.
    """
    x[x == 0.5] = 0.0
    return x


def vadv2_loss_pdm_wo_progress(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: Vadv2Config,
        vocab_pdm_score
):
    """
    Helper function calculating the complete VADv2 loss with per-metric
    PDM supervision (without the progress term).
    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global VADv2 config
    :param vocab_pdm_score: dictionary of PDM sub-scores per vocabulary trajectory
    :return: tuple of combined loss value and dictionary of individual losses
    """
    noc, da, dd, ttc, comfort = (predictions['noc'], predictions['da'],
                                 predictions['dd'], predictions['ttc'],
                                 predictions['comfort'])
    imi = predictions['imi']
    # 2-class (binary) targets
    da_loss = F.binary_cross_entropy(da, vocab_pdm_score['da'].to(da.dtype))
    ttc_loss = F.binary_cross_entropy(ttc, vocab_pdm_score['ttc'].to(da.dtype))
    comfort_loss = F.binary_cross_entropy(comfort, vocab_pdm_score['comfort'].to(da.dtype))
    # 3-class scores collapsed to 2 classes (see three_to_two_classes)
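    # illustrative example of the collapse applied below:
    #   three_to_two_classes(torch.tensor([0.0, 0.5, 1.0])) -> tensor([0., 0., 1.])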
    noc_loss = F.binary_cross_entropy(noc, three_to_two_classes(vocab_pdm_score['noc'].to(da.dtype)))
    dd_loss = F.binary_cross_entropy(dd, three_to_two_classes(vocab_pdm_score['dd'].to(da.dtype)))
    # regression
    # progress_weight = torch.ones_like(progress)
    # progress_target = vocab_pdm_score['progress'].to(da.dtype)
    # mask_0_5 = progress_target <= 0.5
    # mask_5_8 = (progress_target > 0.5).logical_and(progress_target <= 0.8)
    # mask_8_1 = progress_target > 0.8
    # progress_weight[mask_0_5] = 0.36
    # progress_weight[mask_5_8] = 5.73
    # progress_weight[mask_8_1] = 20.19
    # progress_loss = F.binary_cross_entropy(progress, progress_target,
    #                                        weight=progress_weight)

    vocab = predictions["trajectory_vocab"]
    # B, 8 (4 s at 0.5 s intervals), 3
    target_traj = targets["trajectory"]
    # indices 4, 9, ..., 39
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    B = target_traj.shape[0]
    l2_distance = -((vocab[:, sampled_timepoints][None].repeat(B, 1, 1, 1) -
                     target_traj[:, None]) ** 2) / config.sigma
    imi_loss = F.cross_entropy(imi, l2_distance.sum((-2, -1)).softmax(1))

    imi_loss_final = config.trajectory_imi_weight * imi_loss
    noc_loss_final = config.trajectory_pdm_weight['noc'] * noc_loss
    da_loss_final = config.trajectory_pdm_weight['da'] * da_loss
    dd_loss_final = config.trajectory_pdm_weight['dd'] * dd_loss
    ttc_loss_final = config.trajectory_pdm_weight['ttc'] * ttc_loss
    # progress_loss_final = config.trajectory_pdm_weight['progress'] * progress_loss
    comfort_loss_final = config.trajectory_pdm_weight['comfort'] * comfort_loss

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    bev_semantic_loss = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss
    bev_semantic_loss_final = config.bev_semantic_weight * bev_semantic_loss
    loss = (
            imi_loss_final
            + noc_loss_final
            + da_loss_final
            + dd_loss_final
            + ttc_loss_final
            # + progress_loss_final
            + comfort_loss_final
            + agent_class_loss_final
            + agent_box_loss_final
            + bev_semantic_loss_final
    )
    return loss, {
        'imi_loss': imi_loss_final,
        'pdm_noc_loss': noc_loss_final,
        'pdm_da_loss': da_loss_final,
        'pdm_dd_loss': dd_loss_final,
        'pdm_ttc_loss': ttc_loss_final,
        # 'pdm_progress_loss': progress_loss_final,
        'pdm_comfort_loss': comfort_loss_final,
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
        'bev_semantic_loss': bev_semantic_loss_final
    }


def vadv2_loss_pdm_w_progress(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: Vadv2Config,
        vocab_pdm_score
):
    """
    Helper function calculating the complete VADv2 loss with per-metric
    PDM supervision, including the progress term (the 'dd' score is not
    supervised here).
    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global VADv2 config
    :param vocab_pdm_score: dictionary of PDM sub-scores per vocabulary trajectory
    :return: tuple of combined loss value and dictionary of individual losses
    """
    noc, da, ttc, comfort, progress = (predictions['noc'], predictions['da'],
                                       predictions['ttc'], predictions['comfort'],
                                       predictions['progress'])
    imi = predictions['imi']
    # 2-class (binary) targets
    da_loss = F.binary_cross_entropy(da, vocab_pdm_score['da'].to(da.dtype))
    ttc_loss = F.binary_cross_entropy(ttc, vocab_pdm_score['ttc'].to(da.dtype))
    comfort_loss = F.binary_cross_entropy(comfort, vocab_pdm_score['comfort'].to(da.dtype))
    # 3-class score collapsed to 2 classes (see three_to_two_classes)
    noc_loss = F.binary_cross_entropy(noc, three_to_two_classes(vocab_pdm_score['noc'].to(da.dtype)))
    # regression
    # progress_weight = torch.ones_like(progress)
    # progress_target = vocab_pdm_score['progress'].to(da.dtype)
    # mask_0_5 = progress_target <= 0.5
    # mask_5_8 = (progress_target > 0.5).logical_and(progress_target <= 0.8)
    # mask_8_1 = progress_target > 0.8
    # progress_weight[mask_0_5] = 0.36
    # progress_weight[mask_5_8] = 5.73
    # progress_weight[mask_8_1] = 20.19
    progress_loss = F.binary_cross_entropy(progress, vocab_pdm_score['progress'].to(progress.dtype))

    vocab = predictions["trajectory_vocab"]
    # B, 8 (4 s at 0.5 s intervals), 3
    target_traj = targets["trajectory"]
    # indices 4, 9, ..., 39
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    B = target_traj.shape[0]
    l2_distance = -((vocab[:, sampled_timepoints][None].repeat(B, 1, 1, 1) -
                     target_traj[:, None]) ** 2) / config.sigma
    imi_loss = F.cross_entropy(imi, l2_distance.sum((-2, -1)).softmax(1))

    imi_loss_final = config.trajectory_imi_weight * imi_loss
    noc_loss_final = config.trajectory_pdm_weight['noc'] * noc_loss
    da_loss_final = config.trajectory_pdm_weight['da'] * da_loss
    ttc_loss_final = config.trajectory_pdm_weight['ttc'] * ttc_loss
    progress_loss_final = config.trajectory_pdm_weight['progress'] * progress_loss
    comfort_loss_final = config.trajectory_pdm_weight['comfort'] * comfort_loss

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    bev_semantic_loss = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss
    bev_semantic_loss_final = config.bev_semantic_weight * bev_semantic_loss
    loss = (
            imi_loss_final
            + noc_loss_final
            + da_loss_final
            + ttc_loss_final
            + progress_loss_final
            + comfort_loss_final
            + agent_class_loss_final
            + agent_box_loss_final
            + bev_semantic_loss_final
    )
    return loss, {
        'imi_loss': imi_loss_final,
        'pdm_noc_loss': noc_loss_final,
        'pdm_da_loss': da_loss_final,
        'pdm_ttc_loss': ttc_loss_final,
        'pdm_progress_loss': progress_loss_final,
        'pdm_comfort_loss': comfort_loss_final,
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
        'bev_semantic_loss': bev_semantic_loss_final
    }


def vadv2_loss_pdm(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: Vadv2Config,
        vocab_pdm_score
):
    """
    Helper function calculating the complete VADv2 loss with per-metric
    PDM supervision over all sub-scores (no imitation term).
    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global VADv2 config
    :param vocab_pdm_score: dictionary of PDM sub-scores per vocabulary trajectory
    :return: tuple of combined loss value and dictionary of individual losses
    """
    noc, da, dd, ttc, comfort, progress = (predictions['noc'], predictions['da'],
                                           predictions['dd'], predictions['ttc'],
                                           predictions['comfort'], predictions['progress'])
    # 2-class (binary) targets
    da_loss = F.binary_cross_entropy(da, vocab_pdm_score['da'].to(da.dtype))
    ttc_loss = F.binary_cross_entropy(ttc, vocab_pdm_score['ttc'].to(da.dtype))
    comfort_loss = F.binary_cross_entropy(comfort, vocab_pdm_score['comfort'].to(da.dtype))
    # 3-class scores collapsed to 2 classes (see three_to_two_classes)
    noc_loss = F.binary_cross_entropy(noc, three_to_two_classes(vocab_pdm_score['noc'].to(da.dtype)))
    dd_loss = F.binary_cross_entropy(dd, three_to_two_classes(vocab_pdm_score['dd'].to(da.dtype)))
    # regression
    progress_loss = F.binary_cross_entropy(progress, vocab_pdm_score['progress'].to(da.dtype))

    noc_loss_final = config.trajectory_pdm_weight['noc'] * noc_loss
    da_loss_final = config.trajectory_pdm_weight['da'] * da_loss
    dd_loss_final = config.trajectory_pdm_weight['dd'] * dd_loss
    ttc_loss_final = config.trajectory_pdm_weight['ttc'] * ttc_loss
    progress_loss_final = config.trajectory_pdm_weight['progress'] * progress_loss
    comfort_loss_final = config.trajectory_pdm_weight['comfort'] * comfort_loss

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    bev_semantic_loss = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss
    bev_semantic_loss_final = config.bev_semantic_weight * bev_semantic_loss
    loss = (
            noc_loss_final
            + da_loss_final
            + dd_loss_final
            + ttc_loss_final
            + progress_loss_final
            + comfort_loss_final
            + agent_class_loss_final
            + agent_box_loss_final
            + bev_semantic_loss_final
    )
    return loss, {
        'pdm_noc_loss': noc_loss_final,
        'pdm_da_loss': da_loss_final,
        'pdm_dd_loss': dd_loss_final,
        'pdm_ttc_loss': ttc_loss_final,
        'pdm_progress_loss': progress_loss_final,
        'pdm_comfort_loss': comfort_loss_final,
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
        'bev_semantic_loss': bev_semantic_loss_final
    }


def _agent_loss(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: TransfuserConfig
):
    """
    Hungarian matching loss for agent detection
    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global Transfuser config
    :return: tuple of classification and bounding-box regression loss
    """
    gt_states, gt_valid = targets["agent_states"], targets["agent_labels"]
    pred_states, pred_logits = predictions["agent_states"], predictions["agent_labels"]

    # save constants
    batch_dim, num_instances = pred_states.shape[:2]
    num_gt_instances = gt_valid.sum()
    # guard against division by zero when no ground-truth agents are present
    num_gt_instances = num_gt_instances if num_gt_instances > 0 else num_gt_instances + 1

    ce_cost = _get_ce_cost(gt_valid, pred_logits)
    l1_cost = _get_l1_cost(gt_states, pred_states, gt_valid)

    cost = config.agent_class_weight * ce_cost + config.agent_box_weight * l1_cost
    cost = cost.cpu()

    indices = [linear_sum_assignment(c) for c in cost]
    matching = [
        (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
        for i, j in indices
    ]
    idx = _get_src_permutation_idx(matching)

    pred_states_idx = pred_states[idx]
    gt_states_idx = torch.cat([t[i] for t, (_, i) in zip(gt_states, indices)], dim=0)

    pred_valid_idx = pred_logits[idx]
    gt_valid_idx = torch.cat([t[i] for t, (_, i) in zip(gt_valid, indices)], dim=0).float()

    l1_loss = F.l1_loss(pred_states_idx, gt_states_idx, reduction="none")
    l1_loss = l1_loss.sum(-1) * gt_valid_idx
    l1_loss = l1_loss.view(batch_dim, -1).sum() / num_gt_instances

    ce_loss = F.binary_cross_entropy_with_logits(pred_valid_idx, gt_valid_idx, reduction="none")
    ce_loss = ce_loss.view(batch_dim, -1).mean()

    return ce_loss, l1_loss

@torch.no_grad()
def _get_ce_cost(gt_valid: torch.Tensor, pred_logits: torch.Tensor) -> torch.Tensor:
    """
    Function to calculate cross-entropy cost for cost matrix.
    :param gt_valid: tensor of binary ground-truth labels
    :param pred_logits: tensor of predicted logits of neural net
    :return: bce cost matrix as tensor
    """
    # NOTE: numerically stable BCE with logits
    # https://github.com/pytorch/pytorch/blob/c64e006fc399d528bb812ae589789d0365f3daf4/aten/src/ATen/native/Loss.cpp#L214
    gt_valid_expanded = gt_valid[:, :, None].detach().float()  # (b, n, 1)
    pred_logits_expanded = pred_logits[:, None, :].detach()  # (b, 1, n)

    max_val = torch.relu(-pred_logits_expanded)
    helper_term = max_val + torch.log(
        torch.exp(-max_val) + torch.exp(-pred_logits_expanded - max_val)
    )
    ce_cost = (1 - gt_valid_expanded) * pred_logits_expanded + helper_term  # (b, n, n)
    ce_cost = ce_cost.permute(0, 2, 1)

    return ce_cost


@torch.no_grad()
def _get_l1_cost(
        gt_states: torch.Tensor, pred_states: torch.Tensor, gt_valid: torch.Tensor
) -> torch.Tensor:
    """
    Function to calculate L1 cost for cost matrix.
    :param gt_states: tensor of ground-truth bounding boxes
    :param pred_states: tensor of predicted bounding boxes
    :param gt_valid: mask of binary ground-truth labels
    :return: l1 cost matrix as tensor
    """
    # compare only the (x, y) centers of the boxes
    gt_states_expanded = gt_states[:, :, None, :2].detach()  # (b, n, 1, 2)
    pred_states_expanded = pred_states[:, None, :, :2].detach()  # (b, 1, n, 2)
    l1_cost = gt_valid[..., None].float() * (gt_states_expanded - pred_states_expanded).abs().sum(
        dim=-1
    )
    l1_cost = l1_cost.permute(0, 2, 1)

    return l1_cost


def _get_src_permutation_idx(indices):
    """
    Helper function to align indices after matching
    :param indices: matched indices
    :return: permuted indices
    """
    # permute predictions following indices
    batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    src_idx = torch.cat([src for (src, _) in indices])
    return batch_idx, src_idx
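
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): shows how the soft imitation
# target shared by the loss functions above is built. The batch size,
# vocabulary size, and sigma below are assumed stand-ins, not values taken
# from Vadv2Config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, N_VOCAB = 2, 16                     # assumed batch and vocabulary sizes
    sigma = 0.5                            # stand-in for config.sigma
    vocab = torch.randn(N_VOCAB, 40, 3)    # vocabulary: 40 poses at 0.1 s intervals
    target_traj = torch.randn(B, 8, 3)     # target: 8 poses at 0.5 s intervals
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]  # indices 4, 9, ..., 39
    # negative scaled squared distance of every vocabulary trajectory to the
    # target trajectory, summed over time steps and pose dimensions
    l2_distance = -((vocab[:, sampled_timepoints][None].repeat(B, 1, 1, 1) -
                     target_traj[:, None]) ** 2) / sigma
    soft_target = l2_distance.sum((-2, -1)).softmax(1)     # (B, N_VOCAB)
    # each row is a distribution over the vocabulary that peaks at the
    # trajectory closest to the target; rows sum to 1
    print(soft_target.shape, soft_target.sum(1))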