File size: 5,538 Bytes
da2e2ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from typing import Dict
from scipy.optimize import linear_sum_assignment
import torch
import torch.nn.functional as F
from navsim.agents.transfuser.transfuser_config import TransfuserConfig
def transfuser_loss(
    targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: TransfuserConfig
):
    """
    Combine the trajectory, agent-detection and BEV-semantic losses of Transfuser.
    :param targets: dictionary of name tensor pairings (ground truth)
    :param predictions: dictionary of name tensor pairings (network outputs)
    :param config: global Transfuser config providing the per-term weights
    :return: tuple of (weighted total loss, dictionary of the individual weighted terms)
    """
    # Unweighted terms, evaluated in a fixed order: imitation, detection, semantics.
    imitation = F.l1_loss(predictions["trajectory"], targets["trajectory"])
    agent_cls, agent_box = _agent_loss(targets, predictions, config)
    semantic = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )

    # Each weighted term is built once and reused for the total and the logging dict.
    weighted_terms = {
        'trajectory_loss': config.trajectory_imi_weight * imitation,
        'agent_class_loss': config.agent_class_weight * agent_cls,
        'agent_box_loss': config.agent_box_weight * agent_box,
        'bev_semantic_loss': config.bev_semantic_weight * semantic
    }

    total = (
        weighted_terms['trajectory_loss']
        + weighted_terms['agent_class_loss']
        + weighted_terms['agent_box_loss']
        + weighted_terms['bev_semantic_loss']
    )
    return total, weighted_terms
def _agent_loss(
    targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: TransfuserConfig
):
    """
    Hungarian matching loss for agent detection.
    Predictions are assigned one-to-one to ground-truth slots by minimising a
    weighted classification + box cost, then the losses are evaluated on the
    matched pairs.
    :param targets: dictionary of name tensor pairings ("agent_states", "agent_labels")
    :param predictions: dictionary of name tensor pairings ("agent_states", "agent_labels")
    :param config: global Transfuser config providing the matching-cost weights
    :return: tuple of (classification loss, box L1 loss)
    """
    gt_states = targets["agent_states"]
    gt_valid = targets["agent_labels"]
    pred_states = predictions["agent_states"]
    pred_logits = predictions["agent_labels"]

    batch_size = pred_states.shape[0]

    # Normaliser for the box loss; bumped by one when no ground-truth agent
    # is marked valid so the division below stays finite.
    normalizer = gt_valid.sum()
    if not normalizer > 0:
        normalizer = normalizer + 1

    # Pairwise matching cost per sample, moved to CPU for SciPy.
    class_cost = _get_ce_cost(gt_valid, pred_logits)
    box_cost = _get_l1_cost(gt_states, pred_states, gt_valid)
    cost_matrix = config.agent_class_weight * class_cost + config.agent_box_weight * box_cost
    cost_matrix = cost_matrix.cpu()

    # One optimal assignment (row indices, column indices) per batch sample.
    assignments = [linear_sum_assignment(sample_cost) for sample_cost in cost_matrix]
    matching = [
        (
            torch.as_tensor(rows, dtype=torch.int64),
            torch.as_tensor(cols, dtype=torch.int64),
        )
        for rows, cols in assignments
    ]
    pred_selector = _get_src_permutation_idx(matching)

    def _gather_targets(batched: torch.Tensor) -> torch.Tensor:
        # Pick, per sample, the ground-truth entries in matched (column) order.
        picked = [sample[cols] for sample, (_, cols) in zip(batched, assignments)]
        return torch.cat(picked, dim=0)

    matched_pred_states = pred_states[pred_selector]
    matched_gt_states = _gather_targets(gt_states)
    matched_pred_logits = pred_logits[pred_selector]
    matched_gt_valid = _gather_targets(gt_valid).float()

    # Box regression: summed over state dims, counted only for valid targets.
    box_error = F.l1_loss(matched_pred_states, matched_gt_states, reduction="none")
    box_error = box_error.sum(-1) * matched_gt_valid
    l1_loss = box_error.view(batch_size, -1).sum() / normalizer

    # Classification: mean BCE over every matched pair, valid or not.
    class_error = F.binary_cross_entropy_with_logits(
        matched_pred_logits, matched_gt_valid, reduction="none"
    )
    ce_loss = class_error.view(batch_size, -1).mean()

    return ce_loss, l1_loss
