gzzyyxy's picture
Upload folder using huggingface_hub
c1a7f73 verified
import torch
from torch import Tensor
def get_yaw_rotation_2d(yaw):
"""
Gets a 2D rotation matrix given a yaw angle.
Args:
yaw: torch.Tensor, rotation angle in radians. Can be any shape except empty.
Returns:
rotation: torch.Tensor with shape [..., 2, 2], where `...` matches input shape.
"""
cos_yaw = torch.cos(yaw)
sin_yaw = torch.sin(yaw)
rotation = torch.stack([
torch.stack([cos_yaw, -sin_yaw], dim=-1),
torch.stack([sin_yaw, cos_yaw], dim=-1),
], dim=-2) # Shape: [..., 2, 2]
return rotation
def get_yaw_rotation(yaw):
"""
Computes a 3D rotation matrix given a yaw angle (rotation around the Z-axis).
Args:
yaw: torch.Tensor of any shape, representing yaw angles in radians.
Returns:
rotation: torch.Tensor of shape [input_shape, 3, 3], representing the rotation matrices.
"""
cos_yaw = torch.cos(yaw)
sin_yaw = torch.sin(yaw)
ones = torch.ones_like(yaw)
zeros = torch.zeros_like(yaw)
return torch.stack([
torch.stack([cos_yaw, -sin_yaw, zeros], dim=-1),
torch.stack([sin_yaw, cos_yaw, zeros], dim=-1),
torch.stack([zeros, zeros, ones], dim=-1),
], dim=-2)
def get_transform(rotation, translation):
"""
Combines an NxN rotation matrix and an Nx1 translation vector into an (N+1)x(N+1) transformation matrix.
Args:
rotation: torch.Tensor of shape [..., N, N], representing rotation matrices.
translation: torch.Tensor of shape [..., N], representing translation vectors.
This must have the same dtype as rotation.
Returns:
transform: torch.Tensor of shape [..., (N+1), (N+1)], representing the transformation matrices.
This has the same dtype as rotation.
"""
# [..., N, 1]
translation_n_1 = translation.unsqueeze(-1)
# [..., N, N+1] - Combine rotation and translation
transform = torch.cat([rotation, translation_n_1], dim=-1)
# [..., N] - Create the last row, which is [0, 0, ..., 0, 1]
last_row = torch.zeros_like(translation)
last_row = torch.cat([last_row, torch.ones_like(last_row[..., :1])], dim=-1)
# [..., N+1, N+1] - Append the last row to form the final transformation matrix
transform = torch.cat([transform, last_row.unsqueeze(-2)], dim=-2)
return transform
def get_upright_3d_box_corners(boxes: Tensor):
"""
Given a set of upright 3D bounding boxes, return its 8 corner points.
Args:
boxes: torch.Tensor [N, 7]. The inner dims are [center{x,y,z}, length, width,
height, heading].
Returns:
corners: torch.Tensor [N, 8, 3].
"""
center_x, center_y, center_z, length, width, height, heading = boxes.unbind(dim=-1)
# Compute rotation matrix [N, 3, 3]
rotation = get_yaw_rotation(heading)
# Translation [N, 3]
translation = torch.stack([center_x, center_y, center_z], dim=-1)
l2, w2, h2 = length * 0.5, width * 0.5, height * 0.5
# Define the 8 corners in local coordinates [N, 8, 3]
corners_local = torch.stack([
torch.stack([ l2, w2, -h2], dim=-1),
torch.stack([-l2, w2, -h2], dim=-1),
torch.stack([-l2, -w2, -h2], dim=-1),
torch.stack([ l2, -w2, -h2], dim=-1),
torch.stack([ l2, w2, h2], dim=-1),
torch.stack([-l2, w2, h2], dim=-1),
torch.stack([-l2, -w2, h2], dim=-1),
torch.stack([ l2, -w2, h2], dim=-1),
], dim=1) # Shape: [N, 8, 3]
# Rotate and translate the corners
corners = torch.einsum('n i j, n k j -> n k i', rotation, corners_local) + translation.unsqueeze(1)
return corners