|
"""This script defines the base network model for Deep3DFaceRecon_pytorch |
|
""" |
|
|
|
import os |
|
import torch |
|
from abc import ABC, abstractmethod |
|
|
|
|
|
class BaseModel(ABC): |
|
"""This class is an abstract base class (ABC) for models. |
|
To create a subclass, you need to implement the following five functions: |
|
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). |
|
-- <set_input>: unpack data from dataset and apply preprocessing. |
|
-- <forward>: produce intermediate results. |
|
-- <optimize_parameters>: calculate losses, gradients, and update network weights. |
|
-- <modify_commandline_options>: (optionally) add model-specific options and set default options. |
|
""" |
|
|
|
def __init__(self, opt): |
|
"""Initialize the BaseModel class. |
|
|
|
Parameters: |
|
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions |
|
""" |
|
self.opt = opt |
|
self.device = opt.device |
|
|
|
@abstractmethod |
|
def set_input(self, input): |
|
"""Unpack input data from the dataloader and perform necessary pre-processing steps. |
|
|
|
Parameters: |
|
input (dict): includes the data itself and its metadata information. |
|
""" |
|
pass |
|
|
|
@abstractmethod |
|
def forward(self): |
|
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" |
|
pass |
|
|
|
def eval(self): |
|
"""Make models eval mode""" |
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
net = getattr(self, name) |
|
net.eval() |
|
|
|
def test(self): |
|
"""Forward function used in test time. |
|
|
|
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop |
|
It also calls <compute_visuals> to produce additional visualization results |
|
""" |
|
with torch.no_grad(): |
|
self.forward() |
|
|
|
|
|
def load_networks(self, load_path): |
|
"""Load all the networks from the disk. |
|
|
|
Parameters: |
|
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) |
|
""" |
|
state_dict = torch.load(load_path, map_location=self.device) |
|
print('loading the model from %s' % load_path) |
|
|
|
for name in self.model_names: |
|
if isinstance(name, str): |
|
net = getattr(self, name) |
|
if isinstance(net, torch.nn.DataParallel): |
|
net = net.module |
|
net.load_state_dict(state_dict[name]) |
|
|