ameerazam08's picture
Upload folder using huggingface_hub
03da825 verified
raw
history blame
2.66 kB
"""This script defines the base network model for Deep3DFaceRecon_pytorch
"""
import os
import torch
from abc import ABC, abstractmethod
class BaseModel(ABC):
"""This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
def __init__(self, opt):
"""Initialize the BaseModel class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
self.opt = opt
self.device = opt.device
@abstractmethod
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): includes the data itself and its metadata information.
"""
pass
@abstractmethod
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
pass
def eval(self):
"""Make models eval mode"""
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, name)
net.eval()
def test(self):
"""Forward function used in test time.
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
with torch.no_grad():
self.forward()
# self.compute_visuals()
def load_networks(self, load_path):
"""Load all the networks from the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
state_dict = torch.load(load_path, map_location=self.device)
print('loading the model from %s' % load_path)
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, name)
if isinstance(net, torch.nn.DataParallel):
net = net.module
net.load_state_dict(state_dict[name])