# File size: 6,631 Bytes
# da2e2ac
from typing import Dict
import torch
import torch.nn.functional as F
from navsim.agents.vadv2.vadv2_config import Vadv2Config
from navsim.agents.vadv2.vadv2_loss import _agent_loss, three_to_two_classes
def hydra_loss(
    targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
    vocab_pdm_score
):
    """
    Combine hydra_kd_imi_agent_loss and hydra_kd_imi_agent_loss_one2many into one training loss.
    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global Vadv2 config
    :param vocab_pdm_score: per-metric scores indexed by metric name, forwarded to both helpers
    :return: tuple of (sum of both loss values, dictionary of the individual weighted terms)
    """
    loss_val, loss = hydra_kd_imi_agent_loss(targets, predictions, config, vocab_pdm_score)
    loss_one2many_val, loss_one2many = hydra_kd_imi_agent_loss_one2many(
        targets, predictions, config, vocab_pdm_score
    )
    # Both helpers report their imitation term under 'imi_loss'. A plain dict.update() would
    # silently replace the main term with the one2many term, so colliding names are stored
    # under a '_one2many' suffix and every term stays visible in the breakdown.
    for name, value in loss_one2many.items():
        key = f"{name}_one2many" if name in loss else name
        loss[key] = value
    return loss_val + loss_one2many_val, loss
def hydra_kd_imi_agent_loss(
    targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
    vocab_pdm_score
):
    """
    Compute the weighted sum of the imitation, PDM sub-score and agent losses.
    :param targets: dictionary of name tensor pairings; "trajectory" is read here, the whole
        dictionary is forwarded to _agent_loss
    :param predictions: dictionary of name tensor pairings; reads the eight sub-score heads
        ('noc', 'da', 'ttc', 'comfort', 'progress', 'ddc', 'lk', 'tl'), 'imi' and
        "trajectory_vocab"
    :param config: global Vadv2 config; supplies sigma, trajectory_imi_weight,
        trajectory_pdm_weight, agent_class_weight and agent_box_weight
    :param vocab_pdm_score: per-metric label tensors indexed by the same sub-score names
    :return: tuple of (combined loss value, dictionary of the individual weighted terms)
    """
    pdm_weights = config.trajectory_pdm_weight

    # Binary cross-entropy between every sub-score head and its label. Each label is cast to
    # the dtype of the prediction it is compared with, rather than to an unrelated head's dtype.
    pdm_losses = {}
    for name in ('noc', 'da', 'ttc', 'progress', 'comfort', 'ddc', 'lk', 'tl'):
        pred = predictions[name]
        label = vocab_pdm_score[name].to(pred.dtype)
        if name in ('noc', 'ddc'):
            # These two labels are passed through three_to_two_classes before the BCE.
            label = three_to_two_classes(label)
        pdm_losses[name] = pdm_weights[name] * F.binary_cross_entropy(pred, label)

    vocab = predictions["trajectory_vocab"]
    target_traj = targets["trajectory"]
    # Vocabulary time indices 4, 9, ..., 39 are the ones compared with the target trajectory.
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    # Broadcasting pairs every target with every vocabulary entry, so the vocabulary does not
    # have to be copied once per batch element.
    # NOTE(review): assumes vocab is (num_vocab, T, D) and target_traj is (B, 8, D) - confirm.
    l2_distance = -((vocab[:, sampled_timepoints][None] - target_traj[:, None]) ** 2) / config.sigma
    # Soft target over the vocabulary: softmax of the negative scaled squared distance.
    imi_loss = F.cross_entropy(predictions['imi'], l2_distance.sum((-2, -1)).softmax(1))
    imi_loss_final = config.trajectory_imi_weight * imi_loss

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss

    # Summation order is kept fixed so the total stays numerically reproducible.
    loss = (
        imi_loss_final
        + pdm_losses['noc']
        + pdm_losses['da']
        + pdm_losses['ttc']
        + pdm_losses['progress']
        + pdm_losses['comfort']
        + agent_class_loss_final
        + agent_box_loss_final
        + pdm_losses['ddc']
        + pdm_losses['lk']
        + pdm_losses['tl']
    )
    return loss, {
        'imi_loss': imi_loss_final,
        'pdm_noc_loss': pdm_losses['noc'],
        'pdm_da_loss': pdm_losses['da'],
        'pdm_ttc_loss': pdm_losses['ttc'],
        'pdm_progress_loss': pdm_losses['progress'],
        'pdm_ddc_loss': pdm_losses['ddc'],
        'pdm_lk_loss': pdm_losses['lk'],
        'pdm_tl_loss': pdm_losses['tl'],
        'pdm_comfort_loss': pdm_losses['comfort'],
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
    }
def hydra_kd_imi_agent_loss_one2many(
    targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
    vocab_pdm_score
):
    """
    Compute the imitation loss scaled by half of config.trajectory_imi_weight.
    :param targets: dictionary of name tensor pairings; reads "trajectory"
    :param predictions: dictionary of name tensor pairings; reads 'imi' and "trajectory_vocab"
    :param config: global Vadv2 config; supplies sigma and trajectory_imi_weight
    :param vocab_pdm_score: unused here; accepted so the signature matches
        hydra_kd_imi_agent_loss
    :return: tuple of (loss value, dictionary holding that value under 'imi_loss')
    """
    # NOTE(review): this reads the same predictions['imi'] as hydra_kd_imi_agent_loss, so the
    # result is that function's imitation term scaled by 0.5 - confirm this is intended.
    imi = predictions['imi']
    vocab = predictions["trajectory_vocab"]
    target_traj = targets["trajectory"]
    # Vocabulary time indices 4, 9, ..., 39 are the ones compared with the target trajectory.
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    # Broadcasting pairs every target with every vocabulary entry, so the vocabulary does not
    # have to be copied once per batch element.
    # NOTE(review): assumes vocab is (num_vocab, T, D) and target_traj is (B, 8, D) - confirm.
    l2_distance = -((vocab[:, sampled_timepoints][None] - target_traj[:, None]) ** 2) / config.sigma
    # Soft target over the vocabulary: softmax of the negative scaled squared distance.
    imi_loss = F.cross_entropy(imi, l2_distance.sum((-2, -1)).softmax(1))
    imi_loss_final = config.trajectory_imi_weight * imi_loss * 0.5
    return imi_loss_final, {
        'imi_loss': imi_loss_final,
    }