File size: 6,631 Bytes
da2e2ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from typing import Dict

import torch
import torch.nn.functional as F

from navsim.agents.vadv2.vadv2_config import Vadv2Config
from navsim.agents.vadv2.vadv2_loss import _agent_loss, three_to_two_classes

def hydra_loss(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """Combine the one-to-one and one-to-many hydra losses.

    :param targets: dictionary of name/tensor pairs (ground truth)
    :param predictions: dictionary of name/tensor pairs (model outputs)
    :param config: global Vadv2 config providing the loss weights
    :param vocab_pdm_score: per-metric PDM scores over the trajectory vocabulary
    :return: (summed scalar loss, merged dict of individual loss terms)
    """
    one2one_val, log_dict = hydra_kd_imi_agent_loss(targets, predictions, config, vocab_pdm_score)
    one2many_val, one2many_logs = hydra_kd_imi_agent_loss_one2many(targets, predictions, config, vocab_pdm_score)
    # NOTE(review): both dicts carry an 'imi_loss' key, so the one2many value
    # overwrites the one2one entry in the merged log dict — confirm intended.
    log_dict.update(one2many_logs)
    return one2one_val + one2many_val, log_dict

def hydra_kd_imi_agent_loss(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """
    Compute the one-to-one hydra loss: an imitation term (soft cross-entropy
    against an L2-distance-derived distribution over the trajectory vocabulary),
    knowledge-distillation BCE terms against per-metric PDM scores, and the
    agent detection class/box losses.

    :param targets: dictionary of name/tensor pairs; must contain "trajectory"
        (B, 8, 3) plus whatever `_agent_loss` consumes
    :param predictions: dictionary of name/tensor pairs with per-metric scores
        ('noc', 'da', 'ttc', 'comfort', 'progress', 'ddc', 'lk', 'tl'),
        'imi' logits and the 'trajectory_vocab'
    :param config: global Vadv2 config providing sigma and the loss weights
    :param vocab_pdm_score: per-metric PDM scores over the vocabulary (BCE targets)
    :return: (combined loss value, dict of individual weighted loss terms)
    """
    # 'noc' and 'ddc' scores are 3-class and must be collapsed to 2-class
    # before binary cross-entropy; the remaining metrics are already binary.
    three_class_metrics = frozenset({'noc', 'ddc'})
    pdm_losses: Dict[str, torch.Tensor] = {}
    for key in ('noc', 'da', 'ttc', 'comfort', 'progress', 'ddc', 'lk', 'tl'):
        pred = predictions[key]
        score = vocab_pdm_score[key].to(pred.dtype)
        if key in three_class_metrics:
            score = three_to_two_classes(score)
        pdm_losses[key] = config.trajectory_pdm_weight[key] * F.binary_cross_entropy(pred, score)

    imi = predictions['imi']
    vocab = predictions["trajectory_vocab"]
    # B, 8 (4 secs sampled every 0.5 s), 3 — assumed from the 8 sampled
    # timepoints below; TODO confirm against the dataloader.
    target_traj = targets["trajectory"]
    # Vocabulary timepoints 4, 9, ..., 39 align with the 8 target waypoints.
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    # Broadcasting (1, V, 8, 3) against (B, 1, 8, 3) replaces the original
    # repeat(B, 1, 1, 1), avoiding a full materialized copy of the vocabulary.
    l2_distance = -((vocab[:, sampled_timepoints][None] - target_traj[:, None]) ** 2) / config.sigma
    # Soft target distribution: vocabulary entries closer to the ground-truth
    # trajectory receive higher probability mass.
    imi_loss = F.cross_entropy(imi, l2_distance.sum((-2, -1)).softmax(1))
    imi_loss_final = config.trajectory_imi_weight * imi_loss

    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    agent_class_loss_final = config.agent_class_weight * agent_class_loss
    agent_box_loss_final = config.agent_box_weight * agent_box_loss

    loss = (
            imi_loss_final
            + sum(pdm_losses.values())
            + agent_class_loss_final
            + agent_box_loss_final
    )
    return loss, {
        'imi_loss': imi_loss_final,
        'pdm_noc_loss': pdm_losses['noc'],
        'pdm_da_loss': pdm_losses['da'],
        'pdm_ttc_loss': pdm_losses['ttc'],
        'pdm_progress_loss': pdm_losses['progress'],
        'pdm_ddc_loss': pdm_losses['ddc'],
        'pdm_lk_loss': pdm_losses['lk'],
        'pdm_tl_loss': pdm_losses['tl'],
        'pdm_comfort_loss': pdm_losses['comfort'],
        'agent_class_loss': agent_class_loss_final,
        'agent_box_loss': agent_box_loss_final,
    }


def hydra_kd_imi_agent_loss_one2many(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """
    Auxiliary one-to-many imitation loss: the same soft cross-entropy over the
    trajectory vocabulary as `hydra_kd_imi_agent_loss`, scaled by 0.5, without
    the PDM-distillation or agent terms.

    :param targets: dictionary of name/tensor pairs; must contain "trajectory" (B, 8, 3)
    :param predictions: dictionary of name/tensor pairs with 'imi' logits and
        the 'trajectory_vocab'
    :param config: global Vadv2 config providing sigma and trajectory_imi_weight
    :param vocab_pdm_score: unused here; kept for signature parity with
        `hydra_kd_imi_agent_loss`
    :return: (weighted imitation loss, dict with the single 'imi_loss' entry)
    """
    imi = predictions['imi']
    vocab = predictions["trajectory_vocab"]
    # B, 8 (4 secs sampled every 0.5 s), 3 — assumed; TODO confirm.
    target_traj = targets["trajectory"]
    # Vocabulary timepoints 4, 9, ..., 39 align with the 8 target waypoints.
    sampled_timepoints = [5 * k - 1 for k in range(1, 9)]
    # Broadcasting (1, V, 8, 3) against (B, 1, 8, 3) replaces the original
    # repeat(B, 1, 1, 1), avoiding a full materialized copy of the vocabulary.
    l2_distance = -((vocab[:, sampled_timepoints][None] - target_traj[:, None]) ** 2) / config.sigma
    imi_loss = F.cross_entropy(imi, l2_distance.sum((-2, -1)).softmax(1))

    # 0.5 down-weights this auxiliary head relative to the one-to-one loss.
    imi_loss_final = config.trajectory_imi_weight * imi_loss * 0.5

    return imi_loss_final, {
        'imi_loss': imi_loss_final,
    }