import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import math
import numpy as np
from modules.real3d.facev2v_warp.network import AppearanceFeatureExtractor, CanonicalKeypointDetector, PoseExpressionEstimator, MotionFieldEstimator, Generator
from modules.real3d.facev2v_warp.func_utils import transform_kp, make_coordinate_grid_2d, apply_imagenet_normalization
from modules.real3d.facev2v_warp.losses import PerceptualLoss, GANLoss, FeatureMatchingLoss, EquivarianceLoss, KeypointPriorLoss, HeadPoseLoss, DeformationPriorLoss
from utils.commons.image_utils import erode, dilate
from utils.commons.hparams import hparams
class Hopenet(nn.Module):
# Hopenet with 3 output layers for yaw, pitch and roll
# Predicts Euler angles by binning and regression with the expected value
def __init__(self, block, layers, num_bins):
self.inplanes = 64
super(Hopenet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(7)
self.fc_yaw = nn.Linear(512 * block.expansion, num_bins)
self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)
self.fc_roll = nn.Linear(512 * block.expansion, num_bins)
# Vestigial layer from previous experiments
self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
self.idx_tensor = torch.FloatTensor(list(range(num_bins))).unsqueeze(0).cuda()
self.n_bins = num_bins
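        # He (fan-out) initialization for conv weights; BatchNorm layers start as the identity (weight 1, bias 0).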
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2.0 / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
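        # Decode continuous angles from the per-bin logits: softmax over bins, take the expected bin index,
        # shift so bin n_bins//2 maps to zero, then convert to degrees (3 degrees per bin) and finally radians.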
real_yaw = self.fc_yaw(x)
real_pitch = self.fc_pitch(x)
real_roll = self.fc_roll(x)
real_yaw = torch.softmax(real_yaw, dim=1)
real_pitch = torch.softmax(real_pitch, dim=1)
real_roll = torch.softmax(real_roll, dim=1)
real_yaw = (real_yaw * self.idx_tensor).sum(dim=1)
real_pitch = (real_pitch * self.idx_tensor).sum(dim=1)
real_roll = (real_roll * self.idx_tensor).sum(dim=1)
real_yaw = (real_yaw - self.n_bins // 2) * 3 * np.pi / 180
real_pitch = (real_pitch - self.n_bins // 2) * 3 * np.pi / 180
real_roll = (real_roll - self.n_bins // 2) * 3 * np.pi / 180
return real_yaw, real_pitch, real_roll
class Transform:
"""
    Random TPS transformation for equivariance constraints.
    Reference: FOMM
"""
def __init__(self, bs, sigma_affine=0.05, sigma_tps=0.005, points_tps=5):
noise = torch.normal(mean=0, std=sigma_affine * torch.ones([bs, 2, 3]))
self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
self.bs = bs
self.control_points = make_coordinate_grid_2d((points_tps, points_tps))
self.control_points = self.control_points.unsqueeze(0)
self.control_params = torch.normal(mean=0, std=sigma_tps * torch.ones([bs, 1, points_tps ** 2]))
def transform_frame(self, frame):
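        # Build a dense identity sampling grid over the frame, warp it with the random affine+TPS transform,
        # and resample the frame with reflection padding.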
grid = make_coordinate_grid_2d(frame.shape[2:]).unsqueeze(0)
grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
return F.grid_sample(frame, grid, align_corners=True, padding_mode="reflection")
def warp_coordinates(self, coordinates):
theta = self.theta.type(coordinates.type())
theta = theta.unsqueeze(1)
transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
transformed = transformed.squeeze(-1)
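        # Thin-plate-spline residual: radial kernel U(r) = r^2 * log(r) evaluated on the distances
        # (sum of absolute coordinate differences, as in FOMM) to the control points, weighted by control_params.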
control_points = self.control_points.type(coordinates.type())
control_params = self.control_params.type(coordinates.type())
distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
distances = torch.abs(distances).sum(-1)
result = distances ** 2
result = result * torch.log(distances + 1e-6)
result = result * control_params
result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
transformed = transformed + result
return transformed
class WarpBasedTorsoModel(nn.Module):
def __init__(self, model_scale='small'):
super().__init__()
self.appearance_extractor = AppearanceFeatureExtractor(model_scale)
self.canonical_kp_detector = CanonicalKeypointDetector(model_scale)
self.pose_exp_estimator = PoseExpressionEstimator(model_scale)
self.motion_field_estimator = MotionFieldEstimator(model_scale)
self.deform_based_generator = Generator()
self.pretrained_hopenet = Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], num_bins=66).cuda()
pretrained_path = "/home/tiger/nfs/myenv/cache/useful_ckpts/hopenet_robust_alpha1.pkl" # https://drive.google.com/open?id=1m25PrSE7g9D2q2XJVMR6IA7RaCvWSzCR
self.pretrained_hopenet.load_state_dict(torch.load(pretrained_path, map_location=torch.device("cpu")))
self.pretrained_hopenet.requires_grad_(False)
self.pose_loss_fn = HeadPoseLoss() # 20
self.equivariance_loss_fn = EquivarianceLoss() # 20
self.keypoint_prior_loss_fn = KeypointPriorLoss()# 10
self.deform_prior_loss_fn = DeformationPriorLoss() # 5
def forward(self, torso_src_img, src_img, drv_img, cal_loss=False):
# predict cano keypoint
cano_keypoint = self.canonical_kp_detector(src_img)
# predict src_pose and drv_pose
transform_fn = Transform(drv_img.shape[0])
transformed_drv_img = transform_fn.transform_frame(drv_img)
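        # Estimate pose/expression in one batched pass over the stacked source, driving, and TPS-transformed
        # driving images, then split the predictions back into the three groups below.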
cat_imgs = torch.cat([src_img, drv_img, transformed_drv_img], dim=0)
yaw, pitch, roll, t, delta = self.pose_exp_estimator(cat_imgs)
[yaw_s, yaw_d, yaw_tran], [pitch_s, pitch_d, pitch_tran], [roll_s, roll_d, roll_tran] = (
torch.chunk(yaw, 3, dim=0),
torch.chunk(pitch, 3, dim=0),
torch.chunk(roll, 3, dim=0),
)
[t_s, t_d, t_tran], [delta_s, delta_d, delta_tran] = (
torch.chunk(t, 3, dim=0),
torch.chunk(delta, 3, dim=0),
)
kp_s, Rs = transform_kp(cano_keypoint, yaw_s, pitch_s, roll_s, t_s, delta_s)
kp_d, Rd = transform_kp(cano_keypoint, yaw_d, pitch_d, roll_d, t_d, delta_d)
# deform the torso img
torso_appearance_feats = self.appearance_extractor(torso_src_img)
deformation, occlusion = self.motion_field_estimator(torso_appearance_feats, kp_s, kp_d, Rs, Rd)
deformed_torso_img = self.deform_based_generator(torso_appearance_feats, deformation, occlusion)
ret = {'kp_src': kp_s, 'kp_drv': kp_d}
if cal_loss:
losses = {}
with torch.no_grad():
self.pretrained_hopenet.eval()
real_yaw, real_pitch, real_roll = self.pretrained_hopenet(F.interpolate(apply_imagenet_normalization(cat_imgs), size=(224, 224)))
pose_loss = self.pose_loss_fn(yaw, pitch, roll, real_yaw, real_pitch, real_roll)
losses['facev2v/pose_pred_loss'] = pose_loss
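            # Equivariance: keypoints predicted on the TPS-transformed image, mapped through the known warp
            # back to the original frame, should match the driving keypoints.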
kp_tran, _ = transform_kp(cano_keypoint, yaw_tran, pitch_tran, roll_tran, t_tran, delta_tran)
reverse_kp = transform_fn.warp_coordinates(kp_tran[:, :, :2])
equivariance_loss = self.equivariance_loss_fn(kp_d, reverse_kp)
losses['facev2v/equivariance_loss'] = equivariance_loss
keypoint_prior_loss = self.keypoint_prior_loss_fn(kp_d)
losses['facev2v/keypoint_prior_loss'] = keypoint_prior_loss
deform_prior_loss = self.deform_prior_loss_fn(delta_d)
losses['facev2v/deform_prior_loss'] = deform_prior_loss
ret['losses'] = losses
return deformed_torso_img, ret
class WarpBasedTorsoModelMediaPipe(nn.Module):
def __init__(self, model_scale='small'):
super().__init__()
self.appearance_extractor = AppearanceFeatureExtractor(model_scale)
        self.motion_field_estimator = MotionFieldEstimator(model_scale, input_channels=32+2, num_keypoints=hparams['torso_kp_num']) # 32 appearance channels plus 2 segmentation channels from the segmap
# self.motion_field_estimator = MotionFieldEstimator(model_scale, input_channels=32+2, num_keypoints=9) # 32 channel appearance channel, and 3 channel for segmap
self.deform_based_generator = Generator()
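        # Small conv head that refines the alpha-blending mask: it takes the generator's 64-channel hidden
        # features concatenated with the upsampled coarse occlusion_2 map and outputs a 0~1 mask.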
self.occlusion_2_predictor = nn.Sequential(*[
nn.Conv2d(64+1, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(32, 1, 3, 1, 1),
nn.Sigmoid()
])
    # V2: warp first, then average
def forward(self, torso_src_img, segmap, kp_s, kp_d, tgt_head_img, cal_loss=False, target_torso_mask=None):
"""
        kp_s, kp_d: [b, 68, 3], values in the range [-1, 1]
"""
torso_appearance_feats = self.appearance_extractor(torso_src_img) # [B, C, D, H, W]
torso_segmap = torch.nn.functional.interpolate(segmap[:,[2,4]].float(), size=(64,64), mode='bilinear', align_corners=False, antialias=False) # see tasks/eg3ds/loss_utils/segment_loss/mp_segmenter.py for the segmap convention
        torso_mask = torso_segmap.sum(dim=1).unsqueeze(1) # [b, 1, h, w]
torso_mask = dilate(torso_mask, ksize=hparams.get("torso_mask_dilate_ksize", 7))
if hparams.get("mul_torso_mask", True):
torso_appearance_feats = torso_appearance_feats * torso_mask.unsqueeze(1)
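        # Append the 2 segmentation channels (broadcast along the depth axis) to the 32-channel appearance
        # volume, forming the 32+2-channel input of the motion field estimator.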
motion_inp_appearance_feats = torch.cat([torso_appearance_feats, torso_segmap.unsqueeze(2).repeat([1,1,torso_appearance_feats.shape[2],1,1])], dim=1)
if hparams['torso_kp_num'] == 4:
kp_s = kp_s[:,[0,8,16,27],:]
kp_d = kp_d[:,[0,8,16,27],:]
elif hparams['torso_kp_num'] == 9:
kp_s = kp_s[:,[0, 3, 6, 8, 10, 13, 16, 27, 33],:]
kp_d = kp_d[:,[0, 3, 6, 8, 10, 13, 16, 27, 33],:]
else:
raise NotImplementedError()
# deform the torso img
Rs = torch.eye(3, 3).unsqueeze(0).repeat([kp_s.shape[0], 1, 1]).to(kp_s.device)
Rd = torch.eye(3, 3).unsqueeze(0).repeat([kp_d.shape[0], 1, 1]).to(kp_d.device)
deformation, occlusion, occlusion_2 = self.motion_field_estimator(motion_inp_appearance_feats, kp_s, kp_d, Rs, Rd)
motion_estimator_grad_scale_factor = 0.1
# motion_estimator_grad_scale_factor = 1.0
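        # x * s + x.detach() * (1 - s) keeps the forward values unchanged but scales the gradients flowing
        # back into the motion field estimator by s (0.1 here).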
deformation = deformation * motion_estimator_grad_scale_factor + deformation.detach() * (1-motion_estimator_grad_scale_factor)
        # occlusion: a 0~1 mask predicting the segmentation of the warped torso, used in the occlusion-aware decoder
occlusion = occlusion * motion_estimator_grad_scale_factor + occlusion.detach() * (1-motion_estimator_grad_scale_factor)
        # occlusion_2: a 0~1 mask predicting the segmentation of the warped torso, used for alpha-blending
occlusion_2 = occlusion_2 * motion_estimator_grad_scale_factor + occlusion_2.detach() * (1-motion_estimator_grad_scale_factor)
ret = {'kp_src': kp_s, 'kp_drv': kp_d, 'occlusion': occlusion, 'occlusion_2': occlusion_2}
deformed_torso_img, deformed_torso_hid = self.deform_based_generator(torso_appearance_feats, deformation, occlusion, return_hid=True)
ret['deformed_torso_hid'] = deformed_torso_hid
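        # Refine occlusion_2 with the generator's hidden features plus the upsampled coarse map;
        # this overwrites the coarse estimate stored in ret above.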
occlusion_2 = self.occlusion_2_predictor(torch.cat([deformed_torso_hid, F.interpolate(occlusion_2, size=(256,256), mode='bilinear')], dim=1))
ret['occlusion_2'] = occlusion_2
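        # Clamp so the binary-entropy regularizer below never evaluates log(0).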
alphas = occlusion_2.clamp(1e-5, 1 - 1e-5)
if target_torso_mask is None:
ret['losses'] = {
'facev2v/occlusion_reg_l1': occlusion.mean(),
'facev2v/occlusion_2_reg_l1': occlusion_2.mean(),
'facev2v/occlusion_2_weights_entropy': torch.mean(- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)), # you can visualize this fn at https://www.desmos.com/calculator/rwbs7bruvj?lang=zh-TW
}
else:
non_target_torso_mask_1 = torch.nn.functional.interpolate((~target_torso_mask).unsqueeze(1).float(), size=occlusion.shape[-2:])
non_target_torso_mask_2 = torch.nn.functional.interpolate((~target_torso_mask).unsqueeze(1).float(), size=occlusion_2.shape[-2:])
ret['losses'] = {
'facev2v/occlusion_reg_l1': self.masked_l1_reg_loss(occlusion, non_target_torso_mask_1.bool(), masked_weight=1, unmasked_weight=hparams['torso_occlusion_reg_unmask_factor']),
'facev2v/occlusion_2_reg_l1': self.masked_l1_reg_loss(occlusion_2, non_target_torso_mask_2.bool(), masked_weight=1, unmasked_weight=hparams['torso_occlusion_reg_unmask_factor']),
'facev2v/occlusion_2_weights_entropy': torch.mean(- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)), # you can visualize this fn at https://www.desmos.com/calculator/rwbs7bruvj?lang=zh-TW
}
# if hparams.get("fuse_with_deform_source"):
# B, _, H, W = deformed_torso_img.shape
# deformation_256 = F.interpolate(deformation.mean(dim=1).permute(0,3,1,2), size=256, mode='bilinear',antialias=True).permute(0,2,3,1)[...,:2]
# deformed_source_torso_img = F.grid_sample(torso_src_img, deformation_256, align_corners=True).view(B, -1, H, W)
# occlusion_256 = F.interpolate(occlusion, size=256, antialias=True, mode='bilinear').reshape([B,1,H,W])
# # deformed_torso_img = deformed_torso_img * (1 - occlusion_256[:,0]) + deformed_source_torso_img[:,0] * occlusion_256[:,0]
# deformed_torso_img = deformed_torso_img * (1 - occlusion_256) + deformed_source_torso_img * occlusion_256
return deformed_torso_img, ret
def masked_l1_reg_loss(self, img_pred, mask, masked_weight=0.01, unmasked_weight=0.001, mode='l1'):
        # Because of the deformation, the background of the raw image cannot be made fully black, which
        # inflates the error there; mask it out so the loss is computed mainly on the face/torso region.
        masked_weight = 1.0  # note: this hard-coded value overrides the masked_weight argument
weight_mask = mask.float() * masked_weight + (~mask).float() * unmasked_weight
if mode == 'l1':
error = (img_pred).abs().sum(dim=1) * weight_mask
else:
error = (img_pred).pow(2).sum(dim=1) * weight_mask
loss = error.mean()
return loss
@torch.no_grad()
def infer_forward_stage1(self, torso_src_img, segmap, kp_s, kp_d, tgt_head_img, cal_loss=False):
"""
        kp_s, kp_d: [b, 68, 3], values in the range [-1, 1]
"""
kp_s = kp_s[:,[0,8,16,27],:]
kp_d = kp_d[:,[0,8,16,27],:]
torso_segmap = torch.nn.functional.interpolate(segmap[:,[2,4]].float(), size=(64,64), mode='bilinear', align_corners=False, antialias=False) # see tasks/eg3ds/loss_utils/segment_loss/mp_segmenter.py for the segmap convention
torso_appearance_feats = self.appearance_extractor(torso_src_img)
        torso_mask = torso_segmap.sum(dim=1).unsqueeze(1) # [b, 1, h, w]
torso_mask = dilate(torso_mask, ksize=hparams.get("torso_mask_dilate_ksize", 7))
if hparams.get("mul_torso_mask", True):
torso_appearance_feats = torso_appearance_feats * torso_mask.unsqueeze(1)
motion_inp_appearance_feats = torch.cat([torso_appearance_feats, torso_segmap.unsqueeze(2).repeat([1,1,torso_appearance_feats.shape[2],1,1])], dim=1)
# deform the torso img
Rs = torch.eye(3, 3).unsqueeze(0).repeat([kp_s.shape[0], 1, 1]).to(kp_s.device)
Rd = torch.eye(3, 3).unsqueeze(0).repeat([kp_d.shape[0], 1, 1]).to(kp_d.device)
deformation, occlusion, occlusion_2 = self.motion_field_estimator(motion_inp_appearance_feats, kp_s, kp_d, Rs, Rd)
motion_estimator_grad_scale_factor = 0.1
deformation = deformation * motion_estimator_grad_scale_factor + deformation.detach() * (1-motion_estimator_grad_scale_factor)
occlusion = occlusion * motion_estimator_grad_scale_factor + occlusion.detach() * (1-motion_estimator_grad_scale_factor)
occlusion_2 = occlusion_2 * motion_estimator_grad_scale_factor + occlusion_2.detach() * (1-motion_estimator_grad_scale_factor)
ret = {'kp_src': kp_s, 'kp_drv': kp_d, 'occlusion': occlusion, 'occlusion_2': occlusion_2}
ret['torso_appearance_feats'] = torso_appearance_feats
ret['deformation'] = deformation
ret['occlusion'] = occlusion
return ret
@torch.no_grad()
def infer_forward_stage2(self, ret):
torso_appearance_feats = ret['torso_appearance_feats']
deformation = ret['deformation']
occlusion = ret['occlusion']
deformed_torso_img, deformed_torso_hid = self.deform_based_generator(torso_appearance_feats, deformation, occlusion, return_hid=True)
ret['deformed_torso_hid'] = deformed_torso_hid
return deformed_torso_img
if __name__ == '__main__':
from utils.nn.model_utils import num_params
import tqdm
model = WarpBasedTorsoModel('small')
model.cuda()
num_params(model)
for n, m in model.named_children():
num_params(m, model_name=n)
torso_ref_img = torch.randn([2, 3, 256, 256]).cuda()
ref_img = torch.randn([2, 3, 256, 256]).cuda()
mv_img = torch.randn([2, 3, 256, 256]).cuda()
out = model(torso_ref_img, ref_img, mv_img)
for i in tqdm.trange(100):
out_img, losses = model(torso_ref_img, ref_img, mv_img, cal_loss=True)
print(" ")
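    # A rough, untested sketch of exercising WarpBasedTorsoModelMediaPipe as well; shapes and the
    # hparams['torso_kp_num'] key are inferred from its forward pass above, not verified here:
    #   hparams['torso_kp_num'] = 4
    #   torso_model = WarpBasedTorsoModelMediaPipe('small').cuda()
    #   segmap = torch.randn([2, 6, 512, 512]).softmax(dim=1).cuda()  # per-pixel class probabilities
    #   kp_s = torch.rand([2, 68, 3]).cuda() * 2 - 1                  # 68 landmarks in [-1, 1]
    #   kp_d = torch.rand([2, 68, 3]).cuda() * 2 - 1
    #   deformed_torso, extra = torso_model(torso_ref_img, segmap, kp_s, kp_d, ref_img, cal_loss=True)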