@torch.no_grad()
def _get_ce_cost(gt_valid: torch.Tensor, pred_logits: torch.Tensor) -> torch.Tensor:
    """
    Build the pairwise binary cross-entropy cost between predictions and targets.
    Entry [b, i, j] is the BCE-with-logits of prediction i against label j in sample b.
    :param gt_valid: tensor of binary ground-truth labels, indexed as (batch, gt)
    :param pred_logits: tensor of predicted logits of neural net, indexed as (batch, pred)
    :return: bce cost matrix as tensor, indexed as (batch, pred, gt)
    """
    # NOTE: numerically stable BCE with logits
    # https://github.com/pytorch/pytorch/blob/c64e006fc399d528bb812ae589789d0365f3daf4/aten/src/ATen/native/Loss.cpp#L214
    labels = gt_valid.unsqueeze(2).detach().float()  # (b, n_gt, 1)
    logits = pred_logits.unsqueeze(1).detach()  # (b, 1, n_pred)

    # Shift by max(-x, 0) so that neither exponential below can overflow.
    shift = torch.relu(-logits)
    log_sum = torch.log(torch.exp(-shift) + torch.exp(-logits - shift))
    stable_softplus = shift + log_sum

    pairwise = (1 - labels) * logits + stable_softplus  # (b, n_gt, n_pred)

    # Put predictions on the row axis and targets on the column axis.
    return pairwise.transpose(1, 2)
@torch.no_grad()
def _get_l1_cost(
    gt_states: torch.Tensor, pred_states: torch.Tensor, gt_valid: torch.Tensor
) -> torch.Tensor:
    """
    Build the pairwise L1 cost between predicted and ground-truth boxes.
    Only the first two entries of the last state dimension enter the distance,
    and columns belonging to targets with a zero label cost nothing.
    :param gt_states: tensor of ground-truth bounding boxes, indexed as (batch, gt, state)
    :param pred_states: tensor of predicted bounding boxes, indexed as (batch, pred, state)
    :param gt_valid: mask of binary ground-truth labels, indexed as (batch, gt)
    :return: l1 cost matrix as tensor, indexed as (batch, pred, gt)
    """
    gt_xy = gt_states[..., :2].unsqueeze(2).detach()  # (b, n_gt, 1, 2)
    pred_xy = pred_states[..., :2].unsqueeze(1).detach()  # (b, 1, n_pred, 2)

    # Sum of absolute differences for every (target, prediction) pair.
    pair_distance = torch.abs(gt_xy - pred_xy).sum(dim=-1)  # (b, n_gt, n_pred)

    # Zero the rows of invalid targets so they never influence the matching.
    valid_mask = gt_valid.unsqueeze(-1).float()  # (b, n_gt, 1)
    masked_distance = valid_mask * pair_distance

    # Put predictions on the row axis and targets on the column axis.
    return masked_distance.transpose(1, 2)
def _get_src_permutation_idx(indices):
    """
    Flatten per-sample matched indices into a pair of advanced-indexing tensors.
    :param indices: sequence of (source indices, target indices) tensor pairs, one per batch sample
    :return: tuple of (batch indices, source indices), each concatenated over all samples
    """
    batch_parts = []
    src_parts = []
    for sample_no, (src, _) in enumerate(indices):
        # Tag every matched source index with the sample it belongs to.
        batch_parts.append(torch.full_like(src, sample_no))
        src_parts.append(src)

    batch_idx = torch.cat(batch_parts)
    src_idx = torch.cat(src_parts)
    return batch_idx, src_idx
|