|
from typing import Dict |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from navsim.agents.dm.dm_config import DMConfig |
|
from navsim.agents.dm.dm_model import DMTrajHead |
|
from navsim.agents.vadv2.vadv2_loss import _agent_loss |
|
|
|
|
|
def dm_imi_loss(
    targets: Dict[str, torch.Tensor],
    predictions: Dict[str, torch.Tensor],
    config: DMConfig,
    traj_head: DMTrajHead,
):
    """
    Compute the imitation loss of the DM agent.

    The loss is a weighted sum of a diffusion noise-prediction loss on the ego
    trajectory and the agent classification / box losses from ``_agent_loss``.

    For the diffusion term, the target trajectory is standardized with
    ``traj_head.standardizer`` (using ``predictions['history_waypoints']``),
    noised at one random timestep per sample via ``traj_head.scheduler.add_noise``,
    and ``traj_head.denoise`` is trained with an MSE loss to recover that noise.

    :param targets: dictionary of name tensor pairings; must contain 'trajectory'
        plus whatever keys ``_agent_loss`` reads
    :param predictions: dictionary of name tensor pairings; must contain
        'history_waypoints', 'env_features' and 'status_encoding' plus whatever
        keys ``_agent_loss`` reads
    :param config: DM config; reads diffusion_loss_weight, agent_class_weight
        and agent_box_weight
    :param traj_head: trajectory head supplying the standardizer, the noise
        scheduler and the denoiser
    :return: tuple of (total weighted loss, dict with the individual weighted
        terms under 'diffusion_loss', 'agent_class_loss' and 'agent_box_loss')
    """
    history_waypoints = predictions['history_waypoints']
    target_trajectory = targets['trajectory']
    batch_size = target_trajectory.shape[0]

    standard_traj = traj_head.standardizer.transform_features(target_trajectory, history_waypoints)

    # randn_like keeps the noise on the same device AND dtype as the trajectory.
    noise = torch.randn_like(standard_traj)
    # One timestep per sample, uniform in [0, num_train_timesteps); randint already returns int64.
    # NOTE(review): scheduler looks diffusers-style (config.num_train_timesteps / add_noise) — confirm.
    timesteps = torch.randint(
        0,
        traj_head.scheduler.config.num_train_timesteps,
        (batch_size,),
        device=standard_traj.device,
    )

    ego_noisy_trajectory = traj_head.scheduler.add_noise(standard_traj, noise, timesteps)

    pred_noise = traj_head.denoise(
        ego_noisy_trajectory,
        predictions['env_features'],
        predictions['status_encoding'],
        timesteps,
    )

    # Flatten BOTH operands to (B, -1) so the MSE is strictly element-wise and
    # cannot silently broadcast if the denoiser output carries extra dims.
    diffusion_loss = F.mse_loss(
        pred_noise.reshape(batch_size, -1),
        noise.reshape(batch_size, -1),
    )
    diffusion_loss_final = diffusion_loss * config.diffusion_loss_weight

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss

    loss = diffusion_loss_final + agent_class_loss_final + agent_box_loss_final

    # The logged values are the weighted terms, i.e. exactly what is summed into `loss`.
    return loss, {
        'diffusion_loss': diffusion_loss_final,
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
    }
|
|