|
import multiprocessing |
|
from functools import partial |
|
|
|
import numpy as np |
|
|
|
from core import mathlib |
|
from core.interact import interact as io |
|
from core.leras import nn |
|
from facelib import FaceType |
|
from models import ModelBase |
|
from samplelib import * |
|
|
|
class QModel(ModelBase):
    """Quick96 model: a fixed-configuration 96x96 full-face autoencoder.

    One shared ``encoder`` + ``inter`` pair produces a latent code that is
    decoded by two separate decoders, ``decoder_src`` and ``decoder_dst``.
    The swapped face is obtained by decoding a dst code with ``decoder_src``.
    All hyper-parameters (resolution, layer widths, batch size, optimizer
    settings) are hard-coded in :meth:`on_initialize`.
    """

    def on_initialize(self):
        """Build the graph, load or initialize weights and set up training.

        Always creates the four networks and registers them in
        ``self.model_filename_list``. When ``self.is_training`` is true it also
        creates the RMSprop optimizer, builds one training tower per device,
        binds ``self.src_dst_train`` / ``self.AE_view`` and installs the src/dst
        sample generators; otherwise it binds ``self.AE_merge`` for inference.
        """
        device_config = nn.getCurrentDeviceConfig()
        devices = device_config.devices
        # Channels-first only when devices are present and not in debug mode.
        self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC"
        nn.initialize(data_format=self.model_data_format)
        tf = nn.tf

        # ---- fixed Quick96 hyper-parameters ----
        resolution = self.resolution = 96
        self.face_type = FaceType.FULL
        ae_dims = 128
        e_dims = 64
        d_dims = 64
        d_mask_dims = 16
        self.pretrain = False
        self.pretrain_just_disabled = False

        # When True, image losses are computed on images multiplied by the
        # blurred target mask instead of on the raw images.
        masked_training = True

        # Weights/optimizer live on the default device only when training and
        # every device reports >= 4 GB of memory; otherwise on the CPU.
        models_opt_on_gpu = len(devices) >= 1 and all(dev.total_mem_gb >= 4 for dev in devices)
        models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
        optimizer_vars_on_cpu = models_opt_device == '/CPU:0'

        input_ch = 3
        bgr_shape = nn.get4Dshape(resolution, resolution, input_ch)
        mask_shape = nn.get4Dshape(resolution, resolution, 1)

        self.model_filename_list = []

        model_archi = nn.DeepFakeArchi(resolution, opts='ud')

        # Input placeholders live on the CPU; each training tower slices its
        # share of the batch out of them.
        with tf.device('/CPU:0'):
            self.warped_src = tf.placeholder(nn.floatx, bgr_shape)
            self.warped_dst = tf.placeholder(nn.floatx, bgr_shape)

            self.target_src = tf.placeholder(nn.floatx, bgr_shape)
            self.target_dst = tf.placeholder(nn.floatx, bgr_shape)

            self.target_srcm = tf.placeholder(nn.floatx, mask_shape)
            self.target_dstm = tf.placeholder(nn.floatx, mask_shape)

        # Networks (and, when training, the optimizer) shared by all towers.
        with tf.device(models_opt_device):
            self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
            encoder_out_ch = self.encoder.get_out_ch() * self.encoder.get_out_res(resolution) ** 2

            self.inter = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
            inter_out_ch = self.inter.get_out_ch()

            self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
            self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')

            self.model_filename_list += [[self.encoder, 'encoder.npy'],
                                         [self.inter, 'inter.npy'],
                                         [self.decoder_src, 'decoder_src.npy'],
                                         [self.decoder_dst, 'decoder_dst.npy']]

            if self.is_training:
                self.src_dst_trainable_weights = (self.encoder.get_weights()
                                                  + self.inter.get_weights()
                                                  + self.decoder_src.get_weights()
                                                  + self.decoder_dst.get_weights())

                self.src_dst_opt = nn.RMSprop(lr=2e-4, lr_dropout=0.3, name='src_dst_opt')
                self.src_dst_opt.initialize_variables(self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu)
                self.model_filename_list += [(self.src_dst_opt, 'src_dst_opt.npy')]

        if self.is_training:
            # Aim for a total batch of 4 split across devices, with at least
            # one sample per device.
            gpu_count = max(1, len(devices))
            bs_per_gpu = max(1, 4 // gpu_count)
            self.set_batch_size(gpu_count * bs_per_gpu)

            # Loop-invariant constants used by every tower.
            mask_blur_size = max(1, resolution // 32)
            dssim_filter_size = int(resolution / 11.6)

            gpu_pred_src_src_list = []
            gpu_pred_dst_dst_list = []
            gpu_pred_src_dst_list = []
            gpu_pred_src_srcm_list = []
            gpu_pred_dst_dstm_list = []
            gpu_pred_src_dstm_list = []

            gpu_src_losses = []
            gpu_dst_losses = []
            gpu_src_dst_loss_gvs = []

            for gpu_id in range(gpu_count):
                with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else '/CPU:0'):
                    batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu)
                    with tf.device('/CPU:0'):
                        # This tower's share of every input placeholder.
                        gpu_warped_src = self.warped_src[batch_slice, :, :, :]
                        gpu_warped_dst = self.warped_dst[batch_slice, :, :, :]
                        gpu_target_src = self.target_src[batch_slice, :, :, :]
                        gpu_target_dst = self.target_dst[batch_slice, :, :, :]
                        gpu_target_srcm = self.target_srcm[batch_slice, :, :, :]
                        gpu_target_dstm = self.target_dstm[batch_slice, :, :, :]

                    # Shared encoder/inter for both sides; each side is decoded
                    # by its own decoder, and the dst code also by decoder_src.
                    gpu_src_code = self.inter(self.encoder(gpu_warped_src))
                    gpu_dst_code = self.inter(self.encoder(gpu_warped_dst))
                    gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code)
                    gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
                    gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)

                    gpu_pred_src_src_list.append(gpu_pred_src_src)
                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)

                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

                    gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, mask_blur_size)
                    gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, mask_blur_size)

                    gpu_target_src_masked_opt = gpu_target_src * gpu_target_srcm_blur if masked_training else gpu_target_src
                    gpu_target_dst_masked_opt = gpu_target_dst * gpu_target_dstm_blur if masked_training else gpu_target_dst

                    gpu_pred_src_src_masked_opt = gpu_pred_src_src * gpu_target_srcm_blur if masked_training else gpu_pred_src_src
                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst * gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst

                    # Each side: 10*dssim + 10*squared error on the (masked)
                    # image, plus 10*squared error between target and predicted
                    # masks. Squared errors are averaged over axes 1-3, keeping
                    # the batch axis.
                    gpu_src_loss = tf.reduce_mean(10 * nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=dssim_filter_size), axis=[1])
                    gpu_src_loss += tf.reduce_mean(10 * tf.square(gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt), axis=[1, 2, 3])
                    gpu_src_loss += tf.reduce_mean(10 * tf.square(gpu_target_srcm - gpu_pred_src_srcm), axis=[1, 2, 3])

                    gpu_dst_loss = tf.reduce_mean(10 * nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=dssim_filter_size), axis=[1])
                    gpu_dst_loss += tf.reduce_mean(10 * tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt), axis=[1, 2, 3])
                    gpu_dst_loss += tf.reduce_mean(10 * tf.square(gpu_target_dstm - gpu_pred_dst_dstm), axis=[1, 2, 3])

                    gpu_src_losses += [gpu_src_loss]
                    gpu_dst_losses += [gpu_dst_loss]

                    # One joint objective; gradients w.r.t. all four networks.
                    gpu_G_loss = gpu_src_loss + gpu_dst_loss
                    gpu_src_dst_loss_gvs += [nn.gradients(gpu_G_loss, self.src_dst_trainable_weights)]

            # Merge tower outputs and average gradients on the model device.
            with tf.device(models_opt_device):
                pred_src_src = nn.concat(gpu_pred_src_src_list, 0)
                pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0)
                pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0)
                pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0)
                pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0)
                pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0)

                src_loss = nn.average_tensor_list(gpu_src_losses)
                dst_loss = nn.average_tensor_list(gpu_dst_losses)
                src_dst_loss_gv = nn.average_gv_list(gpu_src_dst_loss_gvs)
                src_dst_loss_gv_op = self.src_dst_opt.get_update_op(src_dst_loss_gv)

            def src_dst_train(warped_src, target_src, target_srcm,
                              warped_dst, target_dst, target_dstm):
                """Run one optimizer step; return (mean src loss, mean dst loss)."""
                s, d, _ = nn.tf_sess.run([src_loss, dst_loss, src_dst_loss_gv_op],
                                         feed_dict={self.warped_src: warped_src,
                                                    self.target_src: target_src,
                                                    self.target_srcm: target_srcm,
                                                    self.warped_dst: warped_dst,
                                                    self.target_dst: target_dst,
                                                    self.target_dstm: target_dstm,
                                                    })
                return np.mean(s), np.mean(d)
            self.src_dst_train = src_dst_train

            def AE_view(warped_src, warped_dst):
                """Return [src->src, dst->dst, dst->dst mask, dst->src, dst->src mask]."""
                return nn.tf_sess.run([pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
                                      feed_dict={self.warped_src: warped_src,
                                                 self.warped_dst: warped_dst})
            self.AE_view = AE_view
        else:
            # Inference graph: only the dst -> (src face, masks) path is needed.
            with tf.device(nn.tf_default_device_name if len(devices) != 0 else '/CPU:0'):
                gpu_dst_code = self.inter(self.encoder(self.warped_dst))
                gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
                _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)

            def AE_merge(warped_dst):
                """Return [swapped face, dst-decoder mask, src-decoder mask]."""
                return nn.tf_sess.run([gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm],
                                      feed_dict={self.warped_dst: warped_dst})
            self.AE_merge = AE_merge

        # Load saved weights, fall back to a pretrained model directory, and
        # finally to fresh initialization.
        for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
            if self.pretrain_just_disabled:
                # Only `inter` is re-initialized right after pretraining ends.
                do_init = model == self.inter
            else:
                do_init = self.is_first_run()

            if not do_init:
                do_init = not model.load_weights(self.get_strpath_storage_for_file(filename))

            if do_init and self.pretrained_model_path is not None:
                pretrained_filepath = self.pretrained_model_path / filename
                if pretrained_filepath.exists():
                    do_init = not model.load_weights(pretrained_filepath)

            if do_init:
                model.init_weights()

        if self.is_training:
            training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
            training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path()

            cpu_count = min(multiprocessing.cpu_count(), 8)
            # cpu_count // 2 is 0 on a single-core host; always ask for at
            # least one generator per side.
            src_generators_count = max(1, cpu_count // 2)
            dst_generators_count = max(1, cpu_count // 2)

            def make_generator(data_path, generators_count):
                # Output order (warped BGR, unwarped BGR, full-face mask) must
                # match the unpacking in onTrainOneIter / onGetPreview.
                def sample_type(sample_kind, warp, channel_type, **extra):
                    return {'sample_type': sample_kind, 'warp': warp, 'transform': True,
                            'channel_type': channel_type, **extra,
                            'face_type': self.face_type, 'data_format': nn.data_format,
                            'resolution': resolution}

                return SampleGeneratorFace(
                    data_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                    sample_process_options=SampleProcessor.Options(random_flip=bool(self.pretrain)),
                    output_sample_types=[
                        sample_type(SampleProcessor.SampleType.FACE_IMAGE, True, SampleProcessor.ChannelType.BGR),
                        sample_type(SampleProcessor.SampleType.FACE_IMAGE, False, SampleProcessor.ChannelType.BGR),
                        sample_type(SampleProcessor.SampleType.FACE_MASK, False, SampleProcessor.ChannelType.G,
                                    face_mask_type=SampleProcessor.FaceMaskType.FULL_FACE),
                    ],
                    generators_count=generators_count)

            self.set_training_data_generators([
                make_generator(training_data_src_path, src_generators_count),
                make_generator(training_data_dst_path, dst_generators_count),
            ])

            self.last_samples = None

    def get_model_filename_list(self):
        """Return the [model_or_optimizer, filename] pairs used for load/save."""
        return self.model_filename_list

    def onSave(self):
        """Save every registered model/optimizer to its storage file."""
        for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False):
            model.save_weights(self.get_strpath_storage_for_file(filename))

    def onTrainOneIter(self):
        """Run one training step and return the named loss values.

        Every third iteration (once a batch exists) the previous batch is
        reused with its unwarped targets fed in place of the warped inputs;
        otherwise a fresh batch is drawn and remembered.
        """
        if self.get_iter() % 3 == 0 and self.last_samples is not None:
            ((_, target_src, target_srcm),
             (_, target_dst, target_dstm)) = self.last_samples
            warped_src = target_src
            warped_dst = target_dst
        else:
            samples = self.last_samples = self.generate_next_samples()
            ((warped_src, target_src, target_srcm),
             (warped_dst, target_dst, target_dstm)) = samples

        src_loss, dst_loss = self.src_dst_train(warped_src, target_src, target_srcm,
                                                warped_dst, target_dst, target_dstm)

        return (('src_loss', src_loss), ('dst_loss', dst_loss),)

    def onGetPreview(self, samples, for_history=False):
        """Build the 'Quick96' and 'Quick96 masked' preview grids.

        Each row (up to 4) shows: target src, src->src, target dst, dst->dst,
        dst->src. Images are converted to NHWC and clipped to [0, 1].
        """
        ((warped_src, target_src, target_srcm),
         (warped_dst, target_dst, target_dstm)) = samples

        S, D, SS, DD, DDM, SD, SDM = [np.clip(nn.to_data_format(x, "NHWC", self.model_data_format), 0.0, 1.0)
                                      for x in ([target_src, target_dst] + self.AE_view(target_src, target_dst))]
        # Broadcast single-channel masks to 3 channels for display.
        DDM, SDM = [np.repeat(x, (3,), -1) for x in [DDM, SDM]]

        target_srcm, target_dstm = [nn.to_data_format(x, "NHWC", self.model_data_format)
                                    for x in (target_srcm, target_dstm)]

        n_samples = min(4, self.get_batch_size())

        rows = [np.concatenate((S[i], SS[i], D[i], DD[i], SD[i]), axis=1)
                for i in range(n_samples)]
        masked_rows = [np.concatenate((S[i] * target_srcm[i], SS[i], D[i] * target_dstm[i],
                                       DD[i] * DDM[i], SD[i] * (DDM[i] * SDM[i])), axis=1)
                       for i in range(n_samples)]

        return [('Quick96', np.concatenate(rows, axis=0)),
                ('Quick96 masked', np.concatenate(masked_rows, axis=0))]

    def predictor_func(self, face=None):
        """Swap one NHWC-layout face image (no batch axis).

        Returns (swapped BGR face, src-decoder mask, dst-decoder mask), the
        masks with their channel axis dropped, as float32.
        """
        face = nn.to_data_format(face[None, ...], self.model_data_format, "NHWC")

        bgr, mask_dst_dstm, mask_src_dstm = [nn.to_data_format(x, "NHWC", self.model_data_format).astype(np.float32)
                                             for x in self.AE_merge(face)]
        return bgr[0], mask_src_dstm[0][..., 0], mask_dst_dstm[0][..., 0]

    def get_MergerConfig(self):
        """Return (predictor, input shape, masked-merger config in overlay mode)."""
        import merger
        return self.predictor_func, (self.resolution, self.resolution, 3), merger.MergerConfigMasked(face_type=self.face_type,
                                                                                                      default_mode='overlay',
                                                                                                      )


# Module-level alias for the model class (presumably the name the model
# loader looks up — verify against the loader).
Model = QModel
|
|