import functools

import torch
import torch.utils.data

from frame_field_learning import data_transforms
from lydorn_utils import print_utils


def inria_aerial_train_tile_filter(tile, train_val_split_point):
    return tile["number"] <= train_val_split_point


def inria_aerial_val_tile_filter(tile, train_val_split_point):
    return train_val_split_point < tile["number"]


def get_inria_aerial_folds(config, root_dir, folds):
    from torch_lydorn.torchvision.datasets import InriaAerial

    # --- Online transform done on the host (CPU):
    online_cpu_transform = data_transforms.get_online_cpu_transform(config,
                                                                    augmentations=config["data_aug_params"]["enable"])
    mask_only = config["dataset_params"]["mask_only"]
    kwargs = {
        "pre_process": config["dataset_params"]["pre_process"],
        "transform": online_cpu_transform,
        "patch_size": config["dataset_params"]["data_patch_size"],
        "patch_stride": config["dataset_params"]["input_patch_size"],
        "pre_transform": data_transforms.get_offline_transform_patch(distances=not mask_only, sizes=not mask_only),
        "small": config["dataset_params"]["small"],
        "pool_size": config["num_workers"],
        "gt_source": config["dataset_params"]["gt_source"],
        "gt_type": config["dataset_params"]["gt_type"],
        "gt_dirname": config["dataset_params"]["gt_dirname"],
        "mask_only": mask_only,
    }
    # The Inria train set has 36 tiles per city: tiles numbered up to the split point go to train, the rest to val.
    train_val_split_point = config["dataset_params"]["train_fraction"] * 36
    partial_train_tile_filter = functools.partial(inria_aerial_train_tile_filter,
                                                  train_val_split_point=train_val_split_point)
    partial_val_tile_filter = functools.partial(inria_aerial_val_tile_filter,
                                                train_val_split_point=train_val_split_point)

    ds_list = []
    for fold in folds:
        if fold == "train":
            ds = InriaAerial(root_dir, fold="train", tile_filter=partial_train_tile_filter, **kwargs)
            ds_list.append(ds)
        elif fold == "val":
            ds = InriaAerial(root_dir, fold="train", tile_filter=partial_val_tile_filter, **kwargs)
            ds_list.append(ds)
        elif fold == "train_val":
            ds = InriaAerial(root_dir, fold="train", **kwargs)
            ds_list.append(ds)
        elif fold == "test":
            ds = InriaAerial(root_dir, fold="test", **kwargs)
            ds_list.append(ds)
        else:
            print_utils.print_error("ERROR: fold \"{}\" not recognized, implement it in dataset_folds.py.".format(fold))

    return ds_list


def get_luxcarta_buildings(config, root_dir, folds):
    from torch_lydorn.torchvision.datasets import LuxcartaBuildings

    # --- Online transform done on the host (CPU):
    online_cpu_transform = data_transforms.get_online_cpu_transform(config,
                                                                    augmentations=config["data_aug_params"]["enable"])
    data_patch_size = config["dataset_params"]["data_patch_size"] if config["data_aug_params"]["enable"] \
        else config["input_patch_size"]
    ds = LuxcartaBuildings(root_dir,
                           transform=online_cpu_transform,
                           patch_size=data_patch_size,
                           patch_stride=config["dataset_params"]["input_patch_size"],
                           pre_transform=data_transforms.get_offline_transform_patch(),
                           fold="train",
                           pool_size=config["num_workers"])
    torch.manual_seed(config["dataset_params"]["seed"])  # Set the seed so the train/val split is reproducible
    train_split_length = int(round(config["dataset_params"]["train_fraction"] * len(ds)))
    val_split_length = len(ds) - train_split_length
    train_ds, val_ds = torch.utils.data.random_split(ds, [train_split_length, val_split_length])

    ds_list = []
    for fold in folds:
        if fold == "train":
            ds_list.append(train_ds)
        elif fold == "val":
            ds_list.append(val_ds)
        elif fold == "test":
            # TODO: handle patching with multi-GPU processing
            print_utils.print_error("WARNING: handle patching with multi-GPU processing")
            ds = LuxcartaBuildings(root_dir,
                                   transform=online_cpu_transform,
                                   pre_transform=data_transforms.get_offline_transform_patch(),
                                   fold="test",
                                   pool_size=config["num_workers"])
            ds_list.append(ds)
        else:
            print_utils.print_error("ERROR: fold \"{}\" not recognized, implement it in dataset_folds.py.".format(fold))

    return ds_list


def get_mapping_challenge(config, root_dir, folds):
    from torch_lydorn.torchvision.datasets import MappingChallenge

    if "train" in folds or "val" in folds or "train_val" in folds:
        # --- Online transform done on the host (CPU):
        train_online_cpu_transform = data_transforms.get_online_cpu_transform(config,
                                                                              augmentations=config["data_aug_params"]["enable"])
        ds = MappingChallenge(root_dir,
                              transform=train_online_cpu_transform,
                              pre_transform=data_transforms.get_offline_transform_patch(),
                              small=config["dataset_params"]["small"],
                              fold="train",
                              pool_size=config["num_workers"])
        torch.manual_seed(config["dataset_params"]["seed"])  # Set the seed so the train/val split is reproducible
        train_split_length = int(round(config["dataset_params"]["train_fraction"] * len(ds)))
        val_split_length = len(ds) - train_split_length
        train_ds, val_ds = torch.utils.data.random_split(ds, [train_split_length, val_split_length])

    ds_list = []
    for fold in folds:
        if fold == "train":
            ds_list.append(train_ds)
        elif fold == "val":
            ds_list.append(val_ds)
        elif fold == "train_val":
            ds_list.append(ds)
        elif fold == "test":
            # The val fold from the original challenge is used as test here
            # because we don't have the ground truth for the test_images fold of the challenge:
            test_online_cpu_transform = data_transforms.get_eval_online_cpu_transform()
            test_ds = MappingChallenge(root_dir,
                                       transform=test_online_cpu_transform,
                                       pre_transform=data_transforms.get_offline_transform_patch(),
                                       small=config["dataset_params"]["small"],
                                       fold="val",
                                       pool_size=config["num_workers"])
            ds_list.append(test_ds)
        else:
            print_utils.print_error("ERROR: fold \"{}\" not recognized, implement it in dataset_folds.py.".format(fold))
            exit()

    return ds_list


def get_opencities_competition(config, root_dir, folds):
    from torch_lydorn.torchvision.datasets import RasterizedOpenCities, OpenCitiesTestDataset

    data_patch_size = config["dataset_params"]["data_patch_size"] if config["data_aug_params"]["enable"] \
        else config["input_patch_size"]

    ds_list = []
    for fold in folds:
        if fold == "train":
            train_ds = RasterizedOpenCities(tier=1, augment=False, small_subset=False, resize_size=data_patch_size,
                                            data_dir=root_dir, baseline_mode=False, val=False,
                                            val_split=config["dataset_params"]["val_fraction"])
            ds_list.append(train_ds)
        elif fold == "val":
            val_ds = RasterizedOpenCities(tier=1, augment=False, small_subset=False, resize_size=data_patch_size,
                                          data_dir=root_dir, baseline_mode=False, val=True,
                                          val_split=config["dataset_params"]["val_fraction"])
            ds_list.append(val_ds)
        elif fold == "test":
            test_ds = OpenCitiesTestDataset(root_dir + "/test/", 1024)
            ds_list.append(test_ds)
        else:
            print_utils.print_error("ERROR: fold \"{}\" not recognized, implement it in dataset_folds.py.".format(fold))

    return ds_list


def get_xview2_dataset(config, root_dir, folds):
    from torch_lydorn.torchvision.datasets import xView2Dataset

    if "train" in folds or "val" in folds or "train_val" in folds:
        # --- Online transform done on the host (CPU):
        train_online_cpu_transform = data_transforms.get_online_cpu_transform(config,
                                                                              augmentations=config["data_aug_params"]["enable"])
        ds = xView2Dataset(root_dir,
                           fold="train",
                           pre_process=True,
                           patch_size=config["dataset_params"]["data_patch_size"],
                           pre_transform=data_transforms.get_offline_transform_patch(),
                           transform=train_online_cpu_transform,
                           small=config["dataset_params"]["small"],
                           pool_size=config["num_workers"])
        torch.manual_seed(config["dataset_params"]["seed"])  # Set the seed so the train/val split is reproducible
        train_split_length = int(round(config["dataset_params"]["train_fraction"] * len(ds)))
        val_split_length = len(ds) - train_split_length
        train_ds, val_ds = torch.utils.data.random_split(ds, [train_split_length, val_split_length])

    ds_list = []
    for fold in folds:
        if fold == "train":
            ds_list.append(train_ds)
        elif fold == "val":
            ds_list.append(val_ds)
        elif fold == "train_val":
            ds_list.append(ds)
        elif fold == "test":
            raise NotImplementedError("Test fold not yet implemented (skip pre-processing?)")
        elif fold == "hold":
            raise NotImplementedError("Hold fold not yet implemented (skip pre-processing?)")
        else:
            print_utils.print_error("ERROR: fold \"{}\" not recognized, implement it in dataset_folds.py.".format(fold))
            exit()

    return ds_list


def get_folds(config, root_dir, folds):
    assert set(folds).issubset({"train", "val", "train_val", "test"}), \
        'fold in folds should be in ["train", "val", "train_val", "test"]'

    if config["dataset_params"]["root_dirname"] == "AerialImageDataset":
        return get_inria_aerial_folds(config, root_dir, folds)
    elif config["dataset_params"]["root_dirname"] == "luxcarta_precise_buildings":
        return get_luxcarta_buildings(config, root_dir, folds)
    elif config["dataset_params"]["root_dirname"] == "mapping_challenge_dataset":
        return get_mapping_challenge(config, root_dir, folds)
    elif config["dataset_params"]["root_dirname"] == "segbuildings":
        return get_opencities_competition(config, root_dir, folds)
    elif config["dataset_params"]["root_dirname"] == "xview2_xbd_dataset":
        return get_xview2_dataset(config, root_dir, folds)
    else:
        print_utils.print_error("ERROR: config[\"dataset_params\"][\"root_dirname\"] = \"{}\" is an unknown dataset! "
                                "If it is a new dataset, add it in dataset_folds.py's get_folds() function.".format(
                                    config["dataset_params"]["root_dirname"]))
        exit()
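

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way get_folds() might be called from a
# training script. The config below lists only the keys this module reads for
# the Mapping Challenge dataset; get_online_cpu_transform() and the dataset
# classes may require additional keys. The concrete values and the dataset
# path are placeholder assumptions, not values shipped with the repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_config = {
        "num_workers": 4,
        "data_aug_params": {"enable": True},
        "dataset_params": {
            "root_dirname": "mapping_challenge_dataset",
            "seed": 0,
            "small": False,
            "train_fraction": 0.75,
        },
    }
    # Assumes the CrowdAI Mapping Challenge data lives under this placeholder path:
    train_ds, val_ds = get_folds(example_config, "/data/mapping_challenge_dataset", ["train", "val"])
    print("Train/val sizes:", len(train_ds), len(val_ds))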