|
import os |
|
from abc import ABC, abstractmethod |
|
from collections import OrderedDict |
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
from . import networks |
|
|
|
|
|
class BaseModel(ABC):
    """Abstract base class (ABC) for models.

    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.

    Networks are looked up by name: every entry ``name`` in ``self.net_names``
    must be available as the attribute ``self.net_<name>``.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                                  This method reads opt.gpu_ids, opt.isTrain, opt.checkpoints_dir,
                                  opt.name and opt.preprocess.

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.net_names (str list): define networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers.
               You can define one optimizer for each network.
               If two networks are updated at the same time, you can use itertools.chain to group them.
               See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # Use the first listed GPU when any GPU id is configured, otherwise the CPU.
        self.device = (
            torch.device("cuda:{}".format(self.gpu_ids[0]))
            if self.gpu_ids
            else torch.device("cpu")
        )
        print(self.device)
        # All checkpoints of this experiment live in <checkpoints_dir>/<name>.
        self.save_dir = Path(opt.checkpoints_dir).joinpath(opt.name)

        # Enable cuDNN auto-tuning unless preprocessing is "scale_width"
        # (presumably because that mode yields varying input sizes -- TODO confirm).
        if opt.preprocess != "scale_width":
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.net_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        # Not read anywhere in this class; presumably consumed by a
        # metric-driven learning-rate policy elsewhere -- TODO confirm.
        self.metric = 0

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag
                               to add training-specific or test-specific options.

        Returns:
            the modified parser (the base implementation returns it unchanged).
        """
        return parser

    def setup(self, opt, load_weight=None):
        """Load and print networks; create schedulers.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
            load_weight        -- optional checkpoint name forwarded to <load_networks>; when truthy,
                                  networks are loaded from '<load_weight>.pth' instead of the
                                  per-network epoch file.
        """
        if self.isTrain:
            # One scheduler per optimizer.
            self.schedulers = [
                networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers
            ]
        if not self.isTrain or opt.continue_train:
            # A positive load_iter selects an iteration checkpoint; otherwise use opt.epoch.
            load_suffix = "iter_%d" % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix, load_weight)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.net_names:
            if not isinstance(name, str):
                continue
            getattr(self, "net_" + name).eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to an HTML

        Returns:
            OrderedDict mapping each string entry of self.visual_names to the
            attribute of the same name on this object.
        """
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_image_paths(self):
        """Return image paths that are used to load current data"""
        return self.image_paths

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        The state dict is always taken from the underlying module (a
        DataParallel wrapper is unwrapped first), so the saved keys match
        what <load_networks> expects. The checkpoint directory is created
        if it does not exist yet.

        Parameters:
            epoch (int or str) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        self.save_dir.mkdir(parents=True, exist_ok=True)
        use_gpu = len(self.gpu_ids) > 0 and torch.cuda.is_available()
        for name in self.net_names:
            if not isinstance(name, str):
                continue
            save_filename = "%s_net_%s.pth" % (epoch, name)
            save_path = Path(self.save_dir, save_filename)
            net = getattr(self, "net_" + name)

            # Unwrap DataParallel the same way load_networks does, instead of
            # assuming the wrapper is present whenever GPUs are configured.
            module = net.module if isinstance(net, torch.nn.DataParallel) else net
            torch.save(module.cpu().state_dict(), save_path)
            if use_gpu:
                # .cpu() moved the weights in place; put them back on the first GPU.
                net.cuda(self.gpu_ids[0])

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)

        Walks ``module`` along the dotted path ``keys`` and, at the final
        component, drops stale InstanceNorm entries from ``state_dict``.

        Parameters:
            state_dict -- loaded state dict; modified in place
            module     -- module reached so far while following ``keys``
            keys (list)-- a state-dict key split on '.'
            i (int)    -- index of the path component currently examined
        """
        key = keys[i]
        if i + 1 == len(keys):
            # Last component: ``key`` names an entry of ``module`` itself.
            if module.__class__.__name__.startswith("InstanceNorm"):
                if key in ("running_mean", "running_var") and getattr(module, key) is None:
                    state_dict.pop(".".join(keys))
                if key == "num_batches_tracked":
                    state_dict.pop(".".join(keys))
        else:
            self.__patch_instance_norm_state_dict(
                state_dict, getattr(module, key), keys, i + 1
            )

    def load_networks(self, epoch, load_weight=None):
        """Load all the networks from the disk.

        Parameters:
            epoch (int or str) -- checkpoint tag; used in the file name '%s_net_%s.pth' % (epoch, name).
                                  <setup> passes a string such as 'iter_<n>' or opt.epoch.
            load_weight        -- optional; when truthy, every network is loaded from
                                  '<load_weight>.pth' inside self.save_dir and ``epoch`` is ignored.
        """
        for name in self.net_names:
            if not isinstance(name, str):
                continue
            if load_weight:
                load_filename = f"{load_weight}.pth"
            else:
                load_filename = "%s_net_%s.pth" % (epoch, name)
            load_path = self.save_dir.joinpath(load_filename)

            net = getattr(self, "net_" + name)
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            print("loading the model from %s" % load_path)

            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from trusted sources.
            state_dict = torch.load(load_path, map_location=self.device)
            if hasattr(state_dict, "_metadata"):
                del state_dict._metadata

            # Iterate over a copy of the keys: patching may pop entries.
            for key in list(state_dict.keys()):
                self.__patch_instance_norm_state_dict(
                    state_dict, net, key.split(".")
                )
            net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print("---------- Networks initialized -------------")
        for name in self.net_names:
            if not isinstance(name, str):
                continue
            net = getattr(self, "net_" + name)
            num_params = sum(param.numel() for param in net.parameters())
            if verbose:
                print(net)
            print(
                "[Network %s] Total number of parameters : %.3f M"
                % (name, num_params / 1e6)
            )
        print("-----------------------------------------------")
|
|