|
from typing import Dict
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from scipy.optimize import linear_sum_assignment
|
|
|
|
from navsim.agents.transfuser.transfuser_config import TransfuserConfig
|
|
from navsim.agents.vadv2.vadv2_config import Vadv2Config
|
|
|
|
def vadv2_loss_pdm_ablate(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """
    Combined loss: imitation + distilled PDM "total" score + perception auxiliaries.

    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global config with loss weights and sigma
    :param vocab_pdm_score: per-trajectory PDM sub-scores (expects key 'total')
    :return: (combined loss, dict of individual weighted loss terms)
    """
    pred_total = predictions['total']
    pred_imi = predictions['imi']

    # Distill the PDM "total" score of every vocabulary trajectory.
    pdmtotal_loss = F.binary_cross_entropy(pred_total, vocab_pdm_score['total'].to(pred_total.dtype))

    # Imitation target: soft distribution from (negative, scaled) L2 distance
    # between each vocabulary trajectory and the ground-truth trajectory.
    vocab = predictions["trajectory_vocab"]
    target_traj = targets["trajectory"]
    batch_size = target_traj.shape[0]
    # Keep every 5th timestep (indices 4, 9, ..., 39) to match the 8 target poses
    # — assumes vocab has >= 40 timesteps; TODO confirm against vocab builder.
    keep_idx = [5 * k - 1 for k in range(1, 9)]
    sampled_vocab = vocab[:, keep_idx][None].repeat(batch_size, 1, 1, 1)
    l2_logits = -((sampled_vocab - target_traj[:, None]) ** 2) / config.sigma
    imi_loss = F.cross_entropy(pred_imi, l2_logits.sum((-2, -1)).softmax(1))

    # Auxiliary perception losses.
    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    bev_semantic_loss = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )

    weighted = {
        'imi_loss': config.trajectory_imi_weight * imi_loss,
        'pdmtotal_loss': config.trajectory_pdm_weight['total'] * pdmtotal_loss,
        'agent_class_loss': config.agent_class_weight * agent_class_loss,
        'agent_box_loss': config.agent_box_weight * agent_box_loss,
        'bev_semantic_loss': config.bev_semantic_weight * bev_semantic_loss,
    }
    return sum(weighted.values()), weighted
