import cv2
import numpy as np
import torch
from PIL import Image

from det_map.data.datasets.dataclasses import Camera
from det_map.data.pipelines.color_utils import bgr2hsv, hsv2bgr, mmlabNormalize


class PrepareImageInputs(object):
    """Augment and normalize a single camera image.

    The image is resized, cropped, optionally flipped and rotated, optionally
    photometrically distorted (training only), and finally normalized. The
    2D affine transform that was applied is stored on the camera object as
    ``post_rot`` / ``post_tran`` (padded to 3x3 / 3).

    Args:
        data_config (dict): Augmentation settings. The keys read by this
            class are ``'input_size'``, ``'resize'``, ``'crop_h'``,
            ``'flip'`` and ``'rot'``, plus the optional keys
            ``'random_crop_height'``, ``'vflip'``, ``'resize_test'`` and
            ``'pmd'``. ``'input_size'``, ``'resize'``, ``'crop_h'`` and
            ``'rot'`` are passed through ``eval``.
        is_train (bool): Sample random augmentations when True, use the
            deterministic test-time parameters otherwise. Defaults to False.
        opencv_pp (bool): When True the image is warped once with
            ``cv2.warpAffine`` using the accumulated affine transform instead
            of the step-by-step PIL operations. Defaults to False.
    """

    def __init__(
            self,
            data_config,
            is_train=False,
            opencv_pp=False,
    ):
        self.is_train = is_train
        self.data_config = data_config
        self.normalize_img = mmlabNormalize
        self.opencv_pp = opencv_pp

    def get_rot(self, h):
        """Return the 2x2 matrix ``[[cos h, sin h], [-sin h, cos h]]``.

        Args:
            h (float): Angle in radians.

        Returns:
            torch.Tensor: The 2x2 rotation matrix.
        """
        return torch.Tensor([
            [np.cos(h), np.sin(h)],
            [-np.sin(h), np.cos(h)],
        ])

    def img_transform(self, img, post_rot, post_tran, resize, resize_dims,
                      crop, flip, rotate):
        """Apply the sampled augmentation and accumulate its affine transform.

        Args:
            img: Image to transform (a PIL image on the non-OpenCV path).
            post_rot (torch.Tensor): 2x2 linear part accumulated so far.
            post_tran (torch.Tensor): 2-vector translation accumulated so far.
            resize (float): Isotropic scale factor.
            resize_dims (tuple): ``(width, height)`` after resizing.
            crop (tuple): ``(left, upper, right, lower)`` crop box.
            flip (bool): Whether to mirror the image horizontally.
            rotate (float): Rotation angle in degrees.

        Returns:
            tuple: ``(img, post_rot, post_tran)`` with the transformed image
            and the updated affine parameters. The input tensors are left
            untouched.
        """
        # adjust image
        if not self.opencv_pp:
            img = self.img_transform_core(img, resize_dims, crop, flip, rotate)

        # post-homography transformation
        # Out-of-place arithmetic: the original ``*=`` / ``-=`` silently
        # modified the caller's tensors while returning different ones.
        post_rot = post_rot * resize
        post_tran = post_tran - torch.Tensor(crop[:2])
        if flip:
            A = torch.Tensor([[-1, 0], [0, 1]])
            b = torch.Tensor([crop[2] - crop[0], 0])
            post_rot = A.matmul(post_rot)
            post_tran = A.matmul(post_tran) + b

        # Rotate about the centre of the cropped image.
        A = self.get_rot(rotate / 180 * np.pi)
        b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
        b = A.matmul(-b) + b
        post_rot = A.matmul(post_rot)
        post_tran = A.matmul(post_tran) + b

        if self.opencv_pp:
            img = self.img_transform_core_opencv(img, post_rot, post_tran, crop)
        return img, post_rot, post_tran

    def img_transform_core_opencv(self, img, post_rot, post_tran, crop):
        """Warp the image in a single step with ``cv2.warpAffine``.

        Args:
            img: Input image; converted to a float32 numpy array.
            post_rot (torch.Tensor): 2x2 linear part of the affine transform.
            post_tran (torch.Tensor): 2-vector translation of the transform.
            crop (tuple): ``(left, upper, right, lower)``; only its width and
                height are used, as the output size.

        Returns:
            numpy.ndarray: The warped float32 image.
        """
        img = np.array(img).astype(np.float32)
        # Build the 2x3 affine matrix [post_rot | post_tran].
        affine = np.concatenate([post_rot, post_tran.reshape(2, 1)], axis=1)
        out_size = (crop[2] - crop[0], crop[3] - crop[1])
        img = cv2.warpAffine(img, affine, out_size, flags=cv2.INTER_LINEAR)
        return img

    def img_transform_core(self, img, resize_dims, crop, flip, rotate):
        """Resize, crop, optionally flip and rotate a PIL image.

        Args:
            img (PIL.Image.Image): Input image.
            resize_dims (tuple): ``(width, height)`` to resize to.
            crop (tuple): ``(left, upper, right, lower)`` crop box.
            flip (bool): Whether to mirror the image horizontally.
            rotate (float): Rotation angle in degrees.

        Returns:
            PIL.Image.Image: The transformed image.
        """
        # adjust image
        img = img.resize(resize_dims)
        img = img.crop(crop)
        if flip:
            img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
        img = img.rotate(rotate)
        return img

    def sample_augmentation(self, H, W, flip=None, scale=None):
        """Choose resize / crop / flip / rotate parameters for one image.

        In training mode all parameters are drawn at random from the ranges
        in ``data_config`` and the ``flip`` / ``scale`` arguments are ignored.
        In test mode the parameters are deterministic; ``scale`` (if given)
        replaces ``data_config['resize_test']`` and ``flip`` (if given) is
        used as-is.

        Args:
            H (int): Source image height.
            W (int): Source image width.
            flip (bool | None): Test-time flip flag. Defaults to None (no flip).
            scale (float | None): Test-time offset added to the base resize
                factor. Defaults to None.

        Returns:
            tuple: ``(resize, resize_dims, crop, flip, rotate)``.
        """
        # NOTE(review): ``eval`` runs arbitrary code from the config values.
        # This is only acceptable while ``data_config`` comes from a trusted
        # source; consider ``ast.literal_eval`` if the values are literals.
        fH, fW = eval(self.data_config['input_size'])
        if self.is_train:
            resize = float(fW) / float(W)
            resize += np.random.uniform(*eval(self.data_config['resize']))
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            random_crop_height = \
                self.data_config.get('random_crop_height', False)
            if random_crop_height:
                crop_h = int(np.random.uniform(max(0.3 * newH, newH - fH),
                                               newH - fH))
            else:
                crop_h = \
                    int((1 - np.random.uniform(*eval(self.data_config['crop_h']))) *
                        newH) - fH
            crop_w = int(np.random.uniform(0, max(0, newW - fW)))
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = self.data_config['flip'] and np.random.choice([0, 1])
            rotate = np.random.uniform(*eval(self.data_config['rot']))
            if self.data_config.get('vflip', False) and np.random.choice([0, 1]):
                rotate += 180
        else:
            resize = float(fW) / float(W)
            if scale is not None:
                resize += scale
            else:
                resize += self.data_config.get('resize_test', 0.0)
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = int((1 - np.mean(eval(self.data_config['crop_h']))) * newH) - fH
            crop_w = int(max(0, newW - fW) / 2)
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False if flip is None else flip
            rotate = 0
        return resize, resize_dims, crop, flip, rotate

    def photo_metric_distortion(self, img, pmd):
        """Randomly jitter brightness, contrast, saturation, hue and channels.

        Args:
            img: Image to distort; converted to a float32 numpy array.
            pmd (dict): Distortion settings. Reads ``'brightness_delta'``,
                ``'contrast_lower'``, ``'contrast_upper'``,
                ``'saturation_lower'``, ``'saturation_upper'``,
                ``'hue_delta'`` and the optional ``'rate'`` (probability of
                applying the distortion, default 1.0).

        Returns:
            The input ``img`` unchanged when the distortion is skipped,
            otherwise a ``PIL.Image.Image`` holding the distorted uint8 image.
        """
        if np.random.rand() > pmd.get('rate', 1.0):
            return img

        img = np.array(img).astype(np.float32)

        # random brightness
        if np.random.randint(2):
            delta = np.random.uniform(-pmd['brightness_delta'],
                                      pmd['brightness_delta'])
            img += delta

        # mode == 0 --> do random contrast first
        # mode == 1 --> do random contrast last
        mode = np.random.randint(2)
        if mode == 1:
            if np.random.randint(2):
                alpha = np.random.uniform(pmd['contrast_lower'],
                                          pmd['contrast_upper'])
                img *= alpha

        # convert color from BGR to HSV
        img = bgr2hsv(img)

        # random saturation
        if np.random.randint(2):
            img[..., 1] *= np.random.uniform(pmd['saturation_lower'],
                                             pmd['saturation_upper'])

        # random hue
        if np.random.randint(2):
            img[..., 0] += np.random.uniform(-pmd['hue_delta'], pmd['hue_delta'])
            # wrap hue back into [0, 360]
            img[..., 0][img[..., 0] > 360] -= 360
            img[..., 0][img[..., 0] < 0] += 360

        # convert color from HSV to BGR
        img = hsv2bgr(img)

        # random contrast
        if mode == 0:
            if np.random.randint(2):
                alpha = np.random.uniform(pmd['contrast_lower'],
                                          pmd['contrast_upper'])
                img *= alpha

        # randomly swap channels
        if np.random.randint(2):
            img = img[..., np.random.permutation(3)]

        # Clamp before the uint8 cast: the jitter above can push values
        # outside [0, 255], and casting those would wrap around instead of
        # saturating.
        img = np.clip(img, 0, 255)
        return Image.fromarray(img.astype(np.uint8))

    def get_inputs(self, cam: Camera, flip=None, scale=None):
        """Augment ``cam.image`` and store the results back on ``cam``.

        Sets ``cam.canvas`` (tensor copy of the original image),
        ``cam.image`` (augmented and normalized), ``cam.post_rot`` (3x3) and
        ``cam.post_tran`` (3), and converts ``sensor2lidar_rotation``,
        ``sensor2lidar_translation``, ``intrinsics`` and ``distortion`` to
        tensors.

        Args:
            cam (Camera): Camera sample; modified in place.
            flip (bool | None): Forwarded to ``sample_augmentation``.
            scale (float | None): Forwarded to ``sample_augmentation``.

        Returns:
            Camera: The same ``cam`` object.
        """
        img = Image.fromarray(cam.image)

        # original copy of image
        cam.canvas = torch.tensor(np.array(img))

        post_rot = torch.eye(2)
        post_tran = torch.zeros(2)

        # image view augmentation (resize, crop, horizontal flip, rotate)
        img_augs = self.sample_augmentation(
            H=img.height, W=img.width, flip=flip, scale=scale)
        resize, resize_dims, crop, flip, rotate = img_augs
        img, post_rot2, post_tran2 = \
            self.img_transform(img, post_rot,
                               post_tran,
                               resize=resize,
                               resize_dims=resize_dims,
                               crop=crop,
                               flip=flip,
                               rotate=rotate)

        # for convenience, make augmentation matrices 3x3
        post_tran = torch.zeros(3)
        post_rot = torch.eye(3)
        post_tran[:2] = post_tran2
        post_rot[:2, :2] = post_rot2

        if self.is_train and self.data_config.get('pmd', None) is not None:
            img = self.photo_metric_distortion(img, self.data_config['pmd'])

        # original image
        cam.image = self.normalize_img(img)
        cam.post_rot = post_rot
        cam.post_tran = post_tran
        cam.sensor2lidar_rotation = torch.tensor(cam.sensor2lidar_rotation)
        cam.sensor2lidar_translation = torch.tensor(cam.sensor2lidar_translation)
        cam.intrinsics = torch.tensor(cam.intrinsics)
        cam.distortion = torch.tensor(cam.distortion)
        return cam

    def __call__(self, results):
        """Run ``get_inputs`` on ``results`` and return its result."""
        return self.get_inputs(results)