import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import math
import numpy as np

from modules.real3d.facev2v_warp.network import AppearanceFeatureExtractor, CanonicalKeypointDetector, PoseExpressionEstimator, MotionFieldEstimator, Generator
from modules.real3d.facev2v_warp.func_utils import transform_kp, make_coordinate_grid_2d, apply_imagenet_normalization
from modules.real3d.facev2v_warp.losses import PerceptualLoss, GANLoss, FeatureMatchingLoss, EquivarianceLoss, KeypointPriorLoss, HeadPoseLoss, DeformationPriorLoss
from utils.commons.image_utils import erode, dilate
from utils.commons.hparams import hparams

class Hopenet(nn.Module):
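    """
    ResNet-style head-pose estimator (Hopenet). Yaw, pitch and roll are each
    predicted as logits over `num_bins` angle bins; `forward` converts them to
    continuous angles in radians via a softmax expectation over the bins.
    """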
    def __init__(self, block, layers, num_bins):
        self.inplanes = 64
        super(Hopenet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc_yaw = nn.Linear(512 * block.expansion, num_bins)
        self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)
        self.fc_roll = nn.Linear(512 * block.expansion, num_bins)

        self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
        self.idx_tensor = torch.FloatTensor(list(range(num_bins))).unsqueeze(0).cuda()
        self.n_bins = num_bins
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        real_yaw = self.fc_yaw(x)
        real_pitch = self.fc_pitch(x)
        real_roll = self.fc_roll(x)
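        # Soft-binned regression: softmax over the angle bins, expectation
        # against the bin indices, then map the bin offset (3 degrees per bin,
        # centred at n_bins // 2) to radians.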
        real_yaw = torch.softmax(real_yaw, dim=1)
        real_pitch = torch.softmax(real_pitch, dim=1)
        real_roll = torch.softmax(real_roll, dim=1)
        real_yaw = (real_yaw * self.idx_tensor).sum(dim=1)
        real_pitch = (real_pitch * self.idx_tensor).sum(dim=1)
        real_roll = (real_roll * self.idx_tensor).sum(dim=1)
        real_yaw = (real_yaw - self.n_bins // 2) * 3 * np.pi / 180
        real_pitch = (real_pitch - self.n_bins // 2) * 3 * np.pi / 180
        real_roll = (real_roll - self.n_bins // 2) * 3 * np.pi / 180

        return real_yaw, real_pitch, real_roll


class Transform:
    """
    Random TPS transformation for equivariance constraints.
    Reference: FOMM.
    """

    def __init__(self, bs, sigma_affine=0.05, sigma_tps=0.005, points_tps=5):
        noise = torch.normal(mean=0, std=sigma_affine * torch.ones([bs, 2, 3]))
        self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
        self.bs = bs

        self.control_points = make_coordinate_grid_2d((points_tps, points_tps))
        self.control_points = self.control_points.unsqueeze(0)
        self.control_params = torch.normal(mean=0, std=sigma_tps * torch.ones([bs, 1, points_tps ** 2]))

    def transform_frame(self, frame):
        grid = make_coordinate_grid_2d(frame.shape[2:]).unsqueeze(0)
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
        return F.grid_sample(frame, grid, align_corners=True, padding_mode="reflection")

    def warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type())
        theta = theta.unsqueeze(1)
        transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
        transformed = transformed.squeeze(-1)

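        # TPS-style part: a radial basis r^2 * log(r) on the (L1) distances to
        # the random control points, weighted by control_params and added on
        # top of the affine warp above.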
        control_points = self.control_points.type(coordinates.type())
        control_params = self.control_params.type(coordinates.type())
        distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
        distances = torch.abs(distances).sum(-1)

        result = distances ** 2
        result = result * torch.log(distances + 1e-6)
        result = result * control_params
        result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
        transformed = transformed + result

        return transformed


class WarpBasedTorsoModel(nn.Module):
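    """
    face-vid2vid-style warping model for the torso: canonical keypoints plus a
    pose/expression estimate give source and driving keypoints, a motion field
    is estimated from them, and the torso image is warped by the generator.
    """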
    def __init__(self, model_scale='small'):
        super().__init__()
        self.appearance_extractor = AppearanceFeatureExtractor(model_scale)
        self.canonical_kp_detector = CanonicalKeypointDetector(model_scale)
        self.pose_exp_estimator = PoseExpressionEstimator(model_scale)
        self.motion_field_estimator = MotionFieldEstimator(model_scale)
        self.deform_based_generator = Generator()

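        # Frozen, pre-trained Hopenet used only as a teacher for the head-pose loss.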
        self.pretrained_hopenet = Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], num_bins=66).cuda()
        pretrained_path = "/home/tiger/nfs/myenv/cache/useful_ckpts/hopenet_robust_alpha1.pkl"
        self.pretrained_hopenet.load_state_dict(torch.load(pretrained_path, map_location=torch.device("cpu")))
        self.pretrained_hopenet.requires_grad_(False)

        self.pose_loss_fn = HeadPoseLoss()
        self.equivariance_loss_fn = EquivarianceLoss()
        self.keypoint_prior_loss_fn = KeypointPriorLoss()
        self.deform_prior_loss_fn = DeformationPriorLoss()

    def forward(self, torso_src_img, src_img, drv_img, cal_loss=False):
        cano_keypoint = self.canonical_kp_detector(src_img)

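        # A random TPS-warped copy of the driving frame is passed through the same
        # pose/expression estimator so the equivariance loss can be computed below.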
        transform_fn = Transform(drv_img.shape[0])
        transformed_drv_img = transform_fn.transform_frame(drv_img)
        cat_imgs = torch.cat([src_img, drv_img, transformed_drv_img], dim=0)
        yaw, pitch, roll, t, delta = self.pose_exp_estimator(cat_imgs)
        [yaw_s, yaw_d, yaw_tran], [pitch_s, pitch_d, pitch_tran], [roll_s, roll_d, roll_tran] = (
            torch.chunk(yaw, 3, dim=0),
            torch.chunk(pitch, 3, dim=0),
            torch.chunk(roll, 3, dim=0),
        )
        [t_s, t_d, t_tran], [delta_s, delta_d, delta_tran] = (
            torch.chunk(t, 3, dim=0),
            torch.chunk(delta, 3, dim=0),
        )
        kp_s, Rs = transform_kp(cano_keypoint, yaw_s, pitch_s, roll_s, t_s, delta_s)
        kp_d, Rd = transform_kp(cano_keypoint, yaw_d, pitch_d, roll_d, t_d, delta_d)

        torso_appearance_feats = self.appearance_extractor(torso_src_img)
        deformation, occlusion = self.motion_field_estimator(torso_appearance_feats, kp_s, kp_d, Rs, Rd)
        deformed_torso_img = self.deform_based_generator(torso_appearance_feats, deformation, occlusion)

        ret = {'kp_src': kp_s, 'kp_drv': kp_d}
        if cal_loss:
            losses = {}
            with torch.no_grad():
                self.pretrained_hopenet.eval()
                real_yaw, real_pitch, real_roll = self.pretrained_hopenet(F.interpolate(apply_imagenet_normalization(cat_imgs), size=(224, 224)))
            pose_loss = self.pose_loss_fn(yaw, pitch, roll, real_yaw, real_pitch, real_roll)
            losses['facev2v/pose_pred_loss'] = pose_loss

            kp_tran, _ = transform_kp(cano_keypoint, yaw_tran, pitch_tran, roll_tran, t_tran, delta_tran)
            reverse_kp = transform_fn.warp_coordinates(kp_tran[:, :, :2])
            equivariance_loss = self.equivariance_loss_fn(kp_d, reverse_kp)
            losses['facev2v/equivariance_loss'] = equivariance_loss

            keypoint_prior_loss = self.keypoint_prior_loss_fn(kp_d)
            losses['facev2v/keypoint_prior_loss'] = keypoint_prior_loss

            deform_prior_loss = self.deform_prior_loss_fn(delta_d)
            losses['facev2v/deform_prior_loss'] = deform_prior_loss
            ret['losses'] = losses

        return deformed_torso_img, ret


class WarpBasedTorsoModelMediaPipe(nn.Module):
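    """
    Torso warping model driven by externally supplied facial landmarks
    (e.g. MediaPipe) instead of a learned canonical keypoint detector.
    A subset of the landmarks (hparams['torso_kp_num']) drives the motion
    field, and a second occlusion map is refined at full resolution.
    """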
    def __init__(self, model_scale='small'):
        super().__init__()
        self.appearance_extractor = AppearanceFeatureExtractor(model_scale)
        self.motion_field_estimator = MotionFieldEstimator(model_scale, input_channels=32 + 2, num_keypoints=hparams['torso_kp_num'])

        self.deform_based_generator = Generator()

        self.occlusion_2_predictor = nn.Sequential(
            nn.Conv2d(64 + 1, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 1, 3, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, torso_src_img, segmap, kp_s, kp_d, tgt_head_img, cal_loss=False, target_torso_mask=None):
        """
        kp_s, kp_d: [b, 68, 3], values in the range [-1, 1].
        """
        torso_appearance_feats = self.appearance_extractor(torso_src_img)
        torso_segmap = torch.nn.functional.interpolate(segmap[:, [2, 4]].float(), size=(64, 64), mode='bilinear', align_corners=False, antialias=False)
        torso_mask = torso_segmap.sum(dim=1).unsqueeze(1)
        torso_mask = dilate(torso_mask, ksize=hparams.get("torso_mask_dilate_ksize", 7))
        if hparams.get("mul_torso_mask", True):
            torso_appearance_feats = torso_appearance_feats * torso_mask.unsqueeze(1)
        motion_inp_appearance_feats = torch.cat([torso_appearance_feats, torso_segmap.unsqueeze(2).repeat([1, 1, torso_appearance_feats.shape[2], 1, 1])], dim=1)

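        # Keep only a sparse subset of the 68 input landmarks, as configured by
        # hparams['torso_kp_num'], to drive the torso motion field.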
        if hparams['torso_kp_num'] == 4:
            kp_s = kp_s[:, [0, 8, 16, 27], :]
            kp_d = kp_d[:, [0, 8, 16, 27], :]
        elif hparams['torso_kp_num'] == 9:
            kp_s = kp_s[:, [0, 3, 6, 8, 10, 13, 16, 27, 33], :]
            kp_d = kp_d[:, [0, 3, 6, 8, 10, 13, 16, 27, 33], :]
        else:
            raise NotImplementedError()

        Rs = torch.eye(3, 3).unsqueeze(0).repeat([kp_s.shape[0], 1, 1]).to(kp_s.device)
        Rd = torch.eye(3, 3).unsqueeze(0).repeat([kp_d.shape[0], 1, 1]).to(kp_d.device)
        deformation, occlusion, occlusion_2 = self.motion_field_estimator(motion_inp_appearance_feats, kp_s, kp_d, Rs, Rd)
        motion_estimator_grad_scale_factor = 0.1
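        # Gradient-scaling trick: forward values are unchanged, but gradients
        # flowing back into the motion field estimator are scaled by 0.1.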
        deformation = deformation * motion_estimator_grad_scale_factor + deformation.detach() * (1 - motion_estimator_grad_scale_factor)
        occlusion = occlusion * motion_estimator_grad_scale_factor + occlusion.detach() * (1 - motion_estimator_grad_scale_factor)
        occlusion_2 = occlusion_2 * motion_estimator_grad_scale_factor + occlusion_2.detach() * (1 - motion_estimator_grad_scale_factor)
        ret = {'kp_src': kp_s, 'kp_drv': kp_d, 'occlusion': occlusion, 'occlusion_2': occlusion_2}

        deformed_torso_img, deformed_torso_hid = self.deform_based_generator(torso_appearance_feats, deformation, occlusion, return_hid=True)
        ret['deformed_torso_hid'] = deformed_torso_hid
        occlusion_2 = self.occlusion_2_predictor(torch.cat([deformed_torso_hid, F.interpolate(occlusion_2, size=(256, 256), mode='bilinear')], dim=1))
        ret['occlusion_2'] = occlusion_2
        alphas = occlusion_2.clamp(1e-5, 1 - 1e-5)

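        # Occlusion regularization: penalize the mean occlusion values (optionally
        # reweighted by the non-torso region when a target torso mask is given),
        # plus a binary-entropy term that pushes the blend weights toward 0 or 1.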
        if target_torso_mask is None:
            ret['losses'] = {
                'facev2v/occlusion_reg_l1': occlusion.mean(),
                'facev2v/occlusion_2_reg_l1': occlusion_2.mean(),
                'facev2v/occlusion_2_weights_entropy': torch.mean(- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)),
            }
        else:
            non_target_torso_mask_1 = torch.nn.functional.interpolate((~target_torso_mask).unsqueeze(1).float(), size=occlusion.shape[-2:])
            non_target_torso_mask_2 = torch.nn.functional.interpolate((~target_torso_mask).unsqueeze(1).float(), size=occlusion_2.shape[-2:])
            ret['losses'] = {
                'facev2v/occlusion_reg_l1': self.masked_l1_reg_loss(occlusion, non_target_torso_mask_1.bool(), masked_weight=1, unmasked_weight=hparams['torso_occlusion_reg_unmask_factor']),
                'facev2v/occlusion_2_reg_l1': self.masked_l1_reg_loss(occlusion_2, non_target_torso_mask_2.bool(), masked_weight=1, unmasked_weight=hparams['torso_occlusion_reg_unmask_factor']),
                'facev2v/occlusion_2_weights_entropy': torch.mean(- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)),
            }

        return deformed_torso_img, ret

    def masked_l1_reg_loss(self, img_pred, mask, masked_weight=0.01, unmasked_weight=0.001, mode='l1'):
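        # NOTE: masked_weight is overridden below, so the value passed in by the
        # caller is effectively ignored (current call sites pass 1 anyway).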
        masked_weight = 1.0
        weight_mask = mask.float() * masked_weight + (~mask).float() * unmasked_weight
        if mode == 'l1':
            error = img_pred.abs().sum(dim=1) * weight_mask
        else:
            error = img_pred.pow(2).sum(dim=1) * weight_mask
        loss = error.mean()
        return loss

    @torch.no_grad()
    def infer_forward_stage1(self, torso_src_img, segmap, kp_s, kp_d, tgt_head_img, cal_loss=False):
        """
        kp_s, kp_d: [b, 68, 3], values in the range [-1, 1].
        """
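        # NOTE: inference hard-codes the 4-keypoint landmark subset, i.e. it
        # assumes hparams['torso_kp_num'] == 4.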
        kp_s = kp_s[:, [0, 8, 16, 27], :]
        kp_d = kp_d[:, [0, 8, 16, 27], :]

        torso_segmap = torch.nn.functional.interpolate(segmap[:, [2, 4]].float(), size=(64, 64), mode='bilinear', align_corners=False, antialias=False)
        torso_appearance_feats = self.appearance_extractor(torso_src_img)
        torso_mask = torso_segmap.sum(dim=1).unsqueeze(1)
        torso_mask = dilate(torso_mask, ksize=hparams.get("torso_mask_dilate_ksize", 7))
        if hparams.get("mul_torso_mask", True):
            torso_appearance_feats = torso_appearance_feats * torso_mask.unsqueeze(1)
        motion_inp_appearance_feats = torch.cat([torso_appearance_feats, torso_segmap.unsqueeze(2).repeat([1, 1, torso_appearance_feats.shape[2], 1, 1])], dim=1)

        Rs = torch.eye(3, 3).unsqueeze(0).repeat([kp_s.shape[0], 1, 1]).to(kp_s.device)
        Rd = torch.eye(3, 3).unsqueeze(0).repeat([kp_d.shape[0], 1, 1]).to(kp_d.device)
        deformation, occlusion, occlusion_2 = self.motion_field_estimator(motion_inp_appearance_feats, kp_s, kp_d, Rs, Rd)
        motion_estimator_grad_scale_factor = 0.1
        deformation = deformation * motion_estimator_grad_scale_factor + deformation.detach() * (1 - motion_estimator_grad_scale_factor)
        occlusion = occlusion * motion_estimator_grad_scale_factor + occlusion.detach() * (1 - motion_estimator_grad_scale_factor)
        occlusion_2 = occlusion_2 * motion_estimator_grad_scale_factor + occlusion_2.detach() * (1 - motion_estimator_grad_scale_factor)
        ret = {'kp_src': kp_s, 'kp_drv': kp_d, 'occlusion': occlusion, 'occlusion_2': occlusion_2}
        ret['torso_appearance_feats'] = torso_appearance_feats
        ret['deformation'] = deformation
        ret['occlusion'] = occlusion
        return ret

    @torch.no_grad()
    def infer_forward_stage2(self, ret):
        torso_appearance_feats = ret['torso_appearance_feats']
        deformation = ret['deformation']
        occlusion = ret['occlusion']
        deformed_torso_img, deformed_torso_hid = self.deform_based_generator(torso_appearance_feats, deformation, occlusion, return_hid=True)
        ret['deformed_torso_hid'] = deformed_torso_hid
        return deformed_torso_img


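# Quick smoke test: build the small model, print parameter counts, and run
# 100 training-style forward passes on random inputs.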
if __name__ == '__main__':
    from utils.nn.model_utils import num_params
    import tqdm
    model = WarpBasedTorsoModel('small')
    model.cuda()
    num_params(model)
    for n, m in model.named_children():
        num_params(m, model_name=n)
    torso_ref_img = torch.randn([2, 3, 256, 256]).cuda()
    ref_img = torch.randn([2, 3, 256, 256]).cuda()
    mv_img = torch.randn([2, 3, 256, 256]).cuda()
    out = model(torso_ref_img, ref_img, mv_img)
    for i in tqdm.trange(100):
        out_img, losses = model(torso_ref_img, ref_img, mv_img, cal_loss=True)
    print(" ")