# MapLocNet/dataset/dataset.py
# Author: wangerniu (commit 124ba77)
# Copyright (c) Meta Platforms, Inc. and affiliates.
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List
# from logger import logger
import numpy as np
# import torch
# import torch.utils.data as torchdata
# import torchvision.transforms as tvf
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
from dataset.UAV.dataset import UavMapPair
# from torch.utils.data import Dataset, DataLoader
# from torchvision import transforms
from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as tvf
# Custom data-module class, inheriting from pl.LightningDataModule.
class UavMapDatasetModule(pl.LightningDataModule):
    """Lightning data module serving paired UAV-image / map samples per city.

    Builds one ``UavMapPair`` dataset per configured city and concatenates
    them into a single train and a single validation dataset.

    Expected ``cfg`` keys (DictConfig or dict-like with attribute access):
        root            — dataset root directory
        image_size      — target size passed to ``tvf.Resize``
        train_citys     — iterable of city names for training
        val_citys       — iterable of city names for validation
        train / val     — each with ``batch_size`` and ``num_workers``
        augmentation.image — ``apply`` flag plus ColorJitter parameters
    """

    def __init__(self, cfg: Dict[str, Any]):
        super().__init__()
        self.cfg = cfg

        # Transforms shared by both splits: tensor conversion + resize.
        base_tfs = [tvf.ToTensor(), tvf.Resize(self.cfg.image_size)]

        # BUG FIX: ``tvf.Compose`` keeps a reference to the list it is given,
        # so appending ColorJitter to the same list afterwards (as the old
        # code did) silently leaked the train-only color augmentation into
        # the validation pipeline. Compose independent copies instead.
        self.val_tfs = tvf.Compose(list(base_tfs))

        train_tfs = list(base_tfs)
        if cfg.augmentation.image.apply:
            # Keep only the ColorJitter-relevant keys from the config node.
            args = OmegaConf.masked_copy(
                cfg.augmentation.image,
                ["brightness", "contrast", "saturation", "hue"],
            )
            train_tfs.append(tvf.ColorJitter(**args))
        self.train_tfs = tvf.Compose(train_tfs)

        self.init()

    def init(self):
        """Build concatenated train/val datasets over the configured cities."""
        self.train_dataset = ConcatDataset([
            UavMapPair(root=Path(self.cfg.root), city=city, training=True,
                       transform=self.train_tfs)
            for city in self.cfg.train_citys
        ])
        self.val_dataset = ConcatDataset([
            UavMapPair(root=Path(self.cfg.root), city=city, training=False,
                       transform=self.val_tfs)
            for city in self.cfg.val_citys
        ])

    def train_dataloader(self):
        """Return the shuffled training DataLoader."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.cfg.train.batch_size,
            num_workers=self.cfg.train.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return the validation DataLoader.

        BUG FIX: validation previously used ``shuffle=True``, which makes
        per-epoch metrics order-dependent and non-reproducible; validation
        data must be iterated deterministically.
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.cfg.val.batch_size,
            num_workers=self.cfg.val.num_workers,
            shuffle=False,
            pin_memory=True,
        )