File size: 3,860 Bytes
124ba77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (c) Meta Platforms, Inc. and affiliates.

from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List
# from logger import logger
import numpy as np
# import torch
# import torch.utils.data as torchdata
# import torchvision.transforms as tvf
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
from dataset.UAV.dataset import UavMapPair
# from torch.utils.data import Dataset, DataLoader
# from torchvision import transforms
from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as tvf

# Lightning data module pairing UAV images with map tiles, per city.
class UavMapDatasetModule(pl.LightningDataModule):
    """Builds train/val ``ConcatDataset``s of :class:`UavMapPair` and loaders.

    Expects ``cfg`` (OmegaConf-style, attribute access) providing:
    ``root``, ``image_size``, ``train_citys``, ``val_citys``,
    ``augmentation.image.{apply, brightness, contrast, saturation, hue}``,
    ``train.{batch_size, num_workers}`` and ``val.{batch_size, num_workers}``.
    """

    def __init__(self, cfg: Dict[str, Any]):
        super().__init__()
        self.cfg = cfg

        # Base pipeline shared by train and val: tensor conversion + resize.
        tfs = [tvf.ToTensor(), tvf.Resize(self.cfg.image_size)]
        # Validation uses the base pipeline only (no augmentation).
        self.val_tfs = tvf.Compose(tfs)

        # Training optionally adds color jitter on top of the base pipeline.
        if cfg.augmentation.image.apply:
            args = OmegaConf.masked_copy(
                cfg.augmentation.image, ["brightness", "contrast", "saturation", "hue"]
            )
            tfs.append(tvf.ColorJitter(**args))
        self.train_tfs = tvf.Compose(tfs)

        self.init()

    def init(self):
        """Instantiate one ``UavMapPair`` per configured city and concatenate."""
        self.train_dataset = ConcatDataset([
            UavMapPair(root=Path(self.cfg.root), city=city, training=True,
                       transform=self.train_tfs)
            for city in self.cfg.train_citys
        ])

        self.val_dataset = ConcatDataset([
            UavMapPair(root=Path(self.cfg.root), city=city, training=False,
                       transform=self.val_tfs)
            for city in self.cfg.val_citys
        ])

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return DataLoader(self.train_dataset,
                          batch_size=self.cfg.train.batch_size,
                          num_workers=self.cfg.train.num_workers,
                          shuffle=True, pin_memory=True)

    def val_dataloader(self):
        """Return the validation loader.

        Fix: validation must not be shuffled (``shuffle=False``) so that
        evaluation order — and therefore per-batch metrics/logging — is
        deterministic across epochs and runs.
        """
        return DataLoader(self.val_dataset,
                          batch_size=self.cfg.val.batch_size,
                          num_workers=self.cfg.val.num_workers,
                          shuffle=False, pin_memory=True)