# From https://github.com/TRI-ML/KP2D.
# Copyright 2020 Toyota Research Institute. All rights reserved.

import random
from math import pi

import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image

from lanet_utils import image_grid


def filter_dict(dictionary, keywords):
    """
    Returns only the keywords that are keys of a dictionary.

    Parameters
    ----------
    dictionary : dict
        Dictionary for filtering
    keywords : list of str
        Keywords that will be filtered

    Returns
    -------
    keywords : list of str
        List containing the keywords that are keys in dictionary
    """
    return [key for key in keywords if key in dictionary]
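

# Hedged usage sketch for filter_dict; the sample keys below are illustrative,
# not taken from the KP2D data pipeline.
def _filter_dict_example():
    sample = {"image": None, "homography": None}
    # Only keywords that are actual keys of the dictionary survive the filter.
    assert filter_dict(sample, ["image", "depth"]) == ["image"]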


def resize_sample(sample, image_shape, image_interpolation=Image.LANCZOS):
    """
    Resizes a sample, which contains an input image.

    Parameters
    ----------
    sample : dict
        Dictionary with sample values (output from a dataset's __getitem__ method)
    image_shape : tuple (H,W)
        Output shape
    image_interpolation : int
        Interpolation mode

    Returns
    -------
    sample : dict
        Resized sample
    """
    # Image.LANCZOS is the non-deprecated alias of the original Image.ANTIALIAS.
    image_transform = transforms.Resize(image_shape, interpolation=image_interpolation)
    sample["image"] = image_transform(sample["image"])
    return sample
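

# Hedged usage sketch for resize_sample, assuming the sample stores a PIL image
# under the "image" key (as the functions above expect); the 64x64 input is
# made up for illustration.
def _resize_sample_example():
    sample = {"image": Image.new("RGB", (64, 64))}
    sample = resize_sample(sample, image_shape=(120, 160))
    # PIL reports size as (W, H), so a (120, 160) target gives a 160x120 image.
    assert sample["image"].size == (160, 120)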


def spatial_augment_sample(sample):
    """Apply spatial augmentation to an image (flipping and random affine transformation)."""
    augment_image = transforms.Compose(
        [
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        ]
    )
    sample["image"] = augment_image(sample["image"])
    return sample


def unnormalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Counterpart method of torchvision.transforms.Normalize."""
    for t, m, s in zip(tensor, mean, std):
        # Invert (x - mean) / std channel by channel: x * std + mean.
        t.mul_(s).add_(m)
    return tensor
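

# Hedged sanity check: unnormalize_image is meant to invert
# torchvision.transforms.Normalize for per-channel (mean, std); the random
# 3x4x4 tensor below is illustrative only.
def _unnormalize_image_example():
    mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    original = torch.rand(3, 4, 4)
    normalized = transforms.Normalize(mean, std)(original.clone())
    restored = unnormalize_image(normalized, mean, std)
    assert torch.allclose(restored, original, atol=1e-6)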


def sample_homography(
    shape,
    perspective=True,
    scaling=True,
    rotation=True,
    translation=True,
    n_scales=100,
    n_angles=100,
    scaling_amplitude=0.1,
    perspective_amplitude=0.4,
    patch_ratio=0.8,
    max_angle=pi / 4,
):
    """Sample a random homography that includes perspective, scale, translation and rotation operations."""
    hw_ratio = float(shape[0]) / float(shape[1])
    # Corners of the source patch and of the (perturbed) target patch, in
    # normalized [-1, 1] coordinates.
    pts1 = np.stack([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]], axis=0)
    pts2 = pts1.copy() * patch_ratio
    pts2[:, 1] *= hw_ratio

    if perspective:
        perspective_amplitude_x = np.random.normal(0.0, perspective_amplitude / 2, (2))
        perspective_amplitude_y = np.random.normal(
            0.0, hw_ratio * perspective_amplitude / 2, (2)
        )
        perspective_amplitude_x = np.clip(
            perspective_amplitude_x,
            -perspective_amplitude / 2,
            perspective_amplitude / 2,
        )
        perspective_amplitude_y = np.clip(
            perspective_amplitude_y,
            hw_ratio * -perspective_amplitude / 2,
            hw_ratio * perspective_amplitude / 2,
        )
        pts2[0, 0] -= perspective_amplitude_x[1]
        pts2[0, 1] -= perspective_amplitude_y[1]
        pts2[1, 0] -= perspective_amplitude_x[0]
        pts2[1, 1] += perspective_amplitude_y[1]
        pts2[2, 0] += perspective_amplitude_x[1]
        pts2[2, 1] -= perspective_amplitude_y[0]
        pts2[3, 0] += perspective_amplitude_x[0]
        pts2[3, 1] += perspective_amplitude_y[0]
    if scaling:
        random_scales = np.random.normal(1, scaling_amplitude / 2, (n_scales))
        random_scales = np.clip(
            random_scales, 1 - scaling_amplitude / 2, 1 + scaling_amplitude / 2
        )
        scales = np.concatenate([[1.0], random_scales], 0)

        center = np.mean(pts2, axis=0, keepdims=True)
        scaled = (
            np.expand_dims(pts2 - center, axis=0)
            * np.expand_dims(np.expand_dims(scales, 1), 1)
            + center
        )
        # All sampled scales are valid except the identity scale at index 0.
        valid = np.arange(1, n_scales + 1)
        idx = valid[np.random.randint(valid.shape[0])]
        pts2 = scaled[idx]
    if translation:
        t_min, t_max = np.min(pts2 - [-1.0, -hw_ratio], axis=0), np.min(
            [1.0, hw_ratio] - pts2, axis=0
        )
        pts2 += np.expand_dims(
            np.stack(
                [
                    np.random.uniform(-t_min[0], t_max[0]),
                    np.random.uniform(-t_min[1], t_max[1]),
                ]
            ),
            axis=0,
        )
    if rotation:
        angles = np.linspace(-max_angle, max_angle, n_angles)
        angles = np.concatenate([[0.0], angles], axis=0)

        center = np.mean(pts2, axis=0, keepdims=True)
        rot_mat = np.reshape(
            np.stack(
                [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)],
                axis=1,
            ),
            [-1, 2, 2],
        )
        rotated = (
            np.matmul(
                np.tile(np.expand_dims(pts2 - center, axis=0), [n_angles + 1, 1, 1]),
                rot_mat,
            )
            + center
        )
        # Keep only the rotations that leave all four corners inside the image
        # (the zero angle prepended above is always valid, so `valid` is never empty).
        valid = np.where(
            np.all(
                (rotated >= [-1.0, -hw_ratio]) & (rotated < [1.0, hw_ratio]),
                axis=(1, 2),
            )
        )[0]
        idx = valid[np.random.randint(valid.shape[0])]
        pts2 = rotated[idx]

    pts2[:, 1] /= hw_ratio

    # Solve the direct linear transform (DLT) system for the homography that
    # maps pts1 to pts2.
    def ax(p, q):
        return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]

    def ay(p, q):
        return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]

    a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0)
    p_mat = np.transpose(
        np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
    )
    homography = np.matmul(np.linalg.pinv(a_mat), p_mat).squeeze()
    homography = np.concatenate([homography, [1.0]]).reshape(3, 3)
    return homography
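

# Hedged usage sketch for sample_homography: draws one random 3x3 homography
# for a 240x320 image; the shape and patch_ratio values are illustrative only.
def _sample_homography_example():
    homography = sample_homography([240, 320], patch_ratio=0.7)
    assert homography.shape == (3, 3)
    # The bottom-right entry is fixed to 1 by the DLT construction above.
    assert homography[2, 2] == 1.0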


def warp_homography(sources, homography):
    """Warp features given a homography

    Parameters
    ----------
    sources: torch.Tensor (1,H,W,2)
        Keypoint vector.
    homography: torch.Tensor (3,3)
        Homography.

    Returns
    -------
    warped_sources: torch.Tensor (1,H,W,2)
        Warped feature vector.
    """
    _, H, W, _ = sources.shape
    warped_sources = sources.clone().squeeze()
    warped_sources = warped_sources.view(-1, 2)
    # Apply [x, y] -> H @ [x, y, 1] in homogeneous coordinates, then
    # dehomogenize by the third coordinate.
    warped_sources = torch.addmm(homography[:, 2], warped_sources, homography[:, :2].t())
    warped_sources.mul_(1 / warped_sources[:, 2].unsqueeze(1))
    warped_sources = warped_sources[:, :2].contiguous().view(1, H, W, 2)
    return warped_sources
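

# Hedged sanity check for warp_homography: warping a coordinate grid with the
# identity homography should return the grid unchanged. The 1x4x4 grid of
# normalized coordinates below is illustrative only.
def _warp_homography_example():
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, 4), torch.linspace(-1, 1, 4))
    grid = torch.stack([xs, ys], dim=-1).unsqueeze(0)  # (1, H, W, 2)
    warped = warp_homography(grid, torch.eye(3))
    assert torch.allclose(warped, grid, atol=1e-6)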


def add_noise(img, mode="gaussian", percent=0.02):
    """Add image noise

    Parameters
    ----------
    img : np.array
        Input image
    mode: str
        Type of noise, from ['gaussian','salt','pepper','s&p']
    percent: float
        Percentage of image points to add noise to.

    Returns
    -------
    img : np.array
        Image plus noise.
    """
    original_dtype = img.dtype
    if mode == "gaussian":
        mean = 0
        var = 0.1
        sigma = var * 0.5
        if img.ndim == 2:
            h, w = img.shape
            gauss = np.random.normal(mean, sigma, (h, w))
        else:
            h, w, c = img.shape
            gauss = np.random.normal(mean, sigma, (h, w, c))
        if img.dtype not in [np.float32, np.float64]:
            # np.float was removed from NumPy; use np.float64 explicitly.
            gauss = gauss * np.iinfo(img.dtype).max
            img = np.clip(img.astype(np.float64) + gauss, 0, np.iinfo(img.dtype).max)
        else:
            img = np.clip(img.astype(np.float64) + gauss, 0, 1)
| elif mode == "salt": | |
| print(img.dtype) | |
| s_vs_p = 1 | |
| num_salt = np.ceil(percent * img.size * s_vs_p) | |
| coords = tuple([np.random.randint(0, i - 1, int(num_salt)) for i in img.shape]) | |
| if img.dtype in [np.float32, np.float64]: | |
| img[coords] = 1 | |
| else: | |
| img[coords] = np.iinfo(img.dtype).max | |
| print(img.dtype) | |
| elif mode == "pepper": | |
| s_vs_p = 0 | |
| num_pepper = np.ceil(percent * img.size * (1.0 - s_vs_p)) | |
| coords = tuple( | |
| [np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape] | |
| ) | |
| img[coords] = 0 | |
| elif mode == "s&p": | |
| s_vs_p = 0.5 | |
| # Salt mode | |
| num_salt = np.ceil(percent * img.size * s_vs_p) | |
| coords = tuple([np.random.randint(0, i - 1, int(num_salt)) for i in img.shape]) | |
| if img.dtype in [np.float32, np.float64]: | |
| img[coords] = 1 | |
| else: | |
| img[coords] = np.iinfo(img.dtype).max | |
| # Pepper mode | |
| num_pepper = np.ceil(percent * img.size * (1.0 - s_vs_p)) | |
| coords = tuple( | |
| [np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape] | |
| ) | |
| img[coords] = 0 | |
| else: | |
| raise ValueError("not support mode for {}".format(mode)) | |
| noisy = img.astype(original_dtype) | |
| return noisy | |
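

# Hedged usage sketch for add_noise on a uint8 image; the flat gray 32x32x3
# array is illustrative only.
def _add_noise_example():
    img = np.full((32, 32, 3), 128, dtype=np.uint8)
    noisy = add_noise(img.copy(), mode="s&p", percent=0.02)
    # Salt-and-pepper noise pushes a handful of pixels to the extremes
    # of the integer range while preserving the input dtype.
    assert noisy.dtype == np.uint8
    assert noisy.min() == 0 and noisy.max() == 255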


def non_spatial_augmentation(
    img_warp_ori, jitter_parameters, color_order=[0, 1, 2], to_gray=False
):
    """Apply non-spatial augmentation to an image (jittering, color swap, conversion to grayscale, Gaussian blur)."""
    brightness, contrast, saturation, hue = jitter_parameters
    color_augmentation = transforms.ColorJitter(brightness, contrast, saturation, hue)
| """ | |
| augment_image = color_augmentation.get_params(brightness=[max(0, 1 - brightness), 1 + brightness], | |
| contrast=[max(0, 1 - contrast), 1 + contrast], | |
| saturation=[max(0, 1 - saturation), 1 + saturation], | |
| hue=[-hue, hue]) | |
| """ | |
    B = img_warp_ori.shape[0]
    img_warp = []
    kernel_sizes = [0, 1, 3, 5]
    for b in range(B):
        img_warp_sub = img_warp_ori[b].cpu()
        img_warp_sub = torchvision.transforms.functional.to_pil_image(img_warp_sub)

        img_warp_sub_np = np.array(img_warp_sub)
        img_warp_sub_np = img_warp_sub_np[:, :, color_order]

        # Randomly add noise, then blur with a randomly chosen kernel size.
        if np.random.rand() > 0.5:
            img_warp_sub_np = add_noise(img_warp_sub_np)

        rand_index = np.random.randint(4)
        kernel_size = kernel_sizes[rand_index]
        if kernel_size > 0:
            img_warp_sub_np = cv2.GaussianBlur(
                img_warp_sub_np, (kernel_size, kernel_size), sigmaX=0
            )
        if to_gray:
            img_warp_sub_np = cv2.cvtColor(img_warp_sub_np, cv2.COLOR_RGB2GRAY)
            img_warp_sub_np = cv2.cvtColor(img_warp_sub_np, cv2.COLOR_GRAY2RGB)

        img_warp_sub = Image.fromarray(img_warp_sub_np)
        img_warp_sub = color_augmentation(img_warp_sub)

        img_warp_sub = torchvision.transforms.functional.to_tensor(img_warp_sub).to(
            img_warp_ori.device
        )
        img_warp.append(img_warp_sub)

    img_warp = torch.stack(img_warp, dim=0)
    return img_warp


def ha_augment_sample(
    data,
    jitter_parameters=[0.5, 0.5, 0.2, 0.05],
    patch_ratio=0.7,
    scaling_amplitude=0.2,
    max_angle=pi / 4,
):
    """Apply Homography Adaptation image augmentation."""
    input_img = data["image"].unsqueeze(0)
    _, _, H, W = input_img.shape
    device = input_img.device

    homography = (
        torch.from_numpy(
            sample_homography(
                [H, W],
                patch_ratio=patch_ratio,
                scaling_amplitude=scaling_amplitude,
                max_angle=max_angle,
            )
        )
        .float()
        .to(device)
    )
    homography_inv = torch.inverse(homography)

    source = (
        image_grid(
            1, H, W, dtype=input_img.dtype, device=device, ones=False, normalized=True
        )
        .clone()
        .permute(0, 2, 3, 1)
    )
    target_warped = warp_homography(source, homography)
    # align_corners=True matches the pre-1.3.0 PyTorch default this code was
    # written against, and silences the warning on newer versions.
    img_warp = torch.nn.functional.grid_sample(
        input_img, target_warped, align_corners=True
    )

    color_order = [0, 1, 2]
    if np.random.rand() > 0.5:
        random.shuffle(color_order)

    to_gray = False
    if np.random.rand() > 0.5:
        to_gray = True

    input_img = non_spatial_augmentation(
        input_img,
        jitter_parameters=jitter_parameters,
        color_order=color_order,
        to_gray=to_gray,
    )
    img_warp = non_spatial_augmentation(
        img_warp,
        jitter_parameters=jitter_parameters,
        color_order=color_order,
        to_gray=to_gray,
    )

    data["image"] = input_img.squeeze()
    data["image_aug"] = img_warp.squeeze()
    data["homography"] = homography
    data["homography_inv"] = homography_inv
    return data
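

# Hedged end-to-end sketch of the augmentation pipeline above; the random
# 3x240x320 tensor stands in for a real normalized image from a dataset.
if __name__ == "__main__":
    data = {"image": torch.rand(3, 240, 320)}
    data = ha_augment_sample(data)
    print("image:", tuple(data["image"].shape))
    print("image_aug:", tuple(data["image_aug"].shape))
    print("homography:", tuple(data["homography"].shape))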