# NOTE: the three lines below were page-scrape residue from a git hosting UI,
# preserved here as a comment so the module parses:
#   Hancy's picture / init / 851751e
import pdb
from abc import abstractmethod
from functools import partial
import PIL
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, IterableDataset
from ..utils.aug_utils import get_lidar_transform, get_camera_transform, get_anno_transform
class DatasetBase(Dataset):
    """Base class for range-image LiDAR datasets.

    Subclasses populate ``self.data`` in :meth:`prepare_data` and implement
    the ``load_*`` static hooks.  This base handles configuration bookkeeping,
    optional low-resolution degradation for super-resolution training, and
    normalisation of raw range scans into ``[-1, 1]`` images with a validity
    mask.
    """

    def __init__(self, data_root, split, dataset_config, aug_config, return_pcd=False, condition_key=None,
                 scale_factors=None, degradation=None, **kwargs):
        # Basic bookkeeping.
        self.data_root = data_root
        self.split = split
        self.data = []
        self.aug_config = aug_config
        self.return_pcd = return_pcd
        self.condition_key = condition_key

        # Geometry / normalisation settings pulled from the dataset config.
        self.img_size = dataset_config.size
        self.fov = dataset_config.fov
        self.depth_range = dataset_config.depth_range
        self.filtered_map_cats = dataset_config.filtered_map_cats
        self.depth_scale = dataset_config.depth_scale
        self.log_scale = dataset_config.log_scale

        # Threshold just above the normalised value of a one-pixel-unit depth
        # (1/255); anything at or below it is treated as "no return" when
        # building the validity mask in process_scan.
        one_px = np.log2(1. / 255. + 1) if self.log_scale else 1. / 255.
        self.depth_thresh = one_px / self.depth_scale * 2. - 1 + 1e-6

        # Optional downscaling transform used for super-resolution setups.
        if scale_factors is not None and degradation is not None:
            lr_size = (int(self.img_size[0] / scale_factors[0]), int(self.img_size[1] / scale_factors[1]))
            interp = {
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]
            self.degradation_transform = partial(TF.resize, size=lr_size, interpolation=interp)
        else:
            self.degradation_transform = None

        # Augmentations; conditional ones depend on what the model conditions on.
        self.lidar_transform = get_lidar_transform(aug_config, split)
        self.anno_transform = None
        if condition_key in ('bbox', 'center'):
            self.anno_transform = get_anno_transform(aug_config, split)
        self.view_transform = None
        if condition_key in ('camera',):
            self.view_transform = get_camera_transform(aug_config, split)

        self.prepare_data()

    def prepare_data(self):
        """Fill ``self.data`` with sample records; must be overridden."""
        raise NotImplementedError

    def process_scan(self, range_img):
        """Normalise a raw range image to [-1, 1] and build a validity mask.

        Negative depths are clamped to zero, an optional log2 compression is
        applied, and the result is scaled by ``depth_scale`` into [-1, 1].
        Returns ``(range_img, range_mask)`` with a leading channel axis, where
        the mask is -1 at (near-)empty returns and 1 elsewhere.
        """
        depth = np.where(range_img < 0, 0, range_img)
        if self.log_scale:
            # Compress dynamic range; +1 keeps log2 >= 0, epsilon avoids log(0+1)=0 ties.
            depth = np.log2(depth + 0.0001 + 1)
        scaled = np.clip(depth / self.depth_scale * 2. - 1., -1, 1)
        scaled = scaled[None, ...]  # add channel axis (== np.expand_dims(axis=0))
        # Validity mask: -1 where the normalised depth falls below the
        # "one pixel unit" threshold computed in __init__, else 1.
        mask = np.ones_like(scaled)
        mask[scaled < self.depth_thresh] = -1
        return scaled, mask

    @staticmethod
    def load_lidar_sweep(*args, **kwargs):
        """Load a raw LiDAR sweep; dataset-specific, must be overridden."""
        raise NotImplementedError

    @staticmethod
    def load_semantic_map(*args, **kwargs):
        """Load a semantic map; dataset-specific, must be overridden."""
        raise NotImplementedError

    @staticmethod
    def load_camera(*args, **kwargs):
        """Load camera imagery; dataset-specific, must be overridden."""
        raise NotImplementedError

    @staticmethod
    def load_annotation(*args, **kwargs):
        """Load annotations; dataset-specific, must be overridden."""
        raise NotImplementedError

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Base implementation yields an empty example; subclasses populate it.
        return {}
class Txt2ImgIterableBaseDataset(IterableDataset):
    """
    Define an interface to make the IterableDatasets for text2img data chainable
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        # Record count drives __len__; sample_ids starts as the full valid set.
        self.size = size
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        # Concrete subclasses must yield examples here.
        pass