"""This script defines the base network model for Deep3DFaceRecon_pytorch """ import os import torch from abc import ABC, abstractmethod class BaseModel(ABC): """This class is an abstract base class (ABC) for models. To create a subclass, you need to implement the following five functions: -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). -- : unpack data from dataset and apply preprocessing. -- : produce intermediate results. -- : calculate losses, gradients, and update network weights. -- : (optionally) add model-specific options and set default options. """ def __init__(self, opt): """Initialize the BaseModel class. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ self.opt = opt self.device = opt.device @abstractmethod def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): includes the data itself and its metadata information. """ pass @abstractmethod def forward(self): """Run forward pass; called by both functions and .""" pass def eval(self): """Make models eval mode""" for name in self.model_names: if isinstance(name, str): net = getattr(self, name) net.eval() def test(self): """Forward function used in test time. This function wraps function in no_grad() so we don't save intermediate steps for backprop It also calls to produce additional visualization results """ with torch.no_grad(): self.forward() # self.compute_visuals() def load_networks(self, load_path): """Load all the networks from the disk. Parameters: epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) """ state_dict = torch.load(load_path, map_location=self.device) print('loading the model from %s' % load_path) for name in self.model_names: if isinstance(name, str): net = getattr(self, name) if isinstance(net, torch.nn.DataParallel): net = net.module net.load_state_dict(state_dict[name])