navsim_ours / navsim /agents /dm /dm_loss_fn.py
lkllkl's picture
Upload folder using huggingface_hub
da2e2ac verified
raw
history blame
2.09 kB
from typing import Dict
import torch
import torch.nn.functional as F
from navsim.agents.dm.dm_config import DMConfig
from navsim.agents.dm.dm_model import DMTrajHead
from navsim.agents.vadv2.vadv2_loss import _agent_loss
def dm_imi_loss(
    targets: Dict[str, torch.Tensor],
    predictions: Dict[str, torch.Tensor],
    config: DMConfig,
    traj_head: DMTrajHead,
):
    """
    Compute the DM imitation loss: a weighted diffusion noise-regression term
    plus weighted agent classification and agent box terms.

    The target trajectory is standardized by the trajectory head, noised by the
    head's scheduler at one randomly drawn timestep per sample, and the output
    of ``traj_head.denoise`` is compared against the drawn noise with MSE.

    :param targets: dictionary of name tensor pairings; 'trajectory' is read
        here and the whole dictionary is forwarded to ``_agent_loss``
    :param predictions: dictionary of name tensor pairings; 'history_waypoints',
        'env_features' and 'status_encoding' are read here and the whole
        dictionary is forwarded to ``_agent_loss``
    :param config: provides ``diffusion_loss_weight``, ``agent_class_weight``
        and ``agent_box_weight``; also forwarded to ``_agent_loss``
    :param traj_head: provides ``standardizer.transform_features``,
        ``scheduler`` (``add_noise`` and ``config.num_train_timesteps``) and
        ``denoise``
    :return: tuple of (summed weighted loss, dictionary with the individual
        weighted terms under 'diffusion_loss', 'agent_class_loss' and
        'agent_box_loss')
    """
    past_waypoints = predictions['history_waypoints']
    gt_trajectory = targets['trajectory']
    batch_size = gt_trajectory.shape[0]

    # Diffusion is trained in the standardized trajectory space of the head.
    normalized_traj = traj_head.standardizer.transform_features(gt_trajectory, past_waypoints)
    device = normalized_traj.device

    # Draw the regression target (noise) first, then one timestep per sample.
    sampled_noise = torch.randn(normalized_traj.shape, device=device)
    scheduler = traj_head.scheduler
    step_ids = torch.randint(
        0,
        scheduler.config.num_train_timesteps,
        (batch_size,),
        dtype=torch.long,
        device=device,
    )

    noisy_traj = scheduler.add_noise(normalized_traj, sampled_noise, step_ids)
    predicted_noise = traj_head.denoise(
        noisy_traj,
        predictions['env_features'],
        predictions['status_encoding'],
        step_ids,
    )

    # The noise is flattened per sample before the comparison.
    # NOTE(review): presumably `denoise` returns a (batch, -1) tensor; a
    # different shape would be silently broadcast by mse_loss -- confirm.
    noise_target = sampled_noise.reshape(batch_size, -1)
    weighted_diffusion = F.mse_loss(predicted_noise, noise_target) * config.diffusion_loss_weight

    class_term, box_term = _agent_loss(targets, predictions, config)
    weighted_class = config.agent_class_weight * class_term
    weighted_box = config.agent_box_weight * box_term

    loss_terms = {
        'diffusion_loss': weighted_diffusion,
        'agent_class_loss': weighted_class,
        'agent_box_loss': weighted_box,
    }
    total_loss = weighted_diffusion + weighted_class + weighted_box
    return total_loss, loss_terms