from typing import Dict

import torch
import torch.nn.functional as F

from navsim.agents.dreamer.hydra_dreamer_config import HydraDreamerConfig
from navsim.agents.vadv2.vadv2_config import Vadv2Config
from navsim.agents.vadv2.vadv2_loss import _agent_loss, three_to_two_classes


def latent_wm_loss(targets, predictions, config: HydraDreamerConfig, vit_model):
    """
    Weighted MSE between predicted world-model tokens and ViT features of the target image.

    :param targets: dictionary of name tensor pairings; reads 'img_gt'
    :param predictions: dictionary of name tensor pairings; reads 'pred', a rank-3
        tensor unpacked as (B, L, C)
    :param config: config providing ``wm_loss_weight``
    :param vit_model: callable applied to ``targets['img_gt']``; its output is
        flattened to (B, C, -1) and transposed to (B, -1, C)
    :return: tuple of (weighted loss, {'wm_loss': weighted loss})
    """
    pred = predictions['pred']
    batch_size, _, channels = pred.shape
    # Flatten all trailing axes of the ViT output into a token axis and move channels
    # last so the target lines up with ``pred`` as (B, tokens, C).
    # NOTE(review): assumes vit_model output is channel-first (B, C, ...) and that its
    # token count equals L -- confirm against the ViT used here.
    # ``reshape`` (rather than ``view``) also tolerates non-contiguous ViT outputs.
    vit_features = vit_model(targets['img_gt'])
    target_tokens = vit_features.reshape(batch_size, channels, -1).permute(0, 2, 1)
    wm_loss_final = F.mse_loss(pred, target_tokens) * config.wm_loss_weight
    return wm_loss_final, {'wm_loss': wm_loss_final}


def hydra_kd_imi_agent_loss(
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        config: Vadv2Config,
        vocab_pdm_score
):
    """
    Helper function calculating the complete Hydra imitation + PDM-distillation + agent loss.

    The total is the weighted sum of:
      * an imitation loss: cross-entropy between the vocabulary logits
        ``predictions['imi']`` and a soft target, namely a softmax over the negative,
        ``config.sigma``-scaled squared distances between each (sub-sampled)
        vocabulary trajectory and the ground-truth trajectory;
      * binary cross-entropy losses between the predicted PDM sub-scores
        ('noc', 'da', 'ttc', 'comfort', 'progress') and ``vocab_pdm_score``;
      * the agent class / box losses returned by ``_agent_loss``.

    :param targets: dictionary of name tensor pairings; reads 'trajectory' (plus
        whatever ``_agent_loss`` reads)
    :param predictions: dictionary of name tensor pairings; reads 'noc', 'da', 'ttc',
        'comfort', 'progress', 'imi', 'trajectory_vocab' (plus whatever
        ``_agent_loss`` reads)
    :param config: config providing ``sigma``, ``trajectory_imi_weight``,
        ``trajectory_pdm_weight`` (a mapping keyed by sub-score name),
        ``agent_class_weight`` and ``agent_box_weight``
    :param vocab_pdm_score: mapping with tensor targets under 'noc', 'da', 'ttc',
        'comfort', 'progress'
    :return: tuple of (combined loss, dict of every weighted component loss)
    """
    noc, da, ttc, comfort, progress = (predictions['noc'], predictions['da'],
                                       predictions['ttc'], predictions['comfort'],
                                       predictions['progress'])
    imi = predictions['imi']

    # PDM sub-score distillation. F.binary_cross_entropy expects probabilities as input
    # and a target of the same dtype, so each target is cast to the dtype of its own
    # prediction. 'noc' targets are additionally collapsed by three_to_two_classes.
    da_loss = F.binary_cross_entropy(da, vocab_pdm_score['da'].to(da.dtype))
    ttc_loss = F.binary_cross_entropy(ttc, vocab_pdm_score['ttc'].to(ttc.dtype))
    comfort_loss = F.binary_cross_entropy(comfort, vocab_pdm_score['comfort'].to(comfort.dtype))
    noc_loss = F.binary_cross_entropy(noc, three_to_two_classes(vocab_pdm_score['noc'].to(noc.dtype)))
    progress_loss = F.binary_cross_entropy(progress, vocab_pdm_score['progress'].to(progress.dtype))

    # Imitation target. The vocabulary is sampled at time indices 4, 9, ..., 39 so that
    # each vocabulary trajectory yields 8 poses to compare against the ground truth.
    # NOTE(review): presumably vocab is (V, T, 3) with T >= 40 and target_traj is
    # (B, 8, 3) -- confirm against the dataloader / vocabulary file.
    vocab = predictions["trajectory_vocab"]
    target_traj = targets["trajectory"]
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    sampled_vocab = vocab[:, sampled_timepoints]
    # (1, V, 8, 3) - (B, 1, 8, 3) broadcasts to (B, V, 8, 3); no batch-wide copy of the
    # vocabulary is materialised.
    neg_scaled_sq_dist = -((sampled_vocab[None] - target_traj[:, None]) ** 2) / config.sigma
    imi_soft_target = neg_scaled_sq_dist.sum((-2, -1)).softmax(1)
    imi_loss = F.cross_entropy(imi, imi_soft_target)

    pdm_weight = config.trajectory_pdm_weight
    imi_loss_final = config.trajectory_imi_weight * imi_loss
    noc_loss_final = pdm_weight['noc'] * noc_loss
    da_loss_final = pdm_weight['da'] * da_loss
    ttc_loss_final = pdm_weight['ttc'] * ttc_loss
    progress_loss_final = pdm_weight['progress'] * progress_loss
    comfort_loss_final = pdm_weight['comfort'] * comfort_loss

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss

    loss = (
            imi_loss_final
            + noc_loss_final
            + da_loss_final
            + ttc_loss_final
            + progress_loss_final
            + comfort_loss_final
            + agent_class_loss_final
            + agent_box_loss_final
    )
    return loss, {
        'imi_loss': imi_loss_final,
        'pdm_noc_loss': noc_loss_final,
        'pdm_da_loss': da_loss_final,
        'pdm_ttc_loss': ttc_loss_final,
        'pdm_progress_loss': progress_loss_final,
        'pdm_comfort_loss': comfort_loss_final,
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
    }