import os
from abc import ABC, abstractmethod
from collections import OrderedDict
from pathlib import Path

import torch

from . import networks


class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.

    To create a subclass, you need to implement the following five functions:
        -- <__init__>:                   initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>:                  unpack data from dataset and apply preprocessing.
        -- <forward>:                    produce intermediate results.
        -- <optimize_parameters>:        calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.

    NOTE(review): the <set_input>/<optimize_parameters> names follow the upstream
    pix2pix/CycleGAN template; this class does not declare them as abstract methods.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>.
        Then, you need to define four lists:
            -- self.loss_names (str list):          specify the training losses that you want to plot and save.
            -- self.net_names (str list):           define networks used in our training.
            -- self.visual_names (str list):        specify the images that you want to display and save.
            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network.
               If two networks are updated at the same time, you can use itertools.chain to group them.
               See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # The first listed GPU hosts the model; fall back to CPU when no GPU ids are given.
        self.device = (
            torch.device("cuda:{}".format(self.gpu_ids[0]))
            if self.gpu_ids
            else torch.device("cpu")
        )
        print(self.device)
        # save all the checkpoints to save_dir
        self.save_dir = Path(opt.checkpoints_dir).joinpath(opt.name)
        # with [scale_width], input images might have different sizes,
        # which hurts the performance of cudnn.benchmark.
        if opt.preprocess != "scale_width":
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.net_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag
                               to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    def setup(self, opt, load_weight=None):
        """Load and print networks; create schedulers.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
            load_weight        -- optional checkpoint base name forwarded to <load_networks>;
                                  when truthy, "<load_weight>.pth" is loaded instead of the
                                  per-epoch file name.
        """
        if self.isTrain:
            self.schedulers = [
                networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers
            ]
        if not self.isTrain or opt.continue_train:
            # A positive load_iter selects an iteration checkpoint; otherwise use opt.epoch.
            load_suffix = "iter_%d" % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix, load_weight)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.net_names:
            if isinstance(name, str):
                getattr(self, "net_" + name).eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate
        steps for backprop. It also calls <compute_visuals> to produce additional
        visualization results.
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_current_visuals(self):
        """Return visualization images.

        train.py will display these images with visdom, and save the images to an HTML.
        Each entry of self.visual_names names an attribute of this model.
        """
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_image_paths(self):
        """Return image paths that are used to load current data"""
        return self.image_paths

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int or str) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.net_names:
            if not isinstance(name, str):
                continue
            save_filename = "%s_net_%s.pth" % (epoch, name)
            save_path = Path(self.save_dir, save_filename)
            net = getattr(self, "net_" + name)
            # Unwrap DataParallel in every case so saved state_dicts never carry the
            # "module." key prefix: <load_networks> unwraps the same way before
            # calling load_state_dict, and unwrapped nets have no `.module` attribute.
            module = net.module if isinstance(net, torch.nn.DataParallel) else net
            torch.save(module.cpu().state_dict(), save_path)
            if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                # .cpu() above moved the weights off the GPU; move them back.
                net.cuda(self.gpu_ids[0])

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4).

        Walks the dotted key path `keys` down from `module` and removes from
        `state_dict` the entries that an InstanceNorm layer does not expect.
        """
        key = keys[i]
        if i + 1 < len(keys):
            # Not yet at the leaf: descend into the submodule named by this key.
            self.__patch_instance_norm_state_dict(
                state_dict, getattr(module, key), keys, i + 1
            )
            return
        # At the end, pointing to a parameter/buffer.
        is_instance_norm = module.__class__.__name__.startswith("InstanceNorm")
        if not is_instance_norm:
            return
        if key in ("running_mean", "running_var"):
            # Drop running stats only when this InstanceNorm does not track them.
            if getattr(module, key) is None:
                state_dict.pop(".".join(keys))
        elif key == "num_batches_tracked":
            state_dict.pop(".".join(keys))

    def load_networks(self, epoch, load_weight=None):
        """Load all the networks from the disk.

        Parameters:
            epoch (int or str) -- current epoch (or a suffix such as "iter_<n>" / "latest");
                                  used in the file name '%s_net_%s.pth' % (epoch, name)
            load_weight        -- optional checkpoint base name; when truthy, every network
                                  is loaded from save_dir / "<load_weight>.pth" instead.
        """
        for name in self.net_names:
            if not isinstance(name, str):
                continue
            if not load_weight:
                load_filename = "%s_net_%s.pth" % (epoch, name)
            else:
                load_filename = f"{load_weight}.pth"
            load_path = self.save_dir.joinpath(load_filename)
            net = getattr(self, "net_" + name)
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            print("loading the model from %s" % load_path)
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(load_path, map_location=self.device)
            if hasattr(state_dict, "_metadata"):
                del state_dict._metadata

            # patch InstanceNorm checkpoints prior to 0.4
            # need to copy keys here because we mutate in the loop
            for key in list(state_dict.keys()):
                self.__patch_instance_norm_state_dict(state_dict, net, key.split("."))
            net.load_state_dict(state_dict)
            # Original author's note: net.half() is not called because the weights are
            # already float16 (and there is no float8).

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print("---------- Networks initialized -------------")
        for name in self.net_names:
            if isinstance(name, str):
                net = getattr(self, "net_" + name)
                num_params = sum(param.numel() for param in net.parameters())
                if verbose:
                    print(net)
                print(
                    "[Network %s] Total number of parameters : %.3f M"
                    % (name, num_params / 1e6)
                )
        print("-----------------------------------------------")