|
import tensorflow as tf |
|
from deep_heatmaps_model_fusion_net import DeepHeatmapsModel |
|
import os |
|
|
|
# Command-line flag definitions (TF1-era tf.app.flags, a thin wrapper around
# absl flags).  Values are parsed from sys.argv by tf.app.run() / on first
# FLAGS attribute access.
flags = tf.app.flags

# ---------------------------------------------------------------------------
# Paths and data locations
# ---------------------------------------------------------------------------

flags.DEFINE_string('output_dir', 'output', "directory for saving models, logs and samples")

flags.DEFINE_string('save_model_path', 'model', "directory for saving the model")

flags.DEFINE_string('save_sample_path', 'sample',

                    "directory for saving the sampled images, relevant if sample_to_log is False")

flags.DEFINE_string('save_log_path', 'logs', "directory for saving the log file")

flags.DEFINE_string('img_path', '~/landmark_detection_datasets', "data directory")

flags.DEFINE_string('valid_data', 'full', 'validation set to use: full/common/challenging/test')

flags.DEFINE_string('train_crop_dir', 'crop_gt_margin_0.25', "directory of train images cropped to bb (+margin)")

flags.DEFINE_string('img_dir_ns', 'crop_gt_margin_0.25_ns', "directory of train imgs cropped to bb + style transfer")

flags.DEFINE_string('epoch_data_dir', 'epoch_data', "directory containing pre-augmented data for each epoch")

flags.DEFINE_bool('use_epoch_data', False, "use pre-augmented data")

# ---------------------------------------------------------------------------
# Logging, sampling and debug behavior
# ---------------------------------------------------------------------------

flags.DEFINE_integer('print_every', 100, "print losses to screen + log every X steps")

flags.DEFINE_integer('save_every', 20000, "save model every X steps")

flags.DEFINE_integer('sample_every', 5000, "sample heatmaps + landmark predictions every X steps")

flags.DEFINE_integer('sample_grid', 4, 'number of training images in sample')

flags.DEFINE_bool('sample_to_log', True, 'samples will be saved to tensorboard log')

flags.DEFINE_integer('valid_size', 20, 'number of validation images to run')

flags.DEFINE_integer('log_valid_every', 10, 'evaluate on valid set every X epochs')

flags.DEFINE_integer('debug_data_size', 20, 'subset data size to test in debug mode')

flags.DEFINE_bool('debug', False, 'run in debug mode - use subset of the data')

# ---------------------------------------------------------------------------
# Pretrained-model loading / fine-tuning
# ---------------------------------------------------------------------------

flags.DEFINE_string('pre_train_path', 'model/deep_heatmaps-40000', 'pretrained model path')

flags.DEFINE_bool('load_pretrain', False, "load pretrained weight?")

flags.DEFINE_bool('load_primary_only', False, 'fine-tuning using only primary network weights')

# ---------------------------------------------------------------------------
# Input representation and heatmap generation
# ---------------------------------------------------------------------------

flags.DEFINE_integer('image_size', 256, "image size")

flags.DEFINE_integer('c_dim', 3, "color channels")

flags.DEFINE_integer('num_landmarks', 68, "number of face landmarks")

flags.DEFINE_float('sigma', 6, "std for heatmap generation gaussian")

flags.DEFINE_integer('scale', 1, 'scale for image normalization 255/1/0')

flags.DEFINE_float('margin', 0.25, 'margin for face crops - % of bb size')

flags.DEFINE_string('bb_type', 'gt', "bb to use - 'gt':for ground truth / 'init':for face detector output")

flags.DEFINE_float('win_mult', 3.33335, 'gaussian filter size for approx maps: 2 * sigma * win_mult + 1')

# ---------------------------------------------------------------------------
# Optimization: loss weights, schedule and initialization
# ---------------------------------------------------------------------------

flags.DEFINE_float('l_weight_primary', 1., 'primary loss weight')

flags.DEFINE_float('l_weight_fusion', 0., 'fusion loss weight')

flags.DEFINE_float('l_weight_upsample', 3., 'upsample loss weight')

flags.DEFINE_integer('train_iter', 60000, 'maximum training iterations')

flags.DEFINE_integer('batch_size', 6, "batch_size")

flags.DEFINE_float('learning_rate', 1e-4, "initial learning rate")

flags.DEFINE_bool('adam_optimizer', True, "use adam optimizer (if False momentum optimizer is used)")

flags.DEFINE_float('momentum', 0.95, "optimizer momentum (if adam_optimizer==False)")

flags.DEFINE_integer('step', 100000, 'step for lr decay')

flags.DEFINE_float('gamma', 0.1, 'exponential base for lr decay')

flags.DEFINE_float('reg', 1e-5, 'scalar multiplier for weight decay (0 to disable)')

