# navsim_ours/navsim/agents/hydra/hydra_loss_fn.py
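"""Loss functions for the Hydra agent: binary cross-entropy against per-vocabulary-trajectory
PDM sub-scores, a soft-target imitation loss over the trajectory vocabulary, auxiliary agent
losses, and a down-weighted one-to-many imitation variant."""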
from typing import Dict
import torch
import torch.nn.functional as F
from navsim.agents.vadv2.vadv2_config import Vadv2Config
from navsim.agents.vadv2.vadv2_loss import _agent_loss, three_to_two_classes


def hydra_loss(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
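    """
    Total Hydra loss: the one-to-one loss (imitation + PDM sub-score BCE + agent losses) plus the
    down-weighted one-to-many imitation loss; also returns a merged dictionary of the individual
    loss terms for logging.
    """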
    loss_val, loss = hydra_kd_imi_agent_loss(targets, predictions, config, vocab_pdm_score)
    loss_one2many_val, loss_one2many = hydra_kd_imi_agent_loss_one2many(targets, predictions, config, vocab_pdm_score)
    loss.update(loss_one2many)
    return loss_val + loss_one2many_val, loss


def hydra_kd_imi_agent_loss(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """
    Computes the one-to-one Hydra loss: binary cross-entropy between the predicted per-vocabulary
    sub-scores and the precomputed PDM sub-scores, a soft-target imitation loss over the
    trajectory vocabulary, and auxiliary agent classification/box losses.
    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global Vadv2Config
    :param vocab_pdm_score: per-vocabulary-trajectory PDM sub-scores used as BCE targets
    :return: combined loss value and a dictionary of the individual weighted loss terms
    """
    noc, da, ttc, comfort, progress = (predictions['noc'], predictions['da'],
                                       predictions['ttc'],
                                       predictions['comfort'], predictions['progress'])
    imi = predictions['imi']
    # Binary targets: each PDM sub-score is a per-vocabulary-trajectory value in [0, 1]
    da_loss = F.binary_cross_entropy(da, vocab_pdm_score['da'].to(da.dtype))
    ttc_loss = F.binary_cross_entropy(ttc, vocab_pdm_score['ttc'].to(da.dtype))
    comfort_loss = F.binary_cross_entropy(comfort, vocab_pdm_score['comfort'].to(da.dtype))
    noc_loss = F.binary_cross_entropy(noc, three_to_two_classes(vocab_pdm_score['noc'].to(da.dtype)))
    progress_loss = F.binary_cross_entropy(progress, vocab_pdm_score['progress'].to(progress.dtype))
    vocab = predictions["trajectory_vocab"]
    # target trajectory: B, 8 (4 secs at 0.5 s intervals), 3
    target_traj = targets["trajectory"]
    # vocabulary indices 4, 9, ..., 39: subsample to the 8 target timesteps
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    B = target_traj.shape[0]
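    # vocab[:, sampled_timepoints] has shape (V, 8, 3); it is tiled to (B, V, 8, 3) and compared
    # against target_traj broadcast as (B, 1, 8, 3); config.sigma acts as a temperature.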
    l2_distance = -((vocab[:, sampled_timepoints][None].repeat(B, 1, 1, 1) - target_traj[:, None]) ** 2) / config.sigma
    imi_loss = F.cross_entropy(imi, l2_distance.sum((-2, -1)).softmax(1))
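    # Soft-target cross-entropy: the target distribution over the V vocabulary entries is the
    # softmax of the (negative, sigma-scaled) summed squared distances to the ground-truth
    # trajectory, so entries closer to the ground truth receive more probability mass.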
    imi_loss_final = config.trajectory_imi_weight * imi_loss
    noc_loss_final = config.trajectory_pdm_weight['noc'] * noc_loss
    da_loss_final = config.trajectory_pdm_weight['da'] * da_loss
    ttc_loss_final = config.trajectory_pdm_weight['ttc'] * ttc_loss
    progress_loss_final = config.trajectory_pdm_weight['progress'] * progress_loss
    comfort_loss_final = config.trajectory_pdm_weight['comfort'] * comfort_loss
    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss
    loss = (
        imi_loss_final
        + noc_loss_final
        + da_loss_final
        + ttc_loss_final
        + progress_loss_final
        + comfort_loss_final
        + agent_class_loss_final
        + agent_box_loss_final
    )
    return loss, {
        'imi_loss': imi_loss_final,
        'pdm_noc_loss': noc_loss_final,
        'pdm_da_loss': da_loss_final,
        'pdm_ttc_loss': ttc_loss_final,
        'pdm_progress_loss': progress_loss_final,
        'pdm_comfort_loss': comfort_loss_final,
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
    }


def hydra_kd_imi_agent_loss_one2many(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """
    One-to-many auxiliary variant of the Hydra loss: only a down-weighted imitation loss is
    applied; the PDM sub-score and agent losses are disabled for this head.
    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global Vadv2Config
    :param vocab_pdm_score: per-vocabulary-trajectory PDM sub-scores (unused in this variant)
    :return: imitation loss value and a dictionary containing it
    """
    # The PDM sub-score (noc/da/ttc/comfort/progress) and agent losses are intentionally disabled
    # for the one-to-many head; only the imitation term is used.
    imi = predictions['imi']
    vocab = predictions["trajectory_vocab"]
    # target trajectory: B, 8 (4 secs at 0.5 s intervals), 3
    target_traj = targets["trajectory"]
    # vocabulary indices 4, 9, ..., 39: subsample to the 8 target timesteps
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    B = target_traj.shape[0]
    l2_distance = -((vocab[:, sampled_timepoints][None].repeat(B, 1, 1, 1) - target_traj[:, None]) ** 2) / config.sigma
    imi_loss = F.cross_entropy(imi, l2_distance.sum((-2, -1)).softmax(1))
    # Down-weighted (x0.5) relative to the one-to-one imitation loss
    imi_loss_final = config.trajectory_imi_weight * imi_loss * 0.5
    loss = imi_loss_final
    return loss, {
        # Distinct key so the dict.update() in hydra_loss does not overwrite the one-to-one 'imi_loss'
        'imi_loss_one2many': imi_loss_final,
    }