|
import torch |
|
import torch.nn as nn |
|
import render_util |
|
import geo_transform |
|
import numpy as np |
|
|
|
|
|
def compute_tri_normal(geometry, tris):
    """Compute one unit normal per triangle of a batched mesh.

    Args:
        geometry: Vertex positions indexed as ``(batch, vertex, xyz)``; the
            last axis must have size 3 because a cross product is taken
            over it.
        tris: Integer index tensor with one row per triangle; columns 0, 1
            and 2 hold the indices of the triangle's three vertices.

    Returns:
        Tensor of shape ``(batch, num_tris, 3)`` containing the normalized
        cross product ``(v2 - v1) x (v3 - v1)`` for every triangle.
        Degenerate (zero-area) triangles yield a zero vector, since
        ``normalize`` clamps the divisor to a small epsilon.
    """
    # Gather the three corners of every triangle along the vertex axis.
    vert_1 = torch.index_select(geometry, 1, tris[:, 0])
    vert_2 = torch.index_select(geometry, 1, tris[:, 1])
    vert_3 = torch.index_select(geometry, 1, tris[:, 2])

    # Un-normalized face normal; its orientation follows the vertex winding.
    raw_normal = torch.cross(vert_2 - vert_1, vert_3 - vert_1, dim=2)
    return nn.functional.normalize(raw_normal, dim=2)
|
|
|
|
|
class Compute_normal_base(torch.autograd.Function):
    """Autograd bridge to the ``render_util`` normal-base kernels.

    The forward and backward passes are both delegated to the
    ``render_util`` extension; this class only wires them into autograd.
    """

    @staticmethod
    def forward(ctx, normal):
        """Run ``render_util.normal_base_forward`` on ``normal``.

        Args:
            ctx: Autograd context; ``normal`` is stashed on it for backward.
            normal: Normal tensor handed straight to the extension.
                NOTE(review): callers in this file pass a contiguous
                ``(batch, num_tris, 3)`` tensor -- confirm the kernel
                requires that layout.

        Returns:
            The single tensor produced by the extension's forward kernel.
        """
        # The extension returns a one-element sequence; unpack it.
        normal_b, = render_util.normal_base_forward(normal)
        ctx.save_for_backward(normal)
        return normal_b

    @staticmethod
    def backward(ctx, grad_normal_b):
        """Back-propagate through the base computation.

        Args:
            ctx: Autograd context holding the ``normal`` saved in forward.
            grad_normal_b: Gradient with respect to the forward output.

        Returns:
            Gradient with respect to ``normal``, as computed by
            ``render_util.normal_base_backward``.
        """
        normal, = ctx.saved_tensors
        grad_normal, = render_util.normal_base_backward(grad_normal_b, normal)
        return grad_normal
|
|
|
|
|
class Normal_Base(torch.nn.Module):
    """Parameter-free module wrapper around ``Compute_normal_base``."""

    def __init__(self):
        super(Normal_Base, self).__init__()

    def forward(self, normal):
        """Apply ``Compute_normal_base`` to ``normal`` and return the result."""
        return Compute_normal_base.apply(normal)
|
|
|
|
|
def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img):
    """Pose and project a mesh and derive the per-vertex visibility inputs.

    Args:
        geometry: Vertex positions; ``geometry.shape[1]`` is taken as the
            number of vertices.
        euler: Rotation argument forwarded to
            ``geo_transform.euler_trans_geo``.
        trans: Translation argument forwarded to
            ``geo_transform.euler_trans_geo``.
        cam: Camera argument forwarded to ``geo_transform.proj_geo``.
        tris: Triangle vertex indices, as expected by ``compute_tri_normal``.
        vert_tris: 1-D index tensor selecting, for every vertex, the
            triangle whose normal stands in for that vertex.
        ori_img: Image tensor; only its shape and device are read here.

    Returns:
        Tuple ``(rott_geo, proj_geo, rot_tri_normal, is_visible,
        pixel_valid)``:

        * ``rott_geo`` -- geometry after ``euler_trans_geo``.
        * ``proj_geo`` -- ``rott_geo`` after ``proj_geo``.
        * ``rot_tri_normal`` -- unit normal of every posed triangle.
        * ``is_visible`` -- ``(batch, point_num)`` score: the negated dot
          product between the vertex's normal and its normalized position;
          scores below ``0.01`` are overwritten with ``-1``.
        * ``pixel_valid`` -- float32 zeros of shape
          ``(ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2])`` on
          ``ori_img``'s device.
    """
    point_num = geometry.shape[1]
    rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans)
    proj_geo = geo_transform.proj_geo(rott_geo, cam)

    rot_tri_normal = compute_tri_normal(rott_geo, tris)
    rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris)

    # One (1x3) @ (3x1) product per vertex: normal . unit position vector.
    normal_rows = rot_vert_normal.reshape(-1, 1, 3)
    view_dirs = nn.functional.normalize(rott_geo.reshape(-1, 3, 1), dim=1)
    is_visible = -torch.bmm(normal_rows, view_dirs).reshape(-1, point_num)
    # Anything below the threshold is flagged with the sentinel -1.
    is_visible[is_visible < 0.01] = -1

    # NOTE(review): this assumes ori_img is (batch, h, w, ...). render_mesh
    # passes an already flattened image, but discards pixel_valid.
    pixel_valid = torch.zeros(
        (ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2]),
        dtype=torch.float32, device=ori_img.device)
    return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid
