Spaces:
Building
Building
import os | |
import numpy as np | |
import torch.utils.data as data | |
import umsgpack | |
from PIL import Image | |
import json | |
import torchvision.transforms as tvf | |
from .transform import BEVTransform | |
from ..schema import KITTIDataConfiguration | |
class BEVKitti360Dataset(data.Dataset): | |
_IMG_DIR = "img" | |
_BEV_MSK_DIR = "bev_msk" | |
_BEV_PLABEL_DIR = "bev_plabel_dynamic" | |
_FV_MSK_DIR = "front_msk_seam" | |
_BEV_DIR = "bev_ortho" | |
_LST_DIR = "split" | |
_PERCENTAGES_DIR = "percentages" | |
_BEV_METADATA_FILE = "metadata_ortho.bin" | |
_FV_METADATA_FILE = "metadata_front.bin" | |
def __init__(self, cfg: KITTIDataConfiguration, split_name="train"): | |
super(BEVKitti360Dataset, self).__init__() | |
self.cfg = cfg | |
self.seam_root_dir = cfg.seam_root_dir # Directory of seamless data | |
self.kitti_root_dir = cfg.dataset_root_dir # Directory of the KITTI360 data | |
self.split_name = split_name | |
self.rgb_cameras = ['front'] | |
if cfg.bev_percentage < 1: | |
self.bev_percentage = cfg.bev_percentage | |
else: | |
self.bev_percentage = int(cfg.bev_percentage) | |
# Folders | |
self._img_dir = os.path.join(self.seam_root_dir, BEVKitti360Dataset._IMG_DIR) | |
self._bev_msk_dir = os.path.join(self.seam_root_dir, BEVKitti360Dataset._BEV_MSK_DIR, BEVKitti360Dataset._BEV_DIR) | |
self._bev_plabel_dir = os.path.join(self.seam_root_dir, BEVKitti360Dataset._BEV_PLABEL_DIR, BEVKitti360Dataset._BEV_DIR) | |
self._fv_msk_dir = os.path.join(self.seam_root_dir, BEVKitti360Dataset._FV_MSK_DIR, "front") | |
self._lst_dir = os.path.join(self.seam_root_dir, BEVKitti360Dataset._LST_DIR) | |
self._percentages_dir = os.path.join(self.seam_root_dir, BEVKitti360Dataset._LST_DIR, BEVKitti360Dataset._PERCENTAGES_DIR) | |
# Load meta-data and split | |
self._bev_meta, self._bev_images, self._bev_images_all, self._fv_meta, self._fv_images, self._fv_images_all,\ | |
self._img_map, self.bev_percent_split = self._load_split() | |
self.tfs = self.get_augmentations() if split_name == "train" else tvf.Compose([]) | |
self.transform = BEVTransform(cfg, self.tfs) | |
def get_augmentations(self): | |
print(f"Augmentation!", "\n" * 10) | |
augmentations = [ | |
tvf.ColorJitter( | |
brightness=self.cfg.augmentations.brightness, | |
contrast=self.cfg.augmentations.contrast, | |
saturation=self.cfg.augmentations.saturation, | |
hue=self.cfg.augmentations.hue, | |
) | |
] | |
if self.cfg.augmentations.random_resized_crop: | |
augmentations.append( | |
tvf.RandomResizedCrop(scale=(0.8, 1.0)) | |
) # RandomResizedCrop | |
if self.cfg.augmentations.gaussian_noise.enabled: | |
augmentations.append( | |
tvf.GaussianNoise( | |
mean=self.cfg.augmentations.gaussian_noise.mean, | |
std=self.cfg.augmentations.gaussian_noise.std, | |
) | |
) # Gaussian noise | |
if self.cfg.augmentations.brightness_contrast.enabled: | |
augmentations.append( | |
tvf.ColorJitter( | |
brightness=self.cfg.augmentations.brightness_contrast.brightness_factor, | |
contrast=self.cfg.augmentations.brightness_contrast.contrast_factor, | |
saturation=0, # Keep saturation at 0 for brightness and contrast adjustment | |
hue=0, | |
) | |
) # Brightness and contrast adjustment | |
return tvf.Compose(augmentations) | |
# Load the train or the validation split | |
def _load_split(self): | |
with open(os.path.join(self.seam_root_dir, BEVKitti360Dataset._BEV_METADATA_FILE), "rb") as fid: | |
bev_metadata = umsgpack.unpack(fid, encoding="utf-8") | |
with open(os.path.join(self.seam_root_dir, BEVKitti360Dataset._FV_METADATA_FILE), 'rb') as fid: | |
fv_metadata = umsgpack.unpack(fid, encoding="utf-8") | |
# Read the files for this split | |
with open(os.path.join(self._lst_dir, self.split_name + ".txt"), "r") as fid: | |
lst = fid.readlines() | |
lst = [line.strip() for line in lst] | |
if self.split_name == "train": | |
# Get all the frames in the train dataset. This will be used for generating samples for temporal consistency. | |
with open(os.path.join(self._lst_dir, "{}_all.txt".format(self.split_name)), 'r') as fid: | |
lst_all = fid.readlines() | |
lst_all = [line.strip() for line in lst_all] | |
# Get all the samples for which the BEV plabels have to be loaded. | |
percentage_file = os.path.join(self._percentages_dir, "{}_{}.txt".format(self.split_name, self.bev_percentage)) | |
print("Loading {}% file".format(self.bev_percentage)) | |
with open(percentage_file, 'r') as fid: | |
lst_percent = fid.readlines() | |
lst_percent = [line.strip() for line in lst_percent] | |
else: | |
lst_all = lst | |
lst_percent = lst | |
# Remove elements from lst if they are not in _FRONT_MSK_DIR | |
fv_msk_frames = os.listdir(self._fv_msk_dir) | |
fv_msk_frames = [frame.split(".")[0] for frame in fv_msk_frames] | |
fv_msk_frames_exist_map = {entry: True for entry in fv_msk_frames} # This is to speed-up the dataloader | |
lst = [entry for entry in lst if entry in fv_msk_frames_exist_map] | |
lst_all = [entry for entry in lst_all if entry in fv_msk_frames_exist_map] | |
# Filter based on the samples plabels | |
if self.bev_percentage < 100: | |
lst_filt = [entry for entry in lst if entry in lst_percent] | |
lst = lst_filt | |
# Remove any potential duplicates | |
lst = set(lst) | |
lst_percent = set(lst_percent) | |
img_map = {} | |
for camera in self.rgb_cameras: | |
with open(os.path.join(self._img_dir, "{}.json".format(camera))) as fp: | |
map_list = json.load(fp) | |
map_dict = {k: v for d in map_list for k, v in d.items()} | |
img_map[camera] = map_dict | |
bev_meta = bev_metadata["meta"] | |
bev_images = [img_desc for img_desc in bev_metadata["images"] if img_desc["id"] in lst] | |
fv_meta = fv_metadata["meta"] | |
fv_images = [img_desc for img_desc in fv_metadata['images'] if img_desc['id'] in lst] | |
# Check for inconsistency due to inconsistencies in the input files or dataset | |
bev_images_ids = [bev_img["id"] for bev_img in bev_images] | |
fv_images_ids = [fv_img["id"] for fv_img in fv_images] | |
assert set(bev_images_ids) == set(fv_images_ids) and len(bev_images_ids) == len(fv_images_ids), 'Inconsistency between fv_images and bev_images detected' | |
if lst_all is not None: | |
bev_images_all = [img_desc for img_desc in bev_metadata['images'] if img_desc['id'] in lst_all] | |
fv_images_all = [img_desc for img_desc in fv_metadata['images'] if img_desc['id'] in lst_all] | |
else: | |
bev_images_all, fv_images_all = None, None | |
return bev_meta, bev_images, bev_images_all, fv_meta, fv_images, fv_images_all, img_map, lst_percent | |
def _find_index(self, list, key, value): | |
for i, dic in enumerate(list): | |
if dic[key] == value: | |
return i | |
return None | |
def _load_item(self, item_idx): | |
# Find the index of the element in the list containing all elements | |
all_idx = self._find_index(self._fv_images_all, "id", self._fv_images[item_idx]['id']) | |
if all_idx is None: | |
raise IOError("Required index not found!") | |
bev_img_desc = self._bev_images[item_idx] | |
fv_img_desc = self._fv_images[item_idx] | |
scene, frame_id = self._bev_images[item_idx]["id"].split(";") | |
# Get the RGB file names | |
img_file = os.path.join( | |
self.kitti_root_dir, | |
self._img_map["front"]["{}.png" | |
.format(bev_img_desc['id'])] | |
) | |
if not os.path.exists(img_file): | |
raise IOError( | |
"RGB image not found! Scene: {}, Frame: {}".format(scene, frame_id) | |
) | |
# Load the images | |
img = Image.open(img_file).convert(mode="RGB") | |
# Load the BEV mask | |
bev_msk_file = os.path.join( | |
self._bev_msk_dir, | |
"{}.png".format(bev_img_desc['id']) | |
) | |
bev_msk = Image.open(bev_msk_file) | |
bev_plabel = None | |
# Load the front mask | |
fv_msk_file = os.path.join( | |
self._fv_msk_dir, | |
"{}.png".format(fv_img_desc['id']) | |
) | |
fv_msk = Image.open(fv_msk_file) | |
bev_weights_msk_combined = None | |
# Get the other information | |
bev_cat = bev_img_desc["cat"] | |
bev_iscrowd = bev_img_desc["iscrowd"] | |
fv_cat = fv_img_desc['cat'] | |
fv_iscrowd = fv_img_desc['iscrowd'] | |
fv_intrinsics = fv_img_desc["cam_intrinsic"] | |
ego_pose = fv_img_desc['ego_pose'] # This loads the cam0 pose | |
# Get the ids of all the frames | |
frame_ids = bev_img_desc["id"] | |
return img, bev_msk, bev_plabel, fv_msk, bev_weights_msk_combined, bev_cat, \ | |
bev_iscrowd, fv_cat, fv_iscrowd, fv_intrinsics, ego_pose, frame_ids | |
def fv_categories(self): | |
"""Category names""" | |
return self._fv_meta["categories"] | |
def fv_num_categories(self): | |
"""Number of categories""" | |
return len(self.fv_categories) | |
def fv_num_stuff(self): | |
"""Number of "stuff" categories""" | |
return self._fv_meta["num_stuff"] | |
def fv_num_thing(self): | |
"""Number of "thing" categories""" | |
return self.fv_num_categories - self.fv_num_stuff | |
def bev_categories(self): | |
"""Category names""" | |
return self._bev_meta["categories"] | |
def bev_num_categories(self): | |
"""Number of categories""" | |
return len(self.bev_categories) | |
def bev_num_stuff(self): | |
"""Number of "stuff" categories""" | |
return self._bev_meta["num_stuff"] | |
def bev_num_thing(self): | |
"""Number of "thing" categories""" | |
return self.bev_num_categories - self.bev_num_stuff | |
def original_ids(self): | |
"""Original class id of each category""" | |
return self._fv_meta["original_ids"] | |
def palette(self): | |
"""Default palette to be used when color-coding semantic labels""" | |
return np.array(self._fv_meta["palette"], dtype=np.uint8) | |
def img_sizes(self): | |
"""Size of each image of the dataset""" | |
return [img_desc["size"] for img_desc in self._fv_images] | |
def img_categories(self): | |
"""Categories present in each image of the dataset""" | |
return [img_desc["cat"] for img_desc in self._fv_images] | |
def dataset_name(self): | |
return "Kitti360" | |
def __len__(self): | |
if self.cfg.percentage < 1: | |
return int(len(self._fv_images) * self.cfg.percentage) | |
return len(self._fv_images) | |
def __getitem__(self, item): | |
img, bev_msk, bev_plabel, fv_msk, bev_weights_msk, bev_cat, bev_iscrowd, fv_cat, fv_iscrowd, fv_intrinsics, ego_pose, idx = self._load_item(item) | |
rec = self.transform(img=img, bev_msk=bev_msk, bev_plabel=bev_plabel, fv_msk=fv_msk, bev_weights_msk=bev_weights_msk, bev_cat=bev_cat, | |
bev_iscrowd=bev_iscrowd, fv_cat=fv_cat, fv_iscrowd=fv_iscrowd, fv_intrinsics=fv_intrinsics, | |
ego_pose=ego_pose) | |
size = (img.size[1], img.size[0]) | |
# Close the file | |
img.close() | |
bev_msk.close() | |
fv_msk.close() | |
rec["index"] = idx | |
rec["size"] = size | |
rec['name'] = idx | |
return rec | |
def get_image_desc(self, idx): | |
"""Look up an image descriptor given the id""" | |
matching = [img_desc for img_desc in self._images if img_desc["id"] == idx] | |
if len(matching) == 1: | |
return matching[0] | |
else: | |
raise ValueError("No image found with id %s" % idx) |