|
from typing import List |
|
|
|
import torch |
|
from torch import Tensor |
|
|
|
|
|
class BBox:
    """Axis-aligned bounding box plus static helpers for tensors of boxes.

    An instance stores a single box as four scalars. The static helpers
    operate on tensors whose last axis has size 4. By default that axis is in
    "corner" layout ``[left, top, right, bottom]``. The helpers that say so
    use "center" layout ``[center_x, center_y, width, height]`` instead. Any
    leading dimensions are passed through unchanged unless a method says
    otherwise.
    """

    def __init__(self, left: float, top: float, right: float, bottom: float):
        super().__init__()
        self.left = left
        self.top = top
        self.right = right
        self.bottom = bottom

    def __repr__(self) -> str:
        return 'BBox[l={:.1f}, t={:.1f}, r={:.1f}, b={:.1f}]'.format(
            self.left, self.top, self.right, self.bottom)

    def tolist(self) -> List[float]:
        """Return the box as ``[left, top, right, bottom]``."""
        return [self.left, self.top, self.right, self.bottom]

    @staticmethod
    def to_center_base(bboxes: Tensor) -> Tensor:
        """Convert corner-layout boxes to center layout.

        Args:
            bboxes: tensor of shape ``(..., 4)`` as ``[left, top, right, bottom]``.

        Returns:
            Tensor of the same shape as ``[center_x, center_y, width, height]``.
        """
        return torch.stack([
            (bboxes[..., 0] + bboxes[..., 2]) / 2,
            (bboxes[..., 1] + bboxes[..., 3]) / 2,
            bboxes[..., 2] - bboxes[..., 0],
            bboxes[..., 3] - bboxes[..., 1]
        ], dim=-1)

    @staticmethod
    def from_center_base(center_based_bboxes: Tensor) -> Tensor:
        """Convert center-layout boxes back to corner layout.

        Args:
            center_based_bboxes: tensor of shape ``(..., 4)`` as
                ``[center_x, center_y, width, height]``.

        Returns:
            Tensor of the same shape as ``[left, top, right, bottom]``.
        """
        return torch.stack([
            center_based_bboxes[..., 0] - center_based_bboxes[..., 2] / 2,
            center_based_bboxes[..., 1] - center_based_bboxes[..., 3] / 2,
            center_based_bboxes[..., 0] + center_based_bboxes[..., 2] / 2,
            center_based_bboxes[..., 1] + center_based_bboxes[..., 3] / 2
        ], dim=-1)

    @staticmethod
    def calc_transformer(src_bboxes: Tensor, dst_bboxes: Tensor) -> Tensor:
        """Encode the offsets that map ``src_bboxes`` onto ``dst_bboxes``.

        Both inputs are corner-layout tensors of shape ``(..., 4)`` that
        broadcast against each other. With ``(x, y, w, h)`` the center layout
        of a source box and ``(x', y', w', h')`` that of its destination, the
        result on the last axis is::

            [(x' - x) / w, (y' - y) / h, log(w' / w), log(h' / h)]

        A source width or height of zero, or a non-positive size ratio, yields
        ``inf``/``NaN`` entries. Callers must pass non-degenerate boxes.

        ``apply_transformer`` inverts this encoding.
        """
        center_based_src_bboxes = BBox.to_center_base(src_bboxes)
        center_based_dst_bboxes = BBox.to_center_base(dst_bboxes)
        transformers = torch.stack([
            (center_based_dst_bboxes[..., 0] - center_based_src_bboxes[..., 0]) / center_based_src_bboxes[..., 2],
            (center_based_dst_bboxes[..., 1] - center_based_src_bboxes[..., 1]) / center_based_src_bboxes[..., 3],
            torch.log(center_based_dst_bboxes[..., 2] / center_based_src_bboxes[..., 2]),
            torch.log(center_based_dst_bboxes[..., 3] / center_based_src_bboxes[..., 3])
        ], dim=-1)
        return transformers

    @staticmethod
    def apply_transformer(src_bboxes: Tensor, transformers: Tensor) -> Tensor:
        """Decode offsets produced by ``calc_transformer`` back into boxes.

        Args:
            src_bboxes: corner-layout tensor of shape ``(..., 4)``.
            transformers: tensor of shape ``(..., 4)`` encoded as
                ``[dx, dy, log_dw, log_dh]``.

        Returns:
            Corner-layout boxes with center ``(dx * w + x, dy * h + y)`` and
            size ``(exp(log_dw) * w, exp(log_dh) * h)``, where ``(x, y, w, h)``
            is the center layout of each source box.
        """
        center_based_src_bboxes = BBox.to_center_base(src_bboxes)
        center_based_dst_bboxes = torch.stack([
            transformers[..., 0] * center_based_src_bboxes[..., 2] + center_based_src_bboxes[..., 0],
            transformers[..., 1] * center_based_src_bboxes[..., 3] + center_based_src_bboxes[..., 1],
            torch.exp(transformers[..., 2]) * center_based_src_bboxes[..., 2],
            torch.exp(transformers[..., 3]) * center_based_src_bboxes[..., 3]
        ], dim=-1)
        dst_bboxes = BBox.from_center_base(center_based_dst_bboxes)
        return dst_bboxes

    @staticmethod
    def iou(source: Tensor, other: Tensor) -> Tensor:
        """Compute pairwise intersection-over-union between two sets of corner-layout boxes.

        Args:
            source: boxes of shape ``(..., N, 4)``.
            other: boxes of shape ``(..., M, 4)``. Leading dimensions must
                broadcast against those of ``source``.

        Returns:
            Tensor of shape ``(..., N, M)``. Entry ``[..., i, j]`` is the IoU
            of ``source[..., i, :]`` and ``other[..., j, :]``. When both
            inputs are 2-D, the result keeps the historical shape
            ``(1, N, M)``. Pairs whose union area is zero (two degenerate
            boxes) get IoU 0 instead of NaN.
        """
        # Insert singleton axes so the N x M pair grid comes from
        # broadcasting. This avoids materialising repeated copies of both
        # inputs and works for any number of leading dimensions.
        source_pairs = source.unsqueeze(dim=-2)  # (..., N, 1, 4)
        other_pairs = other.unsqueeze(dim=-3)    # (..., 1, M, 4)

        source_area = ((source_pairs[..., 2] - source_pairs[..., 0]) *
                       (source_pairs[..., 3] - source_pairs[..., 1]))
        other_area = ((other_pairs[..., 2] - other_pairs[..., 0]) *
                      (other_pairs[..., 3] - other_pairs[..., 1]))

        intersection_left = torch.max(source_pairs[..., 0], other_pairs[..., 0])
        intersection_top = torch.max(source_pairs[..., 1], other_pairs[..., 1])
        intersection_right = torch.min(source_pairs[..., 2], other_pairs[..., 2])
        intersection_bottom = torch.min(source_pairs[..., 3], other_pairs[..., 3])
        intersection_width = torch.clamp(intersection_right - intersection_left, min=0)
        intersection_height = torch.clamp(intersection_bottom - intersection_top, min=0)
        intersection_area = intersection_width * intersection_height

        union_area = source_area + other_area - intersection_area
        # For well-formed boxes a zero union implies a zero intersection.
        # Substituting 1 for the denominator therefore yields IoU 0 instead
        # of 0/0 = NaN.
        safe_union_area = torch.where(union_area == 0, torch.ones_like(union_area), union_area)
        ious = intersection_area / safe_union_area

        if source.dim() == 2 and other.dim() == 2:
            # The previous repeat-based implementation prepended a leading
            # axis for unbatched input. Keep that shape for existing callers.
            ious = ious.unsqueeze(dim=0)
        return ious

    @staticmethod
    def inside(bboxes: Tensor, left: float, top: float, right: float, bottom: float) -> Tensor:
        """Return a boolean mask of boxes lying entirely within the given rectangle.

        Bounds are inclusive. The result has shape ``bboxes.shape[:-1]``.
        """
        return ((bboxes[..., 0] >= left) & (bboxes[..., 1] >= top) &
                (bboxes[..., 2] <= right) & (bboxes[..., 3] <= bottom))

    @staticmethod
    def clip(bboxes: Tensor, left: float, top: float, right: float, bottom: float) -> Tensor:
        """Clamp box coordinates to the given rectangle.

        Note: this modifies ``bboxes`` in place and returns the same tensor
        object. Pass a clone if the original values must be kept.
        """
        bboxes[..., [0, 2]] = bboxes[..., [0, 2]].clamp(min=left, max=right)
        bboxes[..., [1, 3]] = bboxes[..., [1, 3]].clamp(min=top, max=bottom)
        return bboxes
|
|