import copy
from os.path import join as pjoin
from typing import Any, Callable

from torch.utils.data import DataLoader


class BaseDataModule:
    """Base data module that builds datasets/dataloaders from a config object.

    Subclasses are expected to provide the attributes this class reads but
    does not define (not visible in this file — confirm against subclasses):
    ``cfg`` (config tree with ``DATASET.<NAME>.ROOT``, ``TRAIN``/``VAL``/``TEST``
    sections), ``name`` (dataset name), ``hparams`` (dict of dataset kwargs),
    and ``Dataset`` (dataset class accepting ``split_file=...`` plus hparams).
    """

    def __init__(self, collate_fn: Callable) -> None:
        super().__init__()
        # Batching function forwarded to every DataLoader built here.
        self.collate_fn = collate_fn
        # Multimodal flag: when True, test_dataloader forces batch_size=1.
        self.is_mm = False

    def _dataset_root(self) -> Any:
        """Return ``cfg.DATASET.<NAME>.ROOT`` for this module's dataset.

        Replaces the original ``eval(f"self.cfg.DATASET.{...}.ROOT")`` —
        plain attribute access via ``getattr`` is equivalent and avoids eval.
        """
        return getattr(self.cfg.DATASET, self.name.upper()).ROOT

    def get_sample_set(self, overrides: dict) -> Any:
        """Build a dataset on the TEST split with ``hparams`` overridden.

        Args:
            overrides: keys that replace entries of ``self.hparams`` for
                this one dataset instance.

        Returns:
            A new ``self.Dataset`` built from ``cfg.TEST.SPLIT``.
        """
        # Deep-copy so callers' overrides never mutate the shared hparams.
        sample_params = copy.deepcopy(self.hparams)
        sample_params.update(overrides)
        split_file = pjoin(self._dataset_root(), self.cfg.TEST.SPLIT + ".txt")
        return self.Dataset(split_file=split_file, **sample_params)

    def __getattr__(self, item: str) -> Any:
        """Lazily create and cache ``<split>_dataset`` attributes.

        Accessing e.g. ``self.train_dataset`` builds ``self.Dataset`` from
        ``cfg.TRAIN.SPLIT`` on first use and caches it in ``__dict__`` as
        ``_train_dataset`` (so later accesses bypass ``__getattr__``).

        Raises:
            AttributeError: for any other missing attribute.
        """
        if item.endswith("_dataset") and not item.startswith("_"):
            subset = item[: -len("_dataset")].upper()
            item_c = "_" + item
            if item_c not in self.__dict__:
                # getattr replaces eval(f"self.cfg.{subset}.SPLIT").
                split = getattr(self.cfg, subset).SPLIT
                split_file = pjoin(self._dataset_root(), split + ".txt")
                self.__dict__[item_c] = self.Dataset(
                    split_file=split_file, **self.hparams
                )
            return getattr(self, item_c)
        classname = self.__class__.__name__
        raise AttributeError(f"'{classname}' object has no attribute '{item}'")

    def get_dataloader_options(self, stage: str) -> dict:
        """Return DataLoader kwargs read from ``cfg.<STAGE>``.

        Args:
            stage: stage name, case-insensitive ('train' / 'val' / 'test').
        """
        stage_args = getattr(self.cfg, stage.upper())
        return {
            "batch_size": stage_args.BATCH_SIZE,
            "num_workers": stage_args.NUM_WORKERS,
            "collate_fn": self.collate_fn,
            "persistent_workers": stage_args.PERSISTENT_WORKERS,
        }

    def train_dataloader(self) -> DataLoader:
        """Shuffled DataLoader over the lazily-built train dataset."""
        dataloader_options = self.get_dataloader_options('TRAIN')
        return DataLoader(self.train_dataset, shuffle=True, **dataloader_options)

    def val_dataloader(self) -> DataLoader:
        """Unshuffled DataLoader over the lazily-built validation dataset."""
        dataloader_options = self.get_dataloader_options('VAL')
        return DataLoader(self.val_dataset, shuffle=False, **dataloader_options)

    def test_dataloader(self) -> DataLoader:
        """Unshuffled test DataLoader; batch size 1 in multimodal mode."""
        dataloader_options = self.get_dataloader_options('TEST')
        dataloader_options["batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
        return DataLoader(self.test_dataset, shuffle=False, **dataloader_options)