|
import scipy.io |
|
import scipy.misc |
|
from glob import glob |
|
import os |
|
import numpy as np |
|
from thirdparty.face_of_art.ops import * |
|
import tensorflow as tf |
|
from tensorflow import contrib |
|
from thirdparty.face_of_art.menpo_functions import * |
|
from thirdparty.face_of_art.logging_functions import * |
|
from thirdparty.face_of_art.data_loading_functions import * |
|
|
|
|
|
class DeepHeatmapsModel(object): |
|
|
|
"""facial landmark localization Network""" |
|
|
|
def __init__(self, mode='TRAIN', train_iter=100000, batch_size=10, learning_rate=1e-3, l_weight_primary=1., |
|
l_weight_fusion=1.,l_weight_upsample=3.,adam_optimizer=True,momentum=0.95,step=100000, gamma=0.1,reg=0, |
|
weight_initializer='xavier', weight_initializer_std=0.01, bias_initializer=0.0, image_size=256,c_dim=3, |
|
num_landmarks=68, sigma=1.5, scale=1, margin=0.25, bb_type='gt', win_mult=3.33335, |
|
augment_basic=True,augment_texture=False, p_texture=0., augment_geom=False, p_geom=0., |
|
output_dir='output', save_model_path='model', |
|
save_sample_path='sample', save_log_path='logs', test_model_path='model/deep_heatmaps-50000', |
|
pre_train_path='model/deep_heatmaps-50000', load_pretrain=False, load_primary_only=False, |
|
img_path='data', test_data='full', valid_data='full', valid_size=0, log_valid_every=5, |
|
train_crop_dir='crop_gt_margin_0.25', img_dir_ns='crop_gt_margin_0.25_ns', |
|
print_every=100, save_every=5000, sample_every=5000, sample_grid=9, sample_to_log=True, |
|
debug_data_size=20, debug=False, epoch_data_dir='epoch_data', use_epoch_data=False, menpo_verbose=True): |
|
|
|
|
|
|
|
self.log_histograms = False |
|
self.save_valid_images = True |
|
self.sample_per_channel = False |
|
|
|
|
|
self.reset_training_op = False |
|
|
|
self.fast_img_gen = True |
|
|
|
self.compute_nme = True |
|
|
|
self.config = tf.ConfigProto() |
|
self.config.gpu_options.allow_growth = True |
|
|
|
|
|
self.print_every = print_every |
|
self.save_every = save_every |
|
self.sample_every = sample_every |
|
self.sample_grid = sample_grid |
|
self.sample_to_log = sample_to_log |
|
self.log_valid_every = log_valid_every |
|
|
|
self.debug = debug |
|
self.debug_data_size = debug_data_size |
|
self.use_epoch_data = use_epoch_data |
|
self.epoch_data_dir = epoch_data_dir |
|
|
|
self.load_pretrain = load_pretrain |
|
self.load_primary_only = load_primary_only |
|
self.pre_train_path = pre_train_path |
|
|
|
self.mode = mode |
|
self.train_iter = train_iter |
|
self.learning_rate = learning_rate |
|
|
|
self.image_size = image_size |
|
self.c_dim = c_dim |
|
self.batch_size = batch_size |
|
|
|
self.num_landmarks = num_landmarks |
|
|
|
self.save_log_path = save_log_path |
|
self.save_sample_path = save_sample_path |
|
self.save_model_path = save_model_path |
|
self.test_model_path = test_model_path |
|
self.img_path=img_path |
|
|
|
self.momentum = momentum |
|
self.step = step |
|
self.gamma = gamma |
|
self.reg = reg |
|
self.l_weight_primary = l_weight_primary |
|
self.l_weight_fusion = l_weight_fusion |
|
self.l_weight_upsample = l_weight_upsample |
|
|
|
self.weight_initializer = weight_initializer |
|
self.weight_initializer_std = weight_initializer_std |
|
self.bias_initializer = bias_initializer |
|
self.adam_optimizer = adam_optimizer |
|
|
|
self.sigma = sigma |
|
self.scale = scale |
|
self.win_mult = win_mult |
|
|
|
self.test_data = test_data |
|
self.train_crop_dir = train_crop_dir |
|
self.img_dir_ns = os.path.join(img_path,img_dir_ns) |
|
self.augment_basic = augment_basic |
|
self.augment_texture = augment_texture |
|
self.p_texture = p_texture |
|
self.augment_geom = augment_geom |
|
self.p_geom = p_geom |
|
|
|
self.valid_size = valid_size |
|
self.valid_data = valid_data |
|
|
|
|
|
self.bb_dir = os.path.join(img_path, 'Bounding_Boxes') |
|
self.bb_dictionary = load_bb_dictionary(self.bb_dir, mode, test_data=self.test_data) |
|
|
|
|
|
if self.use_epoch_data: |
|
epoch_0 = os.path.join(self.epoch_data_dir, '0') |
|
self.img_menpo_list = load_menpo_image_list( |
|
img_path, train_crop_dir=epoch_0, img_dir_ns=None, mode=mode, bb_dictionary=self.bb_dictionary, |
|
image_size=self.image_size, test_data=self.test_data, augment_basic=False, augment_texture=False, |
|
augment_geom=False, verbose=menpo_verbose) |
|
else: |
|
self.img_menpo_list = load_menpo_image_list( |
|
img_path, train_crop_dir, self.img_dir_ns, mode, bb_dictionary=self.bb_dictionary, |
|
image_size=self.image_size, margin=margin, bb_type=bb_type, test_data=self.test_data, |
|
augment_basic=augment_basic, augment_texture=augment_texture, p_texture=p_texture, |
|
augment_geom=augment_geom, p_geom=p_geom, verbose=menpo_verbose) |
|
|
|
if mode == 'TRAIN': |
|
|
|
train_params = locals() |
|
print_training_params_to_file(train_params) |
|
|
|
self.train_inds = np.arange(len(self.img_menpo_list)) |
|
|
|
if self.debug: |
|
self.train_inds = self.train_inds[:self.debug_data_size] |
|
self.img_menpo_list = self.img_menpo_list[self.train_inds] |
|
|
|
if valid_size > 0: |
|
|
|
self.valid_bb_dictionary = load_bb_dictionary(self.bb_dir, 'TEST', test_data=self.valid_data) |
|
self.valid_img_menpo_list = load_menpo_image_list( |
|
img_path, train_crop_dir, self.img_dir_ns, 'TEST', bb_dictionary=self.valid_bb_dictionary, |
|
image_size=self.image_size, margin=margin, bb_type=bb_type, test_data=self.valid_data, |
|
verbose=menpo_verbose) |
|
|
|
np.random.seed(0) |
|
self.val_inds = np.arange(len(self.valid_img_menpo_list)) |
|
np.random.shuffle(self.val_inds) |
|
self.val_inds = self.val_inds[:self.valid_size] |
|
|
|
self.valid_img_menpo_list = self.valid_img_menpo_list[self.val_inds] |
|
|
|
self.valid_images_loaded =\ |
|
np.zeros([self.valid_size, self.image_size, self.image_size, self.c_dim]).astype('float32') |
|
self.valid_gt_maps_small_loaded =\ |
|
np.zeros([self.valid_size, self.image_size / 4, self.image_size / 4, |
|
self.num_landmarks]).astype('float32') |
|
self.valid_gt_maps_loaded =\ |
|
np.zeros([self.valid_size, self.image_size, self.image_size, self.num_landmarks] |
|
).astype('float32') |
|
self.valid_landmarks_loaded = np.zeros([self.valid_size, num_landmarks, 2]).astype('float32') |
|
self.valid_landmarks_pred = np.zeros([self.valid_size, self.num_landmarks, 2]).astype('float32') |
|
|
|
load_images_landmarks_approx_maps_alloc_once( |
|
self.valid_img_menpo_list, np.arange(self.valid_size), images=self.valid_images_loaded, |
|
maps_small=self.valid_gt_maps_small_loaded, maps=self.valid_gt_maps_loaded, |
|
landmarks=self.valid_landmarks_loaded, image_size=self.image_size, |
|
num_landmarks=self.num_landmarks, scale=self.scale, win_mult=self.win_mult, sigma=self.sigma, |
|
save_landmarks=self.compute_nme) |
|
|
|
if self.valid_size > self.sample_grid: |
|
self.valid_gt_maps_loaded = self.valid_gt_maps_loaded[:self.sample_grid] |
|
self.valid_gt_maps_small_loaded = self.valid_gt_maps_small_loaded[:self.sample_grid] |
|
else: |
|
self.val_inds = None |
|
|
|
self.epoch_inds_shuffle = train_val_shuffle_inds_per_epoch( |
|
self.val_inds, self.train_inds, train_iter, batch_size, save_log_path) |
|
|
|
    def add_placeholders(self):
        """Create the tf.placeholder graph inputs for the current mode.

        TEST mode creates `images`/`heatmaps`/`heatmaps_small` plus GT and predicted
        landmark placeholders; TRAIN mode creates the same under `train_*` names plus
        `valid_*` landmark placeholders. Optionally creates uint8 image placeholders
        used to log sample grids to TensorBoard.
        """

        if self.mode == 'TEST':
            # input images: NHWC float batch
            self.images = tf.placeholder(
                tf.float32, [None, self.image_size, self.image_size, self.c_dim], 'images')

            # full-resolution GT heatmaps, one channel per landmark
            self.heatmaps = tf.placeholder(
                tf.float32, [None, self.image_size, self.image_size, self.num_landmarks], 'heatmaps')

            # quarter-resolution GT heatmaps (primary/fusion net output size)
            self.heatmaps_small = tf.placeholder(
                tf.float32, [None, int(self.image_size/4), int(self.image_size/4), self.num_landmarks], 'heatmaps_small')
            self.lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'lms')
            self.pred_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'pred_lms')

        elif self.mode == 'TRAIN':
            self.images = tf.placeholder(
                tf.float32, [None, self.image_size, self.image_size, self.c_dim], 'train_images')

            self.heatmaps = tf.placeholder(
                tf.float32, [None, self.image_size, self.image_size, self.num_landmarks], 'train_heatmaps')

            self.heatmaps_small = tf.placeholder(
                tf.float32, [None, int(self.image_size/4), int(self.image_size/4), self.num_landmarks], 'train_heatmaps_small')

            # GT and predicted landmark coordinates, used for NME computation
            self.train_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'train_lms')
            self.train_pred_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'train_pred_lms')

            self.valid_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'valid_lms')
            self.valid_pred_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'valid_pred_lms')

        # placeholders for pre-rendered sample images written to TensorBoard;
        # each holds a row x row grid of (image | pred map | gt map) panels
        if self.sample_to_log:
            row = int(np.sqrt(self.sample_grid))
            self.log_image_map_small = tf.placeholder(
                tf.uint8, [None, row * int(self.image_size/4), 3 * row * int(self.image_size/4), self.c_dim],
                'sample_img_map_small')
            self.log_image_map = tf.placeholder(
                tf.uint8, [None, row * self.image_size, 3 * row * self.image_size, self.c_dim],
                'sample_img_map')
            if self.sample_per_channel:
                # per-channel comparison grid: one cell per landmark channel
                row = np.ceil(np.sqrt(self.num_landmarks)).astype(np.int64)
                self.log_map_channels_small = tf.placeholder(
                    tf.uint8, [None, row * int(self.image_size/4), 2 * row * int(self.image_size/4), self.c_dim],
                    'sample_map_channels_small')
                self.log_map_channels = tf.placeholder(
                    tf.uint8, [None, row * self.image_size, 2 * row * self.image_size, self.c_dim],
                    'sample_map_channels')
