|
from typing import Tuple, List, Union, cast |
|
import torch |
|
|
|
from kornia.geometry.transform import vflip, rotate |
|
|
|
# Return annotation used by the augmentation functions below: either the
# transformed tensor alone, or a (transformed tensor, transformation matrix) pair.
UnionType = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
|
|
|
|
|
def random_rotate(input: torch.Tensor) -> torch.Tensor:
    r"""Rotate a tensor image or a batch of tensor images by a random angle.

    Every image in the batch gets its own angle, drawn uniformly from
    :math:`[-180, 180)` and handed to ``rotate``.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.
    Leading dimensions are collapsed into a single batch dimension before rotating.

    Args:
        input (torch.Tensor): image or batch of images to rotate.

    Returns:
        torch.Tensor: The rotated input, as returned by ``rotate`` for the
        :math:`(B, C, H, W)` batch.

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    device: torch.device = input.device

    # Prepend a batch axis, then fold every leading dimension into it so that
    # both (C, H, W) and (*, C, H, W) inputs end up as (B, C, H, W).
    # ``reshape`` (rather than ``view``) also accepts non-contiguous inputs.
    batch: torch.Tensor = input.unsqueeze(0)
    batch = batch.reshape((-1, *batch.shape[-3:]))

    # One angle per image. The bounds must differ: with identical bounds the
    # "random" angle would always be the same constant.
    angle: torch.Tensor = torch.empty(batch.shape[0], device=device).uniform_(-180, 180)

    return rotate(batch, angle)
|
|
|
|
|
def random_vflip(
    input: torch.Tensor, p: float = 0.5, return_transform: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""Vertically flip a tensor image or a batch of tensor images randomly with a given probability.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.
    Leading dimensions are collapsed, so the result is always a batch of shape :math:`(B, C, H, W)`.

    Args:
        input (torch.Tensor): image or batch of images to flip.
        p (float): probability of the image being flipped. Default value is 0.5
        return_transform (bool): if ``True`` return the matrix describing the transformation applied to each
            input tensor.

    Returns:
        torch.Tensor: The vertically flipped input
        torch.Tensor: The applied transformation matrix :math:`(B, 3, 3)` if return_transform flag
            is set to ``True``. Rows of the batch that were not flipped hold the identity.

    Raises:
        TypeError: if ``input`` is not a tensor, ``p`` is not a float or
            ``return_transform`` is not a bool.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not isinstance(p, float):
        raise TypeError(f"The probability should be a float number. Got {type(p)}")

    if not isinstance(return_transform, bool):
        raise TypeError(f"The return_transform flag must be a bool. Got {type(return_transform)}")

    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    # Prepend a batch axis, then fold every leading dimension into it -> (B, C, H, W).
    # ``reshape`` (rather than ``view``) also accepts non-contiguous inputs.
    input = input.unsqueeze(0)
    input = input.reshape((-1, *input.shape[-3:]))
    batch_size: int = input.shape[0]

    # Draw one number in [0, 1) per sample; a sample is flipped when its draw is below ``p``.
    probs: torch.Tensor = torch.empty(batch_size, device=device).uniform_(0, 1)
    to_flip: torch.Tensor = probs < p
    any_flipped: bool = bool(to_flip.any())

    flipped: torch.Tensor = input.clone()
    if any_flipped:
        flipped[to_flip] = vflip(input[to_flip])

    if not return_transform:
        return flipped

    # ``repeat`` gives every sample its own 3x3 matrix. An ``expand``-ed view shares
    # one matrix between all samples, so it must not be written through a mask.
    trans_mat: torch.Tensor = torch.eye(3, device=device, dtype=dtype).repeat(batch_size, 1, 1)

    if any_flipped:
        # A vertical flip leaves x untouched and mirrors y: (x, y) -> (x, h - y).
        h: int = input.shape[-2]
        flip_mat: torch.Tensor = torch.tensor([[1, 0, 0],
                                               [0, -1, h],
                                               [0, 0, 1]], device=device, dtype=dtype)
        trans_mat[to_flip] = flip_mat

    return flipped, trans_mat
|
|