# ็™ฝ้นญๅ…ˆ็”Ÿ
# NOTE(review): the following lines are file-viewer residue accidentally captured
# with the source ("init", "abd2a81", "raw", "history blame", "8.88 kB");
# kept here as a comment so the module parses as valid Python.
import numbers
import numpy as np
import torch
from torchvision.transforms.functional import _is_pil_image, _is_numpy_image
from torch_lydorn.torch.utils.complex import complex_mul, complex_abs_squared
try:
import accimage
except ImportError:
accimage = None
# Public API of this module. Previously only the first two names were listed,
# which made `from ... import *` silently drop the other public helpers
# defined below.
__all__ = [
    "to_tensor",
    "batch_normalize",
    "batch_denormalize",
    "crop",
    "center_crop",
    "rotate_anglefield",
    "vflip_anglefield",
    "rotate_framefield",
    "vflip_framefield",
    "draw_circle",
]
def _is_numpy(img):
return isinstance(img, np.ndarray)
def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor without typecasting and rescaling.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image in CHW layout (a 2D ndarray gains a leading
        channel dimension).

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor an ndarray.
        ValueError: If ``pic`` is an ndarray whose ndim is not 2 or 3.
    """
    if not(_is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
    if isinstance(pic, np.ndarray):
        # Handle numpy array: promote HW to HW1, then reorder HWC -> CHW.
        if pic.ndim == 2:
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        return img
    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)
    # Handle PIL Image.
    # Fixed: np.asarray replaces np.array(..., copy=False), which raises a
    # ValueError under NumPy >= 2.0 whenever a copy turns out to be required;
    # np.asarray still avoids the copy when it can.
    if pic.mode == 'I':
        img = torch.from_numpy(np.asarray(pic, np.int32))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.asarray(pic, np.int16))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.asarray(pic, np.float32))
    elif pic.mode == '1':
        # Binary images are expanded to 0/255 uint8.
        img = 255 * torch.from_numpy(np.asarray(pic, np.uint8))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    # PIL's .size is (width, height); view as HWC first.
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    return img
def batch_normalize(tensor, mean, std, inplace=False):
    """Normalize a batched tensor image with batched mean and batched standard deviation.

    Computes ``out = (tensor - mean) / std`` with the (B, C) statistics
    broadcast over the spatial dimensions.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    Args:
        tensor (Tensor): Tensor image of size (B, C, H, W) to be normalized.
        mean (Tensor): Tensor means of size (B, C).
        std (Tensor): Tensor standard deviations of size (B, C).
        inplace (bool, optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.
    """
    # Fixed: the first message previously said "(B, H, W, C)", contradicting
    # the documented (B, C, H, W) layout that the channel check below relies on.
    assert len(tensor.shape) == 4, \
        "tensor should have 4 dims (B, C, H, W), not {}".format(len(tensor.shape))
    assert len(mean.shape) == len(std.shape) == 2, \
        "mean and std should have 2 dims (B, C), not {} and {}".format(len(mean.shape), len(std.shape))
    assert tensor.shape[1] == mean.shape[1] == std.shape[1], \
        "tensor, mean and std should have the same number of channels, not {}, {} and {}".format(tensor.shape[1], mean.shape[1], std.shape[1])
    if not inplace:
        tensor = tensor.clone()
    mean = mean.to(tensor.dtype)
    std = std.to(tensor.dtype)
    # Broadcast (B, C) stats over the spatial dims: (B, C) -> (B, C, 1, 1).
    tensor.sub_(mean[..., None, None]).div_(std[..., None, None])
    return tensor
def batch_denormalize(tensor, mean, std, inplace=False):
    """Denormalize a batched tensor image with batched mean and batched standard deviation.

    Inverse of :func:`batch_normalize`: ``out = tensor * std + mean`` with the
    (B, C) statistics broadcast over the spatial dimensions.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    Args:
        tensor (Tensor): Tensor image of size (B, C, H, W) to be denormalized.
        mean (Tensor): Tensor means of size (B, C).
        std (Tensor): Tensor standard deviations of size (B, C).
        inplace (bool, optional): Bool to make this operation inplace.

    Returns:
        Tensor: Denormalized Tensor image.
    """
    # Fixed: the first message previously said "(B, H, W, C)", and the channel
    # message reported shape[-1] while the check itself compares shape[1].
    assert len(tensor.shape) == 4, \
        "tensor should have 4 dims (B, C, H, W), not {}".format(len(tensor.shape))
    assert len(mean.shape) == len(std.shape) == 2, \
        "mean and std should have 2 dims (B, C), not {} and {}".format(len(mean.shape), len(std.shape))
    assert tensor.shape[1] == mean.shape[1] == std.shape[1], \
        "tensor, mean and std should have the same number of channels, not {}, {} and {}".format(tensor.shape[1], mean.shape[1], std.shape[1])
    if not inplace:
        tensor = tensor.clone()
    mean = mean.to(tensor.dtype)
    std = std.to(tensor.dtype)
    # Broadcast (B, C) stats over the spatial dims: (B, C) -> (B, C, 1, 1).
    tensor.mul_(std[..., None, None]).add_(mean[..., None, None])
    return tensor
def crop(tensor, top, left, height, width):
    """Crop the given Tensor batch of images.

    Args:
        tensor (B, C, H, W): Tensor to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        Tensor: Cropped image (a view over the last two dims of the input).
    """
    bottom = top + height
    right = left + width
    return tensor[..., top:bottom, left:right]
def center_crop(tensor, output_size):
    """Crop the given tensor batch of images and resize it to desired size.

    Args:
        tensor (B, C, H, W): Tensor to be cropped.
        output_size (sequence or int): (height, width) of the crop box. If int,
            it is used for both directions.

    Returns:
        Tensor: Cropped tensor.
    """
    if isinstance(output_size, numbers.Number):
        side = int(output_size)
        output_size = (side, side)
    in_height, in_width = tensor.shape[-2:]
    out_height, out_width = output_size
    # Center the crop box; rounding matches the original round-half-to-even.
    top = int(round((in_height - out_height) / 2.))
    left = int(round((in_width - out_width) / 2.))
    return tensor[..., top:top + out_height, left:left + out_width]
def rotate_anglefield(angle_field, angle):
    """Add a per-batch rotation to an angle field, in place.

    :param angle_field: (B, 1, H, W) tensor of angles, in radians (mutated in place)
    :param angle: (B) tensor of rotation angles, in degrees
    :return: the mutated angle field
    """
    assert len(angle_field.shape) == 4, "angle_field should have shape (B, 1, H, W)"
    assert len(angle.shape) == 1, "angle should have shape (B)"
    assert angle_field.shape[0] == angle.shape[0], "angle_field and angle should have the same batch size"
    # Degrees -> radians, broadcast (B) over (B, 1, H, W).
    rad_offset = np.pi * angle[:, None, None, None] / 180
    angle_field += rad_offset
    return angle_field
def vflip_anglefield(angle_field):
    """Vertically flip the values of an angle field.

    :param angle_field: (B, 1, H, W) tensor of angles, in radians
    :return: new tensor with flipped angles (input is left untouched)
    """
    assert len(angle_field.shape) == 4, "angle_field should have shape (B, 1, H, W)"
    # Angle is expressed in ij coordinates, so a vertical flip here is a
    # horizontal flip in xy: theta -> pi - theta.
    flipped = np.pi - angle_field
    return flipped
def rotate_framefield(framefield, angle):
    """
    Only rotates the values of the framefield, does not rotate the spatial domain
    (use already-made torch functions to do that). Mutates ``framefield`` in place.

    @param framefield: shape (B, 4, H, W). The 4 channels represent the c_0 and c_2 complex coefficients.
    @param angle: rotation angle in degrees
    @return: the mutated framefield
    """
    assert framefield.shape[1] == 4, f"framefield should have shape (B, 4, H, W), not {framefield.shape}"
    rad = np.pi * angle / 180
    # c_0 (channels 0:2) is multiplied by (cos 4rad, sin 4rad) and
    # c_2 (channels 2:4) by (cos 2rad, sin 2rad), via complex multiplication.
    for channels, factor in ((slice(0, 2), 4), (slice(2, 4), 2)):
        z = torch.tensor([np.cos(factor * rad), np.sin(factor * rad)],
                         dtype=framefield.dtype, device=framefield.device)
        framefield[:, channels, :, :] = complex_mul(
            framefield[:, channels, :, :], z[None, :, None, None], complex_dim=1)
    return framefield
def vflip_framefield(framefield):
    """
    Flips the framefield vertically, in place. This means switching the signs of the
    real part of u and v (the framefield is in ij coordinates: it's a horizontal flip
    in xy), which translates to switching the signs of the imaginary parts of c_0 and c_2.

    @param framefield: shape (B, 4, H, W). The 4 channels represent the c_0 and c_2 complex coefficients.
    @return: the mutated framefield
    """
    assert framefield.shape[1] == 4, f"framefield should have shape (B, 4, H, W), not {framefield.shape}"
    # Channels 1 and 3 hold the imaginary parts of c_0 and c_2 respectively.
    for imag_channel in (1, 3):
        framefield[:, imag_channel, :, :] = -framefield[:, imag_channel, :, :]
    return framefield
def draw_circle(draw, center, radius, fill):
    """Draw a filled circle (no outline) onto a PIL ImageDraw-like object.

    :param draw: object exposing ``ellipse(bbox, fill=..., outline=...)``
        (e.g. PIL.ImageDraw.ImageDraw) -- assumed, confirm with callers
    :param center: (x, y) center of the circle
    :param radius: circle radius
    :param fill: fill color, passed through to ``draw.ellipse``
    """
    center_x, center_y = center[0], center[1]
    bounding_box = [center_x - radius, center_y - radius,
                    center_x + radius, center_y + radius]
    draw.ellipse(bounding_box, fill=fill, outline=None)