""" Adapted from NVLabs/CatK """
# Not a contribution
# Changes made by NVIDIA CORPORATION & AFFILIATES enabling <CAT-K> or otherwise documented as
# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
# SPDX-FileCopyrightText: Copyright (c) <year> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from typing import Optional, Tuple
import torch
from omegaconf import DictConfig
from torch import Tensor
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal


@torch.no_grad()
def cal_polygon_contour(
pos: Tensor, # [n_agent, n_step, n_target, 2]
head: Tensor, # [n_agent, n_step, n_target]
width_length: Tensor, # [n_agent, 1, 1, 2]
) -> Tensor: # [n_agent, n_step, n_target, 4, 2]
x, y = pos[..., 0], pos[..., 1] # [n_agent, n_step, n_target]
    width, length = width_length[..., 0], width_length[..., 1]  # [n_agent, 1, 1]
half_cos = 0.5 * head.cos() # [n_agent, n_step, n_target]
half_sin = 0.5 * head.sin() # [n_agent, n_step, n_target]
length_cos = length * half_cos # [n_agent, n_step, n_target]
length_sin = length * half_sin # [n_agent, n_step, n_target]
width_cos = width * half_cos # [n_agent, n_step, n_target]
width_sin = width * half_sin # [n_agent, n_step, n_target]
left_front_x = x + length_cos - width_sin
left_front_y = y + length_sin + width_cos
left_front = torch.stack((left_front_x, left_front_y), dim=-1)
right_front_x = x + length_cos + width_sin
right_front_y = y + length_sin - width_cos
right_front = torch.stack((right_front_x, right_front_y), dim=-1)
right_back_x = x - length_cos + width_sin
right_back_y = y - length_sin - width_cos
right_back = torch.stack((right_back_x, right_back_y), dim=-1)
left_back_x = x - length_cos - width_sin
left_back_y = y - length_sin + width_cos
left_back = torch.stack((left_back_x, left_back_y), dim=-1)
polygon_contour = torch.stack(
(left_front, right_front, right_back, left_back), dim=-2
)
return polygon_contour
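
# The contour math above is equivalent to rotating the four box corners
# (+/- length/2, +/- width/2) by `head` and translating by `pos`:
#     corner_global = pos + R(head) @ corner_local.
# Illustrative check (hypothetical values, not from the original code): for
# head = 0, pos = (0, 0), width = 2, length = 4, the contour is
# [(2, 1), (2, -1), (-2, -1), (-2, 1)] in left-front, right-front,
# right-back, left-back order.
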
def transform_to_global(
pos_local: Tensor, # [n_agent, n_step, 2]
head_local: Optional[Tensor], # [n_agent, n_step]
pos_now: Tensor, # [n_agent, 2]
head_now: Tensor, # [n_agent]
) -> Tuple[Tensor, Optional[Tensor]]:
cos, sin = head_now.cos(), head_now.sin()
rot_mat = torch.zeros((head_now.shape[0], 2, 2), device=head_now.device)
rot_mat[:, 0, 0] = cos
rot_mat[:, 0, 1] = sin
rot_mat[:, 1, 0] = -sin
rot_mat[:, 1, 1] = cos
pos_global = torch.bmm(pos_local, rot_mat) # [n_agent, n_step, 2]*[n_agent, 2, 2]
pos_global = pos_global + pos_now.unsqueeze(1)
if head_local is None:
head_global = None
else:
head_global = head_local + head_now.unsqueeze(1)
return pos_global, head_global
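
# Note: `pos_local @ rot_mat` with rot_mat = [[cos, sin], [-sin, cos]] acting on
# row vectors is the same as the counter-clockwise rotation R(head_now) acting
# on column vectors, so pos_global = R(head_now) @ pos_local + pos_now.
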
def transform_to_local(
pos_global: Tensor, # [n_agent, n_step, 2]
head_global: Optional[Tensor], # [n_agent, n_step]
pos_now: Tensor, # [n_agent, 2]
head_now: Tensor, # [n_agent]
) -> Tuple[Tensor, Optional[Tensor]]:
cos, sin = head_now.cos(), head_now.sin()
rot_mat = torch.zeros((head_now.shape[0], 2, 2), device=head_now.device)
rot_mat[:, 0, 0] = cos
rot_mat[:, 0, 1] = -sin
rot_mat[:, 1, 0] = sin
rot_mat[:, 1, 1] = cos
pos_local = pos_global - pos_now.unsqueeze(1)
pos_local = torch.bmm(pos_local, rot_mat) # [n_agent, n_step, 2]*[n_agent, 2, 2]
if head_global is None:
head_local = None
else:
head_local = head_global - head_now.unsqueeze(1)
return pos_local, head_local
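
# transform_to_local is the exact inverse of transform_to_global: its rot_mat
# is the transpose of the one above, so a global -> local -> global round trip
# recovers the input up to floating-point error (see the smoke test at the end
# of this file).
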
def sample_next_token_traj(
token_traj: Tensor, # [n_agent, n_token, 4, 2]
token_traj_all: Tensor, # [n_agent, n_token, 6, 4, 2]
sampling_scheme: DictConfig,
# ! for most-likely sampling
next_token_logits: Tensor, # [n_agent, n_token], with grad
# ! for nearest-pos sampling, sampling near to GT
pos_now: Tensor, # [n_agent, 2]
head_now: Tensor, # [n_agent]
pos_next_gt: Tensor, # [n_agent, 2]
head_next_gt: Tensor, # [n_agent]
valid_next_gt: Tensor, # [n_agent]
token_agent_shape: Tensor, # [n_agent, 2]
) -> Tuple[Tensor, Tensor]:
"""
Returns:
next_token_traj_all: [n_agent, 6, 4, 2], local coord
next_token_idx: [n_agent], without grad
"""
range_a = torch.arange(next_token_logits.shape[0])
next_token_logits = next_token_logits.detach()
if (
sampling_scheme.criterium == "topk_prob"
or sampling_scheme.criterium == "topk_prob_sampled_with_dist"
):
topk_logits, topk_indices = torch.topk(
next_token_logits, sampling_scheme.num_k, dim=-1, sorted=False
)
if sampling_scheme.criterium == "topk_prob_sampled_with_dist":
#! gt_contour: [n_agent, 4, 2] in global coord
gt_contour = cal_polygon_contour(
pos_next_gt, head_next_gt, token_agent_shape
)
gt_contour = gt_contour.unsqueeze(1) # [n_agent, 1, 4, 2]
            token_world_sample = token_traj[range_a.unsqueeze(1), topk_indices]  # [n_agent, K, 4, 2]
token_world_sample = transform_to_global(
pos_local=token_world_sample.flatten(1, 2),
head_local=None,
pos_now=pos_now, # [n_agent, 2]
head_now=head_now, # [n_agent]
)[0].view(*token_world_sample.shape)
            # dist: [n_agent, K]
            dist = torch.norm(token_world_sample - gt_contour, dim=-1).mean(-1)
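            # Where the GT next state is valid, re-score the top-k tokens by
            # negative contour distance to the GT; where it is invalid, keep
            # the raw model logits.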
topk_logits = topk_logits.masked_fill(
valid_next_gt.unsqueeze(1), 0.0
) - 1.0 * dist.masked_fill(~valid_next_gt.unsqueeze(1), 0.0)
elif sampling_scheme.criterium == "topk_dist_sampled_with_prob":
#! gt_contour: [n_agent, 4, 2] in global coord
gt_contour = cal_polygon_contour(pos_next_gt, head_next_gt, token_agent_shape)
gt_contour = gt_contour.unsqueeze(1) # [n_agent, 1, 4, 2]
token_world_sample = transform_to_global(
pos_local=token_traj.flatten(1, 2), # [n_agent, n_token*4, 2]
head_local=None,
pos_now=pos_now, # [n_agent, 2]
head_now=head_now, # [n_agent]
)[0].view(*token_traj.shape)
_invalid = ~valid_next_gt
# dist: [n_agent, n_token]
dist = torch.norm(token_world_sample - gt_contour, dim=-1).mean(-1)
_logits = -1.0 * dist.masked_fill(_invalid.unsqueeze(1), 0.0)
if _invalid.any():
_logits[_invalid] = next_token_logits[_invalid]
_, topk_indices = torch.topk(
_logits, sampling_scheme.num_k, dim=-1, sorted=False
) # [n_agent, K]
topk_logits = next_token_logits[range_a.unsqueeze(1), topk_indices]
else:
raise ValueError(f"Invalid criterium: {sampling_scheme.criterium}")
# topk_logits, topk_indices: [n_agent, K]
topk_logits = topk_logits / sampling_scheme.temp
samples = Categorical(logits=topk_logits).sample() # [n_agent] in K
next_token_idx = topk_indices[range_a, samples]
next_token_traj_all = token_traj_all[range_a, next_token_idx]
return next_token_idx, next_token_traj_all
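
# Illustrative call (hypothetical config values, not from the original pipeline):
#     scheme = DictConfig({"criterium": "topk_prob", "num_k": 5, "temp": 1.0})
#     idx, traj = sample_next_token_traj(
#         token_traj, token_traj_all, scheme, next_token_logits,
#         pos_now, head_now, pos_next_gt, head_next_gt,
#         valid_next_gt, token_agent_shape,
#     )
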
def sample_next_gmm_traj(
token_traj: Tensor, # [n_agent, n_token, 4, 2]
token_traj_all: Tensor, # [n_agent, n_token, 6, 4, 2]
sampling_scheme: DictConfig,
# ! for most-likely sampling
ego_mask: Tensor, # [n_agent], bool, ego_mask.sum()==n_batch
ego_next_logits: Tensor, # [n_batch, n_k_ego_gmm]
ego_next_poses: Tensor, # [n_batch, n_k_ego_gmm, 3]
ego_next_cov: Tensor, # [2], one for pos, one for heading.
# ! for nearest-pos sampling, sampling near to GT
pos_now: Tensor, # [n_agent, 2]
head_now: Tensor, # [n_agent]
pos_next_gt: Tensor, # [n_agent, 2]
head_next_gt: Tensor, # [n_agent]
valid_next_gt: Tensor, # [n_agent]
token_agent_shape: Tensor, # [n_agent, 2]
next_token_idx: Tensor, # [n_agent]
) -> Tuple[Tensor, Tensor]:
"""
Returns:
next_token_traj_all: [n_agent, 6, 4, 2], local coord
next_token_idx: [n_agent], without grad
"""
n_agent = token_traj.shape[0]
n_batch = ego_next_logits.shape[0]
next_token_traj_all = token_traj_all[torch.arange(n_agent), next_token_idx]
# ! sample only the ego-vehicle
assert (
sampling_scheme.criterium == "topk_prob"
or sampling_scheme.criterium == "topk_prob_sampled_with_dist"
)
    topk_logits, topk_indices = torch.topk(
        ego_next_logits, sampling_scheme.num_k, dim=-1, sorted=False
    )  # [n_batch, k], [n_batch, k]
ego_pose_topk = ego_next_poses[
torch.arange(n_batch).unsqueeze(1), topk_indices
] # [n_batch, k, 3]
if sampling_scheme.criterium == "topk_prob_sampled_with_dist":
        # update topk_logits
gt_contour = cal_polygon_contour(
pos_next_gt[ego_mask],
head_next_gt[ego_mask],
token_agent_shape[ego_mask],
) # [n_batch, 4, 2] in global coord
gt_contour = gt_contour.unsqueeze(1) # [n_batch, 1, 4, 2]
ego_pos_global, ego_head_global = transform_to_global(
pos_local=ego_pose_topk[:, :, :2], # [n_batch, k, 2]
head_local=ego_pose_topk[:, :, -1], # [n_batch, k]
pos_now=pos_now[ego_mask], # [n_batch, 2]
head_now=head_now[ego_mask], # [n_batch]
)
ego_contour = cal_polygon_contour(
ego_pos_global, # [n_batch, k, 2]
ego_head_global, # [n_batch, k]
token_agent_shape[ego_mask].unsqueeze(1),
) # [n_batch, k, 4, 2] in global coord
dist = torch.norm(ego_contour - gt_contour, dim=-1).mean(-1) # [n_batch, k]
topk_logits = topk_logits.masked_fill(
valid_next_gt[ego_mask].unsqueeze(1), 0.0
) - 1.0 * dist.masked_fill(~valid_next_gt[ego_mask].unsqueeze(1), 0.0)
topk_logits = topk_logits / sampling_scheme.temp_mode # [n_batch, k]
ego_pose_topk = torch.cat(
[
ego_pose_topk[..., :2],
ego_pose_topk[..., [-1]].cos(),
ego_pose_topk[..., [-1]].sin(),
],
dim=-1,
)
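    # `ego_next_cov` holds one scale for position and one for heading; after
    # repeat_interleave(2) the per-dimension scale is [x, y, cos h, sin h].
    # Note that Normal(loc, scale) takes a standard deviation, despite the name.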
cov = (
(ego_next_cov * sampling_scheme.temp_cov)
.repeat_interleave(2)[None, None, :]
.expand(*ego_pose_topk.shape)
) # [n_batch, k, 4]
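    # Mixture of Gaussians over the top-k ego modes: the Categorical picks a
    # mode and Independent(..., 1) folds the 4 pose dims into a single event,
    # so gmm.sample() returns one [n_batch, 4] pose per batch element.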
gmm = MixtureSameFamily(
Categorical(logits=topk_logits), Independent(Normal(ego_pose_topk, cov), 1)
)
ego_sample = gmm.sample() # [n_batch, 4]
ego_contour_local = cal_polygon_contour(
ego_sample[:, :2], # [n_batch, 2]
torch.arctan2(ego_sample[:, -1], ego_sample[:, -2]), # [n_batch]
token_agent_shape[ego_mask], # [n_batch, 2]
) # [n_batch, 4, 2] in local coord
ego_token_local = token_traj[ego_mask] # [n_batch, n_token, 4, 2]
dist = torch.norm(ego_contour_local.unsqueeze(1) - ego_token_local, dim=-1).mean(
-1
) # [n_batch, n_token]
next_token_idx[ego_mask] = dist.argmin(-1)
    # ego_contour_local: [n_batch, 4, 2] in local coord
    ego_contour_start = next_token_traj_all[ego_mask][:, 0]  # [n_batch, 4, 2]
    n_step = next_token_traj_all.shape[1]
    diff = (ego_contour_local - ego_contour_start) / (n_step - 1)
    ego_token_interp = [ego_contour_start + diff * i for i in range(n_step)]
# [n_batch, 6, 4, 2]
next_token_traj_all[ego_mask] = torch.stack(ego_token_interp, dim=1)
return next_token_idx, next_token_traj_all
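

if __name__ == "__main__":
    # Minimal smoke test (illustrative shapes only, not part of the original
    # pipeline): check contour shapes and the global <-> local round trip.
    n_agent, n_step = 3, 6
    pos = torch.randn(n_agent, n_step, 1, 2)
    head = torch.randn(n_agent, n_step, 1)
    width_length = torch.tensor([2.0, 4.0]).view(1, 1, 1, 2).expand(n_agent, 1, 1, 2)
    contour = cal_polygon_contour(pos, head, width_length)
    assert contour.shape == (n_agent, n_step, 1, 4, 2)

    pos_now = torch.randn(n_agent, 2)
    head_now = torch.randn(n_agent)
    pos_local = torch.randn(n_agent, n_step, 2)
    head_local = torch.randn(n_agent, n_step)
    pos_g, head_g = transform_to_global(pos_local, head_local, pos_now, head_now)
    pos_l, head_l = transform_to_local(pos_g, head_g, pos_now, head_now)
    assert torch.allclose(pos_l, pos_local, atol=1e-5)
    assert torch.allclose(head_l, head_local, atol=1e-5)
    print("contour and transform smoke test passed")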