|
from . import networks |
|
from .base_model import BaseModel |
|
|
|
|
|
class TestModel(BaseModel):
    """Generate CycleGAN results for only one direction.

    This model automatically sets '--dataset_mode single' as the default,
    which only loads images from one collection. It can only be used at
    test time.

    See the test instructions for more details.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add model-specific options and rewrite default values for existing options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether this is the training phase or the test phase.
                You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        Raises:
            AssertionError -- if is_train is true (only while assertions are enabled).

        The model can only be used at test time. It sets '--dataset_mode single'
        as the default dataset mode. Use the option '--model_suffix' to choose
        which saved generator file is loaded (default: no suffix).
        """
        assert not is_train, "TestModel cannot be used during training time"
        parser.set_defaults(dataset_mode="single")
        parser.add_argument(
            "--model_suffix",
            type=str,
            default="",
            help="In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.",
        )
        return parser

    def __init__(self, opt):
        """Initialize the TestModel.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        Raises:
            AssertionError -- if opt.isTrain is true (only while assertions are enabled).
        """
        assert not opt.isTrain, "TestModel cannot be used during training time"
        BaseModel.__init__(self, opt)

        # No losses are computed by this model.
        self.loss_names = []

        # Attributes populated by set_input() ("real") and forward() ("fake");
        # presumably read by BaseModel for visualization/saving -- TODO confirm.
        self.visual_names = ["real", "fake"]

        # Only one network (the generator). Its registered name carries
        # opt.model_suffix so it lines up with the attribute set via setattr below.
        self.net_names = ["G" + opt.model_suffix]
        self.net_G = networks.define_G(
            opt.input_nc,
            opt.output_nc,
            opt.ngf,
            opt.netG,
            opt.norm,
            not opt.no_dropout,
            opt.init_type,
            opt.init_gain,
            self.gpu_ids,
        )

        # Also expose the generator as "net_G" + model_suffix. Presumably BaseModel
        # resolves each entry of net_names via "net_" + name when loading weights
        # -- TODO confirm against BaseModel. With an empty suffix this is a no-op rebind.
        setattr(self, "net_G" + opt.model_suffix, self.net_G)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.
                Reads key 'A' (moved to self.device and stored as self.real) and
                key 'A_paths' (stored as self.image_paths).

        This model expects the 'single' dataset mode, which only loads images
        from one domain.
        """
        self.real = input["A"].to(self.device)
        self.image_paths = input["A_paths"]

    def forward(self):
        """Run the generator on self.real and store the result in self.fake."""
        self.fake = self.net_G(self.real)

    def optimize_parameters(self):
        """Do nothing: the test model has no optimization step."""
        pass
|
|