import multiprocessing
import time
import traceback

import cv2
import numpy as np
import numpy.linalg as npla

from core import mplib
from core import imagelib
from core.interact import interact as io
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from core import mathlib
from facelib import LandmarksProcessor, FaceType
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
                       SampleType)

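# SampleGeneratorSAE produces paired src/dst training batches from two aligned-face
# sample folders. Each item yielded by the generator is a list of 8 arrays
# (see batch_func below):
#   [0..3] src: warped face, target face, target face mask, target eyes+mouth mask
#   [4..7] dst: warped face, target face, target face mask, target eyes+mouth mask
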
class SampleGeneratorSAE(SampleGeneratorBase):
    def __init__ (self, src_samples_path, dst_samples_path,
                        resolution,
                        face_type,
                        random_src_flip=False,
                        random_dst_flip=False,
                        ct_mode=None,
                        uniform_yaw_distribution=False,
                        data_format='NHWC',
                        debug=False, batch_size=1,
                        raise_on_no_data=True,
                        **kwargs):

        super().__init__(debug, batch_size)
        self.initialized = False
        self.resolution = resolution
        self.face_type = face_type
        self.random_src_flip = random_src_flip
        self.random_dst_flip = random_dst_flip
        self.ct_mode = ct_mode
        self.data_format = data_format

        # Debug mode runs a single generator on the calling thread.
        if self.debug:
            self.generators_count = 1
        else:
            self.generators_count = 8

        src_samples = SampleLoader.load (SampleType.FACE, src_samples_path)
        src_samples_len = len(src_samples)

        # Honor raise_on_no_data: either raise on an empty folder or leave the
        # generator uninitialized so __next__ returns empty batches.
        if src_samples_len == 0:
            if raise_on_no_data:
                raise ValueError(f'No samples in {src_samples_path}')
            return

        dst_samples = SampleLoader.load (SampleType.FACE, dst_samples_path)
        dst_samples_len = len(dst_samples)

        if dst_samples_len == 0:
            if raise_on_no_data:
                raise ValueError(f'No samples in {dst_samples_path}')
            return

        # Uniform yaw: draw indexes from yaw bins so profile faces are sampled as
        # often as frontal ones; otherwise sample uniformly over the whole set.
        if uniform_yaw_distribution:
            src_index_host = self._filter_uniform_yaw(src_samples)
            dst_index_host = self._filter_uniform_yaw(dst_samples)
        else:
            src_index_host = mplib.IndexHost(src_samples_len)
            dst_index_host = mplib.IndexHost(dst_samples_len)

        # Separate index host for picking color-transfer reference samples from dst.
        ct_index_host = mplib.IndexHost(dst_samples_len) if ct_mode is not None else None

        # One command queue per worker, used by set_face_scale().
        self.comm_qs = [ multiprocessing.Queue() for i in range(self.generators_count) ]

        if self.debug:
            self.generators = [ThisThreadGenerator ( self.batch_func, (self.comm_qs[0], src_samples, dst_samples, src_index_host.create_cli(), dst_index_host.create_cli(), ct_index_host.create_cli() if ct_index_host is not None else None) )]
        else:
            self.generators = [SubprocessGenerator ( self.batch_func, (self.comm_qs[i], src_samples, dst_samples, src_index_host.create_cli(), dst_index_host.create_cli(), ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False )
                               for i in range(self.generators_count) ]

        self.generator_counter = -1
        self.initialized = True

    def start(self):
        # Workers are created with start_now=False; launch them all here.
        if not self.debug:
            SubprocessGenerator.start_in_parallel( self.generators )

    def _filter_uniform_yaw(self, samples):
        # Bin sample indexes by yaw so that every yaw range is sampled with equal probability.
        samples_pyr = [ ( idx, sample.get_pitch_yaw_roll() ) for idx, sample in enumerate(samples) ]

        grads = 128
        grads_space = np.linspace(-1.2, 1.2, grads)

        yaws_sample_list = [None]*grads
        for g in io.progress_bar_generator( range(grads), "Sort by yaw"):
            yaw = grads_space[g]
            next_yaw = grads_space[g+1] if g < grads-1 else yaw

            # Collect indexes whose yaw falls into this bin; the first and last bins
            # also absorb anything outside the [-1.2, 1.2] range.
            yaw_samples = []
            for idx, pyr in samples_pyr:
                s_yaw = -pyr[1]
                if (g == 0 and s_yaw < next_yaw) or \
                   (g < grads-1 and s_yaw >= yaw and s_yaw < next_yaw) or \
                   (g == grads-1 and s_yaw >= yaw):
                    yaw_samples += [ idx ]
            if len(yaw_samples) > 0:
                yaws_sample_list[g] = yaw_samples

        # Drop empty bins before handing the 2D index list to Index2DHost.
        yaws_sample_list = [ y for y in yaws_sample_list if y is not None ]

        return mplib.Index2DHost( yaws_sample_list )

    def set_face_scale(self, scale):
        # Broadcast the new face scale to every worker through its command queue.
        for comm_q in self.comm_qs:
            comm_q.put( ('face_scale', scale) )

    def is_initialized(self):
        return self.initialized

    def __iter__(self):
        return self

    def __next__(self):
        if not self.initialized:
            return []

        # Round-robin over the worker generators.
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    def batch_func(self, param ):
        comm_q, src_samples, dst_samples, src_index_host, dst_index_host, ct_index_host = param

        # Snapshot the settings once; this typically runs inside a worker process.
        batch_size = self.batch_size
        resolution = self.resolution
        face_type = self.face_type
        data_format = self.data_format
        random_src_flip = self.random_src_flip
        random_dst_flip = self.random_dst_flip
        ct_mode = self.ct_mode

        # Random affine augmentation ranges.
        rotation_range = [-10, 10]
        scale_range = [-0.05, 0.05]
        tx_range = [-0.05, 0.05]
        ty_range = [-0.05, 0.05]
        rnd_state = np.random

        # Face scale can be changed at runtime via the command queue.
        face_scale = 1.0

        # Intermediate HEAD-space resolution used for the random warp.
        hi_res = 1024

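        # gen_sample() builds one training tuple for a single sample:
        #   warped_face      - image randomly warped in hi-res HEAD space, then
        #                      cropped to the target face type
        #   target_face      - single affine crop to the target face type
        #   target_face_mask - face mask (XSeg if available, else landmark hull)
        #   target_face_em   - eyes+mouth mask, limited to the face mask
        # ct_sample is only read when ct_mode is set (color-transfer reference).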
        def gen_sample(sample, target_face_type, resolution, allow_flip=False, scale=1.0, ct_mode=None, ct_sample=None):
            # Random affine parameters for this sample.
            tx = rnd_state.uniform( tx_range[0], tx_range[1] )
            ty = rnd_state.uniform( ty_range[0], ty_range[1] )
            rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] )
            scale = rnd_state.uniform(scale + scale_range[0], scale + scale_range[1])

            # 40% chance of a horizontal flip when flipping is allowed.
            flip = allow_flip and rnd_state.randint(10) < 4

            face_type = sample.face_type
            face_lmrks = sample.landmarks
            face = sample.load_bgr()
            h,w,c = face.shape

            # Matrix from the stored face image to a hi-res HEAD-space image.
            if face_type == FaceType.HEAD:
                hi_mat = LandmarksProcessor.get_transform_mat (face_lmrks, hi_res, FaceType.HEAD)
            else:
                hi_mat = LandmarksProcessor.get_transform_mat (face_lmrks, hi_res, FaceType.HEAD_FACE)

            hi_lmrks = LandmarksProcessor.transform_points(face_lmrks, hi_mat)
            hi_warp_params = imagelib.gen_warp_params(hi_res)
            face_warp_params = imagelib.gen_warp_params(resolution)

            # Hi-res HEAD space -> target face type, with the random affine applied.
            hi_to_target_mat = LandmarksProcessor.get_transform_mat (hi_lmrks, resolution, target_face_type)
            hi_to_target_mat = mathlib.transform_mat(hi_to_target_mat, resolution, tx, ty, rotation, scale)

            # Stored face image -> target face type, with the same random affine.
            face_to_target_mat = LandmarksProcessor.get_transform_mat (face_lmrks, resolution, target_face_type)
            face_to_target_mat = mathlib.transform_mat(face_to_target_mat, resolution, tx, ty, rotation, scale)

            # Warped input: optional color transfer, random grid warp in hi-res space,
            # then crop down to the target face type.
            warped_face = face
            if ct_mode is not None:
                ct_bgr = ct_sample.load_bgr()
                ct_bgr = cv2.resize(ct_bgr, (w,h), interpolation=cv2.INTER_LINEAR )
                warped_face = imagelib.color_transfer (ct_mode, warped_face, ct_bgr)

            warped_face = cv2.warpAffine(warped_face, hi_mat, (hi_res,hi_res), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC )
            warped_face = np.clip( imagelib.warp_by_params (hi_warp_params, warped_face, can_warp=True, can_transform=False, can_flip=False, border_replicate=cv2.BORDER_REPLICATE), 0, 1)
            warped_face = cv2.warpAffine(warped_face, hi_to_target_mat, (resolution,resolution), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC )

""" |
|
if face_type != target_face_type: |
|
... |
|
else: |
|
if w != resolution: |
|
face = cv2.resize(face, (resolution, resolution), interpolation=cv2.INTER_CUBIC ) |
|
""" |
|
|
|
|
|
|
|
|
|
            # Target image: same optional color transfer, but a single affine crop
            # to the target face type (no random grid warp).
            target_face = face
            if ct_mode is not None:
                target_face = imagelib.color_transfer (ct_mode, target_face, ct_bgr)

            target_face = cv2.warpAffine(target_face, face_to_target_mat, (resolution,resolution), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC )

            # Face mask: prefer the XSeg mask if the sample has one, otherwise fall
            # back to the landmark hull mask.
            face_mask = sample.get_xseg_mask()
            if face_mask is not None:
                if face_mask.shape[0] != h or face_mask.shape[1] != w:
                    face_mask = cv2.resize(face_mask, (w,h), interpolation=cv2.INTER_CUBIC)
                face_mask = imagelib.normalize_channels(face_mask, 1)
            else:
                face_mask = LandmarksProcessor.get_image_hull_mask (face.shape, face_lmrks, eyebrows_expand_mod=sample.eyebrows_expand_mod )
            face_mask = np.clip(face_mask, 0, 1)

            target_face_mask = cv2.warpAffine(face_mask, face_to_target_mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LINEAR )
            target_face_mask = imagelib.normalize_channels(target_face_mask, 1)
            target_face_mask = np.clip(target_face_mask, 0, 1)

            # Eyes + mouth mask, cropped the same way and limited to the face mask.
            em_mask = np.clip(LandmarksProcessor.get_image_eye_mask (face.shape, face_lmrks) +
                              LandmarksProcessor.get_image_mouth_mask (face.shape, face_lmrks), 0, 1)

            target_face_em = cv2.warpAffine(em_mask, face_to_target_mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LINEAR )
            target_face_em = imagelib.normalize_channels(target_face_em, 1)

            # Renormalize after interpolation so the mask peaks at 1.0.
            div = target_face_em.max()
            if div != 0.0:
                target_face_em = target_face_em / div

            target_face_em = target_face_em * target_face_mask

            if flip:
                warped_face      = warped_face[:,::-1,...]
                target_face      = target_face[:,::-1,...]
                target_face_mask = target_face_mask[:,::-1,...]
                target_face_em   = target_face_em[:,::-1,...]

            return warped_face, target_face, target_face_mask, target_face_em

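        # Main production loop: runs until the worker is terminated, yielding one
        # batch of src/dst sample tuples per iteration.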
        while True:
            # Apply any pending commands from the host process.
            while not comm_q.empty():
                cmd, param = comm_q.get()
                if cmd == 'face_scale':
                    face_scale = param

            batches = [ [], [], [], [], [], [], [], [] ]

            src_indexes = src_index_host.multi_get(batch_size)
            dst_indexes = dst_index_host.multi_get(batch_size)
            # Color-transfer references are drawn from dst via their own index host.
            ct_indexes = ct_index_host.multi_get(batch_size) if ct_index_host is not None else None

            for n_batch in range(batch_size):
                src_sample = src_samples[src_indexes[n_batch]]
                dst_sample = dst_samples[dst_indexes[n_batch]]
                ct_sample  = dst_samples[ct_indexes[n_batch]] if ct_indexes is not None else None

                src_warped_face, src_target_face, src_target_face_mask, src_target_face_em = \
                    gen_sample(src_sample, face_type, resolution, allow_flip=random_src_flip, scale=face_scale, ct_mode=ct_mode, ct_sample=ct_sample)

                # dst samples are never color-transferred.
                dst_warped_face, dst_target_face, dst_target_face_mask, dst_target_face_em = \
                    gen_sample(dst_sample, face_type, resolution, allow_flip=random_dst_flip, scale=face_scale)

                if data_format == "NCHW":
                    # Convert from HWC to CHW for channels-first models.
                    src_warped_face      = np.transpose(src_warped_face, (2,0,1) )
                    src_target_face      = np.transpose(src_target_face, (2,0,1) )
                    src_target_face_mask = np.transpose(src_target_face_mask, (2,0,1) )
                    src_target_face_em   = np.transpose(src_target_face_em, (2,0,1) )
                    dst_warped_face      = np.transpose(dst_warped_face, (2,0,1) )
                    dst_target_face      = np.transpose(dst_target_face, (2,0,1) )
                    dst_target_face_mask = np.transpose(dst_target_face_mask, (2,0,1) )
                    dst_target_face_em   = np.transpose(dst_target_face_em, (2,0,1) )

                batches[0].append(src_warped_face)
                batches[1].append(src_target_face)
                batches[2].append(src_target_face_mask)
                batches[3].append(src_target_face_em)
                batches[4].append(dst_warped_face)
                batches[5].append(dst_target_face)
                batches[6].append(dst_target_face_mask)
                batches[7].append(dst_target_face_em)

            # Stack each list into a single (batch_size, ...) array.
            yield [ np.array(batch) for batch in batches ]
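

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the generator itself).
# The folder paths, resolution and batch size below are placeholder assumptions;
# in practice the trainer owns this wiring.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    generator = SampleGeneratorSAE('workspace/data_src/aligned',   # hypothetical path
                                   'workspace/data_dst/aligned',   # hypothetical path
                                   resolution=128,
                                   face_type=FaceType.HEAD,
                                   random_src_flip=True,
                                   random_dst_flip=True,
                                   batch_size=4)
    generator.start()

    # Each batch is a list of 8 arrays, shaped (batch_size, H, W, C) in NHWC mode.
    (src_warped, src_target, src_target_mask, src_target_em,
     dst_warped, dst_target, dst_target_mask, dst_target_em) = next(generator)

    # The augmentation face scale can be changed at runtime; workers pick the new
    # value up from their command queues before producing the next batch.
    generator.set_face_scale(1.05)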