|
|
|
    def heatmaps_network(self, input_images, reuse=None, name='pred_heatmaps'):
        """Build the 3-stage landmark heatmap network.

        Stages: a primary net (3 conv+pool blocks, then two banks of parallel
        dilated convolutions) producing quarter-resolution maps; a fusion net that
        concatenates an early and late primary feature map and refines them with
        further dilated conv banks; and an upsample net (4x bilinear-initialized
        deconvolution) producing full-resolution maps.

        Args:
            input_images: float NHWC image batch placeholder/tensor.
            reuse: passed to the conv helpers for variable sharing.
            name: outer name scope.

        Returns:
            (primary_out, fusion_out, out): quarter-res primary maps, quarter-res
            fused maps, and full-res upsampled maps, each with num_landmarks channels.
            Also stores every intermediate layer in self.all_layers (for histograms).
        """

        with tf.name_scope(name):

            # choose kernel initializer per configuration
            if self.weight_initializer == 'xavier':
                weight_initializer = contrib.layers.xavier_initializer()
            else:
                weight_initializer = tf.random_normal_initializer(stddev=self.weight_initializer_std)

            bias_init = tf.constant_initializer(self.bias_initializer)

            with tf.variable_scope('heatmaps_network'):
                with tf.name_scope('primary_net'):

                    # two conv+pool stages reduce resolution by 4x overall
                    l1 = conv_relu_pool(input_images, 5, 128, conv_ker_init=weight_initializer, conv_bias_init=bias_init,
                                        reuse=reuse, var_scope='conv_1')
                    l2 = conv_relu_pool(l1, 5, 128, conv_ker_init=weight_initializer, conv_bias_init=bias_init,
                                        reuse=reuse, var_scope='conv_2')
                    l3 = conv_relu(l2, 5, 128, conv_ker_init=weight_initializer, conv_bias_init=bias_init,
                                   reuse=reuse, var_scope='conv_3')

                    # bank of parallel dilated convolutions (dilation 1..4) — widens
                    # the receptive field without further downsampling
                    l4_1 = conv_relu(l3, 3, 128, conv_dilation=1, conv_ker_init=weight_initializer,
                                     conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_1')
                    l4_2 = conv_relu(l3, 3, 128, conv_dilation=2, conv_ker_init=weight_initializer,
                                     conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_2')
                    l4_3 = conv_relu(l3, 3, 128, conv_dilation=3, conv_ker_init=weight_initializer,
                                     conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_3')
                    l4_4 = conv_relu(l3, 3, 128, conv_dilation=4, conv_ker_init=weight_initializer,
                                     conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_4')

                    # concatenate the parallel branches along channels
                    l4 = tf.concat([l4_1, l4_2, l4_3, l4_4], 3, name='conv_4')

                    # second dilated bank at higher width
                    l5_1 = conv_relu(l4, 3, 256, conv_dilation=1, conv_ker_init=weight_initializer,
                                     conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_1')
                    l5_2 = conv_relu(l4, 3, 256, conv_dilation=2, conv_ker_init=weight_initializer,
                                     conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_2')
                    l5_3 = conv_relu(l4, 3, 256, conv_dilation=3, conv_ker_init=weight_initializer,
                                     conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_3')
                    l5_4 = conv_relu(l4, 3, 256, conv_dilation=4, conv_ker_init=weight_initializer,
                                     conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_4')

                    l5 = tf.concat([l5_1, l5_2, l5_3, l5_4], 3, name='conv_5')

                    # 1x1 convs compress features, then project to landmark channels
                    l6 = conv_relu(l5, 1, 512, conv_ker_init=weight_initializer,
                                   conv_bias_init=bias_init, reuse=reuse, var_scope='conv_6')
                    l7 = conv_relu(l6, 1, 256, conv_ker_init=weight_initializer,
                                   conv_bias_init=bias_init, reuse=reuse, var_scope='conv_7')
                    primary_out = conv(l7, 1, self.num_landmarks, conv_ker_init=weight_initializer,
                                       conv_bias_init=bias_init, reuse=reuse, var_scope='conv_8')

                with tf.name_scope('fusion_net'):

                    # fuse an early (l3) and late (l7) primary feature map
                    l_fsn_0 = tf.concat([l3, l7], 3, name='conv_3_7_fsn')

                    # dilated conv banks refine the fused features
                    l_fsn_1_1 = conv_relu(l_fsn_0, 3, 64, conv_dilation=1, conv_ker_init=weight_initializer,
                                          conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_1_1')
                    l_fsn_1_2 = conv_relu(l_fsn_0, 3, 64, conv_dilation=2, conv_ker_init=weight_initializer,
                                          conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_1_2')
                    l_fsn_1_3 = conv_relu(l_fsn_0, 3, 64, conv_dilation=3, conv_ker_init=weight_initializer,
                                          conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_1_3')

                    l_fsn_1 = tf.concat([l_fsn_1_1, l_fsn_1_2, l_fsn_1_3], 3, name='conv_fsn_1')

                    l_fsn_2_1 = conv_relu(l_fsn_1, 3, 64, conv_dilation=1, conv_ker_init=weight_initializer,
                                          conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_1')
                    l_fsn_2_2 = conv_relu(l_fsn_1, 3, 64, conv_dilation=2, conv_ker_init=weight_initializer,
                                          conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_2')
                    l_fsn_2_3 = conv_relu(l_fsn_1, 3, 64, conv_dilation=4, conv_ker_init=weight_initializer,
                                          conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_3')
                    # NOTE: this branch uses a 5x5 kernel with dilation 3 (by design,
                    # it mixes kernel sizes within the bank)
                    l_fsn_2_4 = conv_relu(l_fsn_1, 5, 64, conv_dilation=3, conv_ker_init=weight_initializer,
                                          conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_4')

                    l_fsn_2 = tf.concat([l_fsn_2_1, l_fsn_2_2, l_fsn_2_3, l_fsn_2_4], 3, name='conv_fsn_2')

                    l_fsn_3_1 = conv_relu(l_fsn_2, 3, 128, conv_dilation=1, conv_ker_init=weight_initializer,
                                          conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_1')
                    l_fsn_3_2 = conv_relu(l_fsn_2, 3, 128, conv_dilation=2, conv_ker_init=weight_initializer,
                                          conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_2')
                    l_fsn_3_3 = conv_relu(l_fsn_2, 3, 128, conv_dilation=4, conv_ker_init=weight_initializer,
                                          conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_3')
                    l_fsn_3_4 = conv_relu(l_fsn_2, 5, 128, conv_dilation=3, conv_ker_init=weight_initializer,
                                          conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_4')

                    l_fsn_3 = tf.concat([l_fsn_3_1, l_fsn_3_2, l_fsn_3_3, l_fsn_3_4], 3, name='conv_fsn_3')

                    # compress and project to landmark channels (quarter resolution)
                    l_fsn_4 = conv_relu(l_fsn_3, 1, 256, conv_ker_init=weight_initializer,
                                        conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_4')
                    fusion_out = conv(l_fsn_4, 1, self.num_landmarks, conv_ker_init=weight_initializer,
                                      conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_5')

                with tf.name_scope('upsample_net'):

                    # 4x learned upsampling, initialized to bilinear interpolation
                    out = deconv(fusion_out, 8, self.num_landmarks, conv_stride=4,
                                 conv_ker_init=deconv2d_bilinear_upsampling_initializer(
                                     [8, 8, self.num_landmarks, self.num_landmarks]), conv_bias_init=bias_init,
                                 reuse=reuse, var_scope='deconv_1')

                # keep references for activation histogram logging
                self.all_layers = [l1, l2, l3, l4, l5, l6, l7, primary_out, l_fsn_1, l_fsn_2, l_fsn_3, l_fsn_4,
                                   fusion_out, out]

                return primary_out, fusion_out, out
|
|
|
def build_model(self): |
|
self.pred_hm_p, self.pred_hm_f, self.pred_hm_u = self.heatmaps_network(self.images,name='heatmaps_prediction') |
|
|
|
def create_loss_ops(self): |
|
|
|
def nme_norm_eyes(pred_landmarks, real_landmarks, normalize=True, name='NME'): |
|
"""calculate normalized mean error on landmarks - normalize with inter pupil distance""" |
|
|
|
with tf.name_scope(name): |
|
with tf.name_scope('real_pred_landmarks_rmse'): |
|
|
|
landmarks_rms_err = tf.reduce_mean( |
|
tf.sqrt(tf.reduce_sum(tf.square(pred_landmarks - real_landmarks), axis=2)), axis=1) |
|
if normalize: |
|
|
|
with tf.name_scope('inter_pupil_dist'): |
|
with tf.name_scope('left_eye_center'): |
|
p1 = tf.reduce_mean(tf.slice(real_landmarks, [0, 42, 0], [-1, 6, 2]), axis=1) |
|
with tf.name_scope('right_eye_center'): |
|
p2 = tf.reduce_mean(tf.slice(real_landmarks, [0, 36, 0], [-1, 6, 2]), axis=1) |
|
|
|
eye_dist = tf.sqrt(tf.reduce_sum(tf.square(p1 - p2), axis=1)) |
|
|
|
return landmarks_rms_err / eye_dist |
|
else: |
|
return landmarks_rms_err |
|
|
|
if self.mode is 'TRAIN': |
|
|
|
|
|
primary_maps_diff = self.pred_hm_p - self.heatmaps_small |
|
fusion_maps_diff = self.pred_hm_f - self.heatmaps_small |
|
upsample_maps_diff = self.pred_hm_u - self.heatmaps |
|
|
|
self.l2_primary = tf.reduce_mean(tf.square(primary_maps_diff)) |
|
self.l2_fusion = tf.reduce_mean(tf.square(fusion_maps_diff)) |
|
self.l2_upsample = tf.reduce_mean(tf.square(upsample_maps_diff)) |
|
|
|
self.total_loss = 1000.*(self.l_weight_primary * self.l2_primary + self.l_weight_fusion * self.l2_fusion + |
|
self.l_weight_upsample * self.l2_upsample) |
|
|
|
|
|
self.total_loss += self.reg * tf.add_n( |
|
[tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'bias' not in v.name]) |
|
|
|
|
|
if self.compute_nme: |
|
self.nme_loss = tf.reduce_mean(nme_norm_eyes(self.train_pred_lms, self.train_lms)) |
|
|
|
if self.valid_size > 0 and self.compute_nme: |
|
self.valid_nme_loss = tf.reduce_mean(nme_norm_eyes(self.valid_pred_lms, self.valid_lms)) |
|
|
|
elif self.mode == 'TEST' and self.compute_nme: |
|
self.nme_per_image = nme_norm_eyes(self.pred_lms, self.lms) |
|
self.nme_loss = tf.reduce_mean(self.nme_per_image) |
|
|
|
def predict_valid_landmarks_in_batches(self, images, session): |
|
|
|
num_images=int(images.shape[0]) |
|
num_batches = int(1.*num_images/self.batch_size) |
|
if num_batches == 0: |
|
batch_size = num_images |
|
num_batches = 1 |
|
else: |
|
batch_size = self.batch_size |
|
|
|
for j in range(num_batches): |
|
|
|
batch_images = images[j * batch_size:(j + 1) * batch_size,:,:,:] |
|
batch_maps_pred = session.run(self.pred_hm_u, {self.images: batch_images}) |
|
batch_heat_maps_to_landmarks_alloc_once( |
|
batch_maps=batch_maps_pred, batch_landmarks=self.valid_landmarks_pred[j * batch_size:(j + 1) * batch_size, :, :], |
|
batch_size=batch_size,image_size=self.image_size,num_landmarks=self.num_landmarks) |
|
|
|
reminder = num_images-num_batches*batch_size |
|
if reminder > 0: |
|
batch_images = images[-reminder:, :, :, :] |
|
batch_maps_pred = session.run(self.pred_hm_u, {self.images: batch_images}) |
|
|
|
batch_heat_maps_to_landmarks_alloc_once( |
|
batch_maps=batch_maps_pred, |
|
batch_landmarks=self.valid_landmarks_pred[-reminder:, :, :], |
|
batch_size=reminder, image_size=self.image_size, num_landmarks=self.num_landmarks) |
|
|
|
    def create_summary_ops(self):
        """create summary ops for logging"""

        # scalar loss summaries, merged into one batch-level op
        l2_primary = tf.summary.scalar('l2_primary', self.l2_primary)
        l2_fusion = tf.summary.scalar('l2_fusion', self.l2_fusion)
        l2_upsample = tf.summary.scalar('l2_upsample', self.l2_upsample)

        l_total = tf.summary.scalar('l_total', self.total_loss)
        self.batch_summary_op = tf.summary.merge([l2_primary,l2_fusion,l2_upsample,l_total])

        if self.compute_nme:
            nme = tf.summary.scalar('nme', self.nme_loss)
            self.batch_summary_op = tf.summary.merge([self.batch_summary_op, nme])

        # optional (expensive) histograms of weights, gradients and activations
        if self.log_histograms:
            var_summary = [tf.summary.histogram(var.name,var) for var in tf.trainable_variables()]
            grads = tf.gradients(self.total_loss, tf.trainable_variables())
            grads = list(zip(grads, tf.trainable_variables()))
            grad_summary = [tf.summary.histogram(var.name+'/grads',grad) for grad,var in grads]
            activ_summary = [tf.summary.histogram(layer.name, layer) for layer in self.all_layers]
            self.batch_summary_op = tf.summary.merge([self.batch_summary_op, var_summary, grad_summary, activ_summary])

        if self.valid_size > 0 and self.compute_nme:
            self.valid_summary = tf.summary.scalar('valid_nme', self.valid_nme_loss)

        # image summaries for pre-rendered sample grids (fed via the log_* placeholders)
        if self.sample_to_log:
            img_map_summary_small = tf.summary.image('compare_map_to_gt_small', self.log_image_map_small)
            img_map_summary = tf.summary.image('compare_map_to_gt', self.log_image_map)

            if self.sample_per_channel:
                map_channels_summary = tf.summary.image('compare_map_channels_to_gt', self.log_map_channels)
                map_channels_summary_small = tf.summary.image('compare_map_channels_to_gt_small',
                                                              self.log_map_channels_small)
                self.img_summary = tf.summary.merge(
                    [img_map_summary, img_map_summary_small,map_channels_summary,map_channels_summary_small])
            else:
                self.img_summary = tf.summary.merge([img_map_summary, img_map_summary_small])

            # separate image summaries for validation samples (same placeholders,
            # different tags) — only when enough validation images exist
            if self.valid_size >= self.sample_grid:
                img_map_summary_valid_small = tf.summary.image('compare_map_to_gt_small_valid', self.log_image_map_small)
                img_map_summary_valid = tf.summary.image('compare_map_to_gt_valid', self.log_image_map)

                if self.sample_per_channel:
                    map_channels_summary_valid_small = tf.summary.image('compare_map_channels_to_gt_small_valid',
                                                                        self.log_map_channels_small)
                    map_channels_summary_valid = tf.summary.image('compare_map_channels_to_gt_valid',
                                                                  self.log_map_channels)
                    self.img_summary_valid = tf.summary.merge(
                        [img_map_summary_valid,img_map_summary_valid_small,map_channels_summary_valid,
                         map_channels_summary_valid_small])
                else:
                    self.img_summary_valid = tf.summary.merge([img_map_summary_valid, img_map_summary_valid_small])
|
|
|
    def train(self):
        """Run the full training loop: build the graph, (optionally) restore
        pre-trained weights, then iterate train steps with periodic loss printing,
        validation NME evaluation, checkpointing and sample-image logging."""

        tf.set_random_seed(1234)
        np.random.seed(1234)

        # build graph: placeholders -> network -> losses -> summaries
        self.add_placeholders()

        self.build_model()

        self.create_loss_ops()

        self.create_summary_ops()

        # exponentially decaying learning rate driven by global_step
        global_step = tf.Variable(0, trainable=False)
        lr = tf.train.exponential_decay(self.learning_rate,global_step, self.step, self.gamma, staircase=True)
        if self.adam_optimizer:
            optimizer = tf.train.AdamOptimizer(lr)
        else:
            optimizer = tf.train.MomentumOptimizer(lr, self.momentum)

        train_op = optimizer.minimize(self.total_loss,global_step=global_step)

        with tf.Session(config=self.config) as sess:

            tf.global_variables_initializer().run()

            # optionally restore pre-trained weights (all vars, or primary net only)
            if self.load_pretrain:
                # NOTE(review): bare `print` is a Python-2 leftover; in Python 3 this
                # expression is a no-op (no blank line is printed)
                print
                print('*** loading pre-trained weights from: '+self.pre_train_path+' ***')
                if self.load_primary_only:
                    print('*** loading primary-net only ***')
                    # primary-net variables = everything outside the fusion/deconv scopes
                    primary_var = [v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if
                                   ('deconv_' not in v.name) and ('_fsn_' not in v.name)]
                    loader = tf.train.Saver(var_list=primary_var)
                else:
                    loader = tf.train.Saver()
                loader.restore(sess, self.pre_train_path)
                print("*** Model restore finished, current global step: %d" % global_step.eval())

            # optionally reset optimizer slot variables + global step after restore
            if self.reset_training_op:
                print ("resetting optimizer and global step")
                opt_var_list = [optimizer.get_slot(var, name) for name in optimizer.get_slot_names()
                                for var in tf.global_variables() if optimizer.get_slot(var, name) is not None]
                opt_var_list_init = tf.variables_initializer(opt_var_list)
                opt_var_list_init.run()
                sess.run(global_step.initializer)

            summary_writer = tf.summary.FileWriter(logdir=self.save_log_path, graph=tf.get_default_graph())
            saver = tf.train.Saver()

            print('\n*** Start Training ***')

            # resume from restored global step; derive current epoch + index row
            resume_step = global_step.eval()
            num_train_images = len(self.img_menpo_list)
            batches_in_epoch = int(float(num_train_images) / float(self.batch_size))
            epoch = int(resume_step / batches_in_epoch)
            img_inds = self.epoch_inds_shuffle[epoch, :]
            log_valid = True
            log_valid_images = True

            # pre-allocate batch buffers once; they are filled in-place every step
            batch_images = np.zeros([self.batch_size, self.image_size, self.image_size, self.c_dim]).astype(
                'float32')
            batch_lms = np.zeros([self.batch_size, self.num_landmarks, 2]).astype('float32')
            batch_lms_pred = np.zeros([self.batch_size, self.num_landmarks, 2]).astype('float32')

            batch_maps_small = np.zeros((self.batch_size, int(self.image_size/4),
                                         int(self.image_size/4), self.num_landmarks)).astype('float32')
            batch_maps = np.zeros((self.batch_size, self.image_size, self.image_size,
                                   self.num_landmarks)).astype('float32')

            # gaussian kernels for generating GT heatmaps at full and quarter scale
            gaussian_filt_large = create_gaussian_filter(sigma=self.sigma, win_mult=self.win_mult)
            gaussian_filt_small = create_gaussian_filter(sigma=1.*self.sigma/4, win_mult=self.win_mult)

            for step in range(resume_step, self.train_iter):

                # j = batch index within the current epoch
                j = step % batches_in_epoch

                # epoch rollover: advance shuffle row, re-enable validation logging,
                # and (optionally) load the next epoch's pre-generated crops
                if step > resume_step and j == 0:
                    epoch += 1
                    img_inds = self.epoch_inds_shuffle[epoch, :]
                    log_valid = True
                    log_valid_images = True
                    if self.use_epoch_data:
                        epoch_dir = os.path.join(self.epoch_data_dir, str(epoch))
                        self.img_menpo_list = load_menpo_image_list(
                            self.img_path, train_crop_dir=epoch_dir, img_dir_ns=None, mode=self.mode,
                            bb_dictionary=self.bb_dictionary, image_size=self.image_size, test_data=self.test_data,
                            augment_basic=False, augment_texture=False, augment_geom=False)

                batch_inds = img_inds[j * self.batch_size:(j + 1) * self.batch_size]

                # fill batch buffers in-place: images, GT maps (2 scales), landmarks
                load_images_landmarks_approx_maps_alloc_once(
                    self.img_menpo_list, batch_inds, images=batch_images, maps_small=batch_maps_small,
                    maps=batch_maps, landmarks=batch_lms, image_size=self.image_size,
                    num_landmarks=self.num_landmarks, scale=self.scale, gauss_filt_large=gaussian_filt_large,
                    gauss_filt_small=gaussian_filt_small, win_mult=self.win_mult, sigma=self.sigma,
                    save_landmarks=self.compute_nme)

                feed_dict_train = {self.images: batch_images, self.heatmaps: batch_maps,
                                   self.heatmaps_small: batch_maps_small}

                # one optimization step
                sess.run(train_op, feed_dict_train)

                # periodic loss printing + batch summaries
                if step == resume_step or (step + 1) % self.print_every == 0:

                    if self.compute_nme:
                        # decode predicted landmarks from the upsample-net maps
                        batch_maps_pred = sess.run(self.pred_hm_u, {self.images: batch_images})

                        batch_heat_maps_to_landmarks_alloc_once(
                            batch_maps=batch_maps_pred,batch_landmarks=batch_lms_pred,
                            batch_size=self.batch_size, image_size=self.image_size,
                            num_landmarks=self.num_landmarks)

                        train_feed_dict_log = {
                            self.images: batch_images, self.heatmaps: batch_maps,
                            self.heatmaps_small: batch_maps_small, self.train_lms: batch_lms,
                            self.train_pred_lms: batch_lms_pred}

                        summary, l_p, l_f, l_t, nme = sess.run(
                            [self.batch_summary_op, self.l2_primary, self.l2_fusion, self.total_loss,
                             self.nme_loss],
                            train_feed_dict_log)

                        print (
                            'epoch: [%d] step: [%d/%d] primary loss: [%.6f] fusion loss: [%.6f]'
                            ' total loss: [%.6f] NME: [%.6f]' % (
                                epoch, step + 1, self.train_iter, l_p, l_f, l_t, nme))
                    else:
                        train_feed_dict_log = {self.images: batch_images, self.heatmaps: batch_maps,
                                               self.heatmaps_small: batch_maps_small}

                        summary, l_p, l_f, l_t = sess.run(
                            [self.batch_summary_op, self.l2_primary, self.l2_fusion, self.total_loss],
                            train_feed_dict_log)
                        print (
                            'epoch: [%d] step: [%d/%d] primary loss: [%.6f] fusion loss: [%.6f] total loss: [%.6f]'
                            % (epoch, step + 1, self.train_iter, l_p, l_f, l_t))

                    summary_writer.add_summary(summary, step)

                    # validation NME: at most once per log_valid_every epochs
                    if self.valid_size > 0 and (log_valid and epoch % self.log_valid_every == 0) \
                            and self.compute_nme:
                        log_valid = False

                        self.predict_valid_landmarks_in_batches(self.valid_images_loaded, sess)
                        valid_feed_dict_log = {
                            self.valid_lms: self.valid_landmarks_loaded,
                            self.valid_pred_lms: self.valid_landmarks_pred}

                        v_summary, v_nme = sess.run([self.valid_summary, self.valid_nme_loss],
                                                    valid_feed_dict_log)
                        summary_writer.add_summary(v_summary, step)
                        print (
                            'epoch: [%d] step: [%d/%d] valid NME: [%.6f]' % (
                                epoch, step + 1, self.train_iter, v_nme))

                # periodic checkpointing
                if (step + 1) % self.save_every == 0:
                    saver.save(sess, os.path.join(self.save_model_path, 'deep_heatmaps'), global_step=step + 1)
                    print ('model/deep-heatmaps-%d saved' % (step + 1))

                # periodic sample-image generation (TensorBoard or PNG files)
                if step == resume_step or (step + 1) % self.sample_every == 0:

                    batch_maps_small_pred = sess.run(self.pred_hm_p, {self.images: batch_images})
                    if not self.compute_nme:
                        # batch_maps_pred was not computed in the print block above
                        batch_maps_pred = sess.run(self.pred_hm_u, {self.images: batch_images})
                        batch_lms_pred = None

                    merged_img = merge_images_landmarks_maps_gt(
                        batch_images.copy(), batch_maps_pred, batch_maps, landmarks=batch_lms_pred,
                        image_size=self.image_size, num_landmarks=self.num_landmarks, num_samples=self.sample_grid,
                        scale=self.scale, circle_size=2, fast=self.fast_img_gen)

                    merged_img_small = merge_images_landmarks_maps_gt(
                        batch_images.copy(), batch_maps_small_pred, batch_maps_small,
                        image_size=self.image_size,
                        num_landmarks=self.num_landmarks, num_samples=self.sample_grid, scale=self.scale,
                        circle_size=0, fast=self.fast_img_gen)

                    if self.sample_per_channel:
                        map_per_channel = map_comapre_channels(
                            batch_images.copy(), batch_maps_pred, batch_maps, image_size=self.image_size,
                            num_landmarks=self.num_landmarks, scale=self.scale)

                        map_per_channel_small = map_comapre_channels(
                            batch_images.copy(), batch_maps_small_pred, batch_maps_small, image_size=int(self.image_size/4),
                            num_landmarks=self.num_landmarks, scale=self.scale)

                    if self.sample_to_log:
                        # write sample grids as TensorBoard image summaries
                        if self.sample_per_channel:
                            summary_img = sess.run(
                                self.img_summary, {self.log_image_map: np.expand_dims(merged_img, 0),
                                                   self.log_map_channels: np.expand_dims(map_per_channel, 0),
                                                   self.log_image_map_small: np.expand_dims(merged_img_small, 0),
                                                   self.log_map_channels_small: np.expand_dims(map_per_channel_small, 0)})
                        else:
                            summary_img = sess.run(
                                self.img_summary, {self.log_image_map: np.expand_dims(merged_img, 0),
                                                   self.log_image_map_small: np.expand_dims(merged_img_small, 0)})
                        summary_writer.add_summary(summary_img, step)

                        # validation sample images: at most once per log_valid_every epochs
                        if (self.valid_size >= self.sample_grid) and self.save_valid_images and\
                                (log_valid_images and epoch % self.log_valid_every == 0):
                            log_valid_images = False

                            batch_maps_small_pred_val,batch_maps_pred_val =\
                                sess.run([self.pred_hm_p,self.pred_hm_u],
                                         {self.images: self.valid_images_loaded[:self.sample_grid]})

                            merged_img_small = merge_images_landmarks_maps_gt(
                                self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_small_pred_val,
                                self.valid_gt_maps_small_loaded, image_size=self.image_size,
                                num_landmarks=self.num_landmarks, num_samples=self.sample_grid,
                                scale=self.scale, circle_size=0, fast=self.fast_img_gen)

                            merged_img = merge_images_landmarks_maps_gt(
                                self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_pred_val,
                                self.valid_gt_maps_loaded, image_size=self.image_size,
                                num_landmarks=self.num_landmarks, num_samples=self.sample_grid,
                                scale=self.scale, circle_size=2, fast=self.fast_img_gen)

                            if self.sample_per_channel:
                                map_per_channel_small = map_comapre_channels(
                                    self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_small_pred_val,
                                    self.valid_gt_maps_small_loaded, image_size=int(self.image_size / 4),
                                    num_landmarks=self.num_landmarks, scale=self.scale)

                                # NOTE(review): this passes batch_maps_pred (train batch),
                                # not batch_maps_pred_val — looks like a copy-paste slip; verify
                                map_per_channel = map_comapre_channels(
                                    self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_pred,
                                    self.valid_gt_maps_loaded, image_size=self.image_size,
                                    num_landmarks=self.num_landmarks, scale=self.scale)

                                summary_img = sess.run(
                                    self.img_summary_valid,
                                    {self.log_image_map: np.expand_dims(merged_img, 0),
                                     self.log_map_channels: np.expand_dims(map_per_channel, 0),
                                     self.log_image_map_small: np.expand_dims(merged_img_small, 0),
                                     self.log_map_channels_small: np.expand_dims(map_per_channel_small, 0)})
                            else:
                                summary_img = sess.run(
                                    self.img_summary_valid,
                                    {self.log_image_map: np.expand_dims(merged_img, 0),
                                     self.log_image_map_small: np.expand_dims(merged_img_small, 0)})

                            summary_writer.add_summary(summary_img, step)
                    else:
                        # save sample grids as PNG files instead of TensorBoard;
                        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2 —
                        # requires an old scipy (or porting to imageio.imwrite)
                        sample_path_imgs = os.path.join(
                            self.save_sample_path, 'epoch-%d-train-iter-%d-1.png' % (epoch, step + 1))
                        sample_path_imgs_small = os.path.join(
                            self.save_sample_path, 'epoch-%d-train-iter-%d-1-s.png' % (epoch, step + 1))
                        scipy.misc.imsave(sample_path_imgs, merged_img)
                        scipy.misc.imsave(sample_path_imgs_small, merged_img_small)

                        if self.sample_per_channel:
                            sample_path_ch_maps = os.path.join(
                                self.save_sample_path, 'epoch-%d-train-iter-%d-3.png' % (epoch, step + 1))
                            sample_path_ch_maps_small = os.path.join(
                                self.save_sample_path, 'epoch-%d-train-iter-%d-3-s.png' % (epoch, step + 1))
                            scipy.misc.imsave(sample_path_ch_maps, map_per_channel)
                            scipy.misc.imsave(sample_path_ch_maps_small, map_per_channel_small)

            print('*** Finished Training ***')
|
|
|
def get_image_maps(self, test_image, reuse=None, norm=False): |
|
""" returns heatmaps of input image (menpo image object)""" |
|
|
|
self.add_placeholders() |
|
|
|
pred_hm_p, pred_hm_f, pred_hm_u = self.heatmaps_network(self.images, reuse=reuse) |
|
|
|
with tf.Session(config=self.config) as sess: |
|
|
|
saver = tf.train.Saver() |
|
saver.restore(sess, self.test_model_path) |
|
_, model_name = os.path.split(self.test_model_path) |
|
|
|
test_image = test_image.pixels_with_channels_at_back().astype('float32') |
|
if norm: |
|
if self.scale is '255': |
|
test_image *= 255 |
|
elif self.scale is '0': |
|
test_image = 2 * test_image - 1 |
|
|
|
map_primary, map_fusion, map_upsample = sess.run( |
|
[pred_hm_p, pred_hm_f, pred_hm_u], {self.images: np.expand_dims(test_image, 0)}) |
|
|
|
return map_primary, map_fusion, map_upsample |
|
|
|
def get_landmark_predictions(self, img_list, pdm_models_dir, clm_model_path, reuse=None, map_to_input_size=False): |
|
|
|
"""returns dictionary with landmark predictions of each step of the ECpTp algorithm and ECT""" |
|
|
|
from thirdparty.face_of_art.pdm_clm_functions import feature_based_pdm_corr, clm_correct |
|
|
|
jaw_line_inds = np.arange(0, 17) |
|
left_brow_inds = np.arange(17, 22) |
|
right_brow_inds = np.arange(22, 27) |
|
|
|
self.add_placeholders() |
|
|
|
_, _, pred_hm_u = self.heatmaps_network(self.images, reuse=reuse) |
|
|
|
with tf.Session(config=self.config) as sess: |
|
|
|
saver = tf.train.Saver() |
|
saver.restore(sess, self.test_model_path) |
|
_, model_name = os.path.split(self.test_model_path) |
|
e_list = [] |
|
ect_list = [] |
|
ecp_list = [] |
|
ecpt_list = [] |
|
ecptp_jaw_list = [] |
|
ecptp_out_list = [] |
|
|
|
for test_image in img_list: |
|
|
|
if map_to_input_size: |
|
test_image_transform = test_image[1] |
|
test_image=test_image[0] |
|
|
|
|
|
if test_image.n_channels < 3: |
|
test_image_map = sess.run( |
|
pred_hm_u, {self.images: np.expand_dims( |
|
gray2rgb(test_image.pixels_with_channels_at_back()).astype('float32'), 0)}) |
|
else: |
|
test_image_map = sess.run( |
|
pred_hm_u, {self.images: np.expand_dims( |
|
test_image.pixels_with_channels_at_back().astype('float32'), 0)}) |
|
init_lms = heat_maps_to_landmarks(np.squeeze(test_image_map)) |
|
|
|
|
|
p_pdm_lms = feature_based_pdm_corr(lms_init=init_lms, models_dir=pdm_models_dir, train_type='basic') |
|
|
|
|
|
try: |
|
pdm_clm_lms = clm_correct( |
|
clm_model_path=clm_model_path, image=test_image, map=test_image_map, lms_init=p_pdm_lms) |
|
except: |
|
pdm_clm_lms = p_pdm_lms.copy() |
|
|
|
|
|
try: |
|
ect_lms = clm_correct( |
|
clm_model_path=clm_model_path, image=test_image, map=test_image_map, lms_init=init_lms) |
|
except: |
|
ect_lms = p_pdm_lms.copy() |
|
|
|
|
|
ecptp_out = p_pdm_lms.copy() |
|
ecptp_out[left_brow_inds] = pdm_clm_lms[left_brow_inds] |
|
ecptp_out[right_brow_inds] = pdm_clm_lms[right_brow_inds] |
|
ecptp_out[jaw_line_inds] = pdm_clm_lms[jaw_line_inds] |
|
|
|
|
|
ecptp_jaw = p_pdm_lms.copy() |
|
ecptp_jaw[jaw_line_inds] = pdm_clm_lms[jaw_line_inds] |
|
|
|
if map_to_input_size: |
|
ecptp_jaw = test_image_transform.apply(ecptp_jaw) |
|
ecptp_out = test_image_transform.apply(ecptp_out) |
|
ect_lms = test_image_transform.apply(ect_lms) |
|
init_lms = test_image_transform.apply(init_lms) |
|
p_pdm_lms = test_image_transform.apply(p_pdm_lms) |
|
pdm_clm_lms = test_image_transform.apply(pdm_clm_lms) |
|
|
|
ecptp_jaw_list.append(ecptp_jaw) |
|
ecptp_out_list.append(ecptp_out) |
|
ect_list.append(ect_lms) |
|
e_list.append(init_lms) |
|
ecp_list.append(p_pdm_lms) |
|
ecpt_list.append(pdm_clm_lms) |
|
|
|
pred_dict = { |
|
'E': e_list, |
|
'ECp': ecp_list, |
|
'ECpT': ecpt_list, |
|
'ECT': ect_list, |
|
'ECpTp_jaw': ecptp_jaw_list, |
|
'ECpTp_out': ecptp_out_list |
|
} |
|
|
|
return pred_dict |
|
|