from typing import List import torch from torch import Tensor class BBox(object): def __init__(self, left: float, top: float, right: float, bottom: float): super().__init__() self.left = left self.top = top self.right = right self.bottom = bottom def __repr__(self) -> str: return 'BBox[l={:.1f}, t={:.1f}, r={:.1f}, b={:.1f}]'.format( self.left, self.top, self.right, self.bottom) def tolist(self) -> List[float]: return [self.left, self.top, self.right, self.bottom] @staticmethod def to_center_base(bboxes: Tensor) -> Tensor: return torch.stack([ (bboxes[..., 0] + bboxes[..., 2]) / 2, (bboxes[..., 1] + bboxes[..., 3]) / 2, bboxes[..., 2] - bboxes[..., 0], bboxes[..., 3] - bboxes[..., 1] ], dim=-1) @staticmethod def from_center_base(center_based_bboxes: Tensor) -> Tensor: return torch.stack([ center_based_bboxes[..., 0] - center_based_bboxes[..., 2] / 2, center_based_bboxes[..., 1] - center_based_bboxes[..., 3] / 2, center_based_bboxes[..., 0] + center_based_bboxes[..., 2] / 2, center_based_bboxes[..., 1] + center_based_bboxes[..., 3] / 2 ], dim=-1) @staticmethod def calc_transformer(src_bboxes: Tensor, dst_bboxes: Tensor) -> Tensor: center_based_src_bboxes = BBox.to_center_base(src_bboxes) center_based_dst_bboxes = BBox.to_center_base(dst_bboxes) transformers = torch.stack([ (center_based_dst_bboxes[..., 0] - center_based_src_bboxes[..., 0]) / center_based_src_bboxes[..., 2], (center_based_dst_bboxes[..., 1] - center_based_src_bboxes[..., 1]) / center_based_src_bboxes[..., 3], torch.log(center_based_dst_bboxes[..., 2] / center_based_src_bboxes[..., 2]), torch.log(center_based_dst_bboxes[..., 3] / center_based_src_bboxes[..., 3]) ], dim=-1) return transformers @staticmethod def apply_transformer(src_bboxes: Tensor, transformers: Tensor) -> Tensor: center_based_src_bboxes = BBox.to_center_base(src_bboxes) center_based_dst_bboxes = torch.stack([ transformers[..., 0] * center_based_src_bboxes[..., 2] + center_based_src_bboxes[..., 0], transformers[..., 1] * center_based_src_bboxes[..., 3] + center_based_src_bboxes[..., 1], torch.exp(transformers[..., 2]) * center_based_src_bboxes[..., 2], torch.exp(transformers[..., 3]) * center_based_src_bboxes[..., 3] ], dim=-1) dst_bboxes = BBox.from_center_base(center_based_dst_bboxes) return dst_bboxes @staticmethod def iou(source: Tensor, other: Tensor) -> Tensor: source, other = source.unsqueeze(dim=-2).repeat(1, 1, other.shape[-2], 1), \ other.unsqueeze(dim=-3).repeat(1, source.shape[-2], 1, 1) source_area = (source[..., 2] - source[..., 0]) * (source[..., 3] - source[..., 1]) other_area = (other[..., 2] - other[..., 0]) * (other[..., 3] - other[..., 1]) intersection_left = torch.max(source[..., 0], other[..., 0]) intersection_top = torch.max(source[..., 1], other[..., 1]) intersection_right = torch.min(source[..., 2], other[..., 2]) intersection_bottom = torch.min(source[..., 3], other[..., 3]) intersection_width = torch.clamp(intersection_right - intersection_left, min=0) intersection_height = torch.clamp(intersection_bottom - intersection_top, min=0) intersection_area = intersection_width * intersection_height return intersection_area / (source_area + other_area - intersection_area) @staticmethod def inside(bboxes: Tensor, left: float, top: float, right: float, bottom: float) -> Tensor: return ((bboxes[..., 0] >= left) * (bboxes[..., 1] >= top) * (bboxes[..., 2] <= right) * (bboxes[..., 3] <= bottom)) @staticmethod def clip(bboxes: Tensor, left: float, top: float, right: float, bottom: float) -> Tensor: bboxes[..., [0, 2]] = bboxes[..., [0, 2]].clamp(min=left, max=right) bboxes[..., [1, 3]] = bboxes[..., [1, 3]].clamp(min=top, max=bottom) return bboxes