sapiens-pose / external /cv /mmcv /ops /chamfer_distance.py
rawalkhirodkar's picture
Add initial commit
28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Sequence, Tuple
import torch
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['chamfer_distance_forward', 'chamfer_distance_backward'])
class ChamferDistanceFunction(Function):
"""This is an implementation of the 2D Chamfer Distance.
It has been used in the paper `Oriented RepPoints for Aerial Object
Detection (CVPR 2022) <https://arxiv.org/abs/2105.11111>_`.
"""
@staticmethod
def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]:
"""
Args:
xyz1 (Tensor): Point set with shape (B, N, 2).
xyz2 (Tensor): Point set with shape (B, N, 2).
Returns:
Sequence[Tensor]:
- dist1 (Tensor): Chamfer distance (xyz1 to xyz2) with
shape (B, N).
- dist2 (Tensor): Chamfer distance (xyz2 to xyz1) with
shape (B, N).
- idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2)
with shape (B, N), which be used in compute gradient.
- idx2 (Tensor): Index of chamfer distance (xyz2 to xyz2)
with shape (B, N), which be used in compute gradient.
"""
batch_size, n, _ = xyz1.size()
_, m, _ = xyz2.size()
device = xyz1.device
xyz1 = xyz1.contiguous()
xyz2 = xyz2.contiguous()
dist1 = torch.zeros(batch_size, n).to(device)
dist2 = torch.zeros(batch_size, m).to(device)
idx1 = torch.zeros(batch_size, n).type(torch.IntTensor).to(device)
idx2 = torch.zeros(batch_size, m).type(torch.IntTensor).to(device)
ext_module.chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1,
idx2)
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
return dist1, dist2, idx1, idx2
@staticmethod
@once_differentiable
def backward(ctx,
grad_dist1: Tensor,
grad_dist2: Tensor,
grad_idx1=None,
grad_idx2=None) -> Tuple[Tensor, Tensor]:
"""
Args:
grad_dist1 (Tensor): Gradient of chamfer distance
(xyz1 to xyz2) with shape (B, N).
grad_dist2 (Tensor): Gradient of chamfer distance
(xyz2 to xyz1) with shape (B, N).
Returns:
Tuple[Tensor, Tensor]:
- grad_xyz1 (Tensor): Gradient of the point set with shape \
(B, N, 2).
- grad_xyz2 (Tensor):Gradient of the point set with shape \
(B, N, 2).
"""
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
device = grad_dist1.device
grad_dist1 = grad_dist1.contiguous()
grad_dist2 = grad_dist2.contiguous()
grad_xyz1 = torch.zeros(xyz1.size()).to(device)
grad_xyz2 = torch.zeros(xyz2.size()).to(device)
ext_module.chamfer_distance_backward(xyz1, xyz2, idx1, idx2,
grad_dist1, grad_dist2, grad_xyz1,
grad_xyz2)
return grad_xyz1, grad_xyz2
chamfer_distance = ChamferDistanceFunction.apply