"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch."""

from additional_modules.deep3dfacerecon.models.base_model import BaseModel
from additional_modules.deep3dfacerecon.models import networks


class FaceReconModel(BaseModel):
    """Wrapper around the ``net_recon`` network that regresses face coefficients.

    Usage: call ``set_input`` with a batch dict, then ``forward``; the predicted
    coefficients are stored in ``self.pred_coeffs_dict`` as named groups
    (see ``split_coeff``).
    """

    def __init__(self, opt):
        """Initialize the model and build the reconstruction network.

        Parameters:
            opt -- training/test options. This method reads opt.net_recon,
                   opt.use_last_fc, opt.init_path and opt.device; BaseModel
                   may read further fields.
        """
        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
        # Names of visual outputs / sub-networks; presumably consumed by
        # BaseModel (visualization, checkpoint save/load) — confirm there.
        self.visual_names = ['output_vis']
        self.model_names = ['net_recon']

        self.net_recon = networks.define_net_recon(
            net_recon=opt.net_recon, use_last_fc=opt.use_last_fc, init_path=opt.init_path
        )
        self.net_recon.to(opt.device)

    def set_input(self, input):
        """Unpack a batch from the dataloader and move its tensors to ``self.device``.

        Parameters:
            input -- dictionary holding the data and its metadata. 'imgs' is
                     required. 'msks', 'lms' and 'M' are optional tensors and
                     'im_paths' is optional metadata; any missing optional
                     entry is stored as None.
        """
        def _optional_to_device(key):
            # Move an optional tensor entry to the device, or return None when absent.
            return input[key].to(self.device) if key in input else None

        self.input_img = input['imgs'].to(self.device)
        self.atten_mask = _optional_to_device('msks')
        self.gt_lm = _optional_to_device('lms')
        self.trans_m = _optional_to_device('M')
        self.image_paths = input['im_paths'] if 'im_paths' in input else None

    def split_coeff(self, coeffs):
        """Split the flat coefficient tensor predicted by net_recon into named groups.

        Parameters:
            coeffs -- torch.tensor of shape (B, C). The fixed-size groups occupy
                      the first 254 columns (80 + 64 + 80 + 3 + 27), and every
                      remaining column is returned as 'trans'. With C == 256
                      'trans' would have only 2 columns; presumably C is 257
                      (3 translation values) — confirm against net_recon's
                      output dimension.

        Return:
            coeffs_dict -- dict of torch.tensors with keys
                           'id' (80 cols), 'exp' (64), 'tex' (80), 'angle' (3),
                           'gamma' (27) and 'trans' (all remaining cols).
        """
        id_coeffs = coeffs[:, :80]
        exp_coeffs = coeffs[:, 80:144]
        tex_coeffs = coeffs[:, 144:224]
        angles = coeffs[:, 224:227]
        gammas = coeffs[:, 227:254]
        translations = coeffs[:, 254:]  # open-ended: takes everything after column 253
        return {
            'id': id_coeffs,
            'exp': exp_coeffs,
            'tex': tex_coeffs,
            'angle': angles,
            'gamma': gammas,
            'trans': translations,
        }

    def forward(self):
        """Run net_recon on ``self.input_img`` and store the split coefficients.

        ``set_input`` must have been called first. The result is written to
        ``self.pred_coeffs_dict``; nothing is returned.
        """
        output_coeff = self.net_recon(self.input_img)
        self.pred_coeffs_dict = self.split_coeff(output_coeff)