Upload 6 files
- app.py +75 -0
- hist_loss.py +208 -0
- ics.jpg +0 -0
- net.py +281 -0
- style_trsfer.py +80 -0
- utils.py +168 -0
app.py
ADDED
@@ -0,0 +1,75 @@
from diffusers import DiffusionPipeline, UniPCMultistepScheduler
import gradio as gr
import torch
import gc
from style_trsfer import style_transfer_method


def generate(style_image, text, negative_prompts, steps, guidance_scale):
    # Load the fine-tuned diffusion pipeline from the local CCLAP checkpoint.
    pipeline = DiffusionPipeline.from_pretrained("./CCLAP")
    pipeline.scheduler = UniPCMultistepScheduler.from_config(
        pipeline.scheduler.config)
    device = torch.device(
        'cuda:0' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        pipeline.enable_xformers_memory_efficient_attention()
    pipeline.to(device)
    torch.cuda.empty_cache()
    gc.collect()
    # First generate the content image from the prompt, then transfer the
    # style of the reference image onto it.
    content_image = pipeline(text,
                             num_inference_steps=steps,
                             negative_prompt=negative_prompts,
                             guidance_scale=guidance_scale).images[0]
    result = style_transfer_method(content_image, style_image)
    return content_image, result


if __name__ == '__main__':
    demo = gr.Interface(
        title="CCLAP",
        description="This is the demo of CCLAP for generating Chinese landscape paintings.",
        css="",
        fn=generate,
        inputs=[gr.Image(label="Style Image", shape=(512, 512)),
                gr.Textbox(lines=3, placeholder="Input the prompt", label="Prompt"),
                gr.Textbox(lines=3, placeholder="low quality", label="Negative prompt"),
                gr.Slider(minimum=0, maximum=100, value=20, label='Steps'),
                gr.Slider(minimum=0, maximum=30, value=7.5, label='Guidance_scale')],
        outputs=[gr.Image(label="Content Output", shape=(256, 256)),
                 gr.Image(label="Final Output", shape=(256, 256))],
        examples=[
            ['style_image/style1.jpg',
             'A Chinese landscape painting of a mountain landscape with trees',
             'low quality', 20, 7.5],
            ['style_image/style2.jpg',
             'A Chinese landscape painting of a building with trees in front of it',
             'low quality', 20, 7.5],
            ['style_image/style3.jpg',
             'A Chinese landscape painting of a landscape with mountains in the background',
             'low quality', 20, 7.5],
            ['style_image/style4.jpg',
             'A Chinese landscape painting of a landscape with mountains and a river',
             'low quality', 20, 7.5],
        ])
    demo.launch()
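For reference, the demo function can also be driven without the Gradio UI. A minimal sketch, assuming the ./CCLAP weights and the bundled style images are present (the output file name is illustrative):

import numpy as np
from PIL import Image
from app import generate

# gr.Image hands generate() a NumPy array, so mimic that here.
style = np.array(Image.open('style_image/style1.jpg').convert('RGB'))
content, stylized = generate(style,
                             'A Chinese landscape painting of a mountain landscape with trees',
                             'low quality', steps=20, guidance_scale=7.5)
stylized.save('result.png')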
hist_loss.py
ADDED
@@ -0,0 +1,208 @@
"""
Copyright 2021 Mahmoud Afifi.
Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown. "HistoGAN:
Controlling Colors of GAN-Generated and Real Images via Color Histograms."
In CVPR, 2021.

@inproceedings{afifi2021histogan,
  title={Histo{GAN}: Controlling Colors of {GAN}-Generated and Real Images via
  Color Histograms},
  author={Afifi, Mahmoud and Brubaker, Marcus A. and Brown, Michael S.},
  booktitle={CVPR},
  year={2021}
}
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

EPS = 1e-6


class RGBuvHistBlock(nn.Module):
  def __init__(self, h=64, insz=150, resizing='interpolation',
               method='inverse-quadratic', sigma=0.02, intensity_scale=True,
               device='cuda'):
    """ Computes the RGB-uv histogram feature of a given image.
    Args:
      h: histogram dimension size (scalar). The default value is 64.
      insz: maximum size of the input image; if it is larger than this size,
        the image will be resized (scalar). Default value is 150 (i.e.,
        150 x 150 pixels).
      resizing: resizing method if applicable. Options are: 'interpolation'
        or 'sampling'. Default is 'interpolation'.
      method: the method used to count the number of pixels for each bin in
        the histogram feature. Options are: 'thresholding', 'RBF' (radial
        basis function), or 'inverse-quadratic'. Default value is
        'inverse-quadratic'.
      sigma: if the method value is 'RBF' or 'inverse-quadratic', then this
        is the sigma parameter of the kernel function. The default value is
        0.02.
      intensity_scale: boolean variable to use the intensity scale (I_y in
        Equation 2). Default value is True.

    Methods:
      forward: accepts an input image and returns its histogram feature.
        Note that unless the method is 'thresholding', this is a
        differentiable function and can be easily integrated with the loss
        function. As mentioned in the paper, the 'inverse-quadratic' kernel
        was found more stable than 'RBF' in our training.
    """
    super(RGBuvHistBlock, self).__init__()
    self.h = h
    self.insz = insz
    self.device = device
    self.resizing = resizing
    self.method = method
    self.intensity_scale = intensity_scale
    if self.method == 'thresholding':
      self.eps = 6.0 / h
    else:
      self.sigma = sigma

  def forward(self, x):
    x = torch.clamp(x, 0, 1)
    if x.shape[2] > self.insz or x.shape[3] > self.insz:
      if self.resizing == 'interpolation':
        x_sampled = F.interpolate(x, size=(self.insz, self.insz),
                                  mode='bilinear', align_corners=False)
      elif self.resizing == 'sampling':
        inds_1 = torch.LongTensor(
          np.linspace(0, x.shape[2], self.h, endpoint=False)).to(
          device=self.device)
        inds_2 = torch.LongTensor(
          np.linspace(0, x.shape[3], self.h, endpoint=False)).to(
          device=self.device)
        x_sampled = x.index_select(2, inds_1)
        x_sampled = x_sampled.index_select(3, inds_2)
      else:
        raise Exception(
          f'Wrong resizing method. It should be: interpolation or sampling. '
          f'But the given value is {self.resizing}.')
    else:
      x_sampled = x

    L = x_sampled.shape[0]  # size of mini-batch
    if x_sampled.shape[1] > 3:
      x_sampled = x_sampled[:, :3, :, :]
    X = torch.unbind(x_sampled, dim=0)
    hists = torch.zeros((x_sampled.shape[0], 3, self.h, self.h)).to(
      device=self.device)
    for l in range(L):
      I = torch.t(torch.reshape(X[l], (3, -1)))
      II = torch.pow(I, 2)
      if self.intensity_scale:
        Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS),
                             dim=1)
      else:
        Iy = 1

      Iu0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 1] + EPS),
                            dim=1)
      Iv0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 2] + EPS),
                            dim=1)
      diff_u0 = abs(
        Iu0 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      diff_v0 = abs(
        Iv0 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      if self.method == 'thresholding':
        diff_u0 = torch.reshape(diff_u0, (-1, self.h)) <= self.eps / 2
        diff_v0 = torch.reshape(diff_v0, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u0 = torch.exp(-diff_u0)  # Radial basis function
        diff_v0 = torch.exp(-diff_v0)
      elif self.method == 'inverse-quadratic':
        diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u0 = 1 / (1 + diff_u0)  # Inverse quadratic
        diff_v0 = 1 / (1 + diff_v0)
      else:
        raise Exception(
          f'Wrong kernel method. It should be either thresholding, RBF,'
          f' inverse-quadratic. But the given value is {self.method}.')
      diff_u0 = diff_u0.type(torch.float32)
      diff_v0 = diff_v0.type(torch.float32)
      a = torch.t(Iy * diff_u0)
      hists[l, 0, :, :] = torch.mm(a, diff_v0)

      Iu1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 0] + EPS),
                            dim=1)
      Iv1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 2] + EPS),
                            dim=1)
      diff_u1 = abs(
        Iu1 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      diff_v1 = abs(
        Iv1 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))

      if self.method == 'thresholding':
        diff_u1 = torch.reshape(diff_u1, (-1, self.h)) <= self.eps / 2
        diff_v1 = torch.reshape(diff_v1, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u1 = torch.exp(-diff_u1)  # Gaussian
        diff_v1 = torch.exp(-diff_v1)
      elif self.method == 'inverse-quadratic':
        diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u1 = 1 / (1 + diff_u1)  # Inverse quadratic
        diff_v1 = 1 / (1 + diff_v1)

      diff_u1 = diff_u1.type(torch.float32)
      diff_v1 = diff_v1.type(torch.float32)
      a = torch.t(Iy * diff_u1)
      hists[l, 1, :, :] = torch.mm(a, diff_v1)

      Iu2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 0] + EPS),
                            dim=1)
      Iv2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 1] + EPS),
                            dim=1)
      diff_u2 = abs(
        Iu2 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      diff_v2 = abs(
        Iv2 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      if self.method == 'thresholding':
        diff_u2 = torch.reshape(diff_u2, (-1, self.h)) <= self.eps / 2
        diff_v2 = torch.reshape(diff_v2, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u2 = torch.exp(-diff_u2)  # Gaussian
        diff_v2 = torch.exp(-diff_v2)
      elif self.method == 'inverse-quadratic':
        diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u2 = 1 / (1 + diff_u2)  # Inverse quadratic
        diff_v2 = 1 / (1 + diff_v2)
      diff_u2 = diff_u2.type(torch.float32)
      diff_v2 = diff_v2.type(torch.float32)
      a = torch.t(Iy * diff_u2)
      hists[l, 2, :, :] = torch.mm(a, diff_v2)

    # normalization
    hists_normalized = hists / (
        ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)

    return hists_normalized
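As the docstring notes, the histogram feature is differentiable for the 'RBF' and 'inverse-quadratic' kernels, so it can sit inside a training loss. A quick shape check on random data (a sketch; the device='cpu' choice is just for illustration):

import torch
from hist_loss import RGBuvHistBlock

hist_block = RGBuvHistBlock(h=64, insz=150, device='cpu')
x = torch.rand(2, 3, 256, 256, requires_grad=True)  # batch of RGB images in [0, 1]
feat = hist_block(x)
print(feat.shape)      # torch.Size([2, 3, 64, 64]) -- one 64x64 uv-plane per R/G/B channel
feat.sum().backward()  # differentiable: gradients flow back to the input image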
ics.jpg
ADDED
net.py
ADDED
@@ -0,0 +1,281 @@
import torch
import torch.nn as nn
from utils import mean_variance_norm, DEVICE
from utils import calc_ss_loss, calc_remd_loss, calc_moment_loss, calc_mse_loss, calc_histogram_loss
from hist_loss import RGBuvHistBlock


class Net(nn.Module):
    def __init__(self, args):
        super(Net, self).__init__()
        self.args = args
        self.vgg = vgg19[:44]  # encoder defined below; [:44] ends at relu5_1
        self.vgg.load_state_dict(torch.load('./checkpoints/encoder.pth', map_location='cpu'), strict=False)
        for param in self.vgg.parameters():
            param.requires_grad = False

        self.align1 = PAMA(512)
        self.align2 = PAMA(512)
        self.align3 = PAMA(512)

        self.decoder = decoder
        self.hist = RGBuvHistBlock(insz=64, h=256,
                                   intensity_scale=True,
                                   method='inverse-quadratic',
                                   device=DEVICE)

        if args.pretrained:
            self.align1.load_state_dict(torch.load('./checkpoints/PAMA1.pth', map_location='cpu'), strict=True)
            self.align2.load_state_dict(torch.load('./checkpoints/PAMA2.pth', map_location='cpu'), strict=True)
            self.align3.load_state_dict(torch.load('./checkpoints/PAMA3.pth', map_location='cpu'), strict=True)
            self.decoder.load_state_dict(torch.load('./checkpoints/decoder.pth', map_location='cpu'), strict=False)

        if not args.requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, Ic, Is):
        feat_c = self.forward_vgg(Ic)
        feat_s = self.forward_vgg(Is)
        Fc, Fs = feat_c[3], feat_s[3]  # relu4_1 features

        Fcs1 = self.align1(Fc, Fs)
        Fcs2 = self.align2(Fcs1, Fs)
        Fcs3 = self.align3(Fcs2, Fs)

        Ics3 = self.decoder(Fcs3)

        if self.args.training:
            Ics1 = self.decoder(Fcs1)
            Ics2 = self.decoder(Fcs2)
            Irc = self.decoder(Fc)
            Irs = self.decoder(Fs)
            feat_cs1 = self.forward_vgg(Ics1)
            feat_cs2 = self.forward_vgg(Ics2)
            feat_cs3 = self.forward_vgg(Ics3)
            feat_rc = self.forward_vgg(Irc)
            feat_rs = self.forward_vgg(Irs)

            content_loss1, remd_loss1, moment_loss1, color_loss1 = 0.0, 0.0, 0.0, 0.0
            content_loss2, remd_loss2, moment_loss2, color_loss2 = 0.0, 0.0, 0.0, 0.0
            content_loss3, remd_loss3, moment_loss3, color_loss3 = 0.0, 0.0, 0.0, 0.0
            loss_rec = 0.0

            # Accumulate the per-stage losses over relu3_1, relu4_1, and relu5_1.
            for l in range(2, 5):
                content_loss1 += self.args.w_content1 * calc_ss_loss(feat_cs1[l], feat_c[l])
                remd_loss1 += self.args.w_remd1 * calc_remd_loss(feat_cs1[l], feat_s[l])
                moment_loss1 += self.args.w_moment1 * calc_moment_loss(feat_cs1[l], feat_s[l])

                content_loss2 += self.args.w_content2 * calc_ss_loss(feat_cs2[l], feat_c[l])
                remd_loss2 += self.args.w_remd2 * calc_remd_loss(feat_cs2[l], feat_s[l])
                moment_loss2 += self.args.w_moment2 * calc_moment_loss(feat_cs2[l], feat_s[l])

                content_loss3 += self.args.w_content3 * calc_ss_loss(feat_cs3[l], feat_c[l])
                remd_loss3 += self.args.w_remd3 * calc_remd_loss(feat_cs3[l], feat_s[l])
                moment_loss3 += self.args.w_moment3 * calc_moment_loss(feat_cs3[l], feat_s[l])

                loss_rec += 0.5 * calc_mse_loss(feat_rc[l], feat_c[l]) + 0.5 * calc_mse_loss(feat_rs[l], feat_s[l])
            loss_rec += 25 * calc_mse_loss(Irc, Ic)
            loss_rec += 25 * calc_mse_loss(Irs, Is)

            if self.args.color_on:
                color_loss1 += self.args.w_color1 * calc_histogram_loss(Ics1, Is, self.hist)
                color_loss2 += self.args.w_color2 * calc_histogram_loss(Ics2, Is, self.hist)
                color_loss3 += self.args.w_color3 * calc_histogram_loss(Ics3, Is, self.hist)

            loss1 = (content_loss1 + remd_loss1 + moment_loss1 + color_loss1) / (self.args.w_content1 + self.args.w_remd1 + self.args.w_moment1 + self.args.w_color1)
            loss2 = (content_loss2 + remd_loss2 + moment_loss2 + color_loss2) / (self.args.w_content2 + self.args.w_remd2 + self.args.w_moment2 + self.args.w_color2)
            loss3 = (content_loss3 + remd_loss3 + moment_loss3 + color_loss3) / (self.args.w_content3 + self.args.w_remd3 + self.args.w_moment3 + self.args.w_color3)
            loss = loss1 + loss2 + loss3 + loss_rec
            return loss
        else:
            return Ics3

    def forward_vgg(self, x):
        relu1_1 = self.vgg[:4](x)
        relu2_1 = self.vgg[4:11](relu1_1)
        relu3_1 = self.vgg[11:18](relu2_1)
        relu4_1 = self.vgg[18:31](relu3_1)
        relu5_1 = self.vgg[31:44](relu4_1)
        return [relu1_1, relu2_1, relu3_1, relu4_1, relu5_1]

    def save_ckpts(self):
        torch.save(self.align1.state_dict(), "./checkpoints/PAMA1.pth")
        torch.save(self.align2.state_dict(), "./checkpoints/PAMA2.pth")
        torch.save(self.align3.state_dict(), "./checkpoints/PAMA3.pth")
        torch.save(self.decoder.state_dict(), "./checkpoints/decoder.pth")

#---------------------------------------------------------------------------------------------------------------

vgg19 = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)

#---------------------------------------------------------------------------------------------------------------

decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),  # relu4_1
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),  # relu3_1
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),  # relu2_1
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1_1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)

#---------------------------------------------------------------------------------------------------------------

class AttentionUnit(nn.Module):
    def __init__(self, channels):
        super(AttentionUnit, self).__init__()
        self.relu6 = nn.ReLU6()
        self.f = nn.Conv2d(channels, channels // 2, (1, 1))
        self.g = nn.Conv2d(channels, channels // 2, (1, 1))
        self.h = nn.Conv2d(channels, channels // 2, (1, 1))

        self.out_conv = nn.Conv2d(channels // 2, channels, (1, 1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Fc, Fs):
        B, C, H, W = Fc.shape
        f_Fc = self.relu6(self.f(mean_variance_norm(Fc)))
        g_Fs = self.relu6(self.g(mean_variance_norm(Fs)))
        h_Fs = self.relu6(self.h(Fs))
        f_Fc = f_Fc.view(f_Fc.shape[0], f_Fc.shape[1], -1).permute(0, 2, 1)
        g_Fs = g_Fs.view(g_Fs.shape[0], g_Fs.shape[1], -1)

        Attention = self.softmax(torch.bmm(f_Fc, g_Fs))

        h_Fs = h_Fs.view(h_Fs.shape[0], h_Fs.shape[1], -1)

        Fcs = torch.bmm(h_Fs, Attention.permute(0, 2, 1))
        Fcs = Fcs.view(B, C // 2, H, W)
        Fcs = self.relu6(self.out_conv(Fcs))

        return Fcs

class FuseUnit(nn.Module):
    def __init__(self, channels):
        super(FuseUnit, self).__init__()
        self.proj1 = nn.Conv2d(2 * channels, channels, (1, 1))
        self.proj2 = nn.Conv2d(channels, channels, (1, 1))
        self.proj3 = nn.Conv2d(channels, channels, (1, 1))

        self.fuse1x = nn.Conv2d(channels, 1, (1, 1), stride=1)
        self.fuse3x = nn.Conv2d(channels, 1, (3, 3), stride=1)
        self.fuse5x = nn.Conv2d(channels, 1, (5, 5), stride=1)

        self.pad3x = nn.ReflectionPad2d((1, 1, 1, 1))
        self.pad5x = nn.ReflectionPad2d((2, 2, 2, 2))
        self.sigmoid = nn.Sigmoid()

    def forward(self, F1, F2):
        Fcat = self.proj1(torch.cat((F1, F2), dim=1))
        F1 = self.proj2(F1)
        F2 = self.proj3(F2)

        fusion1 = self.sigmoid(self.fuse1x(Fcat))
        fusion3 = self.sigmoid(self.fuse3x(self.pad3x(Fcat)))
        fusion5 = self.sigmoid(self.fuse5x(self.pad5x(Fcat)))
        fusion = (fusion1 + fusion3 + fusion5) / 3

        return torch.clamp(fusion, min=0, max=1.0) * F1 + torch.clamp(1 - fusion, min=0, max=1.0) * F2

class PAMA(nn.Module):
    def __init__(self, channels):
        super(PAMA, self).__init__()
        self.conv_in = nn.Conv2d(channels, channels, (3, 3), stride=1)
        self.attn = AttentionUnit(channels)
        self.fuse = FuseUnit(channels)
        self.conv_out = nn.Conv2d(channels, channels, (3, 3), stride=1)

        self.pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.relu6 = nn.ReLU6()

    def forward(self, Fc, Fs):
        Fc = self.relu6(self.conv_in(self.pad(Fc)))
        Fs = self.relu6(self.conv_in(self.pad(Fs)))
        Fcs = self.attn(Fc, Fs)
        Fcs = self.relu6(self.conv_out(self.pad(Fcs)))
        Fcs = self.fuse(Fc, Fcs)

        return Fcs

#---------------------------------------------------------------------------------------------------------------
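Net pulls its configuration from an argparse-style namespace; at inference time only the pretrained, requires_grad, and training fields are consulted. A minimal sketch, assuming the ./checkpoints weights are in place (the SimpleNamespace config is hypothetical, standing in for parsed arguments):

import torch
from types import SimpleNamespace
from net import Net

# Hypothetical minimal config: only these fields are read when training is False.
args = SimpleNamespace(pretrained=True, requires_grad=False, training=False)
model = Net(args).eval()         # loads encoder, PAMA1-3, and decoder weights
Ic = torch.rand(1, 3, 512, 512)  # content batch
Is = torch.rand(1, 3, 512, 512)  # style batch
with torch.no_grad():
    Ics = model(Ic, Is)          # stylized image from the third PAMA stage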
style_trsfer.py
ADDED
@@ -0,0 +1,80 @@
import argparse
import torch
from torchvision.utils import make_grid
from PIL import Image, ImageFile
from net import Net
from utils import DEVICE, test_transform

Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True


def style_transfer_method(content_image, style_img):
    main_parser = argparse.ArgumentParser(description="main parser")
    subparsers = main_parser.add_subparsers(title="subcommands", dest="subcommand")

    # Note: argparse's type=bool treats any non-empty string (including
    # "False") as True, so these flags are effectively driven by their
    # defaults when this function is called from the demo.
    main_parser.add_argument("--pretrained", type=bool, default=True,
                             help="whether to use the pre-trained checkpoints")
    main_parser.add_argument("--requires_grad", type=bool, default=True,
                             help="set to True if the model requires gradients")

    train_parser = subparsers.add_parser("train", help="training mode parser")
    train_parser.add_argument("--training", type=bool, default=True)
    train_parser.add_argument("--iterations", type=int, default=60000,
                              help="total training iterations (default: 60000)")
    train_parser.add_argument("--batch_size", type=int, default=2,
                              help="training batch size (default: 2)")
    train_parser.add_argument("--num_workers", type=int, default=2,
                              help="iterator threads (default: 2)")
    train_parser.add_argument("--lr", type=float, default=1e-4,
                              help="the learning rate during training (default: 1e-4)")
    train_parser.add_argument("--content_folder", type=str, required=True,
                              help="the root of content images; the path should point to a folder")
    train_parser.add_argument("--style_folder", type=str, required=True,
                              help="the root of style images; the path should point to a folder")
    train_parser.add_argument("--log_interval", type=int, default=10000,
                              help="number of images after which the training loss is logged (default: 10000)")

    train_parser.add_argument("--w_content1", type=float, default=12, help="the stage1 content loss weight")
    train_parser.add_argument("--w_content2", type=float, default=9, help="the stage2 content loss weight")
    train_parser.add_argument("--w_content3", type=float, default=7, help="the stage3 content loss weight")
    train_parser.add_argument("--w_remd1", type=float, default=2, help="the stage1 remd loss weight")
    train_parser.add_argument("--w_remd2", type=float, default=2, help="the stage2 remd loss weight")
    train_parser.add_argument("--w_remd3", type=float, default=2, help="the stage3 remd loss weight")
    train_parser.add_argument("--w_moment1", type=float, default=2, help="the stage1 moment loss weight")
    train_parser.add_argument("--w_moment2", type=float, default=2, help="the stage2 moment loss weight")
    train_parser.add_argument("--w_moment3", type=float, default=2, help="the stage3 moment loss weight")
    train_parser.add_argument("--color_on", type=bool, default=True, help="turn on the color loss")
    train_parser.add_argument("--w_color1", type=float, default=0.25, help="the stage1 color loss weight")
    train_parser.add_argument("--w_color2", type=float, default=0.5, help="the stage2 color loss weight")
    train_parser.add_argument("--w_color3", type=float, default=1, help="the stage3 color loss weight")

    eval_parser = subparsers.add_parser("eval", help="evaluation mode parser")
    eval_parser.add_argument("--training", type=bool, default=False)
    eval_parser.add_argument("--run_folder", type=bool, default=False)

    args = main_parser.parse_args()
    args.training = False

    model = Net(args)
    model.eval()
    model = model.to(DEVICE)

    tf = test_transform()

    Ic = tf(content_image).to(DEVICE)
    Is = tf(Image.fromarray(style_img)).to(DEVICE)

    Ic = Ic.unsqueeze(dim=0)
    Is = Is.unsqueeze(dim=0)

    with torch.no_grad():
        Ics = model(Ic, Is)

    grid = make_grid(Ics[0])
    # Scale to [0, 255] and add 0.5 so the float-to-uint8 cast rounds to the
    # nearest integer instead of truncating.
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(ndarr)

    return im
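Note the asymmetric input types: content_image is expected to be a PIL image (as returned by the diffusion pipeline), while style_img is a NumPy array (as delivered by gr.Image). A sketch of a direct call, assuming the checkpoints and a style image exist; since the function parses sys.argv, it should be run from a plain script invocation:

import numpy as np
from PIL import Image
from style_trsfer import style_transfer_method

content = Image.open('ics.jpg').convert('RGB')                          # PIL image
style = np.array(Image.open('style_image/style1.jpg').convert('RGB'))  # NumPy array
result = style_transfer_method(content, style)
result.save('stylized.png')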
utils.py
ADDED
@@ -0,0 +1,168 @@
import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms
import PIL.Image as Image

# Fall back to CPU when CUDA is unavailable so the demo also runs on CPU-only hosts.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
mse = nn.MSELoss()


def calc_histogram_loss(A, B, histogram_block):
    input_hist = histogram_block(A)
    target_hist = histogram_block(B)
    # Hellinger distance between the two color histograms, averaged over the batch.
    histogram_loss = (1 / np.sqrt(2.0) * (torch.sqrt(torch.sum(
        torch.pow(torch.sqrt(target_hist) - torch.sqrt(input_hist), 2)))) /
                      input_hist.shape[0])

    return histogram_loss

# B, C, H, W; mean and variance are taken over H and W
def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def mean_variance_norm(feat):
    size = feat.size()
    mean, std = calc_mean_std(feat)
    normalized_feat = (feat - mean.expand(size)) / std.expand(size)
    return normalized_feat

def train_transform():
    transform_list = [
        transforms.Resize(size=512),
        transforms.RandomCrop(256),
        transforms.ToTensor()
    ]
    return transforms.Compose(transform_list)

def test_transform():
    transform_list = []
    transform_list.append(transforms.Resize(size=512))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform

# https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/7
def plot_grad_flow(named_parameters):
    '''Prints the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.

    Usage: plug this function into the Trainer class after loss.backward() as
    "plot_grad_flow(self.model.named_parameters())" to inspect the gradient flow.'''
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if (p.requires_grad) and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean())
            max_grads.append(p.grad.abs().max())
            print('-' * 82)
            print(n, p.grad.abs().mean(), p.grad.abs().max())
    print('-' * 82)

def InfiniteSampler(n):
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0

class InfiniteSamplerWrapper(data.sampler.Sampler):
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31

class FlatFolderDataset(data.Dataset):
    def __init__(self, root, transform):
        super(FlatFolderDataset, self).__init__()
        self.root = root
        self.paths = os.listdir(self.root)
        self.transform = transform

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        img = self.transform(img)
        return img

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'FlatFolderDataset'

def adjust_learning_rate(optimizer, iteration_count, args):
    """Imitating the original implementation"""
    lr = args.lr / (1.0 + 5e-5 * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def cosine_dismat(A, B):
    A = A.view(A.shape[0], A.shape[1], -1)
    B = B.view(B.shape[0], B.shape[1], -1)

    A_norm = torch.sqrt((A ** 2).sum(1))
    B_norm = torch.sqrt((B ** 2).sum(1))

    A = (A / A_norm.unsqueeze(dim=1).expand(A.shape)).permute(0, 2, 1)
    B = (B / B_norm.unsqueeze(dim=1).expand(B.shape))
    dismat = 1. - torch.bmm(A, B)

    return dismat

def calc_remd_loss(A, B):
    C = cosine_dismat(A, B)
    m1, _ = C.min(1)
    m2, _ = C.min(2)

    remd = torch.max(m1.mean(), m2.mean())

    return remd

def calc_ss_loss(A, B):
    MA = cosine_dismat(A, A)
    MB = cosine_dismat(B, B)
    Lself_similarity = torch.abs(MA - MB).mean()

    return Lself_similarity

def calc_moment_loss(A, B):
    A = A.view(A.shape[0], A.shape[1], -1)
    B = B.view(B.shape[0], B.shape[1], -1)

    mu_a = torch.mean(A, 1, keepdim=True)
    mu_b = torch.mean(B, 1, keepdim=True)
    mu_d = torch.abs(mu_a - mu_b).mean()

    A_c = A - mu_a
    B_c = B - mu_b
    cov_a = torch.bmm(A_c, A_c.permute(0, 2, 1)) / (A.shape[2] - 1)
    cov_b = torch.bmm(B_c, B_c.permute(0, 2, 1)) / (B.shape[2] - 1)
    cov_d = torch.abs(cov_a - cov_b).mean()
    loss = mu_d + cov_d
    return loss

def calc_mse_loss(A, B):
    return mse(A, B)
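The three feature-space losses flatten B x C x H x W activations to B x C x (HW) internally, and each returns a scalar. A quick sketch on random features:

import torch
from utils import calc_ss_loss, calc_remd_loss, calc_moment_loss

A = torch.rand(2, 512, 32, 32)  # e.g. relu4_1 activations of the stylized image
B = torch.rand(2, 512, 32, 32)  # e.g. relu4_1 activations of the style image
print(calc_ss_loss(A, B))       # self-similarity (structure) difference
print(calc_remd_loss(A, B))     # relaxed earth mover's distance over cosine distances
print(calc_moment_loss(A, B))   # mean/covariance (moment) matching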