File size: 1,175 Bytes
58da73e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from .base_dataset import BaseDataset, get_transform
from PIL import Image
from pathlib import Path


class OneDataset(BaseDataset):
    """Dataset over a single image, an image file, or a folder of images.

    ``img`` may be:

    * a path (``str`` or :class:`pathlib.Path`) to an image file -> one item;
    * a path to a directory -> one item per regular file directly inside it,
      in sorted path order (sub-directories are skipped);
    * anything else, which is treated as an already-loaded PIL image -> one
      item whose ``"A_paths"`` entry is ``None``.

    Images read from disk are converted to RGB when loaded. Each item is a
    dict ``{"A": <transformed image>, "A_paths": <source path or None>}``.
    """

    def __init__(self, img, opt):
        """Eagerly load the input image(s).

        Args:
            img: image file path, directory path, or PIL image (see the class
                docstring).
            opt: options object passed to ``BaseDataset.__init__`` and later
                to ``transform``/``get_transform``.

        Raises:
            FileNotFoundError: if ``img`` is a path that is neither an
                existing file nor an existing directory.
        """
        # BaseDataset.__init__ is expected to set self.opt, which
        # __getitem__ reads.
        BaseDataset.__init__(self, opt)
        if isinstance(img, (str, Path)):
            dataroot = Path(img)
            if dataroot.is_file():
                self.A_path = [str(dataroot)]
            elif dataroot.is_dir():
                # Sort for a deterministic order (iterdir() order is
                # arbitrary) and skip sub-directories, which Image.open
                # cannot read.
                self.A_path = sorted(
                    str(entry) for entry in dataroot.iterdir() if entry.is_file()
                )
            else:
                raise FileNotFoundError(
                    f"No such image file or directory: {dataroot}"
                )
            self.A_img = [self._load_rgb(path) for path in self.A_path]
        else:
            # Treated as an already-loaded PIL image and used as-is.
            # NOTE(review): unlike the path branch, no RGB conversion is
            # applied here — confirm callers always pass RGB images.
            self.A_path = [None]
            self.A_img = [img]

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an in-memory RGB copy."""
        # convert() loads the pixel data into a new image, so the file
        # handle can be closed as soon as the ``with`` block exits.
        with Image.open(path) as im:
            return im.convert("RGB")

    def __getitem__(self, idx: int):
        """Return the transformed image at ``idx`` together with its path.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        A_path = self.A_path[idx]
        A_img = self.A_img[idx]
        A = transform(A_img, self.opt)
        return {"A": A, "A_paths": A_path}

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.A_img)


def transform(img, opt):
    """Run ``img`` through the preprocessing pipeline described by ``opt``.

    The pipeline is built by ``get_transform``; grayscale mode is requested
    exactly when ``opt.input_nc`` equals 1.

    Args:
        img: the image to preprocess.
        opt: options object providing ``input_nc`` and whatever else
            ``get_transform`` reads.

    Returns:
        Whatever the pipeline returned by ``get_transform`` produces for
        ``img``.
    """
    single_channel = opt.input_nc == 1
    pipeline = get_transform(opt, grayscale=single_channel)
    return pipeline(img)