from typing import Dict, Tuple

import torch
import torch.nn.functional as F

from navsim.agents.dm.dm_config import DMConfig
from navsim.agents.dm.dm_model import DMTrajHead
from navsim.agents.vadv2.vadv2_loss import _agent_loss


def dm_imi_loss(
    targets: Dict[str, torch.Tensor],
    predictions: Dict[str, torch.Tensor],
    config: DMConfig,
    traj_head: DMTrajHead,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Compute the combined training loss of the DM model.

    The loss is the weighted sum of a diffusion noise-prediction loss on the
    ego trajectory and the agent classification / box losses.

    For the diffusion term, the target trajectory is standardized by
    ``traj_head.standardizer`` (using the predicted history waypoints). One
    random diffusion timestep is drawn per sample, and the standardized
    trajectory is noised with ``traj_head.scheduler.add_noise``.
    ``traj_head.denoise`` is then trained to predict the injected noise
    under an MSE objective.

    :param targets: dictionary of name tensor pairings; this function reads
        ``'trajectory'`` (``_agent_loss`` may read further keys)
    :param predictions: dictionary of name tensor pairings; this function
        reads ``'history_waypoints'``, ``'env_features'`` and
        ``'status_encoding'`` (``_agent_loss`` may read further keys)
    :param config: global DM config; supplies ``diffusion_loss_weight``,
        ``agent_class_weight`` and ``agent_box_weight``
    :param traj_head: trajectory head providing ``standardizer``,
        ``scheduler`` and ``denoise``
    :return: tuple of (combined weighted loss, dict with the individual
        weighted terms under ``'diffusion_loss'``, ``'agent_class_loss'``
        and ``'agent_box_loss'``)
    """
    history_waypoints = predictions['history_waypoints']
    target_trajectory = targets['trajectory']
    batch_size = target_trajectory.shape[0]

    # --- Diffusion (noise-prediction) loss on the ego trajectory ---
    standard_traj = traj_head.standardizer.transform_features(target_trajectory, history_waypoints)
    # randn_like keeps noise on the same device AND dtype as the trajectory,
    # so mixed-precision inputs are not silently upcast by the noise term.
    noise = torch.randn_like(standard_traj)
    timesteps = torch.randint(
        0,
        traj_head.scheduler.config.num_train_timesteps,
        (batch_size,),
        device=standard_traj.device,
        dtype=torch.long,
    )
    ego_noisy_trajectory = traj_head.scheduler.add_noise(standard_traj, noise, timesteps)

    pred_noise = traj_head.denoise(
        ego_noisy_trajectory,
        predictions['env_features'],
        predictions['status_encoding'],
        timesteps,
    )
    # The noise target is flattened per sample; presumably denoise returns a
    # (batch, -1) prediction in the same flattened layout -- TODO confirm.
    diffusion_loss = F.mse_loss(pred_noise, noise.reshape(batch_size, -1))
    diffusion_loss_final = diffusion_loss * config.diffusion_loss_weight

    # --- Agent detection losses (shared implementation with VADv2) ---
    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss

    loss = diffusion_loss_final + agent_class_loss_final + agent_box_loss_final
    return loss, {
        'diffusion_loss': diffusion_loss_final,
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
    }