# Copyright (c) Meta Platforms, Inc. and affiliates.
from pathlib import Path

import pytorch_lightning as pl
import torchvision.transforms as tvf
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import ConcatDataset, DataLoader

from dataset.UAV.dataset import UavMapPair


class UavMapDatasetModule(pl.LightningDataModule):
    """Lightning data module wrapping per-city UAV/map image pairs."""

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg

        # Validation uses only the deterministic transforms. Copy the list
        # before composing: tvf.Compose keeps a reference to the list it is
        # given, so the training-only ColorJitter appended below would
        # otherwise leak into the validation pipeline as well.
        tfs = [tvf.ToTensor(), tvf.Resize(self.cfg.image_size)]
        self.val_tfs = tvf.Compose(list(tfs))

        # Training optionally adds color jitter on top of the base transforms.
        if cfg.augmentation.image.apply:
            args = OmegaConf.masked_copy(
                cfg.augmentation.image,
                ["brightness", "contrast", "saturation", "hue"],
            )
            tfs.append(tvf.ColorJitter(**args))
        self.train_tfs = tvf.Compose(tfs)

        self.init()

    def init(self):
        # Build one UavMapPair dataset per city and concatenate them.
        self.train_dataset = ConcatDataset([
            UavMapPair(
                root=Path(self.cfg.root),
                city=city,
                training=True,
                transform=self.train_tfs,
            )
            for city in self.cfg.train_citys
        ])
        self.val_dataset = ConcatDataset([
            UavMapPair(
                root=Path(self.cfg.root),
                city=city,
                training=False,
                transform=self.val_tfs,
            )
            for city in self.cfg.val_citys
        ])

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.cfg.train.batch_size,
            num_workers=self.cfg.train.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        # Validation data is not shuffled, so metrics are reproducible
        # across runs.
        return DataLoader(
            self.val_dataset,
            batch_size=self.cfg.val.batch_size,
            num_workers=self.cfg.val.num_workers,
            shuffle=False,
            pin_memory=True,
        )
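

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the module API. The config layout below
# is an assumption inferred from how `cfg` is accessed above (root,
# image_size, train_citys/val_citys, augmentation.image.*, train.*, val.*);
# the project's real config files may differ, and the city names and dataset
# root here are hypothetical. Running this requires the UAV dataset on disk.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = OmegaConf.create({
        "root": "data/UAV",        # hypothetical dataset root
        "image_size": [256, 256],
        "train_citys": ["cityA"],  # hypothetical city names
        "val_citys": ["cityB"],
        "augmentation": {
            "image": {
                "apply": True,
                "brightness": 0.2,
                "contrast": 0.2,
                "saturation": 0.2,
                "hue": 0.05,
            },
        },
        "train": {"batch_size": 8, "num_workers": 4},
        "val": {"batch_size": 8, "num_workers": 4},
    })
    dm = UavMapDatasetModule(cfg)
    # Fetch one training batch to sanity-check the pipeline.
    batch = next(iter(dm.train_dataloader()))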