|
|
|
|
|
|
def vadv2_loss_center_woper(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """
    Imitation-only loss: cross entropy between the predicted trajectory
    distribution and a soft target derived from L2 proximity to ground truth.

    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global config with sigma and imitation weight
    :param vocab_pdm_score: unused here (kept for a uniform loss interface)
    :return: (combined loss, dict of individual weighted loss terms)
    """
    pred_dist = predictions["trajectory_distribution"]
    vocab = predictions["trajectory_vocab"]
    target_traj = targets["trajectory"]
    batch_size = pred_dist.shape[0]

    # Keep every 5th timestep (indices 4, 9, ..., 39) to match the 8 target poses
    # — assumes vocab has >= 40 timesteps; TODO confirm against vocab builder.
    keep_idx = [5 * k - 1 for k in range(1, 9)]
    sampled_vocab = vocab[:, keep_idx][None].repeat(batch_size, 1, 1, 1)
    l2_logits = -((sampled_vocab - target_traj[:, None]) ** 2) / config.sigma
    trajectory_loss = F.cross_entropy(pred_dist, l2_logits.sum((-2, -1)).softmax(1))

    trajectory_imi_loss_final = config.trajectory_imi_weight * trajectory_loss
    return trajectory_imi_loss_final, {
        'trajectory_imi_loss': trajectory_imi_loss_final,
    }
|
|
|
|
|
|
def vadv2_loss_center(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """
    Imitation loss over the trajectory distribution plus perception auxiliaries.

    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global config with loss weights and sigma
    :param vocab_pdm_score: unused here (kept for a uniform loss interface)
    :return: (combined loss, dict of individual weighted loss terms)
    """
    pred_dist = predictions["trajectory_distribution"]
    vocab = predictions["trajectory_vocab"]
    target_traj = targets["trajectory"]
    batch_size = pred_dist.shape[0]

    # Keep every 5th timestep (indices 4, 9, ..., 39) to match the 8 target poses
    # — assumes vocab has >= 40 timesteps; TODO confirm against vocab builder.
    keep_idx = [5 * k - 1 for k in range(1, 9)]
    sampled_vocab = vocab[:, keep_idx][None].repeat(batch_size, 1, 1, 1)
    l2_logits = -((sampled_vocab - target_traj[:, None]) ** 2) / config.sigma
    trajectory_loss = F.cross_entropy(pred_dist, l2_logits.sum((-2, -1)).softmax(1))

    # Auxiliary perception losses.
    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    bev_semantic_loss = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )

    weighted = {
        'trajectory_imi_loss': config.trajectory_imi_weight * trajectory_loss,
        'agent_class_loss': config.agent_class_weight * agent_class_loss,
        'agent_box_loss': config.agent_box_weight * agent_box_loss,
        'bev_semantic_loss': config.bev_semantic_weight * bev_semantic_loss,
    }
    return sum(weighted.values()), weighted
|
|
|
|
|
|
def vadv2_loss_ori(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """
    One-hot trajectory classification loss plus perception auxiliaries.

    The classifier's last ``B`` logits appear to correspond to per-sample
    appended trajectories (identity target), the first ``vocab_size`` to the
    fixed vocabulary (all-zero target) — TODO confirm against the model head.

    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global config with loss weights and vocab_size
    :param vocab_pdm_score: unused here (kept for a uniform loss interface)
    :return: (combined loss, dict of individual weighted loss terms)
    """
    pred_dist = predictions["trajectory_distribution"]
    batch_size = pred_dist.shape[0]

    # Target: zeros over the vocabulary slots, identity over the batch slots.
    vocab_part = torch.zeros(
        (batch_size, config.vocab_size), dtype=pred_dist.dtype, device=pred_dist.device
    )
    batch_part = torch.eye(batch_size, dtype=pred_dist.dtype, device=pred_dist.device)
    target_dist = torch.cat([vocab_part, batch_part], dim=-1).contiguous()

    trajectory_loss = F.cross_entropy(pred_dist, target_dist, reduction='mean')

    # Auxiliary perception losses.
    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    bev_semantic_loss = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )

    weighted = {
        'trajectory_pdm_loss': config.trajectory_imi_weight * trajectory_loss,
        'agent_class_loss': config.agent_class_weight * agent_class_loss,
        'agent_box_loss': config.agent_box_weight * agent_box_loss,
        'bev_semantic_loss': config.bev_semantic_weight * bev_semantic_loss,
    }
    return sum(weighted.values()), weighted
|
|
|
|
def three_to_two_classes(x):
    """
    Collapse a three-valued score {0, 0.5, 1} to a binary one {0, 1} by
    mapping 0.5 to 0.0.

    Operates on a clone so the caller's tensor is never mutated: the previous
    in-place version silently rewrote shared/cached score tensors whenever the
    upstream ``.to(dtype)`` was a no-op and returned the same storage.

    :param x: tensor of scores in {0, 0.5, 1}
    :return: new tensor with 0.5 entries replaced by 0.0
    """
    x = x.clone()
    x[x == 0.5] = 0.0
    return x
|
|
|
|
def vadv2_loss_pdm_wo_progress(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """
    Imitation loss plus distilled PDM sub-scores (without the progress metric)
    and perception auxiliaries.

    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global config with loss weights and sigma
    :param vocab_pdm_score: per-trajectory PDM sub-scores (noc/da/dd/ttc/comfort)
    :return: (combined loss, dict of individual weighted loss terms)
    """
    pred_imi = predictions['imi']
    noc, da, dd, ttc, comfort = (predictions['noc'], predictions['da'], predictions['dd'],
                                 predictions['ttc'], predictions['comfort'])

    # Binary metrics distilled directly.
    da_loss = F.binary_cross_entropy(da, vocab_pdm_score['da'].to(da.dtype))
    ttc_loss = F.binary_cross_entropy(ttc, vocab_pdm_score['ttc'].to(da.dtype))
    comfort_loss = F.binary_cross_entropy(comfort, vocab_pdm_score['comfort'].to(da.dtype))
    # Three-valued metrics collapsed to binary (0.5 -> 0.0) before BCE.
    noc_loss = F.binary_cross_entropy(noc, three_to_two_classes(vocab_pdm_score['noc'].to(da.dtype)))
    dd_loss = F.binary_cross_entropy(dd, three_to_two_classes(vocab_pdm_score['dd'].to(da.dtype)))

    # Imitation target: soft distribution from (negative, scaled) L2 distance
    # between each vocabulary trajectory and the ground-truth trajectory.
    vocab = predictions["trajectory_vocab"]
    target_traj = targets["trajectory"]
    batch_size = target_traj.shape[0]
    # Keep every 5th timestep (indices 4, 9, ..., 39) to match the 8 target poses.
    keep_idx = [5 * k - 1 for k in range(1, 9)]
    sampled_vocab = vocab[:, keep_idx][None].repeat(batch_size, 1, 1, 1)
    l2_logits = -((sampled_vocab - target_traj[:, None]) ** 2) / config.sigma
    imi_loss = F.cross_entropy(pred_imi, l2_logits.sum((-2, -1)).softmax(1))

    # Auxiliary perception losses.
    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    bev_semantic_loss = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )

    weighted = {
        'imi_loss': config.trajectory_imi_weight * imi_loss,
        'pdm_noc_loss': config.trajectory_pdm_weight['noc'] * noc_loss,
        'pdm_da_loss': config.trajectory_pdm_weight['da'] * da_loss,
        'pdm_dd_loss': config.trajectory_pdm_weight['dd'] * dd_loss,
        'pdm_ttc_loss': config.trajectory_pdm_weight['ttc'] * ttc_loss,
        'pdm_comfort_loss': config.trajectory_pdm_weight['comfort'] * comfort_loss,
        'agent_class_loss': config.agent_class_weight * agent_class_loss,
        'agent_box_loss': config.agent_box_weight * agent_box_loss,
        'bev_semantic_loss': config.bev_semantic_weight * bev_semantic_loss,
    }
    return sum(weighted.values()), weighted
|
|
|
|
|
|
def vadv2_loss_pdm_w_progress(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """
    Imitation loss plus distilled PDM sub-scores (including progress, excluding
    driving-direction) and perception auxiliaries.

    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global config with loss weights and sigma
    :param vocab_pdm_score: per-trajectory PDM sub-scores (noc/da/ttc/comfort/progress)
    :return: (combined loss, dict of individual weighted loss terms)
    """
    pred_imi = predictions['imi']
    noc, da, ttc, comfort, progress = (predictions['noc'], predictions['da'],
                                       predictions['ttc'],
                                       predictions['comfort'], predictions['progress'])

    # Binary metrics distilled directly.
    da_loss = F.binary_cross_entropy(da, vocab_pdm_score['da'].to(da.dtype))
    ttc_loss = F.binary_cross_entropy(ttc, vocab_pdm_score['ttc'].to(da.dtype))
    comfort_loss = F.binary_cross_entropy(comfort, vocab_pdm_score['comfort'].to(da.dtype))
    # Three-valued metric collapsed to binary (0.5 -> 0.0) before BCE.
    noc_loss = F.binary_cross_entropy(noc, three_to_two_classes(vocab_pdm_score['noc'].to(da.dtype)))
    # Progress is a soft score in [0, 1]; BCE distills it directly.
    progress_loss = F.binary_cross_entropy(progress, vocab_pdm_score['progress'].to(progress.dtype))

    # Imitation target: soft distribution from (negative, scaled) L2 distance
    # between each vocabulary trajectory and the ground-truth trajectory.
    vocab = predictions["trajectory_vocab"]
    target_traj = targets["trajectory"]
    batch_size = target_traj.shape[0]
    # Keep every 5th timestep (indices 4, 9, ..., 39) to match the 8 target poses.
    keep_idx = [5 * k - 1 for k in range(1, 9)]
    sampled_vocab = vocab[:, keep_idx][None].repeat(batch_size, 1, 1, 1)
    l2_logits = -((sampled_vocab - target_traj[:, None]) ** 2) / config.sigma
    imi_loss = F.cross_entropy(pred_imi, l2_logits.sum((-2, -1)).softmax(1))

    # Auxiliary perception losses.
    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    bev_semantic_loss = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )

    weighted = {
        'imi_loss': config.trajectory_imi_weight * imi_loss,
        'pdm_noc_loss': config.trajectory_pdm_weight['noc'] * noc_loss,
        'pdm_da_loss': config.trajectory_pdm_weight['da'] * da_loss,
        'pdm_ttc_loss': config.trajectory_pdm_weight['ttc'] * ttc_loss,
        'pdm_progress_loss': config.trajectory_pdm_weight['progress'] * progress_loss,
        'pdm_comfort_loss': config.trajectory_pdm_weight['comfort'] * comfort_loss,
        'agent_class_loss': config.agent_class_weight * agent_class_loss,
        'agent_box_loss': config.agent_box_weight * agent_box_loss,
        'bev_semantic_loss': config.bev_semantic_weight * bev_semantic_loss,
    }
    return sum(weighted.values()), weighted
|
|
|
|
def vadv2_loss_pdm(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: Vadv2Config,
        vocab_pdm_score
):
    """
    Distilled PDM sub-score losses (no imitation term) plus perception
    auxiliaries.

    :param targets: dictionary of name tensor pairings
    :param predictions: dictionary of name tensor pairings
    :param config: global config with loss weights
    :param vocab_pdm_score: per-trajectory PDM sub-scores (noc/da/dd/ttc/comfort/progress)
    :return: (combined loss, dict of individual weighted loss terms)
    """
    noc, da, dd, ttc, comfort, progress = (predictions['noc'], predictions['da'], predictions['dd'],
                                           predictions['ttc'], predictions['comfort'], predictions['progress'])

    # Binary metrics distilled directly.
    da_loss = F.binary_cross_entropy(da, vocab_pdm_score['da'].to(da.dtype))
    ttc_loss = F.binary_cross_entropy(ttc, vocab_pdm_score['ttc'].to(da.dtype))
    comfort_loss = F.binary_cross_entropy(comfort, vocab_pdm_score['comfort'].to(da.dtype))
    # Three-valued metrics collapsed to binary (0.5 -> 0.0) before BCE.
    noc_loss = F.binary_cross_entropy(noc, three_to_two_classes(vocab_pdm_score['noc'].to(da.dtype)))
    dd_loss = F.binary_cross_entropy(dd, three_to_two_classes(vocab_pdm_score['dd'].to(da.dtype)))
    # Progress is a soft score in [0, 1]; BCE distills it directly.
    progress_loss = F.binary_cross_entropy(progress, vocab_pdm_score['progress'].to(da.dtype))

    # Auxiliary perception losses.
    agent_class_loss, agent_box_loss = _agent_loss(targets, predictions, config)
    bev_semantic_loss = F.cross_entropy(
        predictions["bev_semantic_map"], targets["bev_semantic_map"].long()
    )

    weighted = {
        'pdm_noc_loss': config.trajectory_pdm_weight['noc'] * noc_loss,
        'pdm_da_loss': config.trajectory_pdm_weight['da'] * da_loss,
        'pdm_dd_loss': config.trajectory_pdm_weight['dd'] * dd_loss,
        'pdm_ttc_loss': config.trajectory_pdm_weight['ttc'] * ttc_loss,
        'pdm_progress_loss': config.trajectory_pdm_weight['progress'] * progress_loss,
        'pdm_comfort_loss': config.trajectory_pdm_weight['comfort'] * comfort_loss,
        'agent_class_loss': config.agent_class_weight * agent_class_loss,
        'agent_box_loss': config.agent_box_weight * agent_box_loss,
        'bev_semantic_loss': config.bev_semantic_weight * bev_semantic_loss,
    }
    return sum(weighted.values()), weighted
|
|
|
|
|
|
def _agent_loss(
        targets: Dict[str, torch.Tensor], predictions: Dict[str, torch.Tensor], config: TransfuserConfig
):
    """
    Hungarian matching loss for agent detection.

    Matches predicted agent boxes to ground-truth agents per batch element via
    linear sum assignment over a weighted BCE + L1 cost matrix, then computes
    classification (BCE-with-logits) and box-regression (L1) losses on the
    matched pairs.

    :param targets: dictionary with "agent_states" and "agent_labels" tensors
    :param predictions: dictionary with "agent_states" and "agent_labels" tensors
    :param config: global Transfuser config with agent loss weights
    :return: tuple (classification loss, box regression loss)
    """
    gt_states, gt_valid = targets["agent_states"], targets["agent_labels"]
    pred_states, pred_logits = predictions["agent_states"], predictions["agent_labels"]

    batch_dim, num_instances = pred_states.shape[:2]

    # Guard against division by zero when no valid ground-truth agent exists.
    num_gt_instances = gt_valid.sum()
    num_gt_instances = num_gt_instances if num_gt_instances > 0 else num_gt_instances + 1

    ce_cost = _get_ce_cost(gt_valid, pred_logits)
    l1_cost = _get_l1_cost(gt_states, pred_states, gt_valid)

    cost = config.agent_class_weight * ce_cost + config.agent_box_weight * l1_cost
    cost = cost.cpu()  # scipy's solver needs CPU data

    # One assignment per batch element (was `enumerate` with an unused index).
    indices = [linear_sum_assignment(c) for c in cost]
    matching = [
        (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
        for i, j in indices
    ]
    idx = _get_src_permutation_idx(matching)

    pred_states_idx = pred_states[idx]
    gt_states_idx = torch.cat([t[i] for t, (_, i) in zip(gt_states, indices)], dim=0)

    pred_valid_idx = pred_logits[idx]
    gt_valid_idx = torch.cat([t[i] for t, (_, i) in zip(gt_valid, indices)], dim=0).float()

    # Box regression only counts matched pairs whose ground truth is valid.
    l1_loss = F.l1_loss(pred_states_idx, gt_states_idx, reduction="none")
    l1_loss = l1_loss.sum(-1) * gt_valid_idx
    l1_loss = l1_loss.view(batch_dim, -1).sum() / num_gt_instances

    # Classification is trained on all matched pairs (valid and invalid).
    ce_loss = F.binary_cross_entropy_with_logits(pred_valid_idx, gt_valid_idx, reduction="none")
    ce_loss = ce_loss.view(batch_dim, -1).mean()

    return ce_loss, l1_loss
|
|
|
|
|
|
@torch.no_grad()
|
|
def _get_ce_cost(gt_valid: torch.Tensor, pred_logits: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Function to calculate cross-entropy cost for cost matrix.
|
|
:param gt_valid: tensor of binary ground-truth labels
|
|
:param pred_logits: tensor of predicted logits of neural net
|
|
:return: bce cost matrix as tensor
|
|
"""
|
|
|
|
|
|
|
|
gt_valid_expanded = gt_valid[:, :, None].detach().float()
|
|
pred_logits_expanded = pred_logits[:, None, :].detach()
|
|
|
|
max_val = torch.relu(-pred_logits_expanded)
|
|
helper_term = max_val + torch.log(
|
|
torch.exp(-max_val) + torch.exp(-pred_logits_expanded - max_val)
|
|
)
|
|
ce_cost = (1 - gt_valid_expanded) * pred_logits_expanded + helper_term
|
|
ce_cost = ce_cost.permute(0, 2, 1)
|
|
|
|
return ce_cost
|
|
|
|
|
|
@torch.no_grad()
|
|
def _get_l1_cost(
|
|
gt_states: torch.Tensor, pred_states: torch.Tensor, gt_valid: torch.Tensor
|
|
) -> torch.Tensor:
|
|
"""
|
|
Function to calculate L1 cost for cost matrix.
|
|
:param gt_states: tensor of ground-truth bounding boxes
|
|
:param pred_states: tensor of predicted bounding boxes
|
|
:param gt_valid: mask of binary ground-truth labels
|
|
:return: l1 cost matrix as tensor
|
|
"""
|
|
|
|
gt_states_expanded = gt_states[:, :, None, :2].detach()
|
|
pred_states_expanded = pred_states[:, None, :, :2].detach()
|
|
l1_cost = gt_valid[..., None].float() * (gt_states_expanded - pred_states_expanded).abs().sum(
|
|
dim=-1
|
|
)
|
|
l1_cost = l1_cost.permute(0, 2, 1)
|
|
return l1_cost
|
|
|
|
|
|
def _get_src_permutation_idx(indices):
|
|
"""
|
|
Helper function to align indices after matching
|
|
:param indices: matched indices
|
|
:return: permuted indices
|
|
"""
|
|
|
|
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
|
|
src_idx = torch.cat([src for (src, _) in indices])
|
|
return batch_idx, src_idx
|
|
|