"""Adapted from NVLabs/CatK."""

from typing import Optional, Tuple

import torch
from omegaconf import DictConfig
from torch import Tensor
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal


@torch.no_grad()
def cal_polygon_contour(
    pos: Tensor,
    head: Tensor,
    width_length: Tensor,
) -> Tensor:
    """Compute the four corners of oriented bounding boxes.

    Args:
        pos: [..., 2], box center (x, y).
        head: [...], heading in radians.
        width_length: [..., 2], box (width, length).

    Returns:
        [..., 4, 2], corners ordered (left_front, right_front, right_back, left_back).
    """
    x, y = pos[..., 0], pos[..., 1]
    width, length = width_length[..., 0], width_length[..., 1]

    half_cos = 0.5 * head.cos()
    half_sin = 0.5 * head.sin()
    length_cos = length * half_cos
    length_sin = length * half_sin
    width_cos = width * half_cos
    width_sin = width * half_sin

    left_front_x = x + length_cos - width_sin
    left_front_y = y + length_sin + width_cos
    left_front = torch.stack((left_front_x, left_front_y), dim=-1)

    right_front_x = x + length_cos + width_sin
    right_front_y = y + length_sin - width_cos
    right_front = torch.stack((right_front_x, right_front_y), dim=-1)

    right_back_x = x - length_cos + width_sin
    right_back_y = y - length_sin - width_cos
    right_back = torch.stack((right_back_x, right_back_y), dim=-1)

    left_back_x = x - length_cos - width_sin
    left_back_y = y - length_sin + width_cos
    left_back = torch.stack((left_back_x, left_back_y), dim=-1)

    polygon_contour = torch.stack(
        (left_front, right_front, right_back, left_back), dim=-2
    )

    return polygon_contour
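
# A minimal usage sketch (shapes and values are illustrative, not part of the
# CatK API): one 2 m wide, 5 m long agent at the origin with zero heading.
#
#   pos = torch.zeros(1, 2)                    # [n_agent, 2]
#   head = torch.zeros(1)                      # [n_agent]
#   width_length = torch.tensor([[2.0, 5.0]])  # [n_agent, 2] as (width, length)
#   contour = cal_polygon_contour(pos, head, width_length)
#   # contour[0]: [[2.5, 1.0], [2.5, -1.0], [-2.5, -1.0], [-2.5, 1.0]]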


def transform_to_global(
    pos_local: Tensor,
    head_local: Optional[Tensor],
    pos_now: Tensor,
    head_now: Tensor,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Rotate local points by head_now, then translate by pos_now.

    pos_local is expected to be [n_agent, n_point, 2] so it can be
    batch-matrix-multiplied against the per-agent rotation matrices.
    """
    cos, sin = head_now.cos(), head_now.sin()
    rot_mat = torch.zeros((head_now.shape[0], 2, 2), device=head_now.device)
    rot_mat[:, 0, 0] = cos
    rot_mat[:, 0, 1] = sin
    rot_mat[:, 1, 0] = -sin
    rot_mat[:, 1, 1] = cos

    pos_global = torch.bmm(pos_local, rot_mat)
    pos_global = pos_global + pos_now.unsqueeze(1)
    if head_local is None:
        head_global = None
    else:
        head_global = head_local + head_now.unsqueeze(1)
    return pos_global, head_global


def transform_to_local(
    pos_global: Tensor,
    head_global: Optional[Tensor],
    pos_now: Tensor,
    head_now: Tensor,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Inverse of transform_to_global: translate by -pos_now, then rotate by -head_now."""
    cos, sin = head_now.cos(), head_now.sin()
    rot_mat = torch.zeros((head_now.shape[0], 2, 2), device=head_now.device)
    rot_mat[:, 0, 0] = cos
    rot_mat[:, 0, 1] = -sin
    rot_mat[:, 1, 0] = sin
    rot_mat[:, 1, 1] = cos

    pos_local = pos_global - pos_now.unsqueeze(1)
    pos_local = torch.bmm(pos_local, rot_mat)
    if head_global is None:
        head_local = None
    else:
        head_local = head_global - head_now.unsqueeze(1)
    return pos_local, head_local
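
# The two transforms are mutual inverses (up to floating point). A sketch,
# assuming pos_g is [n_agent, n_point, 2] and head_g is [n_agent, n_point]:
#
#   pos_l, head_l = transform_to_local(pos_g, head_g, pos_now, head_now)
#   pos_g2, head_g2 = transform_to_global(pos_l, head_l, pos_now, head_now)
#   # torch.allclose(pos_g2, pos_g, atol=1e-5) -> True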


def sample_next_token_traj(
    token_traj: Tensor,
    token_traj_all: Tensor,
    sampling_scheme: DictConfig,
    next_token_logits: Tensor,
    pos_now: Tensor,
    head_now: Tensor,
    pos_next_gt: Tensor,
    head_next_gt: Tensor,
    valid_next_gt: Tensor,
    token_agent_shape: Tensor,
) -> Tuple[Tensor, Tensor]:
    """
    Returns:
        next_token_idx: [n_agent], without grad
        next_token_traj_all: [n_agent, 6, 4, 2], local coord
    """
    range_a = torch.arange(next_token_logits.shape[0])
    next_token_logits = next_token_logits.detach()

    if (
        sampling_scheme.criterium == "topk_prob"
        or sampling_scheme.criterium == "topk_prob_sampled_with_dist"
    ):
        topk_logits, topk_indices = torch.topk(
            next_token_logits, sampling_scheme.num_k, dim=-1, sorted=False
        )
        if sampling_scheme.criterium == "topk_prob_sampled_with_dist":
            gt_contour = cal_polygon_contour(
                pos_next_gt, head_next_gt, token_agent_shape
            )  # [n_agent, 4, 2]
            gt_contour = gt_contour.unsqueeze(1)  # [n_agent, 1, 4, 2]
            token_world_sample = token_traj[range_a.unsqueeze(1), topk_indices]
            token_world_sample = transform_to_global(
                pos_local=token_world_sample.flatten(1, 2),
                head_local=None,
                pos_now=pos_now,
                head_now=head_now,
            )[0].view(*token_world_sample.shape)

            # Where the GT is valid, rank the top-k tokens by negative contour
            # distance to the GT; where it is invalid, keep the model logits.
            dist = torch.norm(token_world_sample - gt_contour, dim=-1).mean(-1)
            topk_logits = topk_logits.masked_fill(
                valid_next_gt.unsqueeze(1), 0.0
            ) - 1.0 * dist.masked_fill(~valid_next_gt.unsqueeze(1), 0.0)
    elif sampling_scheme.criterium == "topk_dist_sampled_with_prob":
        gt_contour = cal_polygon_contour(pos_next_gt, head_next_gt, token_agent_shape)
        gt_contour = gt_contour.unsqueeze(1)  # [n_agent, 1, 4, 2]
        token_world_sample = transform_to_global(
            pos_local=token_traj.flatten(1, 2),
            head_local=None,
            pos_now=pos_now,
            head_now=head_now,
        )[0].view(*token_traj.shape)

        _invalid = ~valid_next_gt

        # Take the top-k tokens closest to the GT contour, then sample among
        # them with the model probabilities.
        dist = torch.norm(token_world_sample - gt_contour, dim=-1).mean(-1)
        _logits = -1.0 * dist.masked_fill(_invalid.unsqueeze(1), 0.0)

        if _invalid.any():
            _logits[_invalid] = next_token_logits[_invalid]
        _, topk_indices = torch.topk(
            _logits, sampling_scheme.num_k, dim=-1, sorted=False
        )
        topk_logits = next_token_logits[range_a.unsqueeze(1), topk_indices]
    else:
        raise ValueError(f"Invalid criterium: {sampling_scheme.criterium}")

    topk_logits = topk_logits / sampling_scheme.temp
    samples = Categorical(logits=topk_logits).sample()
    next_token_idx = topk_indices[range_a, samples]
    next_token_traj_all = token_traj_all[range_a, next_token_idx]

    return next_token_idx, next_token_traj_all
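
# A sketch of the sampling-scheme fields this function reads (values are
# illustrative; only `criterium`, `num_k`, and `temp` are accessed above):
#
#   from omegaconf import OmegaConf
#   sampling_scheme = OmegaConf.create(
#       {"criterium": "topk_prob_sampled_with_dist", "num_k": 5, "temp": 1.0}
#   )
#   next_token_idx, next_token_traj_all = sample_next_token_traj(
#       token_traj, token_traj_all, sampling_scheme, next_token_logits,
#       pos_now, head_now, pos_next_gt, head_next_gt, valid_next_gt,
#       token_agent_shape,
#   )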


def sample_next_gmm_traj(
    token_traj: Tensor,
    token_traj_all: Tensor,
    sampling_scheme: DictConfig,
    ego_mask: Tensor,
    ego_next_logits: Tensor,
    ego_next_poses: Tensor,
    ego_next_cov: Tensor,
    pos_now: Tensor,
    head_now: Tensor,
    pos_next_gt: Tensor,
    head_next_gt: Tensor,
    valid_next_gt: Tensor,
    token_agent_shape: Tensor,
    next_token_idx: Tensor,
) -> Tuple[Tensor, Tensor]:
    """
    Returns:
        next_token_idx: [n_agent], without grad
        next_token_traj_all: [n_agent, 6, 4, 2], local coord
    """
    n_agent = token_traj.shape[0]
    n_batch = ego_next_logits.shape[0]
    next_token_traj_all = token_traj_all[torch.arange(n_agent), next_token_idx]

    assert (
        sampling_scheme.criterium == "topk_prob"
        or sampling_scheme.criterium == "topk_prob_sampled_with_dist"
    )
    topk_logits, topk_indices = torch.topk(
        ego_next_logits, sampling_scheme.num_k, dim=-1, sorted=False
    )
    ego_pose_topk = ego_next_poses[torch.arange(n_batch).unsqueeze(1), topk_indices]

    if sampling_scheme.criterium == "topk_prob_sampled_with_dist":
        gt_contour = cal_polygon_contour(
            pos_next_gt[ego_mask],
            head_next_gt[ego_mask],
            token_agent_shape[ego_mask],
        )
        gt_contour = gt_contour.unsqueeze(1)

        ego_pos_global, ego_head_global = transform_to_global(
            pos_local=ego_pose_topk[:, :, :2],
            head_local=ego_pose_topk[:, :, -1],
            pos_now=pos_now[ego_mask],
            head_now=head_now[ego_mask],
        )
        ego_contour = cal_polygon_contour(
            ego_pos_global,
            ego_head_global,
            token_agent_shape[ego_mask].unsqueeze(1),
        )

        # Where the GT is valid, rank the top-k modes by negative contour
        # distance to the GT; where it is invalid, keep the model logits.
        dist = torch.norm(ego_contour - gt_contour, dim=-1).mean(-1)
        topk_logits = topk_logits.masked_fill(
            valid_next_gt[ego_mask].unsqueeze(1), 0.0
        ) - 1.0 * dist.masked_fill(~valid_next_gt[ego_mask].unsqueeze(1), 0.0)

    # Build a GMM over (x, y, cos(head), sin(head)) and draw a continuous sample.
    topk_logits = topk_logits / sampling_scheme.temp_mode
    ego_pose_topk = torch.cat(
        [
            ego_pose_topk[..., :2],
            ego_pose_topk[..., [-1]].cos(),
            ego_pose_topk[..., [-1]].sin(),
        ],
        dim=-1,
    )
    cov = (
        (ego_next_cov * sampling_scheme.temp_cov)
        .repeat_interleave(2)[None, None, :]
        .expand(*ego_pose_topk.shape)
    )
    gmm = MixtureSameFamily(
        Categorical(logits=topk_logits), Independent(Normal(ego_pose_topk, cov), 1)
    )
    ego_sample = gmm.sample()

    # Snap the continuous sample back to the nearest token by contour distance.
    ego_contour_local = cal_polygon_contour(
        ego_sample[:, :2],
        torch.arctan2(ego_sample[:, -1], ego_sample[:, -2]),
        token_agent_shape[ego_mask],
    )

    ego_token_local = token_traj[ego_mask]

    dist = torch.norm(ego_contour_local.unsqueeze(1) - ego_token_local, dim=-1).mean(
        -1
    )
    next_token_idx[ego_mask] = dist.argmin(-1)

    # Linearly interpolate the contour from the current pose to the sampled one.
    ego_contour_start = next_token_traj_all[ego_mask][:, 0]
    n_step = next_token_traj_all.shape[1]
    diff = (ego_contour_local - ego_contour_start) / (n_step - 1)
    ego_token_interp = [ego_contour_start + diff * i for i in range(n_step)]

    next_token_traj_all[ego_mask] = torch.stack(ego_token_interp, dim=1)

    return next_token_idx, next_token_traj_all
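
# The mixture layout used by sample_next_gmm_traj above, on toy shapes
# (illustrative, not part of the CatK API): a batch of 3 mixtures, each with
# num_k=5 components over a 4-dim diagonal Gaussian; sampling yields one
# 4-dim vector per batch element.
#
#   logits = torch.randn(3, 5)
#   means = torch.randn(3, 5, 4)
#   scales = torch.ones(3, 5, 4)
#   gmm = MixtureSameFamily(
#       Categorical(logits=logits), Independent(Normal(means, scales), 1)
#   )
#   gmm.sample().shape  # torch.Size([3, 4])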