# lkllkl's picture
# Upload folder using huggingface_hub
# da2e2ac verified
# raw
# history blame
# 8.62 kB
import cv2
import numpy as np
import torch
from PIL import Image
from det_map.data.datasets.dataclasses import Camera
from det_map.data.pipelines.color_utils import bgr2hsv, hsv2bgr, mmlabNormalize
class PrepareImageInputs(object):
    """Augment and normalize the image of a single camera sample.

    The transform resizes, crops, optionally flips and rotates the image,
    optionally applies photometric distortion (training only), normalizes the
    result with ``mmlabNormalize`` and stores the accumulated 2D
    post-augmentation rotation / translation on the camera object.

    Args:
        data_config: Mapping with the augmentation settings. Keys read by this
            class are ``'input_size'``, ``'resize'``, ``'crop_h'``, ``'flip'``
            and ``'rot'``, plus the optional ``'random_crop_height'``,
            ``'vflip'``, ``'resize_test'`` and ``'pmd'``. Range-like entries
            may be given as strings (e.g. ``'(256, 704)'``) or directly as
            tuples / lists.
        is_train (bool): Whether to sample random augmentations. Defaults to
            False, in which case a deterministic resize + crop is used.
        opencv_pp (bool): If True, the image is warped once with
            ``cv2.warpAffine`` using the accumulated affine transform instead
            of being processed step by step with PIL. Defaults to False.
    """

    def __init__(
            self,
            data_config,
            is_train=False,
            opencv_pp=False,
    ):
        self.is_train = is_train
        self.data_config = data_config
        self.normalize_img = mmlabNormalize
        self.opencv_pp = opencv_pp

    @staticmethod
    def _parse_config_value(value):
        """Return a config entry as a Python value.

        Tuples and lists are returned unchanged; anything else (normally a
        string such as ``'(-0.06, 0.11)'``) is evaluated.
        """
        if isinstance(value, (tuple, list)):
            return value
        # SECURITY NOTE: eval() executes arbitrary code. This is only
        # acceptable while the config comes from a trusted source; consider
        # ast.literal_eval if configs only ever contain plain literals.
        return eval(value)

    def get_rot(self, h):
        """Return the 2x2 rotation matrix for angle ``h`` given in radians."""
        return torch.Tensor([
            [np.cos(h), np.sin(h)],
            [-np.sin(h), np.cos(h)],
        ])

    def img_transform(self, img, post_rot, post_tran, resize, resize_dims,
                      crop, flip, rotate):
        """Apply the view augmentation to ``img`` and update the 2D transform.

        The input tensors are not modified; updated copies are returned.

        Args:
            img: Image to augment (a PIL image when ``opencv_pp`` is False).
            post_rot (torch.Tensor): Current 2x2 linear part of the transform.
            post_tran (torch.Tensor): Current 2-vector translation part.
            resize (float): Isotropic scale factor.
            resize_dims (tuple): ``(width, height)`` after resizing.
            crop (tuple): ``(left, upper, right, lower)`` crop box in the
                resized image.
            flip: Truthy to mirror the crop horizontally.
            rotate (float): Rotation angle in degrees.

        Returns:
            tuple: ``(img, post_rot, post_tran)`` after augmentation.
        """
        # adjust image
        if not self.opencv_pp:
            img = self.img_transform_core(img, resize_dims, crop, flip, rotate)

        # post-homography transformation
        # Out-of-place arithmetic so the caller's tensors are never mutated.
        post_rot = post_rot * resize
        post_tran = post_tran - torch.Tensor(crop[:2])
        if flip:
            # Mirror about the vertical axis, then shift back by the crop width.
            A = torch.Tensor([[-1, 0], [0, 1]])
            b = torch.Tensor([crop[2] - crop[0], 0])
            post_rot = A.matmul(post_rot)
            post_tran = A.matmul(post_tran) + b
        # Rotate about the centre of the cropped image.
        A = self.get_rot(rotate / 180 * np.pi)
        b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
        b = A.matmul(-b) + b
        post_rot = A.matmul(post_rot)
        post_tran = A.matmul(post_tran) + b

        if self.opencv_pp:
            img = self.img_transform_core_opencv(img, post_rot, post_tran, crop)
        return img, post_rot, post_tran

    def img_transform_core_opencv(self, img, post_rot, post_tran,
                                  crop):
        """Warp ``img`` with the accumulated affine transform using OpenCV.

        The output is a float32 array whose size is the crop width / height.
        """
        img = np.array(img).astype(np.float32)
        # Build the 2x3 affine matrix [post_rot | post_tran].
        affine = np.concatenate([post_rot, post_tran.reshape(2, 1)], axis=1)
        out_size = (crop[2] - crop[0], crop[3] - crop[1])
        img = cv2.warpAffine(img, affine, out_size, flags=cv2.INTER_LINEAR)
        return img

    def img_transform_core(self, img, resize_dims, crop, flip, rotate):
        """Resize, crop, optionally mirror and rotate a PIL image."""
        # adjust image
        img = img.resize(resize_dims)
        img = img.crop(crop)
        if flip:
            img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
        img = img.rotate(rotate)
        return img

    def sample_augmentation(self, H, W, flip=None, scale=None):
        """Choose the resize / crop / flip / rotate parameters for one image.

        In training mode the parameters are drawn at random from the ranges
        in ``data_config`` and the ``flip`` / ``scale`` arguments are ignored.
        Otherwise a deterministic resize and horizontally centred crop is
        returned, optionally adjusted by ``scale`` and ``flip``.

        Args:
            H (int): Height of the source image.
            W (int): Width of the source image.
            flip: Flip flag used in test mode; ``None`` means no flip.
            scale (float | None): Extra resize offset used in test mode; when
                ``None`` the config value ``'resize_test'`` (default 0.0) is
                used instead.

        Returns:
            tuple: ``(resize, resize_dims, crop, flip, rotate)``.
        """
        cfg = self.data_config
        parse = self._parse_config_value
        fH, fW = parse(cfg['input_size'])
        if self.is_train:
            # NOTE: the order of the random draws below is kept stable so
            # that seeded runs stay reproducible.
            resize = float(fW) / float(W)
            resize += np.random.uniform(*parse(cfg['resize']))
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            if cfg.get('random_crop_height', False):
                # NOTE(review): the lower bound is never smaller than the
                # upper bound here (max(a, b) >= b) - confirm this is intended.
                crop_h = int(np.random.uniform(max(0.3 * newH, newH - fH),
                                               newH - fH))
            else:
                crop_h = int(
                    (1 - np.random.uniform(*parse(cfg['crop_h']))) * newH) - fH
            crop_w = int(np.random.uniform(0, max(0, newW - fW)))
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = cfg['flip'] and np.random.choice([0, 1])
            rotate = np.random.uniform(*parse(cfg['rot']))
            if cfg.get('vflip', False) and np.random.choice([0, 1]):
                rotate += 180
        else:
            resize = float(fW) / float(W)
            if scale is not None:
                resize += scale
            else:
                resize += cfg.get('resize_test', 0.0)
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            # Use the mean of the crop range and centre the crop horizontally.
            crop_h = int((1 - np.mean(parse(cfg['crop_h']))) * newH) - fH
            crop_w = int(max(0, newW - fW) / 2)
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False if flip is None else flip
            rotate = 0
        return resize, resize_dims, crop, flip, rotate

    def photo_metric_distortion(self, img, pmd):
        """Randomly distort brightness, contrast, saturation, hue and channels.

        Args:
            img: Image convertible with ``np.array`` (e.g. a PIL image).
            pmd: Mapping with the keys ``'brightness_delta'``,
                ``'contrast_lower'``, ``'contrast_upper'``,
                ``'saturation_lower'``, ``'saturation_upper'``,
                ``'hue_delta'`` and the optional ``'rate'`` (probability of
                applying the distortion, default 1.0).

        Returns:
            The input ``img`` unchanged when the distortion is skipped,
            otherwise a new PIL image with uint8 pixels.
        """
        if np.random.rand() > pmd.get('rate', 1.0):
            return img

        img = np.array(img).astype(np.float32)
        # random brightness
        if np.random.randint(2):
            delta = np.random.uniform(-pmd['brightness_delta'],
                                      pmd['brightness_delta'])
            img += delta

        # mode == 0 --> do random contrast first
        # mode == 1 --> do random contrast last
        mode = np.random.randint(2)
        if mode == 1:
            if np.random.randint(2):
                alpha = np.random.uniform(pmd['contrast_lower'],
                                          pmd['contrast_upper'])
                img *= alpha

        # convert color from BGR to HSV
        img = bgr2hsv(img)

        # random saturation
        if np.random.randint(2):
            img[..., 1] *= np.random.uniform(pmd['saturation_lower'],
                                             pmd['saturation_upper'])

        # random hue
        if np.random.randint(2):
            img[..., 0] += np.random.uniform(-pmd['hue_delta'], pmd['hue_delta'])
            # wrap the hue channel back into [0, 360]
            img[..., 0][img[..., 0] > 360] -= 360
            img[..., 0][img[..., 0] < 0] += 360

        # convert color from HSV to BGR
        img = hsv2bgr(img)

        # random contrast
        if mode == 0:
            if np.random.randint(2):
                alpha = np.random.uniform(pmd['contrast_lower'],
                                          pmd['contrast_upper'])
                img *= alpha

        # randomly swap channels
        if np.random.randint(2):
            img = img[..., np.random.permutation(3)]

        # Clip before the uint8 cast: brightness / contrast can push values
        # outside [0, 255], which would otherwise wrap around.
        return Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))

    def get_inputs(self, cam: "Camera", flip=None, scale=None):
        """Augment and normalize ``cam.image`` and attach the transform.

        Sets ``cam.canvas`` (tensor copy of the unaugmented image),
        ``cam.image`` (augmented, normalized image), ``cam.post_rot`` (3x3)
        and ``cam.post_tran`` (3-vector), and converts the camera's
        sensor2lidar rotation / translation, intrinsics and distortion to
        tensors.

        Args:
            cam: Camera sample; modified in place.
            flip: Forwarded to ``sample_augmentation``.
            scale: Forwarded to ``sample_augmentation``.

        Returns:
            The same ``cam`` object.
        """
        img = Image.fromarray(cam.image)

        # original copy of image
        cam.canvas = torch.tensor(np.array(img))

        # image view augmentation (resize, crop, horizontal flip, rotate)
        resize, resize_dims, crop, flip, rotate = self.sample_augmentation(
            H=img.height, W=img.width, flip=flip, scale=scale)
        img, post_rot2, post_tran2 = self.img_transform(
            img,
            torch.eye(2),
            torch.zeros(2),
            resize=resize,
            resize_dims=resize_dims,
            crop=crop,
            flip=flip,
            rotate=rotate)

        # for convenience, make augmentation matrices 3x3
        post_tran = torch.zeros(3)
        post_rot = torch.eye(3)
        post_tran[:2] = post_tran2
        post_rot[:2, :2] = post_rot2

        if self.is_train and self.data_config.get('pmd', None) is not None:
            img = self.photo_metric_distortion(img, self.data_config['pmd'])

        # original image
        cam.image = self.normalize_img(img)
        cam.post_rot = post_rot
        cam.post_tran = post_tran
        cam.sensor2lidar_rotation = torch.tensor(cam.sensor2lidar_rotation)
        cam.sensor2lidar_translation = torch.tensor(cam.sensor2lidar_translation)
        cam.intrinsics = torch.tensor(cam.intrinsics)
        cam.distortion = torch.tensor(cam.distortion)
        return cam

    def __call__(self, results):
        """Run ``get_inputs`` on ``results`` with default flip / scale."""
        return self.get_inputs(results)