import multiprocessing |
import pickle |
import time |
import traceback |
from enum import IntEnum |
import cv2 |
import numpy as np |
from core import imagelib, mplib, pathex |
from core.cv2ex import * |
from core.interact import interact as io |
from core.joblib import SubprocessGenerator, ThisThreadGenerator |
from facelib import LandmarksProcessor |
from samplelib import SampleGeneratorBase |
class MaskType(IntEnum):
    """Facial-part mask categories of the CelebAMask-HQ annotations.

    Each member's name is the part name that appears as the ``<part>`` suffix of
    annotation file stems (``<id>_<part>``). Values are fixed, contiguous integers
    (0..18), so members compare and hash equal to plain ``int`` keys.
    """
    none = 0
    cloth = 1
    ear_r = 2
    eye_g = 3
    hair = 4
    hat = 5
    l_brow = 6
    l_ear = 7
    l_eye = 8
    l_lip = 9
    mouth = 10
    neck = 11
    neck_l = 12
    nose = 13
    r_brow = 14
    r_ear = 15
    r_eye = 16
    skin = 17
    u_lip = 18
# Bidirectional lookup between MaskType integer values and their names.
# Derived from the enum itself (in definition order) so the tables can never
# drift out of sync with MaskType's members.
MaskType_to_name = {int(mask_type): mask_type.name for mask_type in MaskType}