flags.DEFINE_string('weight_initializer', 'xavier', 'weight initializer: random_normal / xavier')

flags.DEFINE_float('weight_initializer_std', 0.01, 'std for random_normal weight initializer')

flags.DEFINE_float('bias_initializer', 0.0, 'constant value for bias initializer')

# ---------------------------------------------------------------------------
# Data augmentation
# ---------------------------------------------------------------------------

flags.DEFINE_bool('augment_basic', True, "use basic augmentation?")

flags.DEFINE_bool('augment_texture', False, "use artistic texture augmentation?")

flags.DEFINE_float('p_texture', 0., 'probability of artistic texture augmentation')

flags.DEFINE_bool('augment_geom', False, "use artistic geometric augmentation?")

flags.DEFINE_float('p_geom', 0., 'probability of artistic geometric augmentation')
|
FLAGS = flags.FLAGS

# Create the root output directory up-front so main() can rely on it.
# os.makedirs (rather than os.mkdir) also creates any missing parent
# directories, so a nested --output_dir such as 'runs/exp1' works too.
if not os.path.exists(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)
|
|
|
|
|
def main(_):
    """Prepare output directories, build the heatmap model and run training.

    Invoked by tf.app.run(), which parses the command-line flags first and
    passes the leftover argv (ignored here).  All configuration comes from
    the module-level FLAGS.
    """
    # All artifacts (checkpoints, samples, logs) are nested under output_dir.
    save_model_path = os.path.join(FLAGS.output_dir, FLAGS.save_model_path)
    save_sample_path = os.path.join(FLAGS.output_dir, FLAGS.save_sample_path)
    save_log_path = os.path.join(FLAGS.output_dir, FLAGS.save_log_path)

    # os.makedirs (rather than os.mkdir) also creates missing parent
    # directories, so nested flag values do not crash here.
    if not os.path.exists(save_model_path):
        os.makedirs(save_model_path)
    if not os.path.exists(save_log_path):
        os.makedirs(save_log_path)
    # When sample_to_log is True, samples go to the tensorboard log instead,
    # so the sample directory is only needed otherwise.
    if not os.path.exists(save_sample_path) and not FLAGS.sample_to_log:
        os.makedirs(save_sample_path)

    # Forward every flag to the model; keyword names mirror the flag names.
    model = DeepHeatmapsModel(
        mode='TRAIN', train_iter=FLAGS.train_iter, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate,
        l_weight_primary=FLAGS.l_weight_primary, l_weight_fusion=FLAGS.l_weight_fusion,
        l_weight_upsample=FLAGS.l_weight_upsample, reg=FLAGS.reg, adam_optimizer=FLAGS.adam_optimizer,
        momentum=FLAGS.momentum, step=FLAGS.step, gamma=FLAGS.gamma,
        weight_initializer=FLAGS.weight_initializer, weight_initializer_std=FLAGS.weight_initializer_std,
        bias_initializer=FLAGS.bias_initializer, image_size=FLAGS.image_size, c_dim=FLAGS.c_dim,
        num_landmarks=FLAGS.num_landmarks, sigma=FLAGS.sigma, scale=FLAGS.scale, margin=FLAGS.margin,
        bb_type=FLAGS.bb_type, win_mult=FLAGS.win_mult, augment_basic=FLAGS.augment_basic,
        augment_texture=FLAGS.augment_texture, p_texture=FLAGS.p_texture, augment_geom=FLAGS.augment_geom,
        p_geom=FLAGS.p_geom, output_dir=FLAGS.output_dir, save_model_path=save_model_path,
        save_sample_path=save_sample_path, save_log_path=save_log_path, pre_train_path=FLAGS.pre_train_path,
        load_pretrain=FLAGS.load_pretrain, load_primary_only=FLAGS.load_primary_only,
        img_path=FLAGS.img_path, valid_data=FLAGS.valid_data, valid_size=FLAGS.valid_size,
        log_valid_every=FLAGS.log_valid_every, train_crop_dir=FLAGS.train_crop_dir, img_dir_ns=FLAGS.img_dir_ns,
        print_every=FLAGS.print_every, save_every=FLAGS.save_every, sample_every=FLAGS.sample_every,
        sample_grid=FLAGS.sample_grid, sample_to_log=FLAGS.sample_to_log, debug_data_size=FLAGS.debug_data_size,
        debug=FLAGS.debug, use_epoch_data=FLAGS.use_epoch_data, epoch_data_dir=FLAGS.epoch_data_dir)

    model.train()
|
|
|
|
|
if __name__ == '__main__':
    # tf.app.run() parses the flags defined above from sys.argv, then calls
    # main() with the remaining arguments.
    tf.app.run()
|
|