|
from typing import Dict |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from navsim.agents.vadv2.vadv2_config import Vadv2Config |
|
from navsim.agents.vadv2.vadv2_loss import _agent_loss, three_to_two_classes |
|
|
|
def hydra_loss(
    targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
    vocab_pdm_score
):
    """
    Combine the main Hydra loss with the one-to-many imitation loss.

    :param targets: dictionary of name tensor pairings, forwarded to both loss helpers
    :param predictions: dictionary of name tensor pairings, forwarded to both loss helpers
    :param config: global Vadv2 config, forwarded to both loss helpers
    :param vocab_pdm_score: per-vocabulary PDM sub-score targets, forwarded to both loss helpers
    :return: tuple of (summed loss value, dict of individual loss terms). Terms from
        ``hydra_kd_imi_agent_loss_one2many`` are stored with an ``_one2many`` key suffix.
    """
    loss_val, loss_dict = hydra_kd_imi_agent_loss(targets, predictions, config, vocab_pdm_score)
    loss_one2many_val, loss_one2many_dict = hydra_kd_imi_agent_loss_one2many(
        targets, predictions, config, vocab_pdm_score
    )
    # Both helpers report their imitation term under 'imi_loss'. A plain dict.update would
    # silently overwrite the main imitation loss with the one2many term in the logs, so
    # the one2many entries get a distinct suffix instead.
    loss_dict.update({f"{key}_one2many": value for key, value in loss_one2many_dict.items()})
    return loss_val + loss_one2many_val, loss_dict
|
|
|
def hydra_kd_imi_agent_loss(
    targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
    vocab_pdm_score
):
    """
    Compute the main Hydra loss: imitation + PDM sub-score distillation + agent detection.

    :param targets: dictionary of name tensor pairings; ``targets["trajectory"]`` is read here,
        and the dict is also passed to ``_agent_loss``
    :param predictions: dictionary of name tensor pairings. Read here: the PDM sub-score heads
        ('noc', 'da', 'ttc', 'comfort', 'progress', 'ddc', 'lk', 'tl'), 'imi' and
        'trajectory_vocab'. The dict is also passed to ``_agent_loss``.
    :param config: global Vadv2 config. Supplies ``sigma``, ``trajectory_imi_weight``,
        ``trajectory_pdm_weight[...]``, ``agent_class_weight`` and ``agent_box_weight``.
    :param vocab_pdm_score: mapping from PDM sub-score name to its target tensor
    :return: tuple of (weighted total loss, dict of the individual weighted loss terms)
    """
    # Binary cross-entropy distillation of every PDM sub-score head. Each target is cast to
    # the dtype of its *own* prediction, because binary_cross_entropy expects input and
    # target dtypes to match. The 'noc' and 'ddc' targets are first mapped through
    # three_to_two_classes.
    pdm_losses = {}
    for key in ('noc', 'da', 'ttc', 'progress', 'comfort', 'ddc', 'lk', 'tl'):
        pred = predictions[key]
        target = vocab_pdm_score[key].to(pred.dtype)
        if key in ('noc', 'ddc'):
            target = three_to_two_classes(target)
        pdm_losses[key] = config.trajectory_pdm_weight[key] * F.binary_cross_entropy(pred, target)

    # Soft imitation target: a softmax over vocabulary entries of the negative squared
    # distance (scaled by 1 / sigma) between each entry and the ground-truth trajectory.
    # Each entry is subsampled at sampled_timepoints (indices 4, 9, ..., 39).
    # Assumes vocab is (V, T, D) and target_traj is (B, 8, D) — TODO confirm.
    vocab = predictions["trajectory_vocab"]
    target_traj = targets["trajectory"]
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    # Broadcasting (1, V, 8, D) against (B, 1, 8, D) yields the same (B, V, 8, D) result
    # as repeating the vocabulary B times, without materialising the copy.
    neg_sq_dist = -((vocab[:, sampled_timepoints][None] - target_traj[:, None]) ** 2) / config.sigma
    imi_target = neg_sq_dist.sum((-2, -1)).softmax(1)
    imi_loss = F.cross_entropy(predictions['imi'], imi_target)
    imi_loss_final = config.trajectory_imi_weight * imi_loss

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss

    # Terms are summed in a fixed order; reordering could change the floating-point result.
    loss = (
        imi_loss_final
        + pdm_losses['noc']
        + pdm_losses['da']
        + pdm_losses['ttc']
        + pdm_losses['progress']
        + pdm_losses['comfort']
        + agent_class_loss_final
        + agent_box_loss_final
        + pdm_losses['ddc']
        + pdm_losses['lk']
        + pdm_losses['tl']
    )
    return loss, {
        'imi_loss': imi_loss_final,
        'pdm_noc_loss': pdm_losses['noc'],
        'pdm_da_loss': pdm_losses['da'],
        'pdm_ttc_loss': pdm_losses['ttc'],
        'pdm_progress_loss': pdm_losses['progress'],
        'pdm_ddc_loss': pdm_losses['ddc'],
        'pdm_lk_loss': pdm_losses['lk'],
        'pdm_tl_loss': pdm_losses['tl'],
        'pdm_comfort_loss': pdm_losses['comfort'],
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
    }
|
|
|
|
|
def hydra_kd_imi_agent_loss_one2many(
    targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: "Vadv2Config",
    vocab_pdm_score, imi_scale: float = 0.5
):
    """
    Compute the one-to-many imitation loss term.

    :param targets: dictionary of name tensor pairings; reads ``targets["trajectory"]``
    :param predictions: dictionary of name tensor pairings; reads 'imi' and 'trajectory_vocab'
    :param config: global Vadv2 config; supplies ``sigma`` and ``trajectory_imi_weight``
    :param vocab_pdm_score: unused; accepted so the signature matches hydra_kd_imi_agent_loss
    :param imi_scale: extra multiplier on the weighted imitation loss (default 0.5)
    :return: tuple of (loss value, ``{'imi_loss': loss value}``)
    """
    vocab = predictions["trajectory_vocab"]
    target_traj = targets["trajectory"]
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]  # indices 4, 9, ..., 39
    # Soft target: a softmax over vocabulary entries of the negative squared distance
    # (scaled by 1 / sigma) to the ground-truth trajectory. Broadcasting (1, V, ...)
    # against (B, 1, ...) replaces an explicit repeat of the vocabulary over the batch.
    # Assumes vocab is (V, T, D) and target_traj is (B, 8, D) — TODO confirm.
    neg_sq_dist = -((vocab[:, sampled_timepoints][None] - target_traj[:, None]) ** 2) / config.sigma
    imi_target = neg_sq_dist.sum((-2, -1)).softmax(1)

    # NOTE(review): this reads the same 'imi' logits and target as hydra_kd_imi_agent_loss,
    # so it currently equals imi_scale times that function's imitation term. Confirm that a
    # dedicated one-to-many prediction head was not intended here.
    imi_loss = F.cross_entropy(predictions['imi'], imi_target)
    imi_loss_final = config.trajectory_imi_weight * imi_loss * imi_scale
    return imi_loss_final, {'imi_loss': imi_loss_final}