import os
from concurrent.futures import ProcessPoolExecutor

import pandas as pd
from torch.utils.data import Dataset, ConcatDataset


def add_data_args(parent_args):
    parser = parent_args.add_argument_group('taiyi stable diffusion data args')
    # Multiple dataset paths may be passed in; each one is loaded separately.
    parser.add_argument(
        "--datasets_path", type=str, default=None, required=True, nargs='+',
        help="One or more folders containing the training data of instance images.",
    )
    parser.add_argument(
        "--datasets_type", type=str, default=None, required=True, choices=['txt', 'csv', 'fs_datasets'], nargs='+',
        help="Dataset type for each path (txt, csv or fs_datasets); must have the same length as datasets_path.",
    )
    parser.add_argument(
        "--resolution", type=int, default=512,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to"
            " this resolution."
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", default=False,
        help="Whether to center-crop images before resizing them to the target resolution."
    )
    parser.add_argument("--thres", type=float, default=0.2,
                        help="Score threshold for filtering image-text pairs.")
    return parent_args
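

# A minimal usage sketch of the argument group above (train.py is a hypothetical
# entry script that calls add_data_args on its ArgumentParser):
#   python train.py --datasets_path /data/zero23m /data/laion_cn \
#       --datasets_type txt csv --resolution 512 --center_crop --thres 0.2
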
class TXTDataset(Dataset):
    # Txt dataset reader, added mainly for the Zero23m dataset.
    def __init__(self,
                 folder_name,
                 thres=0.2):
        super().__init__()
        # print(f'Loading folder data from {folder_name}.')
        self.image_paths = []
        '''
        These score files have not been open-sourced yet.
        score_data = pd.read_csv(os.path.join(folder_name, 'score.csv'))
        img_path2score = {score_data['image_path'][i]: score_data['score'][i]
                          for i in range(len(score_data))}
        '''
        # print(img_path2score)
        # Only file paths are stored here, to keep initialization cheap.
        for each_file in os.listdir(folder_name):
            if each_file.endswith('.jpg'):
                self.image_paths.append(os.path.join(folder_name, each_file))
        # print('Done loading data. Len of images:', len(self.image_paths))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = str(self.image_paths[idx])
        caption_path = img_path.replace('.jpg', '.txt')  # Image and caption files share the same base name.
        with open(caption_path, 'r') as f:
            caption = f.read()
        return {'img_path': img_path, 'caption': caption}
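

# A minimal usage sketch (assumes a hypothetical leaf folder that holds paired
# .jpg/.txt files):
#   ds = TXTDataset('/data/zero23m/part0')
#   sample = ds[0]  # {'img_path': '/data/zero23m/part0/xxx.jpg', 'caption': '...'}
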
# NOTE: To speed up loading, keep the original single-file reader and parallelize
# across files from the outside. This cut loading time from ~30 min to ~3 min.
class CSVDataset(Dataset):
    def __init__(self,
                 input_filename,
                 image_root,
                 img_key,
                 caption_key,
                 thres=0.2):
        super().__init__()
        # logging.debug(f'Loading csv data from {input_filename}.')
        print(f'Loading csv data from {input_filename}.')
        self.images = []
        self.captions = []
        if input_filename.endswith('.csv'):
            # print(f"Load Data from {input_filename}")
            df = pd.read_csv(input_filename, index_col=0, on_bad_lines='skip')
            print(f'file {input_filename} datalen {len(df)}')
            # The image path may need small adjustments depending on the dataset layout.
            self.images.extend(df[img_key].tolist())
            self.captions.extend(df[caption_key].tolist())
        self.image_root = image_root

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_root, str(self.images[idx]))
        return {'img_path': img_path, 'caption': self.captions[idx]}
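

# A minimal usage sketch (assumes a hypothetical CSV with 'name' and 'caption'
# columns and images stored under image_root):
#   ds = CSVDataset('/data/laion_cn/release/part0.csv', '/data/laion_cn/images',
#                   img_key='name', caption_key='caption')
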
def if_final_dir(path: str) -> bool:
    # A directory counts as a leaf directory if it directly contains at least one file.
    for f in os.scandir(path):
        if f.is_file():
            return True
    return False


def process_pool_read_txt_dataset(args,
                                  input_root=None,
                                  thres=0.2):
    p = ProcessPoolExecutor(max_workers=20)
    all_datasets = []
    res = []

    # Recursively walk every subdirectory under the given path and build one
    # TXTDataset per leaf directory.
    def traversal_files(path: str):
        list_subfolders_with_paths = [f.path for f in os.scandir(path) if f.is_dir()]
        for dir_path in list_subfolders_with_paths:
            if if_final_dir(dir_path):
                res.append(p.submit(TXTDataset,
                                    dir_path,
                                    thres))
            else:
                traversal_files(dir_path)
    traversal_files(input_root)
    p.shutdown()
    for future in res:
        all_datasets.append(future.result())
    dataset = ConcatDataset(all_datasets)
    return dataset
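

# A usage sketch (hypothetical root path): every leaf directory under input_root
# becomes one TXTDataset, and all of them are concatenated:
#   ds = process_pool_read_txt_dataset(args, input_root='/data/zero23m')
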
def process_pool_read_csv_dataset(args,
                                  input_root,
                                  thres=0.20):
    # Here input_root is a directory; the metadata CSVs live under input_root/release.
    all_csvs = os.listdir(os.path.join(input_root, 'release'))
    image_root = os.path.join(input_root, 'images')
    # csv_with_score = [each for each in all_csvs if 'score' in each]
    all_datasets = []
    res = []
    p = ProcessPoolExecutor(max_workers=150)
    for path in all_csvs:
        each_csv_path = os.path.join(input_root, 'release', path)
        res.append(p.submit(CSVDataset,
                            each_csv_path,
                            image_root,
                            img_key="name",
                            caption_key="caption",
                            thres=thres))
    p.shutdown()
    for future in res:
        all_datasets.append(future.result())
    dataset = ConcatDataset(all_datasets)
    return dataset
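

# Expected layout (an assumed example):
#   input_root/release/*.csv  -- metadata CSVs with 'name' and 'caption' columns
#   input_root/images/        -- the image files referenced by the 'name' column
# Usage sketch:
#   ds = process_pool_read_csv_dataset(args, input_root='/data/laion_cn')
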
def load_data(args, global_rank=0):
    assert len(args.datasets_path) == len(args.datasets_type), \
        "datasets_path num not equal to datasets_type"
    all_datasets = []
    for path, dataset_type in zip(args.datasets_path, args.datasets_type):
        if dataset_type == 'txt':
            all_datasets.append(process_pool_read_txt_dataset(
                args, input_root=path, thres=args.thres))
        elif dataset_type == 'csv':
            all_datasets.append(process_pool_read_csv_dataset(
                args, input_root=path, thres=args.thres))
        elif dataset_type == 'fs_datasets':
            from fengshen.data.fs_datasets import load_dataset
            # args.num_workers is expected to be provided by another arg group.
            all_datasets.append(load_dataset(path, num_proc=args.num_workers,
                                             thres=args.thres, global_rank=global_rank)['train'])
        else:
            raise ValueError('unsupported dataset type: %s' % dataset_type)
        print(f'load dataset {dataset_type} {path} len {len(all_datasets[-1])}')
    return {'train': ConcatDataset(all_datasets)}
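

# A minimal end-to-end sketch with hypothetical paths (run the module directly,
# e.g. python taiyi_datasets.py --datasets_path /data/zero23m --datasets_type txt):
if __name__ == '__main__':
    import argparse
    from torch.utils.data import DataLoader

    parser = argparse.ArgumentParser()
    parser = add_data_args(parser)
    args = parser.parse_args()
    datasets = load_data(args)
    # The default collate turns each field into a list across the batch.
    loader = DataLoader(datasets['train'], batch_size=4, num_workers=2)
    batch = next(iter(loader))
    print(batch['img_path'][:2], batch['caption'][:2])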