# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
from torch import Tensor

from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor

from .base_bbox_coder import BaseBBoxCoder


@TASK_UTILS.register_module()
class YOLOBBoxCoder(BaseBBoxCoder):
    """YOLO BBox coder.

    Following `YOLO <https://arxiv.org/abs/1506.02640>`_, this coder divides
    the image into grids and encodes bboxes (x1, y1, x2, y2) into
    (cx, cy, dw, dh). cx, cy in [0., 1.] denote the relative center position
    w.r.t. the center of bboxes. dw, dh are the same as
    :obj:`DeltaXYWHBBoxCoder`.

    Args:
        eps (float): Min value of cx, cy when encoding.
    """

    def __init__(self, eps: float = 1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def encode(self, bboxes: Union[Tensor, BaseBoxes],
               gt_bboxes: Union[Tensor, BaseBoxes],
               stride: Union[Tensor, int]) -> Tensor:
        """Get box regression transformation deltas that can be used to
        transform the ``bboxes`` into the ``gt_bboxes``.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes,
                e.g., anchors.
            gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the
                transformation, e.g., ground-truth boxes.
            stride (torch.Tensor | int): Stride of bboxes.

        Returns:
            torch.Tensor: Box transformation deltas.
        """
        bboxes = get_box_tensor(bboxes)
        gt_bboxes = get_box_tensor(gt_bboxes)
        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
        x_center_gt = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) * 0.5
        y_center_gt = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) * 0.5
        w_gt = gt_bboxes[..., 2] - gt_bboxes[..., 0]
        h_gt = gt_bboxes[..., 3] - gt_bboxes[..., 1]
        x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5
        y_center = (bboxes[..., 1] + bboxes[..., 3]) * 0.5
        w = bboxes[..., 2] - bboxes[..., 0]
        h = bboxes[..., 3] - bboxes[..., 1]
        w_target = torch.log((w_gt / w).clamp(min=self.eps))
        h_target = torch.log((h_gt / h).clamp(min=self.eps))
        x_center_target = ((x_center_gt - x_center) / stride + 0.5).clamp(
            self.eps, 1 - self.eps)
        y_center_target = ((y_center_gt - y_center) / stride + 0.5).clamp(
            self.eps, 1 - self.eps)
        encoded_bboxes = torch.stack(
            [x_center_target, y_center_target, w_target, h_target], dim=-1)
        return encoded_bboxes
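
    # Descriptive summary of the transform above: for an anchor with center
    # (xa, ya) and size (wa, ha), and a ground truth with center (xg, yg) and
    # size (wg, hg), encode() produces
    #   cx = clamp((xg - xa) / stride + 0.5, eps, 1 - eps)
    #   cy = clamp((yg - ya) / stride + 0.5, eps, 1 - eps)
    #   dw = log(wg / wa),  dh = log(hg / ha)
    # and decode() below inverts this mapping.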

    def decode(self, bboxes: Union[Tensor, BaseBoxes], pred_bboxes: Tensor,
               stride: Union[Tensor, int]) -> Union[Tensor, BaseBoxes]:
        """Apply transformation ``pred_bboxes`` to ``bboxes``.

        Args:
            bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes,
                e.g. anchors.
            pred_bboxes (torch.Tensor): Encoded boxes with shape (..., 4).
            stride (torch.Tensor | int): Strides of bboxes.

        Returns:
            Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes.
        """
        bboxes = get_box_tensor(bboxes)
        assert pred_bboxes.size(-1) == bboxes.size(-1) == 4
        xy_centers = (bboxes[..., :2] + bboxes[..., 2:]) * 0.5 + (
            pred_bboxes[..., :2] - 0.5) * stride
        whs = (bboxes[..., 2:] -
               bboxes[..., :2]) * 0.5 * pred_bboxes[..., 2:].exp()
        decoded_bboxes = torch.stack(
            (xy_centers[..., 0] - whs[..., 0], xy_centers[..., 1] -
             whs[..., 1], xy_centers[..., 0] + whs[..., 0],
             xy_centers[..., 1] + whs[..., 1]),
            dim=-1)
        if self.use_box_type:
            decoded_bboxes = HorizontalBoxes(decoded_bboxes)
        return decoded_bboxes
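

# A minimal usage sketch: the anchor and ground-truth coordinates are
# illustrative values on an assumed stride-32 feature level, chosen only to
# show that decode() inverts encode(); it assumes mmdet is installed so the
# imports above resolve.
if __name__ == '__main__':
    coder = YOLOBBoxCoder()
    # One 64x64 anchor centered at (128, 128) and one ground-truth box.
    anchors = torch.tensor([[96., 96., 160., 160.]])
    gt = torch.tensor([[100., 90., 180., 170.]])
    deltas = coder.encode(anchors, gt, stride=32)
    restored = coder.decode(anchors, deltas, stride=32)
    print(deltas)    # (cx, cy, dw, dh) regression targets
    print(restored)  # matches `gt` up to clamping and float error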