File size: 3,520 Bytes
da2e2ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import Dict

import torch
import torch.nn.functional as F

from navsim.agents.dreamer.hydra_dreamer_config import HydraDreamerConfig
from navsim.agents.vadv2.vadv2_config import Vadv2Config
from navsim.agents.vadv2.vadv2_loss import _agent_loss, three_to_two_classes


def latent_wm_loss(targets, predictions, config: "HydraDreamerConfig", vit_model):
    """Latent world-model loss: MSE between predicted latents and encoder features.

    The ground-truth image is passed through ``vit_model``; its output is
    flattened to (B, C, -1) and transposed to token-major layout (B, N, C)
    so it lines up with the prediction before the MSE is taken.

    :param targets: dict holding ``img_gt``, the ground-truth image tensor
    :param predictions: dict holding ``pred`` of shape (B, N, C)
    :param config: provides the scalar ``wm_loss_weight``
    :param vit_model: encoder mapping images to feature maps
        (presumably a frozen ViT — confirm with caller)
    :return: tuple of (weighted loss, log dict with key ``wm_loss``)
    """
    latent_pred = predictions['pred']
    batch, _, channels = latent_pred.shape
    encoded = vit_model(targets['img_gt'])
    target_latent = encoded.view(batch, channels, -1).permute(0, 2, 1)
    weighted_loss = F.mse_loss(latent_pred, target_latent) * config.wm_loss_weight
    return weighted_loss, {'wm_loss': weighted_loss}


def hydra_kd_imi_agent_loss(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: "Vadv2Config",
        vocab_pdm_score
):
    """Combined distillation + imitation + agent loss.

    Sums per-metric binary cross-entropy against vocabulary PDM scores,
    a soft cross-entropy imitation term over the trajectory vocabulary,
    and agent classification/box regression losses, each scaled by its
    configured weight.

    :param targets: name->tensor dict; must include ``trajectory``
    :param predictions: name->tensor dict with the per-metric heads,
        ``imi`` logits, and ``trajectory_vocab``
    :param config: provides ``sigma``, ``trajectory_imi_weight``,
        ``trajectory_pdm_weight`` and agent loss weights
    :param vocab_pdm_score: per-metric target scores over the vocabulary
    :return: (total loss, dict of individual weighted loss terms)
    """
    imi = predictions['imi']
    noc = predictions['noc']
    da = predictions['da']
    ttc = predictions['ttc']
    comfort = predictions['comfort']
    progress = predictions['progress']

    # Per-metric BCE against PDM scores; 'noc' targets are first collapsed
    # from three classes down to two.
    da_loss = F.binary_cross_entropy(da, vocab_pdm_score['da'].to(da.dtype))
    ttc_loss = F.binary_cross_entropy(ttc, vocab_pdm_score['ttc'].to(da.dtype))
    comfort_loss = F.binary_cross_entropy(comfort, vocab_pdm_score['comfort'].to(da.dtype))
    noc_loss = F.binary_cross_entropy(noc, three_to_two_classes(vocab_pdm_score['noc'].to(da.dtype)))
    progress_loss = F.binary_cross_entropy(progress, vocab_pdm_score['progress'].to(progress.dtype))

    vocab = predictions["trajectory_vocab"]
    # Ground truth: (B, 8, 3) — 8 poses covering 4 s at 0.5 s spacing.
    target_traj = targets["trajectory"]
    # Vocabulary timesteps matching the GT poses: indices 4, 9, ..., 39.
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    batch = target_traj.shape[0]
    # Soft imitation targets: softmax over temperature-scaled negative
    # squared distances between each vocab trajectory and the GT.
    l2_distance = -((vocab[:, sampled_timepoints][None].repeat(batch, 1, 1, 1) - target_traj[:, None]) ** 2) / config.sigma
    imi_loss = F.cross_entropy(imi, l2_distance.sum((-2, -1)).softmax(1))

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)

    # Weight every term; insertion order mirrors the summation order.
    weighted = {
        'imi_loss': config.trajectory_imi_weight * imi_loss,
        'pdm_noc_loss': config.trajectory_pdm_weight['noc'] * noc_loss,
        'pdm_da_loss': config.trajectory_pdm_weight['da'] * da_loss,
        'pdm_ttc_loss': config.trajectory_pdm_weight['ttc'] * ttc_loss,
        'pdm_progress_loss': config.trajectory_pdm_weight['progress'] * progress_loss,
        'pdm_comfort_loss': config.trajectory_pdm_weight['comfort'] * comfort_loss,
        'agent_class_loss': config.agent_class_weight * agent_class_loss,
        'agent_box_loss': config.agent_box_weight * agent_box_loss,
    }
    total = sum(weighted.values())
    return total, weighted