# Name -> plain int value; used to parse the "<part>" suffix of annotation file stems.
MaskType_from_name = {name: value for value, name in MaskType_to_name.items()}
class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase):
    """Endless generator of augmented (face image, skin mask) batches from CelebAMask-HQ.

    Expected layout under ``root_path / 'CelebAMask-HQ'``:

    * ``CelebA-HQ-img/<id>.jpg`` -- source images
    * ``CelebAMask-HQ-mask-anno/**/<id>_<part>.*`` -- per-part masks, where
      ``<part>`` must be a key of ``MaskType_from_name``

    Each ``next()`` returns ``[images, masks]``, two float32 numpy arrays stacked
    over the batch. Images are 3-channel BGR (required by the HSV jitter). Masks
    are single-channel binary {0, 1} skin masks with hair and hat regions removed.
    The layout is ``(N, H, W, C)``, or ``(N, C, H, W)`` with ``data_format="NCHW"``.
    H and W are expected to equal ``resolution``; the warp output size comes from
    ``imagelib.gen_warp_params`` -- confirm there if it matters.
    """

    def __init__(self, root_path, debug=False, batch_size=1, resolution=256,
                 generators_count=4, data_format="NHWC",
                 **kwargs):
        """Index the dataset and start the worker generators.

        :param root_path: ``pathlib.Path``-like directory containing ``CelebAMask-HQ``.
        :param debug: if true, a single in-thread generator is used instead of subprocesses.
        :param batch_size: number of samples per yielded batch.
        :param resolution: side length, in pixels, of the generated samples.
        :param generators_count: number of subprocess generators (forced to 1 in debug mode).
        :param data_format: ``"NHWC"`` (default) or ``"NCHW"``.
        :param kwargs: accepted for interface compatibility and ignored.
        :raises ValueError: if a dataset directory is missing or no usable sample exists.
        """
        # The base class is expected to set self.debug and self.batch_size (both read below).
        super().__init__(debug, batch_size)
        self.initialized = False

        dataset_path = root_path / 'CelebAMask-HQ'
        if not dataset_path.exists():
            raise ValueError(f'Unable to find {dataset_path}')

        images_path = dataset_path / 'CelebA-HQ-img'
        if not images_path.exists():
            raise ValueError(f'Unable to find {images_path}')

        masks_path = dataset_path / 'CelebAMask-HQ-mask-anno'
        if not masks_path.exists():
            raise ValueError(f'Unable to find {masks_path}')

        self.generators_count = 1 if self.debug else max(1, generators_count)

        source_images_paths = pathex.get_image_paths(images_path, return_Path_class=True)
        mask_images_paths = pathex.get_image_paths(masks_path, subdirs=True, return_Path_class=True)
        if len(source_images_paths) == 0 or len(mask_images_paths) == 0:
            raise ValueError('No training data provided.')

        # file_id -> {MaskType int value -> mask path relative to masks_path}.
        # Relative str paths keep the structure small when handed to worker processes.
        mask_file_id_hash = {}
        for filepath in io.progress_bar_generator(mask_images_paths, "Loading"):
            file_id, mask_type = filepath.stem.split('_', 1)
            per_id_masks = mask_file_id_hash.setdefault(int(file_id), {})
            per_id_masks[MaskType_from_name[mask_type]] = str(filepath.relative_to(masks_path))

        source_file_id_set = {int(filepath.stem) for filepath in source_images_paths}

        # Keep only ids that can actually produce a sample. Without a source image or
        # a skin mask, batch_func would fail on every draw of that id.
        usable_file_id_hash = {}
        for file_id, masks in mask_file_id_hash.items():
            if file_id not in source_file_id_set:
                io.log_err(f"Corrupted dataset: {file_id} not in {images_path}")
            elif MaskType.skin not in masks:
                io.log_err(f"Corrupted dataset: {file_id} has no skin mask in {masks_path}")
            else:
                usable_file_id_hash[file_id] = masks
        if len(usable_file_id_hash) == 0:
            raise ValueError('No training data provided.')

        generator_args = (images_path, masks_path, usable_file_id_hash, resolution, data_format)
        if self.debug:
            self.generators = [ThisThreadGenerator(self.batch_func, generator_args)]
        else:
            self.generators = [SubprocessGenerator(self.batch_func, generator_args, start_now=False)
                               for _ in range(self.generators_count)]
            SubprocessGenerator.start_in_parallel(self.generators)

        self.generator_counter = -1
        self.initialized = True

    def is_initialized(self):
        """Return True once ``__init__`` has finished setting up the generators."""
        return self.initialized

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, drawing from the worker generators round-robin."""
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    @staticmethod
    def _read_mask(path):
        """Read a mask image and return its first channel as float32 scaled to [0, 1].

        Assumes ``cv2_imread`` returns an ``(H, W, C)`` uint8-range array, so the
        result has shape ``(H, W, 1)``.
        """
        return cv2_imread(path)[..., 0:1].astype(np.float32) / 255.0

    def _load_sample(self, images_path, masks_path, masks, file_id):
        """Load one source image and its skin mask with hair/hat regions cut out.

        :param masks: mapping of MaskType int value -> mask path relative to ``masks_path``.
        :returns: ``(img, mask)``. ``img`` is the ``<file_id>.jpg`` image as float32 / 255.
            ``mask`` is the skin mask multiplied by ``(1 - hair)`` and ``(1 - hat)``
            for whichever of those masks exist.
        """
        img = cv2_imread(images_path / f'{file_id}.jpg').astype(np.float32) / 255.0
        mask = self._read_mask(masks_path / masks[MaskType.skin])
        for excluded_type in (MaskType.hair, MaskType.hat):
            rel_path = masks.get(excluded_type)
            if rel_path is None:
                continue
            excluded_path = masks_path / rel_path
            if excluded_path.exists():
                mask *= 1 - self._read_mask(excluded_path)
        return img, mask

    def batch_func(self, param):
        """Worker loop: yield ``[images, masks]`` batches forever.

        :param param: ``(images_path, masks_path, mask_file_id_hash, resolution, data_format)``
            as built by ``__init__``. The older 4-tuple without ``resolution`` is still
            accepted and uses 256.
        """
        if len(param) == 4:
            images_path, masks_path, mask_file_id_hash, data_format = param
            resolution = 256
        else:
            images_path, masks_path, mask_file_id_hash, resolution, data_format = param

        file_ids = list(mask_file_id_hash.keys())
        shuffle_file_ids = []

        # Augmentation settings.
        random_flip = True
        rotation_range = [-15, 15]
        # NOTE(review): this scale range is strongly asymmetric; confirm it is intended
        # (its interpretation is defined by imagelib.gen_warp_params).
        scale_range = [-0.10, 0.95]
        tx_range = [-0.3, 0.3]
        ty_range = [-0.3, 0.3]
        random_bilinear_resize = (25, 75)  # (chance %, max downscale as % of resolution)
        motion_blur = (25, 5)              # (chance %, kernel drawn from 1..max)
        gaussian_blur = (25, 5)            # (chance %, bound b: kernel = 2*randint(b)+1)

        bs = self.batch_size
        while True:
            batches = [[], []]
            n_batch = 0
            while n_batch < bs:
                try:
                    if not shuffle_file_ids:
                        # Refill and reshuffle once every id has been drawn, so each
                        # pass visits every id exactly once.
                        shuffle_file_ids = file_ids.copy()
                        np.random.shuffle(shuffle_file_ids)
                    file_id = shuffle_file_ids.pop()

                    img, mask = self._load_sample(images_path, masks_path,
                                                  mask_file_id_hash[file_id], file_id)

                    # One set of warp params is applied to both image and mask to keep them aligned.
                    warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range,
                                                           scale_range=scale_range, tx_range=tx_range, ty_range=ty_range)

                    # --- image ---
                    img = cv2.resize(img, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4)

                    # Random hue rotation plus saturation/value jitter.
                    h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
                    h = (h + np.random.randint(360)) % 360
                    s = np.clip(s + np.random.random() - 0.5, 0, 1)
                    v = np.clip(v + np.random.random() / 2 - 0.25, 0, 1)
                    img = np.clip(cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR), 0, 1)

                    if motion_blur is not None:
                        chance, mb_max_size = motion_blur
                        chance = np.clip(chance, 0, 100)
                        mblur_rnd_chance = np.random.randint(100)
                        mblur_rnd_kernel = np.random.randint(mb_max_size) + 1
                        mblur_rnd_deg = np.random.randint(360)
                        if mblur_rnd_chance < chance:
                            img = imagelib.LinearMotionBlur(img, mblur_rnd_kernel, mblur_rnd_deg)

                    img = imagelib.warp_by_params(warp_params, img, can_warp=True, can_transform=True, can_flip=True,
                                                  border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)

                    if gaussian_blur is not None:
                        chance, kernel_max_size = gaussian_blur
                        chance = np.clip(chance, 0, 100)
                        gblur_rnd_chance = np.random.randint(100)
                        gblur_rnd_kernel = np.random.randint(kernel_max_size) * 2 + 1
                        if gblur_rnd_chance < chance:
                            img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) * 2, 0)

                    if random_bilinear_resize is not None:
                        # Downscale-then-upscale to simulate low-resolution input,
                        # applied only with the configured chance.
                        chance, max_size_per = random_bilinear_resize
                        chance = np.clip(chance, 0, 100)
                        pick_chance = np.random.randint(100)
                        resize_to = resolution - int(np.random.rand() * int(resolution * (max_size_per / 100.0)))
                        if pick_chance < chance:
                            img = cv2.resize(img, (resize_to, resize_to), interpolation=cv2.INTER_LINEAR)
                            img = cv2.resize(img, (resolution, resolution), interpolation=cv2.INTER_LINEAR)

                    # --- mask ---
                    # cv2.resize drops the single channel axis; [..., None] restores it.
                    mask = cv2.resize(mask, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4)[..., None]
                    mask = imagelib.warp_by_params(warp_params, mask, can_warp=True, can_transform=True, can_flip=True,
                                                   border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4)
                    if mask.ndim == 2:
                        # Defensive: OpenCV warps may also drop the singleton channel.
                        mask = mask[..., None]
                    # Binarize (interpolation leaves soft edges).
                    mask[mask < 0.5] = 0.0
                    mask[mask >= 0.5] = 1.0

                    if data_format == "NCHW":
                        img = np.transpose(img, (2, 0, 1))
                        mask = np.transpose(mask, (2, 0, 1))
                except Exception:
                    # Best effort: log a broken sample and draw another one.
                    io.log_err(traceback.format_exc())
                else:
                    batches[0].append(img)
                    batches[1].append(mask)
                    n_batch += 1

            yield [np.array(batch) for batch in batches]