|
|
|
|
|
class Render_Face(torch.autograd.Function):
    """Autograd bridge to the ``render_util`` face rendering kernels.

    Gradients flow to ``proj_geo``, ``texture`` and ``nbl`` only; the
    remaining forward inputs receive ``None``.
    """

    @staticmethod
    def forward(ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds,
                pixel_valid):
        """Render the mesh through ``render_util.render_face_forward``.

        Args:
            ctx: Autograd context used to stash tensors for backward.
            proj_geo: Projected geometry (differentiable).
            texture: Texture tensor (differentiable).
            nbl: Shading term (differentiable).
            ori_img: Image batch of shape ``(batch, h, w, 3)``.
            is_visible: Visibility scores forwarded to the kernel.
            tri_inds: Triangle vertex indices forwarded to the kernel.
            pixel_valid: Buffer forwarded to the kernel.
                NOTE(review): presumably written in place as a coverage
                mask, since cal_loss_rgb reads it after rendering --
                confirm against the extension.

        Returns:
            Tuple ``(render, real)`` from the kernel.
        """
        batch_size, h, w, _ = ori_img.shape
        # reshape (not view) so non-contiguous images are accepted as well.
        ori_img = ori_img.reshape(batch_size, -1, 3)
        # Flat int32 vector [h, w, h, w, ...], one (h, w) pair per sample.
        ori_size = torch.tensor(
            [h, w], dtype=torch.int32, device=ori_img.device
        ).repeat(batch_size)

        tri_index, tri_coord, render, real = render_util.render_face_forward(
            proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds,
            pixel_valid)
        ctx.save_for_backward(ori_img, ori_size, proj_geo, texture, nbl,
                              tri_inds, tri_index, tri_coord)
        return render, real

    @staticmethod
    def backward(ctx, grad_render, grad_real):
        """Back-propagate through ``render_util.render_face_backward``.

        Args:
            ctx: Autograd context holding the tensors saved in forward.
            grad_render: Gradient with respect to ``render``.
            grad_real: Gradient with respect to ``real``.

        Returns:
            One entry per forward input: gradients for ``proj_geo``,
            ``texture`` and ``nbl``, followed by four ``None`` values.
        """
        (ori_img, ori_size, proj_geo, texture, nbl,
         tri_inds, tri_index, tri_coord) = ctx.saved_tensors
        grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward(
            grad_render, grad_real, ori_img, ori_size, proj_geo, texture, nbl,
            tri_inds, tri_index, tri_coord)
        return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None
|
|
|
|
|
class Render_RGB(nn.Module):
    """Parameter-free module wrapper around ``Render_Face``."""

    def __init__(self):
        super(Render_RGB, self).__init__()

    def forward(self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds,
                pixel_valid):
        """Forward every argument to ``Render_Face.apply``.

        Returns:
            The ``(render, real)`` pair produced by ``Render_Face``.
        """
        return Render_Face.apply(proj_geo, texture, nbl, ori_img, is_visible,
                                 tri_inds, pixel_valid)
|
|
|
|
|
def cal_land(proj_geo, is_visible, lands_info, land_num):
    """Look up the projected 2-D position of every landmark.

    Args:
        proj_geo: Projected geometry whose last axis has size 3; it is
            flattened to ``(-1, 3)`` before the lookup.
        is_visible: Visibility scores forwarded to
            ``render_util.update_contour``.
        lands_info: Landmark description forwarded to
            ``render_util.update_contour``.
        land_num: Number of landmarks per sample.

    Returns:
        Tensor of shape ``(-1, land_num, 2)`` with the first two components
        of the selected projected vertices.
    """
    # The extension returns a one-element sequence holding row indices into
    # the flattened (batch * vertex) geometry.
    land_index, = render_util.update_contour(lands_info, is_visible, land_num)
    flat_geo = proj_geo.reshape(-1, 3)
    land_xy = torch.index_select(flat_geo, 0, land_index)[:, :2]
    return land_xy.reshape(-1, land_num, 2)
