Spaces:
Build error
Build error
| import pytorch_lightning as pl | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms | |
| from src.tracker.signboard_segment.datasets_signboard_detection.dataset import PoIDataset | |
| import src.tracker.signboard_segment.datasets_signboard_detection.utils as utils | |
class POIDataModule(pl.LightningDataModule):
    """LightningDataModule exposing a single prediction DataLoader over ``PoIDataset``.

    Only the ``"predict"`` stage (or ``None``) builds a dataset, and the only
    loader defined here is ``predict_dataloader``.
    """

    def __init__(self,
                 data,
                 train_batch_size=16,
                 test_batch_size=16,
                 seed=42,
                 num_workers=16):
        """Store configuration; no dataset is built until ``setup``/``predict_dataloader``.

        Args:
            data: Forwarded unchanged as the first argument of ``PoIDataset``;
                its expected type is defined by that class (not visible here).
            train_batch_size: Stored on the instance; not used by any loader in this class.
            test_batch_size: Batch size of the prediction DataLoader.
            seed: Stored on the instance; not used by this class.
            num_workers: Worker processes for the prediction DataLoader.
                Defaults to 16, the value previously hard-coded.
        """
        super().__init__()
        self.data = data
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.seed = seed
        self.num_workers = num_workers
        # Defined up front so predict_dataloader() can test it even when setup()
        # was never called, or was called with a stage that builds nothing.
        self.test_dataset = None

    def prepare_data(self):
        # Nothing to download or cache: the data is supplied directly to __init__.
        pass

    def setup(self, stage="fit"):
        """Build the prediction dataset when *stage* is ``"predict"`` or ``None``.

        Any other stage (including the default ``"fit"``) leaves
        ``test_dataset`` unchanged.
        """
        test_transform = transforms.Compose([transforms.ToTensor()])
        if stage == "predict" or stage is None:
            self.test_dataset = PoIDataset(self.data,
                                           transforms=test_transform)

    def predict_dataloader(self):
        """Return a non-shuffled DataLoader over the prediction dataset.

        If ``setup`` has not produced the dataset yet, it is built on demand
        instead of failing or returning ``None``.
        """
        if self.test_dataset is None:
            self.setup("predict")
        return DataLoader(self.test_dataset,
                          batch_size=self.test_batch_size,
                          shuffle=False,
                          num_workers=self.num_workers,
                          # Project helper from the utils module; its batching
                          # behavior is not visible here — TODO confirm.
                          collate_fn=utils.collate_fn)
| def _get_name(filepath): | |
| images = filepath | |
| return images | |