# Provenance (Hugging Face file-view residue): uploaded by ameerazam08 via
# "Upload folder using huggingface_hub", commit 03da825 (verified), 2.37 kB.
"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch
"""
from additional_modules.deep3dfacerecon.models.base_model import BaseModel
from additional_modules.deep3dfacerecon.models import networks
class FaceReconModel(BaseModel):
    """Face reconstruction model that regresses a flat coefficient vector
    from an image batch and splits it into named parameter groups.

    This version only builds the reconstruction network (``net_recon``);
    no losses or optimizers are set up here.
    """

    def __init__(self, opt):
        """Initialize this model class.

        Parameters:
            opt -- training/test options. This method reads ``net_recon``,
                   ``use_last_fc``, ``init_path`` and ``device``; BaseModel
                   may read further fields.
        """
        # BaseModel sets up shared state -- presumably including
        # ``self.device``, which set_input relies on (TODO confirm).
        super().__init__(opt)

        # Names presumably consumed by BaseModel helpers (visualisation /
        # checkpoint handling). NOTE(review): 'output_vis' is never assigned
        # anywhere in this class -- confirm nothing reads it via getattr.
        self.visual_names = ['output_vis']
        self.model_names = ['net_recon']

        self.net_recon = networks.define_net_recon(
            net_recon=opt.net_recon,
            use_last_fc=opt.use_last_fc,
            init_path=opt.init_path,
        )
        # NOTE(review): the network goes to ``opt.device`` while set_input
        # moves tensors to ``self.device``; confirm the two always agree.
        self.net_recon.to(opt.device)

    def set_input(self, input):  # `input` kept (shadows builtin) for keyword-arg compatibility
        """Unpack input data from the dataloader and move tensors to ``self.device``.

        Parameters:
            input -- mapping with the data and its metadata:
                     'imgs'     (required) image tensor fed to net_recon;
                     'msks'     (optional) stored as ``self.atten_mask``;
                     'lms'      (optional) stored as ``self.gt_lm``;
                     'M'        (optional) stored as ``self.trans_m``;
                     'im_paths' (optional) stored as ``self.image_paths``
                                as-is (not moved to a device).
                     Missing optional keys leave the attribute as None.
        """
        def _to_device_or_none(key):
            # Optional tensors are moved to the model device; absent ones become None.
            return input[key].to(self.device) if key in input else None

        self.input_img = input['imgs'].to(self.device)
        self.atten_mask = _to_device_or_none('msks')
        self.gt_lm = _to_device_or_none('lms')
        self.trans_m = _to_device_or_none('M')
        self.image_paths = input['im_paths'] if 'im_paths' in input else None

    def split_coeff(self, coeffs):
        """Split a flat coefficient tensor into named groups.

        Parameters:
            coeffs -- torch.tensor of shape (B, C). The fixed groups below
                      occupy columns 0-253 (80 id + 64 exp + 80 tex +
                      3 angle + 27 gamma); every column from 254 onward is
                      the translation. With the usual C = 257 (TODO confirm
                      against networks.define_net_recon) the translation is
                      3-dimensional.
        Return:
            coeffs_dict -- dict of torch.tensor slices keyed by
                           'id', 'exp', 'tex', 'angle', 'gamma', 'trans'.
        """
        id_coeffs = coeffs[:, :80]
        exp_coeffs = coeffs[:, 80:144]
        tex_coeffs = coeffs[:, 144:224]
        angles = coeffs[:, 224:227]
        gammas = coeffs[:, 227:254]
        translations = coeffs[:, 254:]
        return {
            'id': id_coeffs,
            'exp': exp_coeffs,
            'tex': tex_coeffs,
            'angle': angles,
            'gamma': gammas,
            'trans': translations,
        }

    def forward(self):
        """Run net_recon on ``self.input_img`` (call set_input first).

        The split coefficients are stored in ``self.pred_coeffs_dict`` and
        also returned for convenience (the original returned None, so
        existing callers are unaffected).
        """
        output_coeff = self.net_recon(self.input_img)
        self.pred_coeffs_dict = self.split_coeff(output_coeff)
        return self.pred_coeffs_dict