|
|
|
|
|
class Render_Land(nn.Module):
    """Mesh renderer plus colour / landmark loss for a fixed topology.

    The topology tables are read from text files under ``../data/3DMM/``
    (relative to the working directory) and moved to the current CUDA
    device at construction time.
    """

    def __init__(self):
        super(Render_Land, self).__init__()
        lands_info = np.loadtxt('../data/3DMM/lands_info.txt', dtype=np.int32)
        self.lands_info = torch.as_tensor(lands_info).cuda()
        tris = np.loadtxt('../data/3DMM/tris.txt', dtype=np.int64)
        # Shift the indices stored in the file down by one.
        self.tris = torch.as_tensor(tris).cuda() - 1
        vert_tris = np.loadtxt('../data/3DMM/vert_tris.txt', dtype=np.int64)
        self.vert_tris = torch.as_tensor(vert_tris).cuda()
        self.normal_baser = Normal_Base().cuda()
        self.renderer = Render_RGB().cuda()

    def render_mesh(self, geometry, euler, trans, cam, ori_img, light):
        """Render the posed mesh with a constant texture.

        Args:
            geometry: Vertex positions passed to ``preprocess_render``.
            euler: Rotation argument passed to ``preprocess_render``.
            trans: Translation argument passed to ``preprocess_render``.
            cam: Camera argument passed to ``preprocess_render``.
            ori_img: Image batch of shape ``(batch, h, w, 3)``.
            light: Lighting coefficients, reshaped to ``(-1, 9, 3)``; only
                the first of the three columns is used, replicated across
                all three.

        Returns:
            ``uint8`` tensor of shape ``(batch, h, w, 3)``.
        """
        batch_size, h, w, _ = ori_img.shape
        # reshape (not view) so non-contiguous images are accepted as well.
        ori_img = ori_img.reshape(batch_size, -1, 3)
        # Flat int32 vector [h, w, h, w, ...], one (h, w) pair per sample.
        ori_size = torch.tensor(
            [h, w], dtype=torch.int32, device=ori_img.device
        ).repeat(batch_size)

        # Visibility and pixel_valid are not needed on this path.
        rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render(
            geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img)
        tri_nb = self.normal_baser(rot_tri_normal.contiguous())

        # Use the first light column for all three output columns.
        first_column = light.reshape(-1, 9, 3)[:, :, 0]
        nbl = torch.bmm(tri_nb, first_column.unsqueeze(-1).repeat(1, 1, 3))

        # Constant texture value of 200 for every vertex component.
        texture = torch.ones_like(geometry) * 200
        render, = render_util.render_mesh(
            proj_geo, ori_img, ori_size, texture, nbl, self.tris)
        return render.reshape(batch_size, h, w, 3).byte()

    def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light,
                     texture, lands):
        """Compute the colour and landmark losses.

        Args:
            geometry: Vertex positions passed to ``preprocess_render``.
            euler: Rotation argument passed to ``preprocess_render``.
            trans: Translation argument passed to ``preprocess_render``.
            cam: Camera argument passed to ``preprocess_render``.
            ori_img: Image batch of shape ``(batch, h, w, 3)``.
            light: Lighting coefficients, reshaped to ``(-1, 9, 3)``.
            texture: Texture tensor handed to the renderer.
            lands: Target landmarks; ``lands.shape[1]`` is the landmark
                count and the last axis holds 2-D coordinates.

        Returns:
            Tuple ``(col_dis, lan_dis)``: the ``pixel_valid``-weighted mean
            per-pixel colour distance between ``render`` and ``real``, and
            the mean Euclidean landmark distance.
        """
        rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = \
            preprocess_render(geometry, euler, trans, cam,
                              self.tris, self.vert_tris, ori_img)
        tri_nb = self.normal_baser(rot_tri_normal.contiguous())
        nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3))

        render, real = self.renderer(
            proj_geo, texture, nbl, ori_img, is_visible, self.tris,
            pixel_valid)
        proj_land = cal_land(proj_geo, is_visible,
                             self.lands_info, lands.shape[1])

        batch_size = ori_img.shape[0]
        # Per-pixel colour distance, averaged with pixel_valid as weights;
        # the small constant guards against division by zero.
        col_minus = torch.norm((render - real).reshape(-1, 3),
                               dim=1).reshape(batch_size, -1)
        col_dis = torch.mean(col_minus * pixel_valid) / \
            (torch.mean(pixel_valid) + 0.00001)

        # Euclidean distance between projected and target landmarks.
        land_dists = torch.norm(
            (proj_land - lands).reshape(-1, 2), dim=1).reshape(batch_size, -1)
        lan_dis = torch.mean(land_dists)
        return col_dis, lan_dis
|
|