# src/approaches/train_image_translation.py
"""
# Copyright 2020 Adobe
# All Rights Reserved.
# NOTICE: Adobe permits you to use, modify, and distribute this file in
# accordance with the terms of the Adobe license agreement accompanying
# it.
"""
from src.models.model_image_translation import ResUnetGenerator, VGGLoss
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
import time
import numpy as np
import cv2
import os, glob
from src.dataset.image_translation.image_translation_dataset import vis_landmark_on_img, vis_landmark_on_img98, vis_landmark_on_img74
from thirdparty.AdaptiveWingLoss.core import models
from thirdparty.AdaptiveWingLoss.utils.utils import get_preds_fromhm
import face_alignment
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Image_translation_block():
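    """Training / evaluation wrapper for the landmark-to-image generator.

    Builds the ResUNet generator, an online 98-point AWing landmark detector
    (optionally combined with a 68-point FAN detector), the dataset/dataloader
    and the losses, and exposes train(), test() and single_test().
    """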
def __init__(self, opt_parser, single_test=False):
print('Run on device {}'.format(device))
# for key in vars(opt_parser).keys():
# print(key, ':', vars(opt_parser)[key])
self.opt_parser = opt_parser
# model
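        # generator input is 6 channels (rendered-landmark RGB + reference-frame RGB),
        # or 7 channels when a 1-channel audio map is appended (add_audio_in)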
if(opt_parser.add_audio_in):
self.G = ResUnetGenerator(input_nc=7, output_nc=3, num_downs=6, use_dropout=False)
else:
self.G = ResUnetGenerator(input_nc=6, output_nc=3, num_downs=6, use_dropout=False)
if (opt_parser.load_G_name != ''):
ckpt = torch.load(opt_parser.load_G_name)
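            # a checkpoint saved from a DataParallel-wrapped generator stores keys with a
            # 'module.' prefix; if the plain load fails, load through a temporary wrapper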
try:
self.G.load_state_dict(ckpt['G'])
            except Exception:
tmp = nn.DataParallel(self.G)
tmp.load_state_dict(ckpt['G'])
self.G.load_state_dict(tmp.module.state_dict())
del tmp
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs in G mode!")
self.G = nn.DataParallel(self.G)
self.G.to(device)
if(not single_test):
# dataset
if(opt_parser.use_vox_dataset == 'raw'):
if(opt_parser.comb_fan_awing):
from src.dataset.image_translation.image_translation_dataset import \
image_translation_raw74_dataset as image_translation_dataset
elif(opt_parser.add_audio_in):
from src.dataset.image_translation.image_translation_dataset import image_translation_raw98_with_audio_dataset as \
image_translation_dataset
else:
from src.dataset.image_translation.image_translation_dataset import image_translation_raw98_dataset as \
image_translation_dataset
else:
from src.dataset.image_translation.image_translation_dataset import image_translation_preprocessed98_dataset as \
image_translation_dataset
self.dataset = image_translation_dataset(num_frames=opt_parser.num_frames)
self.dataloader = torch.utils.data.DataLoader(self.dataset,
batch_size=opt_parser.batch_size,
shuffle=True,
num_workers=opt_parser.num_workers)
# criterion
self.criterionL1 = nn.L1Loss()
self.criterionVGG = VGGLoss()
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs in VGG model!")
self.criterionVGG = nn.DataParallel(self.criterionVGG)
self.criterionVGG.to(device)
# optimizer
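            # Adam with beta1 = 0.5, a setting commonly used for image-generation networks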
self.optimizer = torch.optim.Adam(self.G.parameters(), lr=opt_parser.lr, betas=(0.5, 0.999))
# writer
if(opt_parser.write):
self.writer = SummaryWriter(log_dir=os.path.join(opt_parser.log_dir, opt_parser.name))
self.count = 0
# ===========================================================
# online landmark alignment : Awing
# ===========================================================
PRETRAINED_WEIGHTS = 'thirdparty/AdaptiveWingLoss/ckpt/WFLW_4HG.pth'
GRAY_SCALE = False
HG_BLOCKS = 4
END_RELU = False
NUM_LANDMARKS = 98
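        # AWing landmark detector: 4-stack hourglass (FAN architecture) trained on the
        # 98-point WFLW annotation; weights are expected under thirdparty/AdaptiveWingLoss/ckpt/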
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_ft = models.FAN(HG_BLOCKS, END_RELU, GRAY_SCALE, NUM_LANDMARKS)
checkpoint = torch.load(PRETRAINED_WEIGHTS)
if 'state_dict' not in checkpoint:
model_ft.load_state_dict(checkpoint)
else:
pretrained_weights = checkpoint['state_dict']
model_weights = model_ft.state_dict()
pretrained_weights = {k: v for k, v in pretrained_weights.items() \
if k in model_weights}
model_weights.update(pretrained_weights)
model_ft.load_state_dict(model_weights)
        print('Load AWing model successfully')
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs for AWing!")
self.fa_model = nn.DataParallel(model_ft).to(self.device).eval()
else:
self.fa_model = model_ft.to(self.device).eval()
# ===========================================================
# online landmark alignment : FAN
# ===========================================================
if(opt_parser.comb_fan_awing):
if(opt_parser.fan_2or3D == '2D'):
self.predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
device='cuda' if torch.cuda.is_available() else "cpu",
flip_input=True)
else:
self.predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D,
device='cuda' if torch.cuda.is_available() else "cpu",
flip_input=True)
def __train_pass__(self, epoch, is_training=True):
st_epoch = time.time()
if(is_training):
self.G.train()
status = 'TRAIN'
else:
self.G.eval()
status = 'EVAL'
g_time = 0.0
for i, batch in enumerate(self.dataloader):
if(i >= len(self.dataloader)-2):
break
st_batch = time.time()
if(self.opt_parser.comb_fan_awing):
image_in, image_out, fan_pred_landmarks = batch
fan_pred_landmarks = fan_pred_landmarks.reshape(-1, 68, 3).detach().cpu().numpy()
elif(self.opt_parser.add_audio_in):
image_in, image_out, audio_in = batch
audio_in = audio_in.reshape(-1, 1, 256, 256).to(device)
else:
image_in, image_out = batch
with torch.no_grad():
# # online landmark (AwingNet)
image_in, image_out = \
image_in.reshape(-1, 3, 256, 256).to(device), image_out.reshape(-1, 3, 256, 256).to(device)
inputs = image_out
outputs, boundary_channels = self.fa_model(inputs)
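                # the last hourglass stage outputs the 98 landmark heatmaps plus a boundary
                # channel at 1/4 of the input resolution, hence the x4 rescale of the
                # decoded coordinates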
pred_heatmap = outputs[-1][:, :-1, :, :].detach().cpu()
pred_landmarks, _ = get_preds_fromhm(pred_heatmap)
pred_landmarks = pred_landmarks.numpy() * 4
# online landmark (FAN) -> replace jaw + eye brow in AwingNet
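                # FAN contributes the 27 jaw/eyebrow points (indices 0-26); AWing points 51-97
                # (nose, eyes and mouth region) supply the remaining 47, for 74 points in total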
if(self.opt_parser.comb_fan_awing):
fl_jaw_eyebrow = fan_pred_landmarks[:, 0:27, 0:2]
fl_rest = pred_landmarks[:, 51:, :]
                    pred_landmarks = np.concatenate([fl_jaw_eyebrow, fl_rest], axis=1).astype(int)
                # draw landmarks on a white background
img_fls = []
for pred_fl in pred_landmarks:
img_fl = np.ones(shape=(256, 256, 3)) * 255.0
if(self.opt_parser.comb_fan_awing):
img_fl = vis_landmark_on_img74(img_fl, pred_fl) # 74x2
else:
img_fl = vis_landmark_on_img98(img_fl, pred_fl) # 98x2
img_fls.append(img_fl.transpose((2, 0, 1)))
img_fls = np.stack(img_fls, axis=0).astype(np.float32) / 255.0
image_fls_in = torch.tensor(img_fls, requires_grad=False).to(device)
if(self.opt_parser.add_audio_in):
# print(image_fls_in.shape, image_in.shape, audio_in.shape)
image_in = torch.cat([image_fls_in, image_in, audio_in], dim=1)
else:
image_in = torch.cat([image_fls_in, image_in], dim=1)
# image_in, image_out = \
# image_in.reshape(-1, 6, 256, 256).to(device), image_out.reshape(-1, 3, 256, 256).to(device)
# image2image net fp
g_out = self.G(image_in)
g_out = torch.tanh(g_out)
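            # total loss: pixel-wise L1 + VGG perceptual + VGG style terms, equally weighted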
loss_l1 = self.criterionL1(g_out, image_out)
loss_vgg, loss_style = self.criterionVGG(g_out, image_out, style=True)
loss_vgg, loss_style = torch.mean(loss_vgg), torch.mean(loss_style)
loss = loss_l1 + loss_vgg + loss_style
if(is_training):
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# log
if(self.opt_parser.write):
self.writer.add_scalar('loss', loss.cpu().detach().numpy(), self.count)
self.writer.add_scalar('loss_l1', loss_l1.cpu().detach().numpy(), self.count)
self.writer.add_scalar('loss_vgg', loss_vgg.cpu().detach().numpy(), self.count)
self.count += 1
# save image to track training process
if (i % self.opt_parser.jpg_freq == 0):
vis_in = np.concatenate([image_in[0, 3:6].cpu().detach().numpy().transpose((1, 2, 0)),
image_in[0, 0:3].cpu().detach().numpy().transpose((1, 2, 0))], axis=1)
vis_out = np.concatenate([image_out[0].cpu().detach().numpy().transpose((1, 2, 0)),
g_out[0].cpu().detach().numpy().transpose((1, 2, 0))], axis=1)
vis = np.concatenate([vis_in, vis_out], axis=0)
                os.makedirs(os.path.join(self.opt_parser.jpg_dir, self.opt_parser.name), exist_ok=True)
cv2.imwrite(os.path.join(self.opt_parser.jpg_dir, self.opt_parser.name, 'e{:03d}_b{:04d}.jpg'.format(epoch, i)), vis * 255.0)
# save ckpt
if (i % self.opt_parser.ckpt_last_freq == 0):
self.__save_model__('last', epoch)
print("Epoch {}, Batch {}/{}, loss {:.4f}, l1 {:.4f}, vggloss {:.4f}, styleloss {:.4f} time {:.4f}".format(
epoch, i, len(self.dataset) // self.opt_parser.batch_size,
loss.cpu().detach().numpy(),
loss_l1.cpu().detach().numpy(),
loss_vgg.cpu().detach().numpy(),
loss_style.cpu().detach().numpy(),
time.time() - st_batch))
g_time += time.time() - st_batch
if(self.opt_parser.test_speed):
if(i >= 100):
break
print('Epoch time usage:', time.time() - st_epoch, 'I/O time usage:', time.time() - st_epoch - g_time, '\n=========================')
if(self.opt_parser.test_speed):
exit(0)
if(epoch % self.opt_parser.ckpt_epoch_freq == 0):
self.__save_model__('{:02d}'.format(epoch), epoch)
def __save_model__(self, save_type, epoch):
        os.makedirs(os.path.join(self.opt_parser.ckpt_dir, self.opt_parser.name), exist_ok=True)
if (self.opt_parser.write):
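            # note: the whole optimizer object (not just its state_dict) is serialized, and
            # checkpoints are only written when opt_parser.write is enabled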
torch.save({
'G': self.G.state_dict(),
'opt': self.optimizer,
'epoch': epoch
}, os.path.join(self.opt_parser.ckpt_dir, self.opt_parser.name, 'ckpt_{}.pth'.format(save_type)))
def train(self):
for epoch in range(self.opt_parser.nepoch):
self.__train_pass__(epoch, is_training=True)
def test(self):
if (self.opt_parser.use_vox_dataset == 'raw'):
if(self.opt_parser.add_audio_in):
from src.dataset.image_translation.image_translation_dataset import \
image_translation_raw98_with_audio_test_dataset as image_translation_test_dataset
else:
from src.dataset.image_translation.image_translation_dataset import image_translation_raw98_test_dataset as image_translation_test_dataset
else:
from src.dataset.image_translation.image_translation_dataset import image_translation_preprocessed98_test_dataset as image_translation_test_dataset
self.dataset = image_translation_test_dataset(num_frames=self.opt_parser.num_frames)
self.dataloader = torch.utils.data.DataLoader(self.dataset,
batch_size=1,
shuffle=True,
num_workers=self.opt_parser.num_workers)
self.G.eval()
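        # evaluate on at most 50 test clips; each clip is written out as a 4-panel video:
        # reference frame | generated frame | rendered landmarks | ground truth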
for i, batch in enumerate(self.dataloader):
            print('test batch {} / 50'.format(i))
if (i > 50):
break
if (self.opt_parser.add_audio_in):
image_in, image_out, audio_in = batch
audio_in = audio_in.reshape(-1, 1, 256, 256).to(device)
else:
image_in, image_out = batch
# # online landmark (AwingNet)
with torch.no_grad():
image_in, image_out = \
image_in.reshape(-1, 3, 256, 256).to(device), image_out.reshape(-1, 3, 256, 256).to(device)
pred_landmarks = []
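                # run the landmark detector in chunks of 16 frames; frames beyond the last
                # full chunk are dropped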
for j in range(image_in.shape[0] // 16):
inputs = image_out[j*16:j*16+16]
outputs, boundary_channels = self.fa_model(inputs)
pred_heatmap = outputs[-1][:, :-1, :, :].detach().cpu()
pred_landmark, _ = get_preds_fromhm(pred_heatmap)
pred_landmarks.append(pred_landmark.numpy() * 4)
pred_landmarks = np.concatenate(pred_landmarks, axis=0)
                # draw landmarks on a white background
img_fls = []
for pred_fl in pred_landmarks:
img_fl = np.ones(shape=(256, 256, 3)) * 255.0
img_fl = vis_landmark_on_img98(img_fl, pred_fl) # 98x2
img_fls.append(img_fl.transpose((2, 0, 1)))
img_fls = np.stack(img_fls, axis=0).astype(np.float32) / 255.0
image_fls_in = torch.tensor(img_fls, requires_grad=False).to(device)
if (self.opt_parser.add_audio_in):
# print(image_fls_in.shape, image_in.shape, audio_in.shape)
image_in = torch.cat([image_fls_in,
image_in[0:image_fls_in.shape[0]],
audio_in[0:image_fls_in.shape[0]]], dim=1)
else:
image_in = torch.cat([image_fls_in, image_in[0:image_fls_in.shape[0]]], dim=1)
# normal 68 test dataset
# image_in, image_out = image_in.reshape(-1, 6, 256, 256), image_out.reshape(-1, 3, 256, 256)
# random single frame
# cv2.imwrite('random_img_{}.jpg'.format(i), np.swapaxes(image_out[5].numpy(),0, 2)*255.0)
image_in, image_out = image_in.to(device), image_out.to(device)
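            # 25 fps temporary video with four 256x256 panels side by side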
writer = cv2.VideoWriter('tmp_{:04d}.mp4'.format(i), cv2.VideoWriter_fourcc(*'mjpg'), 25, (256*4, 256))
for j in range(image_in.shape[0] // 16):
g_out = self.G(image_in[j*16:j*16+16])
g_out = torch.tanh(g_out)
# norm 68 pts
# g_out = np.swapaxes(g_out.cpu().detach().numpy(), 1, 3)
# ref_out = np.swapaxes(image_out[j*16:j*16+16].cpu().detach().numpy(), 1, 3)
# ref_in = np.swapaxes(image_in[j*16:j*16+16, 3:6, :, :].cpu().detach().numpy(), 1, 3)
# fls_in = np.swapaxes(image_in[j * 16:j * 16 + 16, 0:3, :, :].cpu().detach().numpy(), 1, 3)
g_out = g_out.cpu().detach().numpy().transpose((0, 2, 3, 1))
g_out[g_out < 0] = 0
ref_out = image_out[j * 16:j * 16 + 16].cpu().detach().numpy().transpose((0, 2, 3, 1))
ref_in = image_in[j * 16:j * 16 + 16, 3:6, :, :].cpu().detach().numpy().transpose((0, 2, 3, 1))
fls_in = image_in[j * 16:j * 16 + 16, 0:3, :, :].cpu().detach().numpy().transpose((0, 2, 3, 1))
for k in range(g_out.shape[0]):
frame = np.concatenate((ref_in[k], g_out[k], fls_in[k], ref_out[k]), axis=1) * 255.0
writer.write(frame.astype(np.uint8))
writer.release()
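            # re-encode to yuv420p with ffmpeg so the clip plays in standard players,
            # then remove the temporary file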
os.system('ffmpeg -y -i tmp_{:04d}.mp4 -pix_fmt yuv420p random_{:04d}.mp4'.format(i, i))
os.system('rm tmp_{:04d}.mp4'.format(i))
def single_test(self, jpg=None, fls=None, filename=None, prefix='', grey_only=False):
import time
st = time.time()
self.G.eval()
if(jpg is None):
jpg = glob.glob1(self.opt_parser.single_test, '*.jpg')[0]
jpg = cv2.imread(os.path.join(self.opt_parser.single_test, jpg))
if(fls is None):
fls = glob.glob1(self.opt_parser.single_test, '*.txt')[0]
fls = np.loadtxt(os.path.join(self.opt_parser.single_test, fls))
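            # scale/offset (empirical constants) that map the loaded landmark coordinates
            # into the 256x256 canvas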
fls = fls * 95
fls[:, 0::3] += 130
fls[:, 1::3] += 80
writer = cv2.VideoWriter('out.mp4', cv2.VideoWriter_fourcc(*'mjpg'), 62.5, (256 * 3, 256))
for i, frame in enumerate(fls):
img_fl = np.ones(shape=(256, 256, 3)) * 255
fl = frame.astype(int)
img_fl = vis_landmark_on_img(img_fl, np.reshape(fl, (68, 3)))
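            # stack the rendered landmarks and the reference photo into a single 6-channel input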
frame = np.concatenate((img_fl, jpg), axis=2).astype(np.float32)/255.0
image_in, image_out = frame.transpose((2, 0, 1)), np.zeros(shape=(3, 256, 256))
# image_in, image_out = frame.transpose((2, 1, 0)), np.zeros(shape=(3, 256, 256))
image_in, image_out = torch.tensor(image_in, requires_grad=False), \
torch.tensor(image_out, requires_grad=False)
image_in, image_out = image_in.reshape(-1, 6, 256, 256), image_out.reshape(-1, 3, 256, 256)
image_in, image_out = image_in.to(device), image_out.to(device)
g_out = self.G(image_in)
g_out = torch.tanh(g_out)
g_out = g_out.cpu().detach().numpy().transpose((0, 2, 3, 1))
g_out[g_out < 0] = 0
ref_in = image_in[:, 3:6, :, :].cpu().detach().numpy().transpose((0, 2, 3, 1))
fls_in = image_in[:, 0:3, :, :].cpu().detach().numpy().transpose((0, 2, 3, 1))
# g_out = g_out.cpu().detach().numpy().transpose((0, 3, 2, 1))
# g_out[g_out < 0] = 0
# ref_in = image_in[:, 3:6, :, :].cpu().detach().numpy().transpose((0, 3, 2, 1))
# fls_in = image_in[:, 0:3, :, :].cpu().detach().numpy().transpose((0, 3, 2, 1))
if(grey_only):
                g_out_grey = np.mean(g_out, axis=3, keepdims=True)
g_out[:, :, :, 0:1] = g_out[:, :, :, 1:2] = g_out[:, :, :, 2:3] = g_out_grey
            for k in range(g_out.shape[0]):
                frame_out = np.concatenate((ref_in[k], g_out[k], fls_in[k]), axis=1) * 255.0
                writer.write(frame_out.astype(np.uint8))
writer.release()
print('Time - only video:', time.time() - st)
if(filename is None):
filename = 'v'
os.system('ffmpeg -loglevel error -y -i out.mp4 -i {} -pix_fmt yuv420p -strict -2 examples/{}_{}.mp4'.format(
'examples/'+filename[9:-16]+'.wav',
prefix, filename[:-4]))
# os.system('rm out.mp4')
print('Time - ffmpeg add audio:', time.time() - st)
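# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption, not part of the original training flow):
# the repository normally builds `opt_parser` in a separate entry point. The
# block below only enumerates the attributes this file actually reads; flag
# names and default values are illustrative, and the AWing weights under
# thirdparty/AdaptiveWingLoss/ckpt/ must be present for the class to load.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # generator / checkpoint options
    parser.add_argument('--load_G_name', type=str, default='')    # checkpoint with key 'G'
    parser.add_argument('--add_audio_in', action='store_true')    # 7-channel input (landmarks + ref + audio)
    parser.add_argument('--comb_fan_awing', action='store_true')  # merge FAN jaw/eyebrows with AWing points
    parser.add_argument('--fan_2or3D', type=str, default='3D')
    # dataset / optimization options
    parser.add_argument('--use_vox_dataset', type=str, default='raw')
    parser.add_argument('--num_frames', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--nepoch', type=int, default=30)
    parser.add_argument('--test_speed', action='store_true')
    # logging / checkpointing options
    parser.add_argument('--write', action='store_true')
    parser.add_argument('--name', type=str, default='image_translation_debug')
    parser.add_argument('--log_dir', type=str, default='logs')
    parser.add_argument('--jpg_dir', type=str, default='jpgs')
    parser.add_argument('--jpg_freq', type=int, default=50)
    parser.add_argument('--ckpt_dir', type=str, default='ckpts')
    parser.add_argument('--ckpt_last_freq', type=int, default=1000)
    parser.add_argument('--ckpt_epoch_freq', type=int, default=1)
    # directory with one reference .jpg and one landmark .txt for single-image inference
    parser.add_argument('--single_test', type=str, default='')
    opt = parser.parse_args()

    model = Image_translation_block(opt, single_test=(opt.single_test != ''))
    if opt.single_test != '':
        model.single_test()
    else:
        model.train()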