|
|
|
|
|
from copy import deepcopy |
|
from pathlib import Path |
|
from typing import Any, Dict, List |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from omegaconf import DictConfig, OmegaConf |
|
import pytorch_lightning as pl |
|
from dataset.UAV.dataset import UavMapPair |
|
|
|
|
|
from torch.utils.data import Dataset, ConcatDataset |
|
from torch.utils.data import Dataset, DataLoader, random_split |
|
import torchvision.transforms as tvf |
|
|
|
|
|
class UavMapDatasetModule(pl.LightningDataModule):
    """Lightning data module pairing UAV images with map tiles, per city.

    Builds the train/val datasets by concatenating one ``UavMapPair``
    dataset per configured city and exposes the matching dataloaders.
    """

    def __init__(self, cfg: DictConfig):
        """
        Args:
            cfg: OmegaConf config providing ``root``, ``image_size``,
                ``train_citys``, ``val_citys``, ``augmentation.image``,
                ``train.{batch_size,num_workers}`` and
                ``val.{batch_size,num_workers}``.
        """
        super().__init__()
        self.cfg = cfg

        # Base transforms shared by train and validation.
        tfs = []
        tfs.append(tvf.ToTensor())
        tfs.append(tvf.Resize(self.cfg.image_size))
        # BUGFIX: tvf.Compose keeps the list by reference, so the original
        # `tvf.Compose(tfs)` meant the ColorJitter appended below leaked
        # into the validation pipeline. Snapshot the list instead.
        self.val_tfs = tvf.Compose(list(tfs))

        # Optional photometric augmentation — training pipeline only.
        if cfg.augmentation.image.apply:
            args = OmegaConf.masked_copy(
                cfg.augmentation.image, ["brightness", "contrast", "saturation", "hue"]
            )
            tfs.append(tvf.ColorJitter(**args))
        self.train_tfs = tvf.Compose(tfs)

        self.init()

    def init(self):
        """Instantiate the concatenated per-city train/val datasets."""
        self.train_dataset = ConcatDataset([
            UavMapPair(root=Path(self.cfg.root), city=city, training=True,
                       transform=self.train_tfs)
            for city in self.cfg.train_citys
        ])

        self.val_dataset = ConcatDataset([
            UavMapPair(root=Path(self.cfg.root), city=city, training=False,
                       transform=self.val_tfs)
            for city in self.cfg.val_citys
        ])

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        train_loader = DataLoader(self.train_dataset,
                                  batch_size=self.cfg.train.batch_size,
                                  num_workers=self.cfg.train.num_workers,
                                  shuffle=True, pin_memory=True)
        return train_loader

    def val_dataloader(self):
        """Return the validation dataloader.

        BUGFIX: shuffling was enabled here; validation order must be
        deterministic so metrics are comparable across epochs.
        """
        val_loader = DataLoader(self.val_dataset,
                                batch_size=self.cfg.val.batch_size,
                                num_workers=self.cfg.val.num_workers,
                                shuffle=False, pin_memory=True)
        return val_loader
|
|