File size: 4,273 Bytes
dfebd8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from typing import List
import torch
from torch import Tensor
class BBox(object):
    """Axis-aligned bounding box plus static tensor helpers for box arithmetic.

    An instance stores one box as four scalars. The static helpers operate on
    tensors whose last dimension holds box coordinates, either in corner form
    ``(left, top, right, bottom)`` or in "center-based" form
    ``(center_x, center_y, width, height)``. Leading dimensions are preserved.
    """

    def __init__(self, left: float, top: float, right: float, bottom: float):
        super().__init__()
        self.left = left
        self.top = top
        self.right = right
        self.bottom = bottom

    def __repr__(self) -> str:
        return 'BBox[l={:.1f}, t={:.1f}, r={:.1f}, b={:.1f}]'.format(
            self.left, self.top, self.right, self.bottom)

    def tolist(self) -> List[float]:
        """Return the coordinates as ``[left, top, right, bottom]``."""
        return [self.left, self.top, self.right, self.bottom]

    @staticmethod
    def to_center_base(bboxes: Tensor) -> Tensor:
        """Convert corner boxes ``(l, t, r, b)`` to ``(cx, cy, w, h)``.

        Args:
            bboxes: tensor whose last dimension indexes ``l, t, r, b``.

        Returns:
            Tensor of the same leading shape with last dimension
            ``(center_x, center_y, width, height)``.
        """
        return torch.stack([
            (bboxes[..., 0] + bboxes[..., 2]) / 2,
            (bboxes[..., 1] + bboxes[..., 3]) / 2,
            bboxes[..., 2] - bboxes[..., 0],
            bboxes[..., 3] - bboxes[..., 1]
        ], dim=-1)

    @staticmethod
    def from_center_base(center_based_bboxes: Tensor) -> Tensor:
        """Convert ``(cx, cy, w, h)`` boxes back to corner form ``(l, t, r, b)``.

        This is the inverse of :meth:`to_center_base`.
        """
        return torch.stack([
            center_based_bboxes[..., 0] - center_based_bboxes[..., 2] / 2,
            center_based_bboxes[..., 1] - center_based_bboxes[..., 3] / 2,
            center_based_bboxes[..., 0] + center_based_bboxes[..., 2] / 2,
            center_based_bboxes[..., 1] + center_based_bboxes[..., 3] / 2
        ], dim=-1)

    @staticmethod
    def calc_transformer(src_bboxes: Tensor, dst_bboxes: Tensor) -> Tensor:
        """Compute the deltas that map ``src_bboxes`` onto ``dst_bboxes``.

        Both inputs are corner-form boxes. The result's last dimension is
        ``(dx, dy, dw, dh)`` where::

            dx = (dst_cx - src_cx) / src_w
            dy = (dst_cy - src_cy) / src_h
            dw = log(dst_w / src_w)
            dh = log(dst_h / src_h)

        :meth:`apply_transformer` inverts this. Zero or negative source or
        destination sizes produce inf/NaN; no guarding is done here.
        """
        center_based_src_bboxes = BBox.to_center_base(src_bboxes)
        center_based_dst_bboxes = BBox.to_center_base(dst_bboxes)
        transformers = torch.stack([
            (center_based_dst_bboxes[..., 0] - center_based_src_bboxes[..., 0]) / center_based_src_bboxes[..., 2],
            (center_based_dst_bboxes[..., 1] - center_based_src_bboxes[..., 1]) / center_based_src_bboxes[..., 3],
            torch.log(center_based_dst_bboxes[..., 2] / center_based_src_bboxes[..., 2]),
            torch.log(center_based_dst_bboxes[..., 3] / center_based_src_bboxes[..., 3])
        ], dim=-1)
        return transformers

    @staticmethod
    def apply_transformer(src_bboxes: Tensor, transformers: Tensor) -> Tensor:
        """Apply ``(dx, dy, dw, dh)`` deltas to corner-form ``src_bboxes``.

        This is the inverse of :meth:`calc_transformer`. It returns corner-form
        boxes with the same leading shape as the inputs.
        """
        center_based_src_bboxes = BBox.to_center_base(src_bboxes)
        center_based_dst_bboxes = torch.stack([
            transformers[..., 0] * center_based_src_bboxes[..., 2] + center_based_src_bboxes[..., 0],
            transformers[..., 1] * center_based_src_bboxes[..., 3] + center_based_src_bboxes[..., 1],
            torch.exp(transformers[..., 2]) * center_based_src_bboxes[..., 2],
            torch.exp(transformers[..., 3]) * center_based_src_bboxes[..., 3]
        ], dim=-1)
        dst_bboxes = BBox.from_center_base(center_based_dst_bboxes)
        return dst_bboxes

    @staticmethod
    def iou(source: Tensor, other: Tensor) -> Tensor:
        """Compute pairwise intersection-over-union between two sets of boxes.

        Args:
            source: ``(..., N, 4)`` corner-form boxes.
            other: ``(..., M, 4)`` corner-form boxes. Leading dimensions must
                broadcast against those of ``source``.

        Returns:
            ``(..., N, M)`` tensor whose entry ``[..., i, j]`` is the IoU of
            ``source[..., i, :]`` and ``other[..., j, :]``. Pairs whose union
            area is zero yield NaN (0 / 0). When both inputs are 2-D the
            result is ``(1, N, M)``; this matches the shape earlier versions
            returned and is kept for existing callers.
        """
        # Broadcast (..., N, 1, 4) against (..., 1, M, 4) rather than
        # materializing repeat()-ed copies. The values are identical, memory
        # use is lower, and any number of leading batch dims is accepted (the
        # old fixed-arity repeat() only worked for 2-D/3-D inputs).
        src = source.unsqueeze(dim=-2)
        oth = other.unsqueeze(dim=-3)

        src_area = (src[..., 2] - src[..., 0]) * (src[..., 3] - src[..., 1])
        oth_area = (oth[..., 2] - oth[..., 0]) * (oth[..., 3] - oth[..., 1])

        intersection_left = torch.max(src[..., 0], oth[..., 0])
        intersection_top = torch.max(src[..., 1], oth[..., 1])
        intersection_right = torch.min(src[..., 2], oth[..., 2])
        intersection_bottom = torch.min(src[..., 3], oth[..., 3])

        # Non-overlapping pairs get a negative extent; clamp it to zero area.
        intersection_width = torch.clamp(intersection_right - intersection_left, min=0)
        intersection_height = torch.clamp(intersection_bottom - intersection_top, min=0)
        intersection_area = intersection_width * intersection_height

        ious = intersection_area / (src_area + oth_area - intersection_area)
        if source.dim() == 2 and other.dim() == 2:
            # The repeat()-based version prepended a singleton dim for two
            # unbatched inputs; preserve that shape so callers don't mis-index.
            ious = ious.unsqueeze(dim=0)
        return ious

    @staticmethod
    def inside(bboxes: Tensor, left: float, top: float, right: float, bottom: float) -> Tensor:
        """Return a boolean mask of boxes lying entirely within the given region.

        Edges that touch the region boundary count as inside.
        """
        return ((bboxes[..., 0] >= left) & (bboxes[..., 1] >= top) &
                (bboxes[..., 2] <= right) & (bboxes[..., 3] <= bottom))

    @staticmethod
    def clip(bboxes: Tensor, left: float, top: float, right: float, bottom: float) -> Tensor:
        """Clamp box coordinates to the given region.

        Note: ``bboxes`` is modified in place, and the same tensor is returned.
        Pass a clone if the original values must be kept.
        """
        bboxes[..., [0, 2]] = bboxes[..., [0, 2]].clamp(min=left, max=right)
        bboxes[..., [1, 3]] = bboxes[..., [1, 3]].clamp(min=top, max=bottom)
        return bboxes
|