|
import argparse |
|
import torch |
|
from torchvision.utils import make_grid |
|
from PIL import Image, ImageFile |
|
from net import Net |
|
from utils import DEVICE, test_transform |
|
# Setting MAX_IMAGE_PIXELS to None turns off Pillow's pixel-count limit
# (the "decompression bomb" check), so arbitrarily large images can be opened.
Image.MAX_IMAGE_PIXELS = None |
|
# Allow Pillow to load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
|
|
|
|
|
|
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value (including ``"False"``) becomes ``True``. This
    converter maps the usual spellings explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if text in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value)
    )


def _build_arg_parser():
    """Build the option parser whose namespace is handed to ``Net``."""
    main_parser = argparse.ArgumentParser(description="main parser")
    subparsers = main_parser.add_subparsers(title="subcommands", dest="subcommand")

    main_parser.add_argument("--pretrained", type=_str2bool, default=True,
                             help="whether to use the pre-trained checkpoints")
    main_parser.add_argument("--requires_grad", type=_str2bool, default=True,
                             help="set to True if the model requires model gradient")

    train_parser = subparsers.add_parser("train", help="training mode parser")
    train_parser.add_argument("--training", type=_str2bool, default=True)
    # Help texts use %(default)s so the advertised default cannot drift away
    # from the real one.
    train_parser.add_argument("--iterations", type=int, default=60000,
                              help="total training epochs (default: %(default)s)")
    train_parser.add_argument("--batch_size", type=int, default=2,
                              help="training batch size (default: %(default)s)")
    train_parser.add_argument("--num_workers", type=int, default=2,
                              help="iterator threads (default: %(default)s)")
    train_parser.add_argument("--lr", type=float, default=1e-4,
                              help="the learning rate during training (default: 1e-4)")
    train_parser.add_argument("--content_folder", type=str, required=True,
                              help="the root of content images, the path should point to a folder")
    train_parser.add_argument("--style_folder", type=str, required=True,
                              help="the root of style images, the path should point to a folder")
    train_parser.add_argument("--log_interval", type=int, default=10000,
                              help="number of images after which the training loss is logged "
                                   "(default: %(default)s)")

    train_parser.add_argument("--w_content1", type=float, default=12, help="the stage1 content loss weight")
    train_parser.add_argument("--w_content2", type=float, default=9, help="the stage2 content loss weight")
    train_parser.add_argument("--w_content3", type=float, default=7, help="the stage3 content loss weight")
    train_parser.add_argument("--w_remd1", type=float, default=2, help="the stage1 remd loss weight")
    train_parser.add_argument("--w_remd2", type=float, default=2, help="the stage2 remd loss weight")
    train_parser.add_argument("--w_remd3", type=float, default=2, help="the stage3 remd loss weight")
    train_parser.add_argument("--w_moment1", type=float, default=2, help="the stage1 moment loss weight")
    train_parser.add_argument("--w_moment2", type=float, default=2, help="the stage2 moment loss weight")
    train_parser.add_argument("--w_moment3", type=float, default=2, help="the stage3 moment loss weight")
    # Was type=str with a bool default, which made "--color_on False" truthy.
    train_parser.add_argument("--color_on", type=_str2bool, default=True, help="turn on the color loss")
    train_parser.add_argument("--w_color1", type=float, default=0.25, help="the stage1 color loss weight")
    train_parser.add_argument("--w_color2", type=float, default=0.5, help="the stage2 color loss weight")
    train_parser.add_argument("--w_color3", type=float, default=1, help="the stage3 color loss weight")

    eval_parser = subparsers.add_parser("eval", help="evaluation mode parser")
    eval_parser.add_argument("--training", type=_str2bool, default=False)
    eval_parser.add_argument("--run_folder", type=_str2bool, default=False)

    return main_parser


def style_transfer_method(content_image, style_img):
    """Stylise ``content_image`` with the style of ``style_img``.

    A ``Net`` is constructed from the parsed options (with ``training`` forced
    to ``False``), switched to eval mode, moved to ``DEVICE`` and run once on
    the two transformed inputs under ``torch.no_grad()``.

    Args:
        content_image: Content input; handed unchanged to the transform
            returned by ``test_transform()`` (presumably a PIL image -
            confirm against callers).
        style_img: Style input as an array; converted with
            ``Image.fromarray`` before the same transform is applied.

    Returns:
        PIL.Image.Image: The stylised result as an 8-bit image.
    """
    # parse_known_args() instead of parse_args(): options on sys.argv that this
    # parser does not define (e.g. those of the embedding program) are ignored
    # rather than aborting the call with SystemExit. When every option is
    # recognised the resulting namespace is the same as before.
    args, _unrecognised = _build_arg_parser().parse_known_args()

    # Inference only, whatever sub-command or flag was supplied.
    args.training = False

    # NOTE(review): the network is rebuilt on every call; consider caching it
    # if call latency matters.
    model = Net(args)
    model.eval()
    model = model.to(DEVICE)

    tf = test_transform()

    # unsqueeze(dim=0) adds the leading batch dimension to each input.
    Ic = tf(content_image).to(DEVICE).unsqueeze(dim=0)
    Is = tf(Image.fromarray(style_img)).to(DEVICE).unsqueeze(dim=0)

    with torch.no_grad():
        Ics = model(Ic, Is)

    # NOTE(review): Ics[0] is presumably the first image of the output batch -
    # confirm against Net.forward.
    grid = make_grid(Ics[0])

    # Scale to 0..255 with +0.5 for round-to-nearest, clamp, move the first
    # axis last, then cast to uint8 on the CPU for PIL.
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    return Image.fromarray(ndarr)
|
|