白鹭先生
init
abd2a81
raw
history blame
2.81 kB
from typing import Tuple, List, Union, cast
import torch
from kornia.geometry.transform import vflip, rotate
# Return annotation shared by the augmentation functions in this module:
# either the transformed tensor by itself, or a ``(tensor, matrix)`` pair
# when the caller asks for the applied transformation as well.
_TensorAndMatrix = Tuple[torch.Tensor, torch.Tensor]
UnionType = Union[torch.Tensor, _TensorAndMatrix]
def random_rotate(input: torch.Tensor) -> UnionType:
    r"""Rotate a tensor image or a batch of tensor images by a random angle.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.
    All leading dimensions are flattened into one batch axis, and every image in
    that batch gets its own angle drawn uniformly from ``[-180, 180)``
    (interpreted by kornia's ``rotate`` as degrees). A 2-D ``(H, W)`` input is
    treated as a single-channel image.

    Args:
        input (torch.Tensor): image tensor of shape :math:`(C, H, W)` or :math:`(*, C, H, W)`.

    Returns:
        torch.Tensor: The rotated input, of shape :math:`(B, C, H, W)`.

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    device: torch.device = input.device

    # The leading unsqueeze guarantees at least 3 dims, so a bare (H, W) image
    # becomes (1, 1, H, W) after flattening. ``reshape`` (unlike ``view``) also
    # accepts non-contiguous inputs.
    batch = input.unsqueeze(0)
    batch = batch.reshape((-1, *batch.shape[-3:]))

    # One independent angle per image. The bounds must differ; an interval
    # like (-180, -180) would always yield -180 and never actually randomize.
    angle: torch.Tensor = torch.empty(batch.shape[0], device=device).uniform_(-180, 180)
    return rotate(batch, angle)
def random_vflip(input: torch.Tensor, p: float = 0.5, return_transform: bool = False) -> UnionType:
    r"""Vertically flip a tensor image or a batch of tensor images randomly with a given probability.

    Input should be a tensor of shape (C, H, W) or a batch of tensors :math:`(*, C, H, W)`.
    All leading dimensions are flattened into one batch axis, and each image in
    that batch is flipped (top <-> bottom) independently with probability ``p``.

    Args:
        input (torch.Tensor): image tensor of shape :math:`(C, H, W)` or :math:`(*, C, H, W)`.
        p (float): probability of the image being flipped. Default value is 0.5
        return_transform (bool): if ``True`` return the matrix describing the transformation applied to each
                                 input tensor.

    Returns:
        torch.Tensor: The vertically flipped input, of shape :math:`(B, C, H, W)`.
        torch.Tensor: The applied transformation matrix :math:`(B, 3, 3)` if return_transform flag
                      is set to ``True``. Flipped images get
                      ``[[1, 0, 0], [0, -1, H - 1], [0, 0, 1]]`` (row ``y`` maps to
                      row ``H - 1 - y``); the others get the identity.

    Raises:
        TypeError: if ``input`` is not a tensor, ``p`` is not a float, or
            ``return_transform`` is not a bool.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
    if not isinstance(p, float):
        raise TypeError(f"The probability should be a float number. Got {type(p)}")
    if not isinstance(return_transform, bool):
        raise TypeError(f"The return_transform flag must be a bool. Got {type(return_transform)}")

    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    # The leading unsqueeze guarantees at least 3 dims before flattening;
    # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs.
    batch = input.unsqueeze(0)
    batch = batch.reshape((-1, *batch.shape[-3:]))

    # One uniform draw per image; images whose draw falls below p get flipped.
    probs: torch.Tensor = torch.empty(batch.shape[0], device=device).uniform_(0, 1)
    to_flip: torch.Tensor = probs < p

    flipped: torch.Tensor = batch.clone()
    flipped[to_flip] = vflip(batch[to_flip])

    if not return_transform:
        return flipped

    # ``expand`` yields a view in which every batch entry aliases the same 3x3
    # storage; clone it so the masked write below only touches selected rows.
    trans_mat: torch.Tensor = torch.eye(3, device=device, dtype=dtype).expand(batch.shape[0], -1, -1).clone()

    # A vertical flip mirrors the row (y) axis: y' = (H - 1) - y, x unchanged.
    # H is dim -2 of a (B, C, H, W) tensor.
    h: int = batch.shape[-2]
    flip_mat: torch.Tensor = torch.tensor([[1, 0, 0],
                                           [0, -1, h - 1],
                                           [0, 0, 1]])
    # Build as int64 first and then cast, matching the original conversion path.
    trans_mat[to_flip] = flip_mat.to(device).to(dtype)
    return flipped, trans_mat