|
import multiprocessing |
|
import pickle |
|
import time |
|
import traceback |
|
from enum import IntEnum |
|
|
|
import cv2 |
|
import numpy as np |
|
|
|
from core import imagelib, mplib, pathex |
|
from core.cv2ex import * |
|
from core.interact import interact as io |
|
from core.joblib import SubprocessGenerator, ThisThreadGenerator |
|
from facelib import LandmarksProcessor |
|
from samplelib import SampleGeneratorBase |
|
|
|
|
|
class MaskType(IntEnum):
    """Integer IDs of the per-part annotation masks shipped with CelebAMask-HQ.

    The names match the `<file_id>_<mask_type>` suffix of the mask image
    filenames (e.g. '00001_skin.png' -> MaskType.skin).

    NOTE: the previous definition had a trailing comma on every member
    (`none = 0,`), which made each value a 1-tuple that only worked because
    Enum forwards tuple values to `int(...)` as constructor arguments.
    The commas are removed; the resulting integer values are identical.
    """
    none   = 0
    cloth  = 1
    ear_r  = 2
    eye_g  = 3
    hair   = 4
    hat    = 5
    l_brow = 6
    l_ear  = 7
    l_eye  = 8
    l_lip  = 9
    mouth  = 10
    neck   = 11
    neck_l = 12
    nose   = 13
    r_brow = 14
    r_ear  = 15
    r_eye  = 16
    skin   = 17
    u_lip  = 18
|
|
|
|
|
|
|
# Lookup tables between integer mask IDs and their string names.
# Derived from the enum itself instead of 19 hand-maintained entries,
# so the tables can never drift out of sync with MaskType.
# Contents are identical to the previous literal dicts: member names
# (e.g. 'skin') are exactly the strings used in mask filenames.

# int mask value -> name string, e.g. 17 -> 'skin'
MaskType_to_name = { int(m) : m.name for m in MaskType }

