from torch.utils.data import Dataset, ConcatDataset
import os
from concurrent.futures import ProcessPoolExecutor
import pandas as pd


def add_data_args(parent_args):
    parser = parent_args.add_argument_group('taiyi stable diffusion data args')
    # Multiple dataset paths can be passed in; each is loaded separately.
    parser.add_argument(
        "--datasets_path", type=str, default=None, required=True, nargs='+',
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--datasets_type", type=str, default=None, required=True, choices=['txt', 'csv', 'fs_datasets'], nargs='+',
        help="dataset type, txt or csv, same len as datasets_path",
    )
    parser.add_argument(
        "--resolution", type=int, default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", default=False,
        help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument("--thres", type=float, default=0.2)
    return parent_args


class TXTDataset(Dataset):
    # Reads a txt-style dataset; mainly used for the Zero23m dataset.
    def __init__(self,
                 folder_name,
                 thres=0.2):
        super().__init__()
        # print(f'Loading folder data from {folder_name}.')
        self.image_paths = []
        '''
        The score files for this part have not been open-sourced yet.
        score_data = pd.read_csv(os.path.join(folder_name, 'score.csv'))
        img_path2score = {score_data['image_path'][i]: score_data['score'][i]
                          for i in range(len(score_data))}
        '''
        # print(img_path2score)
        # Only file paths are stored here to keep initialization cheap.
        for each_file in os.listdir(folder_name):
            if each_file.endswith('.jpg'):
                self.image_paths.append(os.path.join(folder_name, each_file))

        # print('Done loading data. Len of images:', len(self.image_paths))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = str(self.image_paths[idx])
        caption_path = img_path.replace('.jpg', '.txt')  # the caption file shares the image's base name.
        with open(caption_path, 'r') as f:
            caption = f.read()
        return {'img_path': img_path, 'caption': caption}
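

# Expected on-disk layout for TXTDataset (illustrative, inferred from the reader above;
# the file names are placeholders):
#   <folder_name>/
#       000001.jpg
#       000001.txt   <- caption for 000001.jpg
#       000002.jpg
#       000002.txt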


# NOTE: to speed up data loading (30min -> 3min), the per-file reader is kept as-is and
# parallel reading is handled externally by the callers below.
class CSVDataset(Dataset):
    def __init__(self,
                 input_filename,
                 image_root,
                 img_key,
                 caption_key,
                 thres=0.2):
        super().__init__()
        # logging.debug(f'Loading csv data from {input_filename}.')
        print(f'Loading csv data from {input_filename}.')
        self.images = []
        self.captions = []

        if input_filename.endswith('.csv'):
            # print(f"Load Data from{input_filename}")
            df = pd.read_csv(input_filename, index_col=0, on_bad_lines='skip')
            print(f'file {input_filename} datalen {len(df)}')
            # the image path may need minor adjustments depending on the dataset's directory structure
            self.images.extend(df[img_key].tolist())
            self.captions.extend(df[caption_key].tolist())
        self.image_root = image_root

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_root, str(self.images[idx]))
        return {'img_path': img_path, 'caption': self.captions[idx]}
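

# Illustrative CSV layout expected by CSVDataset (column names are supplied by the caller;
# process_pool_read_csv_dataset below passes img_key="name" and caption_key="caption",
# and each path is joined onto image_root). The rows shown are placeholders:
#   ,name,caption
#   0,part-00000/000001.jpg,a photo of a cat
#   1,part-00000/000002.jpg,a red car parked on the street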


def if_final_dir(path: str) -> bool:
    # a directory counts as a leaf ("final") directory if it contains at least one file
    for f in os.scandir(path):
        if f.is_file():
            return True
    return False


def process_pool_read_txt_dataset(args,
                                  input_root=None,
                                  thres=0.2):
    p = ProcessPoolExecutor(max_workers=20)
    all_datasets = []
    res = []

    # recursively traverse every subdirectory under this path
    def traversal_files(path: str):
        list_subfolders_with_paths = [f.path for f in os.scandir(path) if f.is_dir()]
        for dir_path in list_subfolders_with_paths:
            if if_final_dir(dir_path):
                res.append(p.submit(TXTDataset,
                                    dir_path,
                                    thres))
            else:
                traversal_files(dir_path)
    traversal_files(input_root)
    p.shutdown()
    for future in res:
        all_datasets.append(future.result())
    dataset = ConcatDataset(all_datasets)
    return dataset


def process_pool_read_csv_dataset(args,
                                  input_root,
                                  thres=0.20):
    # here input_root is a directory; the CSV files live under its 'release' subdirectory
    all_csvs = os.listdir(os.path.join(input_root, 'release'))
    image_root = os.path.join(input_root, 'images')
    # csv_with_score = [each for each in all_csvs if 'score' in each]
    all_datasets = []
    res = []
    p = ProcessPoolExecutor(max_workers=150)
    for path in all_csvs:
        each_csv_path = os.path.join(input_root, 'release', path)
        res.append(p.submit(CSVDataset,
                            each_csv_path,
                            image_root,
                            img_key="name",
                            caption_key="caption",
                            thres=thres))
    p.shutdown()
    for future in res:
        all_datasets.append(future.result())
    dataset = ConcatDataset(all_datasets)
    return dataset


def load_data(args, global_rank=0):
    assert len(args.datasets_path) == len(args.datasets_type), \
        "datasets_path and datasets_type must have the same length"
    all_datasets = []
    for path, dataset_type in zip(args.datasets_path, args.datasets_type):
        if dataset_type == 'txt':
            all_datasets.append(process_pool_read_txt_dataset(
                args, input_root=path, thres=args.thres))
        elif dataset_type == 'csv':
            all_datasets.append(process_pool_read_csv_dataset(
                args, input_root=path, thres=args.thres))
        elif dataset_type == 'fs_datasets':
            from fengshen.data.fs_datasets import load_dataset
            all_datasets.append(load_dataset(path, num_proc=args.num_workers,
                                             thres=args.thres, global_rank=global_rank)['train'])
        else:
            raise ValueError('unsupported dataset type: %s' % dataset_type)
        print(f'load dataset {dataset_type} {path} len {len(all_datasets[-1])}')
    return {'train': ConcatDataset(all_datasets)}
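

# Minimal usage sketch (not part of the original module): it assumes argparse is used as the
# parent parser and that the placeholder paths below point at datasets laid out as described
# in the comments above.
if __name__ == '__main__':
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser = add_data_args(arg_parser)
    # num_workers is only consumed by the 'fs_datasets' branch of load_data.
    arg_parser.add_argument('--num_workers', type=int, default=4)
    demo_args = arg_parser.parse_args([
        '--datasets_path', '/path/to/zero23m', '/path/to/csv_dataset_root',
        '--datasets_type', 'txt', 'csv',
    ])
    datasets = load_data(demo_args)
    print('total training samples:', len(datasets['train']))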