from typing import Dict

import torch
import torch.nn.functional as F

from navsim.agents.vadv2.vadv2_config import Vadv2Config
from navsim.agents.vadv2.vadv2_loss import _agent_loss, three_to_two_classes

# Vocabulary time indices 4, 9, ..., 39 that are compared against the 8 target poses.
_SAMPLED_TIMEPOINTS = [5 * k - 1 for k in range(1, 9)]


def _imitation_loss(
        imi_logits: torch.Tensor,
        vocab: torch.Tensor,
        target_traj: torch.Tensor,
        sigma,
) -> torch.Tensor:
    """Soft-target cross-entropy between imitation logits and a distance-based distribution over the vocabulary.

    :param imi_logits: imitation logits, one score per vocabulary trajectory (class dim = 1).
    :param vocab: trajectory vocabulary; must be 3-D (V, T, C) with T >= 40.
    :param target_traj: ground-truth trajectory, broadcast as (B, 1, 8, C) against the sampled vocabulary.
    :param sigma: temperature dividing the negative squared distances before the softmax.
    :return: unweighted scalar cross-entropy loss.
    """
    # (1, V, 8, C) - (B, 1, 8, C) broadcasts to (B, V, 8, C); no need to repeat the vocab B times.
    sampled_vocab = vocab[:, _SAMPLED_TIMEPOINTS][None]
    neg_sq_dist = -((sampled_vocab - target_traj[:, None]) ** 2) / sigma
    # Summing over poses and coordinates gives a (B, V) score; softmax over V yields the soft target.
    soft_targets = neg_sq_dist.sum((-2, -1)).softmax(1)
    return F.cross_entropy(imi_logits, soft_targets)


def _pdm_bce(
        prediction: torch.Tensor,
        target: torch.Tensor,
        collapse_three_classes: bool = False,
) -> torch.Tensor:
    """Binary cross-entropy of a predicted PDM sub-score against its vocabulary target.

    The target is cast to the prediction's own dtype so mixed-precision heads do not clash.
    When ``collapse_three_classes`` is set, the cast target is passed through
    ``three_to_two_classes`` first (used for the 'noc' and 'ddc' metrics).
    """
    target = target.to(prediction.dtype)
    if collapse_three_classes:
        target = three_to_two_classes(target)
    return F.binary_cross_entropy(prediction, target)


def hydra_loss(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: Vadv2Config,
        vocab_pdm_score: Dict[str, torch.Tensor],
):
    """Total Hydra training loss: main distillation/imitation/agent loss plus the one-to-many term.

    :param targets: dictionary of name -> ground-truth tensor.
    :param predictions: dictionary of name -> model output tensor.
    :param config: Vadv2 config supplying loss weights and ``sigma``.
    :param vocab_pdm_score: dictionary of metric name -> per-vocabulary target scores.
    :return: (combined scalar loss, dict of individual weighted loss terms)
    """
    loss_val, loss = hydra_kd_imi_agent_loss(targets, predictions, config, vocab_pdm_score)
    loss_one2many_val, loss_one2many = hydra_kd_imi_agent_loss_one2many(
        targets, predictions, config, vocab_pdm_score
    )
    # Both sub-losses report an 'imi_loss' entry; prefix the auxiliary terms so they
    # no longer overwrite the main imitation loss in the returned/logged dict.
    loss.update({f'one2many_{name}': value for name, value in loss_one2many.items()})
    return loss_val + loss_one2many_val, loss


def hydra_kd_imi_agent_loss(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: Vadv2Config,
        vocab_pdm_score: Dict[str, torch.Tensor],
):
    """Imitation + PDM sub-score distillation + agent detection loss for the Hydra head.

    :param targets: reads 'trajectory' here, plus whatever ``_agent_loss`` consumes.
    :param predictions: reads the sub-score heads ('noc', 'da', 'ttc', 'comfort', 'progress',
        'ddc', 'lk', 'tl'), the 'imi' logits and 'trajectory_vocab', plus whatever
        ``_agent_loss`` consumes.
    :param config: Vadv2 config (``trajectory_imi_weight``, ``trajectory_pdm_weight``,
        ``agent_class_weight``, ``agent_box_weight``, ``sigma``).
    :param vocab_pdm_score: metric name -> target score for every vocabulary trajectory.
    :return: (combined scalar loss, dict of individually weighted loss terms)
    """
    pdm_weight = config.trajectory_pdm_weight

    noc_loss_final = pdm_weight['noc'] * _pdm_bce(
        predictions['noc'], vocab_pdm_score['noc'], collapse_three_classes=True)
    da_loss_final = pdm_weight['da'] * _pdm_bce(predictions['da'], vocab_pdm_score['da'])
    ttc_loss_final = pdm_weight['ttc'] * _pdm_bce(predictions['ttc'], vocab_pdm_score['ttc'])
    comfort_loss_final = pdm_weight['comfort'] * _pdm_bce(
        predictions['comfort'], vocab_pdm_score['comfort'])
    progress_loss_final = pdm_weight['progress'] * _pdm_bce(
        predictions['progress'], vocab_pdm_score['progress'])
    # expansion metrics
    ddc_loss_final = pdm_weight['ddc'] * _pdm_bce(
        predictions['ddc'], vocab_pdm_score['ddc'], collapse_three_classes=True)
    lk_loss_final = pdm_weight['lk'] * _pdm_bce(predictions['lk'], vocab_pdm_score['lk'])
    tl_loss_final = pdm_weight['tl'] * _pdm_bce(predictions['tl'], vocab_pdm_score['tl'])

    imi_loss_final = config.trajectory_imi_weight * _imitation_loss(
        predictions['imi'], predictions["trajectory_vocab"], targets["trajectory"], config.sigma)

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss

    # Summation order kept identical to the original for reproducibility.
    loss = (
            imi_loss_final
            + noc_loss_final
            + da_loss_final
            + ttc_loss_final
            + progress_loss_final
            + comfort_loss_final
            + agent_class_loss_final
            + agent_box_loss_final
            + ddc_loss_final
            + lk_loss_final
            + tl_loss_final
    )
    return loss, {
        'imi_loss': imi_loss_final,
        'pdm_noc_loss': noc_loss_final,
        'pdm_da_loss': da_loss_final,
        'pdm_ttc_loss': ttc_loss_final,
        'pdm_progress_loss': progress_loss_final,
        'pdm_ddc_loss': ddc_loss_final,
        'pdm_lk_loss': lk_loss_final,
        'pdm_tl_loss': tl_loss_final,
        'pdm_comfort_loss': comfort_loss_final,
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
    }


def hydra_kd_imi_agent_loss_one2many(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: Vadv2Config,
        vocab_pdm_score: Dict[str, torch.Tensor],
):
    """Auxiliary imitation loss at half the configured imitation weight.

    ``vocab_pdm_score`` is accepted for signature parity with the main loss but is unused.

    NOTE(review): this reads the same 'imi' logits and targets as ``hydra_kd_imi_agent_loss``,
    so when combined by ``hydra_loss`` the effective imitation weight is 1.5x
    ``config.trajectory_imi_weight``. If a separate one-to-many head was intended, it should
    read a different prediction key — TODO confirm.

    :return: (scalar loss, {'imi_loss': weighted loss})
    """
    imi_loss = _imitation_loss(
        predictions['imi'], predictions["trajectory_vocab"], targets["trajectory"], config.sigma)
    imi_loss_final = config.trajectory_imi_weight * imi_loss * 0.5
    return imi_loss_final, {
        'imi_loss': imi_loss_final,
    }