# CycleGAN / models/base_model.py (last commit 58da73e by Yanguan)
import os
from abc import ABC, abstractmethod
from collections import OrderedDict
from pathlib import Path
import torch
from . import networks
class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.

    To create a subclass, you need to implement the following five functions:
        -- <__init__>:                   initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>:                  unpack data from dataset and apply preprocessing.
        -- <forward>:                    produce intermediate results.
        -- <optimize_parameters>:        calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.

    Every name in ``self.net_names`` refers to a network stored on the instance
    as the attribute ``"net_" + name``. Checkpoints live in
    ``<opt.checkpoints_dir>/<opt.name>/`` as ``<epoch>_net_<name>.pth``.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>.
        Then, you need to define four lists:
            -- self.loss_names (str list):       specify the training losses that you want to plot and save.
            -- self.net_names (str list):        define networks used in our training.
            -- self.visual_names (str list):     specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers.
               You can define one optimizer for each network.
               If two networks are updated at the same time, you can use itertools.chain to group them.
               See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # The first listed GPU is the primary device; with no GPU ids, run on the CPU.
        self.device = (
            torch.device("cuda:{}".format(self.gpu_ids[0]))
            if self.gpu_ids
            else torch.device("cpu")
        )
        print(self.device)
        # save all the checkpoints to save_dir
        self.save_dir = Path(opt.checkpoints_dir).joinpath(opt.name)
        # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
        if opt.preprocess != "scale_width":
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.net_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add
                               training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    def setup(self, opt, load_weight=None):
        """Load and print networks; create schedulers.

        Parameters:
            opt (Option class)        -- stores all the experiment flags; needs to be a subclass of BaseOptions
            load_weight (str or None) -- optional checkpoint stem forwarded to <load_networks>; when given,
                                         every network is loaded from '<save_dir>/<load_weight>.pth'.

        Schedulers are created only in training mode. Weights are loaded at test time,
        or in training mode when opt.continue_train is set. In both cases the file suffix
        is 'iter_<opt.load_iter>' when opt.load_iter > 0, otherwise opt.epoch.
        """
        if self.isTrain:
            self.schedulers = [
                networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers
            ]
        if not self.isTrain or opt.continue_train:
            load_suffix = "iter_%d" % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix, load_weight)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.net_names:
            if isinstance(name, str):
                getattr(self, "net_" + name).eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> in torch.no_grad() so intermediate steps are not kept for backprop.
        It also calls <compute_visuals> to produce additional visualization results.
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to an HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_image_paths(self):
        """Return image paths that are used to load current data"""
        return self.image_paths

    @staticmethod
    def _unwrap_network(net):
        """Return the module wrapped by DataParallel, or ``net`` itself if it is not wrapped."""
        return net.module if isinstance(net, torch.nn.DataParallel) else net

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int or str) -- current epoch (or a tag such as "latest"); used in the file name
                                  '%s_net_%s.pth' % (epoch, name)

        The network is unwrapped from DataParallel before saving, whichever device is in use.
        The saved keys therefore never carry the 'module.' prefix, which matches
        <load_networks>, since it also unwraps before calling load_state_dict.
        """
        for name in self.net_names:
            if not isinstance(name, str):
                continue
            save_path = Path(self.save_dir, "%s_net_%s.pth" % (epoch, name))
            net = getattr(self, "net_" + name)
            # Parameters are moved to the CPU so the checkpoint is device-independent.
            torch.save(self._unwrap_network(net).cpu().state_dict(), save_path)
            if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                # .cpu() above moved the network off the GPU; put it back for further training.
                net.cuda(self.gpu_ids[0])

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4).

        Walks ``module`` along the dotted path ``keys``. If the path ends at an
        InstanceNorm layer, it drops stale entries from ``state_dict`` in place:
        'running_mean'/'running_var' when the layer has no such buffer (attribute is None),
        and 'num_batches_tracked' unconditionally.
        """
        key = keys[i]
        if i + 1 < len(keys):
            # Not at the leaf yet: descend into the child module named by this key component.
            self.__patch_instance_norm_state_dict(
                state_dict, getattr(module, key), keys, i + 1
            )
            return
        # At the end, pointing to a parameter/buffer.
        if not module.__class__.__name__.startswith("InstanceNorm"):
            return
        if key in ("running_mean", "running_var"):
            if getattr(module, key) is None:
                state_dict.pop(".".join(keys))
        elif key == "num_batches_tracked":
            state_dict.pop(".".join(keys))

    def load_networks(self, epoch, load_weight=None):
        """Load all the networks from the disk.

        Parameters:
            epoch (int or str)        -- epoch number or tag (e.g. "latest", "iter_1000"); used in the
                                         file name '%s_net_%s.pth' % (epoch, name)
            load_weight (str or None) -- if truthy, every network is loaded from
                                         '<save_dir>/<load_weight>.pth' and ``epoch`` is ignored.
        """
        for name in self.net_names:
            if not isinstance(name, str):
                continue
            if load_weight:
                load_filename = f"{load_weight}.pth"
            else:
                load_filename = "%s_net_%s.pth" % (epoch, name)
            load_path = self.save_dir.joinpath(load_filename)
            net = self._unwrap_network(getattr(self, "net_" + name))
            print("loading the model from %s" % load_path)
            state_dict = torch.load(load_path, map_location=self.device)
            if hasattr(state_dict, "_metadata"):
                del state_dict._metadata
            # Patch InstanceNorm checkpoints saved before PyTorch 0.4.
            # Iterate over a copy of the keys because the patch pops entries from state_dict.
            for key in list(state_dict.keys()):
                self.__patch_instance_norm_state_dict(state_dict, net, key.split("."))
            net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture.

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print("---------- Networks initialized -------------")
        for name in self.net_names:
            if not isinstance(name, str):
                continue
            net = getattr(self, "net_" + name)
            num_params = sum(param.numel() for param in net.parameters())
            if verbose:
                print(net)
            print(
                "[Network %s] Total number of parameters : %.3f M"
                % (name, num_params / 1e6)
            )
        print("-----------------------------------------------")