# SyncTalk/data_utils/face_tracking/bundle_adjustment.py
# yinwentao
# DockerFile
# 8d34f50
import numpy as np
import os
from util import *
import argparse
def set_requires_grad(tensor_list):
    """Turn on gradient tracking for every tensor in ``tensor_list``.

    Each element has its ``requires_grad`` attribute set to ``True`` in
    place. Nothing is returned.
    """
    for t in tensor_list:
        t.requires_grad = True
# Command-line options. `--path` is used below as the directory that holds
# the tracking files; `--img_h` / `--img_w` give the image size, half of
# which forms `cxy` further down.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--path", type=str, default="", help="idname of target person")
# Help text previously read "height if image".
parser.add_argument('--img_h', type=int, default=512, help='height of image')
parser.add_argument('--img_w', type=int, default=512, help='width of image')
args = parser.parse_args()
# Directory that holds the tracking files for one subject.
id_dir = args.path

# Parameters saved by an earlier tracking step, moved to the GPU.
params_path = os.path.join(id_dir, 'track_params.pt')
params_dict = torch.load(params_path)
euler_angle = params_dict['euler'].cuda()
# Translation is divided by 1000 before optimisation
# (presumably a unit change - TODO confirm against the tracking step).
trans = params_dict['trans'].cuda() / 1000.0
focal_len = params_dict['focal'].cuda()

# Tracked 2D points; the first two axes are read as (frames, points).
xys_path = os.path.join(id_dir, 'track_xys.npy')
track_xys = torch.as_tensor(np.load(xys_path)).float().cuda()
num_frames = track_xys.size(0)
point_num = track_xys.size(1)

# Half the image width/height, passed to forward_transform as `cxy`.
half_w = args.img_w / 2.0
half_h = args.img_h / 2.0
cxy = torch.Tensor((half_w, half_h)).float().cuda()

# One 3D point per tracked point, initialised at the origin and shared by
# all frames.
pts = torch.zeros(point_num, 3, dtype=torch.float32).cuda()

# All three tensors are marked trainable for the optimisation stages below.
set_requires_grad([euler_angle, trans, pts])
# Stage 1: fit only the shared 3D points `pts`; Adam is given nothing else.
optimizer_pts = torch.optim.Adam([pts], lr=1e-2)
iter_num = 500
# `euler_angle` and `trans` are not stepped in this stage, so detached views
# are passed to the projection. This leaves the gradients of `pts` unchanged
# and avoids back-propagating into tensors whose gradients would be thrown
# away (stage 2 calls zero_grad() before its first backward pass).
frozen_euler = euler_angle.detach()
frozen_trans = trans.detach()
# `_` replaces the former loop variable `iter`, which shadowed the builtin.
for _ in range(iter_num):
    # The same point set is broadcast to every frame before projection.
    proj_pts = forward_transform(
        pts.unsqueeze(0).expand(num_frames, -1, -1),
        frozen_euler, frozen_trans, focal_len, cxy)
    # Only the first two components of the projection are compared with the
    # tracked 2D points.
    loss = cal_lan_loss(proj_pts[..., :2], track_xys)
    optimizer_pts.zero_grad()
    loss.backward()
    optimizer_pts.step()
# Stage 2: refine `pts`, `euler_angle` and `trans` jointly, with a smaller
# learning rate and more iterations than stage 1.
optimizer_ba = torch.optim.Adam([pts, euler_angle, trans], lr=1e-4)
iter_num = 8000
# `_` replaces the former loop variable `iter`, which shadowed the builtin.
for _ in range(iter_num):
    # The same point set is broadcast to every frame before projection.
    proj_pts = forward_transform(
        pts.unsqueeze(0).expand(num_frames, -1, -1),
        euler_angle, trans, focal_len, cxy)
    # The landmark loss is the only term, so it is used directly as the
    # training loss.
    loss = cal_lan_loss(proj_pts[..., :2], track_xys)
    optimizer_ba.zero_grad()
    loss.backward()
    optimizer_ba.step()
# Write the refined parameters (detached, on the CPU) into the same directory
# the tracking files were read from.
out_path = os.path.join(id_dir, 'bundle_adjustment.pt')
refined_params = {
    'euler': euler_angle.detach().cpu(),
    'trans': trans.detach().cpu(),
    'focal': focal_len.detach().cpu(),
}
torch.save(refined_params, out_path)
print('bundle adjustment params saved')