|
import numbers |
|
import numpy as np |
|
import torch |
|
from torchvision.transforms.functional import _is_pil_image, _is_numpy_image |
|
|
|
from torch_lydorn.torch.utils.complex import complex_mul, complex_abs_squared |
|
|
|
try: |
|
import accimage |
|
except ImportError: |
|
accimage = None |
|
|
|
|
|
__all__ = ["to_tensor", "batch_normalize"] |
|
|
|
|
|
def _is_numpy(img): |
|
return isinstance(img, np.ndarray) |
|
|
|
|
|
def to_tensor(pic): |
|
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor without typecasting and rescaling |
|
|
|
See ``ToTensor`` for more details. |
|
|
|
Args: |
|
pic (PIL Image or numpy.ndarray): Image to be converted to tensor. |
|
|
|
Returns: |
|
Tensor: Converted image. |
|
""" |
|
    if not (_is_pil_image(pic) or _is_numpy(pic)):
|
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) |
|
|
|
if _is_numpy(pic) and not _is_numpy_image(pic): |
|
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) |
|
|
|
if isinstance(pic, np.ndarray): |
|
|
|
if pic.ndim == 2: |
|
pic = pic[:, :, None] |
|
|
|
img = torch.from_numpy(pic.transpose((2, 0, 1))) |
|
return img |
|
|
|
if accimage is not None and isinstance(pic, accimage.Image): |
|
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8) |
|
pic.copyto(nppic) |
|
return torch.from_numpy(nppic) |
|
|
|
|
|
if pic.mode == 'I': |
|
img = torch.from_numpy(np.array(pic, np.int32, copy=False)) |
|
elif pic.mode == 'I;16': |
|
img = torch.from_numpy(np.array(pic, np.int16, copy=False)) |
|
elif pic.mode == 'F': |
|
img = torch.from_numpy(np.array(pic, np.float32, copy=False)) |
|
    elif pic.mode == '1':
        # Binary images store {0, 1}; scale to {0, 255} to match uint8 conventions
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
|
else: |
|
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) |
|
|
|
if pic.mode == 'YCbCr': |
|
nchannel = 3 |
|
elif pic.mode == 'I;16': |
|
nchannel = 1 |
|
else: |
|
nchannel = len(pic.mode) |
|
img = img.view(pic.size[1], pic.size[0], nchannel) |
|
|
|
|
|
    # Put the image from HWC layout into the CHW layout expected by PyTorch
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
return img |
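# Usage sketch for ``to_tensor`` (hypothetical shapes, for illustration only):
# unlike ``torchvision.transforms.functional.to_tensor``, dtype and value range
# are preserved; only the layout changes from HWC to CHW.
#     >>> arr = np.zeros((256, 256, 3), dtype=np.uint8)
#     >>> to_tensor(arr).shape
#     torch.Size([3, 256, 256])
#     >>> to_tensor(arr).dtype
#     torch.uint8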
|
|
|
|
|
def batch_normalize(tensor, mean, std, inplace=False): |
|
"""Normalize a batched tensor image with batched mean and batched standard deviation. |
|
    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    Args:
        tensor (Tensor): Tensor image of size (B, C, H, W) to be normalized.
        mean (Tensor): Per-sample channel means of size (B, C).
        std (Tensor): Per-sample channel standard deviations of size (B, C).
        inplace (bool, optional): If ``True``, perform the operation in place.

    Returns:
        Tensor: Normalized Tensor image.
|
""" |
|
    assert len(tensor.shape) == 4, \
        "tensor should have 4 dims (B, C, H, W), not {}".format(len(tensor.shape))
    assert len(mean.shape) == len(std.shape) == 2, \
        "mean and std should have 2 dims (B, C), not {} and {}".format(len(mean.shape), len(std.shape))
    assert tensor.shape[1] == mean.shape[1] == std.shape[1], \
        "tensor, mean and std should have the same number of channels, not {}, {} and {}".format(
            tensor.shape[1], mean.shape[1], std.shape[1])
|
|
|
if not inplace: |
|
tensor = tensor.clone() |
|
|
|
mean = mean.to(tensor.dtype) |
|
std = std.to(tensor.dtype) |
|
|
|
tensor.sub_(mean[..., None, None]).div_(std[..., None, None]) |
|
return tensor |
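# Usage sketch (hypothetical values): normalize a batch of two 3-channel images,
# each sample with its own per-channel statistics.
#     >>> t = torch.rand(2, 3, 8, 8)
#     >>> mean, std = t.mean(dim=(2, 3)), t.std(dim=(2, 3))  # both of shape (2, 3)
#     >>> out = batch_normalize(t, mean, std)
#     >>> bool(out.mean(dim=(2, 3)).abs().max() < 1e-5)
#     True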
|
|
|
|
|
def batch_denormalize(tensor, mean, std, inplace=False): |
|
"""Denormalize a batched tensor image with batched mean and batched standard deviation. |
|
    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    Args:
        tensor (Tensor): Tensor image of size (B, C, H, W) to be denormalized.
        mean (Tensor): Per-sample channel means of size (B, C).
        std (Tensor): Per-sample channel standard deviations of size (B, C).
        inplace (bool, optional): If ``True``, perform the operation in place.

    Returns:
        Tensor: Denormalized Tensor image.
|
""" |
|
    assert len(tensor.shape) == 4, \
        "tensor should have 4 dims (B, C, H, W), not {}".format(len(tensor.shape))
    assert len(mean.shape) == len(std.shape) == 2, \
        "mean and std should have 2 dims (B, C), not {} and {}".format(len(mean.shape), len(std.shape))
    assert tensor.shape[1] == mean.shape[1] == std.shape[1], \
        "tensor, mean and std should have the same number of channels, not {}, {} and {}".format(
            tensor.shape[1], mean.shape[1], std.shape[1])
|
|
|
if not inplace: |
|
tensor = tensor.clone() |
|
|
|
mean = mean.to(tensor.dtype) |
|
std = std.to(tensor.dtype) |
|
|
|
tensor.mul_(std[..., None, None]).add_(mean[..., None, None]) |
|
return tensor |
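# Round-trip sketch (hypothetical values): ``batch_denormalize`` inverts
# ``batch_normalize`` when given the same statistics.
#     >>> t = torch.rand(2, 3, 8, 8)
#     >>> mean, std = t.mean(dim=(2, 3)), t.std(dim=(2, 3))
#     >>> restored = batch_denormalize(batch_normalize(t, mean, std), mean, std)
#     >>> torch.allclose(restored, t, atol=1e-6)
#     True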
|
|
|
|
|
def crop(tensor, top, left, height, width): |
|
"""Crop the given Tensor batch of images. |
|
Args: |
|
tensor (B, C, H, W): Tensor to be cropped. (0,0) denotes the top left corner of the image. |
|
top (int): Vertical component of the top left corner of the crop box. |
|
left (int): Horizontal component of the top left corner of the crop box. |
|
height (int): Height of the crop box. |
|
width (int): Width of the crop box. |
|
Returns: |
|
Tensor: Cropped image. |
|
""" |
|
return tensor[..., top:top+height, left:left+width] |
|
|
|
|
|
def center_crop(tensor, output_size): |
|
"""Crop the given tensor batch of images and resize it to desired size. |
|
|
|
Args: |
|
tensor (B, C, H, W): Tensor to be cropped. |
|
output_size (sequence or int): (height, width) of the crop box. If int, |
|
it is used for both directions |
|
Returns: |
|
Tensor: Cropped tensor. |
|
""" |
|
if isinstance(output_size, numbers.Number): |
|
output_size = (int(output_size), int(output_size)) |
|
tensor_height, tensor_width = tensor.shape[-2:] |
|
crop_height, crop_width = output_size |
|
crop_top = int(round((tensor_height - crop_height) / 2.)) |
|
crop_left = int(round((tensor_width - crop_width) / 2.)) |
|
return crop(tensor, crop_top, crop_left, crop_height, crop_width) |
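# Usage sketch (hypothetical sizes): ``center_crop`` computes the centred
# top-left corner and delegates to ``crop``.
#     >>> t = torch.zeros(2, 3, 10, 10)
#     >>> center_crop(t, 4).shape
#     torch.Size([2, 3, 4, 4])
#     >>> torch.equal(center_crop(t, 4), crop(t, 3, 3, 4, 4))
#     True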
|
|
|
|
|
def rotate_anglefield(angle_field, angle): |
|
""" |
|
:param angle_field: (B, 1, H, W), in radians |
|
:param angle_deg: (B) in degrees |
|
:return: |
|
""" |
|
assert len(angle_field.shape) == 4, "angle_field should have shape (B, 1, H, W)" |
|
assert len(angle.shape) == 1, "angle should have shape (B)" |
|
assert angle_field.shape[0] == angle.shape[0], "angle_field and angle should have the same batch size" |
|
angle_field += np.pi * angle[:, None, None, None] / 180 |
|
return angle_field |
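# Sketch (hypothetical values): rotating a zero angle field by 90 degrees adds
# pi/2 radians everywhere. Note the input tensor is modified in place.
#     >>> field = torch.zeros(1, 1, 4, 4)
#     >>> rotate_anglefield(field, torch.tensor([90.0])).unique()
#     tensor([1.5708])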
|
|
|
|
|
def vflip_anglefield(angle_field): |
|
""" |
|
|
|
:param angle_field: (B, 1, H, W), in radians |
|
:return: |
|
""" |
|
assert len(angle_field.shape) == 4, "angle_field should have shape (B, 1, H, W)" |
|
angle_field = np.pi - angle_field |
|
return angle_field |
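# Sketch: a vertical flip maps theta to pi - theta, so a pi/4 field becomes
# 3*pi/4 (values are not wrapped back into any canonical range).
#     >>> vflip_anglefield(torch.full((1, 1, 2, 2), np.pi / 4)).unique()
#     tensor([2.3562])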
|
|
|
|
|
def rotate_framefield(framefield, angle): |
|
""" |
|
    Only rotates the values of the frame field; it does not rotate the spatial domain (use existing torch functions for that). Operates in place on ``framefield``.
|
|
|
@param framefield: shape (B, 4, H, W). The 4 channels represent the c_0 and c_2 complex coefficients. |
|
@param angle: in degrees |
|
    @return: rotated framefield (same tensor, modified in place)
|
""" |
|
assert framefield.shape[1] == 4, f"framefield should have shape (B, 4, H, W), not {framefield.shape}" |
|
rad = np.pi * angle / 180 |
|
z_4angle = torch.tensor([np.cos(4*rad), np.sin(4*rad)], dtype=framefield.dtype, device=framefield.device) |
|
z_2angle = torch.tensor([np.cos(2*rad), np.sin(2*rad)], dtype=framefield.dtype, device=framefield.device) |
|
framefield[:, :2, :, :] = complex_mul(framefield[:, :2, :, :], z_4angle[None, :, None, None], complex_dim=1) |
|
framefield[:, 2:, :, :] = complex_mul(framefield[:, 2:, :, :], z_2angle[None, :, None, None], complex_dim=1) |
|
return framefield |
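# Why 4*angle and 2*angle: c_0 and c_2 are degree-4 and degree-2 expressions in
# the two frame directions, so rotating the frame by theta multiplies c_0 by
# exp(4i*theta) and c_2 by exp(2i*theta).
# Sketch (hypothetical values, modifies its input in place):
#     >>> ff = torch.zeros(1, 4, 2, 2)
#     >>> ff[:, 0] = 1.0  # c_0 = 1 + 0i, c_2 = 0
#     >>> round(rotate_framefield(ff, 45)[0, 0, 0, 0].item())  # c_0 picks up exp(4i*pi/4) = -1
#     -1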
|
|
|
|
|
def vflip_framefield(framefield): |
|
""" |
|
Flips the framefield vertically. This means switching the signs of the real part of u and v |
|
(this is because the framefield is in ij coordinates: it's a horizontal flip in xy), which translates to |
|
switching the signs of the imaginary parts of c_0 and c_2. |
|
|
|
@param framefield: shape (B, 4, H, W). The 4 channels represent the c_0 and c_2 complex coefficients. |
|
    @return: flipped framefield (same tensor, modified in place)
|
""" |
|
assert framefield.shape[1] == 4, f"framefield should have shape (B, 4, H, W), not {framefield.shape}" |
|
framefield[:, 1, :, :] = - framefield[:, 1, :, :] |
|
framefield[:, 3, :, :] = - framefield[:, 3, :, :] |
|
return framefield |
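# Sketch (hypothetical values): only the imaginary-part channels (1 and 3)
# change sign; the input is modified in place.
#     >>> ff = torch.zeros(1, 4, 1, 1)
#     >>> ff[:, 1], ff[:, 3] = 0.5, -0.25
#     >>> vflip_framefield(ff)[0, :, 0, 0]
#     tensor([ 0.0000, -0.5000,  0.0000,  0.2500])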
|
|
|
|
|
def draw_circle(draw, center, radius, fill):
    """Draws a filled circle of the given radius centered on ``center`` using a PIL ``ImageDraw`` object."""
    draw.ellipse([center[0] - radius,
                  center[1] - radius,
                  center[0] + radius,
                  center[1] + radius], fill=fill, outline=None)
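# Usage sketch (hypothetical): draw a filled white disc on a blank PIL mask.
#     >>> from PIL import Image, ImageDraw
#     >>> mask = Image.new("L", (64, 64), 0)
#     >>> draw_circle(ImageDraw.Draw(mask), center=(32, 32), radius=10, fill=255)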
|
|