# CycleGAN / models/test_model.py
# Author: Yanguan · revision 58da73e
from . import networks
from .base_model import BaseModel
class TestModel(BaseModel):
    """Generate CycleGAN results for only one direction (inference only).

    This model automatically sets '--dataset_mode single', which loads images
    from only one collection. See the test instructions for more details.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add test-specific options and rewrite default values for existing options.

        The model can only be used at test time and requires
        '--dataset_mode single', which is set here as the default. Select the
        generator checkpoint to load with the '--model_suffix' option.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether this is the training phase or the test
                phase. Must be False for this model.

        Returns:
            the modified parser.

        Raises:
            AssertionError -- if is_train is True (only while assertions are enabled).
        """
        assert not is_train, "TestModel cannot be used during training time"
        parser.set_defaults(dataset_mode="single")
        parser.add_argument(
            "--model_suffix",
            type=str,
            default="",
            help="In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.",
        )
        return parser

    def __init__(self, opt):
        """Initialize the TestModel.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a
                subclass of BaseOptions. opt.isTrain must be False.

        Raises:
            AssertionError -- if opt.isTrain is truthy (only while assertions are enabled).
        """
        assert not opt.isTrain, "TestModel cannot be used during training time"
        super().__init__(opt)
        # Training losses to print; none at test time.
        # The training/test scripts call <BaseModel.get_current_losses>.
        self.loss_names = []
        # Images to save/display: the input batch and the generated output.
        # The training/test scripts call <BaseModel.get_current_visuals>.
        self.visual_names = ["real", "fake"]
        # Networks to save to / load from disk; only the generator is needed.
        # The suffix selects which generator checkpoint is used.
        # The training/test scripts call <BaseModel.save_networks> and <BaseModel.load_networks>.
        self.net_names = ["G" + opt.model_suffix]
        self.net_G = networks.define_G(
            opt.input_nc,
            opt.output_nc,
            opt.ngf,
            opt.netG,
            opt.norm,
            not opt.no_dropout,
            opt.init_type,
            opt.init_gain,
            self.gpu_ids,
        )
        # Also expose the same generator as self.net_G<model_suffix>
        # (e.g. self.net_G_A for '--model_suffix _A'; plain self.net_G when the suffix is empty)
        # so that it can be loaded. Presumably <BaseModel.load_networks> looks up
        # 'net_' + name for each entry in net_names -- confirm against BaseModel.
        setattr(self, "net_G" + opt.model_suffix, self.net_G)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        This model needs the 'single' dataset mode, which loads images from
        only one domain.

        Parameters:
            input -- a dictionary containing the data and its metadata. This
                method reads the images under 'A' (moved to self.device) and
                their file paths under 'A_paths'.
        """
        self.real = input["A"].to(self.device)
        self.image_paths = input["A_paths"]

    def forward(self):
        """Run the generator on self.real and store the result in self.fake."""
        self.fake = self.net_G(self.real)  # G(real)

    def optimize_parameters(self):
        """Do nothing: the test model performs no optimization."""