# name string -> int mask value, e.g. 'skin' -> 17
MaskType_from_name = { name : value for value, name in MaskType_to_name.items() }
|
|
|
class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
    """Yields training batches of (image, skin-mask) pairs from CelebAMask-HQ.

    Expects this layout under ``root_path``:

        CelebAMask-HQ/CelebA-HQ-img/            source images, named <file_id>.jpg
        CelebAMask-HQ/CelebAMask-HQ-mask-anno/  mask images, named <file_id>_<mask_type>.*

    Each sample is a source image resized to ``resolution`` and randomly
    augmented (hue/sat/value jitter, optional motion blur, random warp,
    optional gaussian blur, optional bilinear down/up-resize), paired with a
    binarized skin mask from which hair and hat regions are subtracted.
    Batches are (images, masks) with values in [0, 1], layout per
    ``data_format`` ("NHWC" or "NCHW").
    """

    def __init__ (self, root_path, debug=False, batch_size=1, resolution=256,
                        generators_count=4, data_format="NHWC",
                        **kwargs):
        """
        root_path        : pathlib.Path containing the 'CelebAMask-HQ' folder.
        debug            : run a single in-process generator (breakpoint friendly).
        batch_size       : samples per yielded batch.
        resolution       : output image/mask side length in pixels.
        generators_count : number of worker subprocesses when not debugging.
        data_format      : "NHWC" (default) or "NCHW".

        Raises ValueError if the dataset folders are missing or empty.
        """
        super().__init__(debug, batch_size)
        self.initialized = False

        dataset_path = root_path / 'CelebAMask-HQ'
        if not dataset_path.exists():
            raise ValueError(f'Unable to find {dataset_path}')

        images_path = dataset_path / 'CelebA-HQ-img'
        if not images_path.exists():
            raise ValueError(f'Unable to find {images_path}')

        masks_path = dataset_path / 'CelebAMask-HQ-mask-anno'
        if not masks_path.exists():
            raise ValueError(f'Unable to find {masks_path}')

        # Debug mode forces a single in-process generator.
        self.generators_count = 1 if self.debug else max(1, generators_count)

        source_images_paths = pathex.get_image_paths(images_path, return_Path_class=True)
        mask_images_paths = pathex.get_image_paths(masks_path, subdirs=True, return_Path_class=True)

        if len(source_images_paths) == 0 or len(mask_images_paths) == 0:
            raise ValueError('No training data provided.')

        # file_id -> { int mask type : mask path relative to masks_path }
        mask_file_id_hash = {}
        for filepath in io.progress_bar_generator(mask_images_paths, "Loading"):
            # Mask filenames look like '<file_id>_<mask_type>', e.g. '00001_skin'.
            file_id_str, mask_type = filepath.stem.split('_', 1)
            file_id = int(file_id_str)
            mask_file_id_hash.setdefault(file_id, {})[ MaskType_from_name[mask_type] ] = str(filepath.relative_to(masks_path))

        # Source image stems are plain numeric file ids ('<file_id>.jpg').
        source_file_id_set = { int(filepath.stem) for filepath in source_images_paths }

        # Report mask annotations whose source image is missing.
        for k in mask_file_id_hash.keys():
            if k not in source_file_id_set:
                io.log_err (f"Corrupted dataset: {k} not in {images_path}")

        # BUGFIX: 'resolution' was previously accepted but ignored —
        # batch_func hard-coded 256. It is now forwarded to the workers.
        param = (images_path, masks_path, mask_file_id_hash, data_format, resolution)

        if self.debug:
            self.generators = [ThisThreadGenerator ( self.batch_func, param )]
        else:
            self.generators = [SubprocessGenerator ( self.batch_func, param, start_now=False )
                               for i in range(self.generators_count) ]
            # Only subprocess generators need to be started;
            # ThisThreadGenerator runs lazily in the caller's thread.
            SubprocessGenerator.start_in_parallel( self.generators )

        self.generator_counter = -1
        self.initialized = True

    def is_initialized(self):
        """Return True once __init__ completed and generators are ready."""
        return self.initialized

    def __iter__(self):
        return self

    def __next__(self):
        # Round-robin over the worker generators.
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators) ]
        return next(generator)

    def batch_func(self, param ):
        """Worker loop: yields [images, masks] batches forever.

        Runs inside a subprocess (or this thread in debug mode), so all
        state must come through 'param'.
        """
        images_path, masks_path, mask_file_id_hash, data_format, resolution = param

        file_ids = list(mask_file_id_hash.keys())
        shuffle_file_ids = []

        # Geometric augmentation ranges.
        random_flip = True
        rotation_range=[-15,15]
        scale_range=[-0.10, 0.95]   # NOTE(review): asymmetric range looks intentional (mostly upscale) — confirm
        tx_range=[-0.3, 0.3]
        ty_range=[-0.3, 0.3]

        # (chance %, magnitude) pairs; set to None to disable.
        random_bilinear_resize = (25,75)   # 25% chance, shrink up to 75% before resizing back
        motion_blur = (25, 5)              # 25% chance, kernel size up to 5
        gaussian_blur = (25, 5)            # 25% chance, kernel radius up to 5

        bs = self.batch_size
        while True:
            batches = None

            n_batch = 0
            while n_batch < bs:
                try:
                    # Refill and reshuffle once the epoch is exhausted.
                    if len(shuffle_file_ids) == 0:
                        shuffle_file_ids = file_ids.copy()
                        np.random.shuffle(shuffle_file_ids)

                    file_id = shuffle_file_ids.pop()
                    masks = mask_file_id_hash[file_id]
                    image_path = images_path / f'{file_id}.jpg'

                    skin_path = masks.get(MaskType.skin, None)
                    hair_path = masks.get(MaskType.hair, None)
                    hat_path  = masks.get(MaskType.hat, None)

                    # BUGFIX: previously a missing skin mask crashed inside
                    # cv2_imread and was silently swallowed by the except
                    # handler; skip such samples explicitly instead.
                    if skin_path is None:
                        continue

                    # Images/masks are loaded as uint8 BGR and normalized to [0,1].
                    img = cv2_imread(image_path).astype(np.float32) / 255.0
                    mask = cv2_imread(masks_path / skin_path)[...,0:1].astype(np.float32) / 255.0

                    # Subtract hair and hat regions from the skin mask.
                    if hair_path is not None:
                        hair_path = masks_path / hair_path
                        if hair_path.exists():
                            hair = cv2_imread(hair_path)[...,0:1].astype(np.float32) / 255.0
                            mask *= (1-hair)

                    if hat_path is not None:
                        hat_path = masks_path / hat_path
                        if hat_path.exists():
                            hat = cv2_imread(hat_path)[...,0:1].astype(np.float32) / 255.0
                            mask *= (1-hat)

                    # One warp parameter set shared by image and mask so they
                    # stay geometrically aligned.
                    warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range )

                    # BUGFIX: interpolation flags were passed positionally into
                    # cv2.resize's 'dst' slot and silently ignored (default
                    # INTER_LINEAR was used). Pass them by keyword.
                    img = cv2.resize( img, (resolution,resolution), interpolation=cv2.INTER_LANCZOS4 )

                    # Random hue rotation and saturation/value jitter.
                    # (float32 input => OpenCV hue range is [0, 360))
                    h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
                    h = ( h + np.random.randint(360) ) % 360
                    s = np.clip ( s + np.random.random()-0.5, 0, 1 )
                    v = np.clip ( v + np.random.random()/2-0.25, 0, 1 )
                    img = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 )

                    if motion_blur is not None:
                        chance, mb_max_size = motion_blur
                        chance = np.clip(chance, 0, 100)

                        mblur_rnd_chance = np.random.randint(100)
                        mblur_rnd_kernel = np.random.randint(mb_max_size)+1
                        mblur_rnd_deg    = np.random.randint(360)

                        if mblur_rnd_chance < chance:
                            img = imagelib.LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg )

                    img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)

                    if gaussian_blur is not None:
                        chance, kernel_max_size = gaussian_blur
                        chance = np.clip(chance, 0, 100)

                        gblur_rnd_chance = np.random.randint(100)
                        gblur_rnd_kernel = np.random.randint(kernel_max_size)*2+1  # odd kernel size required

                        if gblur_rnd_chance < chance:
                            img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0)

                    if random_bilinear_resize is not None:
                        chance, max_size_per = random_bilinear_resize
                        chance = np.clip(chance, 0, 100)
                        # BUGFIX: the chance roll was previously computed but
                        # never tested, so the resize was applied 100% of the
                        # time instead of 'chance' percent.
                        if np.random.randint(100) < chance:
                            resize_to = resolution - int( np.random.rand()* int(resolution*(max_size_per/100.0)) )
                            img = cv2.resize (img, (resize_to,resize_to), interpolation=cv2.INTER_LINEAR )
                            img = cv2.resize (img, (resolution,resolution), interpolation=cv2.INTER_LINEAR )

                    # cv2.resize drops the trailing channel dim of an HxWx1
                    # array; [...,None] restores it.
                    mask = cv2.resize( mask, (resolution,resolution), interpolation=cv2.INTER_LANCZOS4 )[...,None]
                    mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)
                    # Binarize: interpolation produces intermediate values.
                    mask[mask < 0.5] = 0.0
                    mask[mask >= 0.5] = 1.0
                    mask = np.clip(mask, 0, 1)

                    if data_format == "NCHW":
                        img = np.transpose(img, (2,0,1) )
                        mask = np.transpose(mask, (2,0,1) )

                    if batches is None:
                        batches = [ [], [] ]

                    batches[0].append ( img )
                    batches[1].append ( mask )

                    n_batch += 1
                except Exception:
                    # Log and keep going — a single bad sample must not kill
                    # the worker. (Was a bare except:, which also swallowed
                    # KeyboardInterrupt/SystemExit.)
                    io.log_err ( traceback.format_exc() )

            yield [ np.array(batch) for batch in batches]
|
|