HairSwapModel / models /CtrlHair /common_dataset.py
miguelmuzo's picture
Upload 426 files
3de0e37 verified
# -*- coding: utf-8 -*-
"""
# File name: common_dataset.py
# Time : 2021/12/31 19:19
# Author: [email protected]
# Description:
"""
import os
import random
import numpy as np
import pandas as pd
from global_value_utils import GLOBAL_DATA_ROOT, HAIR_IDX, HAT_IDX, DATASET_NAME
from util.util import path_join_abs
import cv2
class DataFilter:
    """Build filtered, shuffled train/test image lists from the configured datasets.

    For every directory ``GLOBAL_DATA_ROOT/<name>`` (``name`` in ``DATASET_NAME``)
    the images under ``images_256`` are optionally filtered by the values in
    ``angle.csv`` and ``attr_gender.csv``. The surviving paths are shuffled with
    a fixed seed and split into a train part and a test part. Two fixed sample
    batches (``test_face_list`` and ``test_hair_list``) are then drawn, in
    order and without overlap, from the test part.

    Attributes set by ``__init__``:
        cfg, data_dirs, random_seed, total_list, test_start, test_list,
        train_list, hair_region_threshold, test_face_list, test_hair_list.
    """

    # Largest fraction of label pixels equal to HAT_IDX that is still accepted.
    HAT_REGION_THRESHOLD = 0.03

    def __init__(self, cfg=None, test_part=0.096):
        """Collect, filter, shuffle and split the image paths.

        :param cfg: configuration object; only ``cfg.sample_batch_size`` is
            read here. When ``None`` the config of ``color_texture_branch``
            is imported and used.
        :param test_part: fraction of the shuffled list placed in the test split.
        :raises IndexError: if the test split does not contain enough valid
            images to fill both sample batches.
        :raises FileNotFoundError: if a label image cannot be read.
        """
        # cfg
        if cfg is None:
            from color_texture_branch.config import cfg
        self.cfg = cfg

        self.data_dirs = [os.path.join(GLOBAL_DATA_ROOT, dn) for dn in DATASET_NAME]

        # Seed the global RNG so the shuffle (and therefore the split) is reproducible.
        self.random_seed = 7
        random.seed(self.random_seed)

        ####################################################################
        # Please modify these if you don't want to use these filters
        angle_filter = True
        gender_filter = True
        gender = ['female']
        # gender = ['male', 'female']
        # Translate names into the codes compared against the `Male` column.
        gender = {{'male': 1, 'female': -1}[g] for g in gender}
        ####################################################################

        self.total_list = []
        for data_dir in self.data_dirs:
            self.total_list += self._collect_images(
                data_dir, angle_filter, gender_filter, gender)
        random.shuffle(self.total_list)

        self.test_start, self.train_list, self.test_list = self._split_train_test(
            self.total_list, test_part)

        # make sure the area of hair is big enough
        self.hair_region_threshold = 0.07

        # The hair batch continues scanning where the face batch stopped, so
        # the two sample batches never share an image.
        self.test_face_list, next_idx = self._take_valid(
            0, cfg.sample_batch_size, self._is_valid_face_file)
        self.test_hair_list, _ = self._take_valid(
            next_idx, cfg.sample_batch_size, self._is_valid_hair_file)

    @staticmethod
    def _collect_images(data_dir, angle_filter, gender_filter, gender):
        """Return the image paths of one dataset directory that pass the filters."""
        img_dir = os.path.join(data_dir, 'images_256')
        if angle_filter:
            # angle.csv is indexed by image number; keep rows with angle < 5.
            angle_csv = pd.read_csv(os.path.join(data_dir, 'angle.csv'), index_col=0)
            kept_ids = list(angle_csv.index[angle_csv['angle'] < 5])
            names = ['%05d.png' % image_id for image_id in kept_ids]
        else:
            names = os.listdir(img_dir)
        if gender_filter:
            attr_filter = pd.read_csv(os.path.join(data_dir, 'attr_gender.csv'))
            # The file stem (name without the 4-char extension) is the row label.
            names = [p for p in names if attr_filter.Male[int(p[:-4])] in gender]
        return [os.path.join(img_dir, p) for p in names]

    @staticmethod
    def _split_train_test(paths, test_part):
        """Split ``paths`` and return ``(test_start, train_list, test_list)``.

        The test list is the tail starting at ``int(len(paths) * (1 - test_part))``;
        the train list holds every path that does not occur in the test list.
        """
        test_start = int(len(paths) * (1 - test_part))
        test_list = paths[test_start:]
        # Set membership keeps the exclusion linear instead of quadratic.
        held_out = set(test_list)
        train_list = [p for p in paths if p not in held_out]
        return test_start, train_list, test_list

    def _take_valid(self, start, count, is_valid):
        """Scan ``self.test_list`` from ``start`` and pick ``count`` accepted files.

        :param start: index in ``self.test_list`` at which scanning begins.
        :param count: number of files to pick.
        :param is_valid: callable taking a file path and returning a truthy
            value when the file is accepted.
        :return: ``(picked, next_index)`` where ``next_index`` is the index
            right after the last inspected file.
        :raises IndexError: if the list is exhausted before ``count`` files
            were accepted.
        """
        picked = []
        idx = start
        while len(picked) < count:
            if idx >= len(self.test_list):
                raise IndexError(
                    'Not enough valid test images: needed %d, found %d'
                    % (count, len(picked)))
            candidate = self.test_list[idx]
            if is_valid(candidate):
                picked.append(candidate)
            idx += 1
        return picked, idx

    @staticmethod
    def _label_location(image_path):
        """Return ``(data_dir, image_id)`` for an image path of the form
        ``<data_dir>/images_256/<image_id>.png``."""
        data_dir = path_join_abs(image_path, '../..')
        # Use the whole file stem so ids that are not 5 characters long work too.
        image_id = os.path.splitext(os.path.basename(image_path))[0]
        return data_dir, image_id

    def _is_valid_face_file(self, image_path):
        """Apply :meth:`valid_face` to the label belonging to ``image_path``."""
        return self.valid_face(*self._label_location(image_path))

    def _is_valid_hair_file(self, image_path):
        """Apply :meth:`valid_hair` to the label belonging to ``image_path``."""
        return self.valid_hair(*self._label_location(image_path))

    @staticmethod
    def _load_label(data_dir, img_idx_str):
        """Read ``<data_dir>/label/<img_idx_str>.png`` as a grayscale array.

        :raises FileNotFoundError: if the image cannot be read.
        """
        label_path = os.path.join(data_dir, 'label', img_idx_str + '.png')
        label_img = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if label_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('Cannot read label image: %s' % label_path)
        return label_img

    @classmethod
    def _has_hat(cls, label_img):
        """Tell whether the fraction of HAT_IDX pixels exceeds the hat threshold."""
        hat_region = label_img == HAT_IDX
        return hat_region.mean() > cls.HAT_REGION_THRESHOLD

    def valid_face(self, data_dir, img_idx_str):
        """Return ``True`` unless the label image contains too large a hat region.

        :param data_dir: dataset directory containing the ``label`` folder.
        :param img_idx_str: file stem of the label image (without ``.png``).
        :raises FileNotFoundError: if the label image cannot be read.
        """
        label_img = self._load_label(data_dir, img_idx_str)
        if self._has_hat(label_img):
            return False
        return True

    def valid_hair(self, data_dir, img_idx_str):
        """Return ``True`` when the label has no large hat region and enough hair.

        The image is rejected when the fraction of HAIR_IDX pixels is below
        ``self.hair_region_threshold``.

        :param data_dir: dataset directory containing the ``label`` folder.
        :param img_idx_str: file stem of the label image (without ``.png``).
        :raises FileNotFoundError: if the label image cannot be read.
        """
        label_img = self._load_label(data_dir, img_idx_str)
        if self._has_hat(label_img):
            return False
        hair_region = label_img == HAIR_IDX
        if hair_region.mean() < self.hair_region_threshold:
            return False
        return True