Spaces:
Runtime error
Runtime error
| from tqdm import tqdm | |
| from os import path as osp | |
| from torch.utils.data import Dataset, DataLoader, ConcatDataset | |
| from src.datasets.megadepth import MegaDepthDataset | |
| from src.datasets.scannet import ScanNetDataset | |
| from src.datasets.aachen import AachenDataset | |
| from src.datasets.inloc import InLocDataset | |
class TestDataLoader(DataLoader):
    """DataLoader over the test split selected by ``config.DATASET``.

    The dataset is chosen from ``config.DATASET.TEST_DATA_SOURCE``
    (matched case-insensitively): ``megadepth`` and ``scannet`` are built
    as a ``ConcatDataset`` of one dataset per entry in the scene list,
    while ``aachen_v1.1`` and ``inloc`` are built directly from the data
    root and list path. The loader always uses ``batch_size=1``,
    ``shuffle=False``, ``num_workers=4`` and ``pin_memory=True``.
    """

    def __init__(self, config):
        """Build the test dataset described by ``config`` and wrap it.

        Args:
            config: Object exposing a ``DATASET`` namespace with the
                ``TEST_*`` paths, ``MIN_OVERLAP_SCORE_TEST`` and, depending
                on the data source, ``MGDPT_IMG_RESIZE`` or ``TEST_IMGSIZE``.

        Raises:
            ValueError: If ``TEST_DATA_SOURCE`` is not a supported dataset.
        """
        # 1. data config
        self.test_data_source = config.DATASET.TEST_DATA_SOURCE
        dataset_name = str(self.test_data_source).lower()

        # testing
        self.test_data_root = config.DATASET.TEST_DATA_ROOT
        self.test_pose_root = config.DATASET.TEST_POSE_ROOT  # (optional)
        self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
        self.test_list_path = config.DATASET.TEST_LIST_PATH
        self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH

        # 2. dataset config
        # general options
        self.min_overlap_score_test = (
            config.DATASET.MIN_OVERLAP_SCORE_TEST
        )  # 0.4, omit data with overlap_score < min_overlap_score

        # MegaDepth options
        if dataset_name == "megadepth":
            self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE  # 800
            self.mgdpt_img_pad = True
            self.mgdpt_depth_pad = True
            self.mgdpt_df = 8
            self.coarse_scale = 0.125
        if dataset_name == "scannet":
            self.img_resize = config.DATASET.TEST_IMGSIZE

        if dataset_name in ("megadepth", "scannet"):
            test_dataset = self._setup_dataset(
                self.test_data_root,
                self.test_npz_root,
                self.test_list_path,
                self.test_intrinsic_path,
                mode="test",
                min_overlap_score=self.min_overlap_score_test,
                pose_dir=self.test_pose_root,
            )
        elif dataset_name == "aachen_v1.1":
            test_dataset = AachenDataset(
                self.test_data_root,
                self.test_list_path,
                img_resize=config.DATASET.TEST_IMGSIZE,
            )
        elif dataset_name == "inloc":
            test_dataset = InLocDataset(
                self.test_data_root,
                self.test_list_path,
                img_resize=config.DATASET.TEST_IMGSIZE,
            )
        else:
            # Raising a plain string is itself a TypeError in Python 3 and
            # hides the real cause; raise a proper exception instead.
            raise ValueError(f"unknown dataset: {self.test_data_source!r}")

        self.test_loader_params = {
            "batch_size": 1,
            "shuffle": False,
            "num_workers": 4,
            "pin_memory": True,
        }
        super().__init__(test_dataset, **self.test_loader_params)

    def _setup_dataset(
        self,
        data_root,
        split_npz_root,
        scene_list_path,
        intri_path,
        mode="train",
        min_overlap_score=0.0,
        pose_dir=None,
    ):
        """Read the scene list and build the concatenated dataset.

        The first whitespace-separated token of every non-blank line in
        ``scene_list_path`` is taken as an npz name; the remaining
        arguments are forwarded to ``_build_concat_dataset``.

        Returns:
            The ``ConcatDataset`` produced by ``_build_concat_dataset``.
        """
        with open(scene_list_path, "r") as f:
            # Skip blank lines: `"".split()[0]` would raise IndexError.
            npz_names = [line.split()[0] for line in f if line.strip()]

        return self._build_concat_dataset(
            data_root,
            npz_names,
            split_npz_root,
            intri_path,
            mode=mode,
            min_overlap_score=min_overlap_score,
            pose_dir=pose_dir,
        )

    def _build_concat_dataset(
        self,
        data_root,
        npz_names,
        npz_dir,
        intrinsic_path,
        mode,
        min_overlap_score=0.0,
        pose_dir=None,
    ):
        """Create one dataset per npz file and concatenate them.

        Args:
            data_root: Passed through to each dataset constructor.
            npz_names: Names of the npz files inside ``npz_dir``; for
                MegaDepth the ``.npz`` suffix is appended here.
            npz_dir: Directory joined with each npz name.
            intrinsic_path: Passed to ``ScanNetDataset`` only.
            mode: Passed through to each dataset constructor.
            min_overlap_score: Passed through to each dataset constructor.
            pose_dir: Passed to ``ScanNetDataset`` only.

        Returns:
            A ``ConcatDataset`` over the per-npz datasets.

        Raises:
            NotImplementedError: If ``self.test_data_source`` is neither
                ScanNet nor MegaDepth (compared case-insensitively).
        """
        # Use the same case-insensitive name as __init__, so a config value
        # such as "megadepth" is not accepted there and rejected here.
        source = str(self.test_data_source).lower()
        if source not in ("scannet", "megadepth"):
            raise NotImplementedError(
                f"unsupported data source: {self.test_data_source!r}"
            )
        if source == "megadepth":
            npz_names = [f"{n}.npz" for n in npz_names]

        datasets = []
        for npz_name in tqdm(npz_names):
            # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path
            # when initialized, which might take time.
            npz_path = osp.join(npz_dir, npz_name)
            if source == "scannet":
                datasets.append(
                    ScanNetDataset(
                        data_root,
                        npz_path,
                        intrinsic_path,
                        mode=mode,
                        img_resize=self.img_resize,
                        min_overlap_score=min_overlap_score,
                        pose_dir=pose_dir,
                    )
                )
            else:
                datasets.append(
                    MegaDepthDataset(
                        data_root,
                        npz_path,
                        mode=mode,
                        min_overlap_score=min_overlap_score,
                        img_resize=self.mgdpt_img_resize,
                        df=self.mgdpt_df,
                        img_padding=self.mgdpt_img_pad,
                        depth_padding=self.mgdpt_depth_pad,
                        coarse_scale=self.coarse_scale,
                    )
                )
        return ConcatDataset(datasets)