|
"""This module defines the face reconstruction model for Deep3DFaceRecon_pytorch.

"""
|
|
|
from additional_modules.deep3dfacerecon.models.base_model import BaseModel |
|
from additional_modules.deep3dfacerecon.models import networks |
|
|
|
|
|
class FaceReconModel(BaseModel):
    """Face reconstruction model built around the ``net_recon`` network.

    ``forward`` feeds ``self.input_img`` (set by ``set_input``) through
    ``net_recon`` and splits the flat coefficient output into the named
    groups returned by ``split_coeff``.
    """

    def __init__(self, opt):
        """Initialize the model and build the reconstruction network.

        Parameters:
            opt -- training/test options. Forwarded to ``BaseModel.__init__``;
                   this constructor itself reads ``opt.net_recon``,
                   ``opt.use_last_fc``, ``opt.init_path`` and ``opt.device``.
        """
        BaseModel.__init__(self, opt)

        # Attribute names listed for BaseModel's generic helpers (presumably
        # visualisation and network save/load -- confirm in BaseModel).
        # NOTE(review): confirm something assigns ``self.output_vis`` before
        # visuals are requested.
        self.visual_names = ['output_vis']
        self.model_names = ['net_recon']

        self.net_recon = networks.define_net_recon(
            net_recon=opt.net_recon, use_last_fc=opt.use_last_fc, init_path=opt.init_path
        )
        # NOTE(review): the network is placed on ``opt.device`` while
        # ``set_input`` moves tensors to ``self.device``; both must name the
        # same device -- confirm BaseModel sets ``self.device`` accordingly.
        self.net_recon.to(opt.device)

    def set_input(self, input):
        """Unpack a batch from the dataloader and move its tensors to ``self.device``.

        Parameters:
            input -- mapping holding the batch. ``'imgs'`` is required.
                     ``'msks'``, ``'lms'``, ``'M'`` and ``'im_paths'`` are
                     optional; when a key is absent the matching attribute
                     is set to ``None``.

        Sets ``input_img``, ``atten_mask``, ``gt_lm`` and ``trans_m`` (each
        moved with ``.to(self.device)``), and ``image_paths`` (stored as-is).
        """
        # ``input`` shadows the builtin, but the parameter name is part of the
        # public signature (callers may pass it by keyword), so it is kept.
        self.input_img = input['imgs'].to(self.device)
        self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None
        self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None
        self.trans_m = input['M'].to(self.device) if 'M' in input else None
        self.image_paths = input['im_paths'] if 'im_paths' in input else None

    def split_coeff(self, coeffs):
        """Split the flat network output into named coefficient groups.

        Parameters:
            coeffs -- torch.tensor, size (B, 257). Columns are laid out as
                      id [0:80] | exp [80:144] | tex [144:224] |
                      angle [224:227] | gamma [227:254] | trans [254:257].

        Return:
            coeffs_dict -- a dict of torch.tensors keyed 'id', 'exp', 'tex',
                           'angle', 'gamma', 'trans'. Each value is a column
                           slice of ``coeffs``; 'trans' takes every column
                           from index 254 onward, so it has 3 columns only
                           when ``coeffs`` is exactly 257 wide.
        """
        id_coeffs = coeffs[:, :80]           # 80 columns
        exp_coeffs = coeffs[:, 80: 144]      # 64 columns
        tex_coeffs = coeffs[:, 144: 224]     # 80 columns
        angles = coeffs[:, 224: 227]         # 3 columns
        gammas = coeffs[:, 227: 254]         # 27 columns
        translations = coeffs[:, 254:]       # remaining columns (3 for a 257-wide input)
        return {
            'id': id_coeffs,
            'exp': exp_coeffs,
            'tex': tex_coeffs,
            'angle': angles,
            'gamma': gammas,
            'trans': translations
        }

    def forward(self):
        """Run ``net_recon`` on ``self.input_img`` and store the split coefficients.

        ``set_input`` must be called first. The result of ``split_coeff`` is
        stored in ``self.pred_coeffs_dict``; nothing is returned.
        """
        output_coeff = self.net_recon(self.input_img)
        self.pred_coeffs_dict = self.split_coeff(output_coeff)
|
|