init

Files changed:
- .gitignore +5 -0
- README.md +1 -12
- app.py +73 -5
- models/__init__.py +0 -0
- models/basic_layer.py +429 -0
- models/c2pDis.py +313 -0
- models/c2pGen.py +266 -0
- models/networks.py +244 -0
- models/p2cGen.py +76 -0
- pixelization.py +151 -0
- reference.png +0 -0
- requirements.txt +5 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
myvenv
myvenv/**/*
__pycache__
flagged
*.pth
README.md
CHANGED
@@ -1,12 +1 @@
----
-title: Pixelization
-emoji: 🚀
-colorFrom: blue
-colorTo: gray
-sdk: gradio
-sdk_version: 3.16.2
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+...
app.py
CHANGED
@@ -1,8 +1,76 @@
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
-
-
-
+import functools
+from pixelization import Model
+import torch
+import argparse
+import huggingface_hub
+import os
+
+TOKEN = "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"
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--theme', type=str, default='default')
+    parser.add_argument('--live', action='store_true')
+    parser.add_argument('--share', action='store_true')
+    parser.add_argument('--port', type=int)
+    parser.add_argument('--disable-queue',
+                        dest='enable_queue',
+                        action='store_false')
+    parser.add_argument('--allow-flagging', type=str, default='never')
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+
+    # Download models
+    # PIX_MODEL
+    os.environ['PIX_MODEL'] = huggingface_hub.hf_hub_download("NoCrypt/pixelization_models", "pixelart_vgg19.pth", token=TOKEN)
+    # NET_MODEL
+    os.environ['NET_MODEL'] = huggingface_hub.hf_hub_download("NoCrypt/pixelization_models", "160_net_G_A.pth", token=TOKEN)
+    # ALIAS_MODEL
+    os.environ['ALIAS_MODEL'] = huggingface_hub.hf_hub_download("NoCrypt/pixelization_models", "alias_net.pth", token=TOKEN)
+
+    # # For local testing
+    # # PIX_MODEL
+    # os.environ['PIX_MODEL'] = "pixelart_vgg19.pth"
+    # # NET_MODEL
+    # os.environ['NET_MODEL'] = "160_net_G_A.pth"
+    # # ALIAS_MODEL
+    # os.environ['ALIAS_MODEL'] = "alias_net.pth"
+
+    use_cpu = True
+    m = Model(device="cpu" if use_cpu else "cuda")
+    m.load()
+
+    # To use the GPU: set use_cpu to False, see my comment in networks.py at lines 107 & 108,
+    # and use a torch build with CUDA support (change in requirements.txt).
+
+    gr.Interface(m.pixelize_modified,
+                 [
+                     gr.components.Image(type='pil', label='Input'),
+                     gr.components.Slider(minimum=1, maximum=16, value=4, step=1, label='Pixel Size'),
+                     gr.components.Checkbox(True, label="Upscale after")
+                 ],
+                 gr.components.Image(type='pil', label='Output'),
+                 title="Pixelization",
+                 description='''
+Demo for [WuZongWei6/Pixelization](https://github.com/WuZongWei6/Pixelization)
+
+Models that are used are private to comply with the license.
+''',
+                 theme=args.theme,
+                 allow_flagging=args.allow_flagging,
+                 live=args.live,
+                 ).launch(
+        enable_queue=args.enable_queue,
+        server_port=args.port,
+        share=args.share,
+    )
+
+
+if __name__ == '__main__':
+    main()
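For reference, a minimal sketch of driving the same model outside Gradio. It assumes only what app.py already uses above: pixelization.Model with load() and pixelize_modified(image, pixel_size, upscale), and the three model-path environment variables already set.

from PIL import Image
from pixelization import Model

# PIX_MODEL / NET_MODEL / ALIAS_MODEL must point at the checkpoints first
m = Model(device="cpu")
m.load()
img = Image.open("reference.png").convert("RGB")
out = m.pixelize_modified(img, 4, True)  # pixel size 4, upscale back afterwards
out.save("pixelized.png")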
models/__init__.py
ADDED
File without changes
models/basic_layer.py
ADDED
@@ -0,0 +1,429 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np

class ModulationConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, stride=1,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(ModulationConvBlock, self).__init__()
        self.in_c = input_dim
        self.out_c = output_dim
        self.ksize = kernel_size
        self.stride = 1
        self.padding = kernel_size // 2

        self.eps = 1e-8
        weight_shape = (output_dim, input_dim, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * input_dim
        wscale = 1.0 / np.sqrt(fan_in)

        self.weight = nn.Parameter(torch.randn(*weight_shape))
        self.wscale = wscale

        self.bias = nn.Parameter(torch.zeros(output_dim))

        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.activate_scale = np.sqrt(2.0)

    def forward(self, x, code):
        batch, in_channel, height, width = x.shape
        weight = self.weight * self.wscale
        _weight = weight.view(1, self.ksize, self.ksize, self.in_c, self.out_c)
        _weight = _weight * code.view(batch, 1, 1, self.in_c, 1)
        # demodulation
        _weight_norm = torch.sqrt(torch.sum(_weight ** 2, dim=[1, 2, 3]) + self.eps)
        _weight = _weight / _weight_norm.view(batch, 1, 1, 1, self.out_c)
        # fused modulate: fold the batch into the channel axis and run one grouped conv
        x = x.view(1, batch * self.in_c, x.shape[2], x.shape[3])
        weight = _weight.permute(1, 2, 3, 0, 4).reshape(
            self.ksize, self.ksize, self.in_c, batch * self.out_c)
        # not use_conv2d_transpose
        weight = weight.permute(3, 2, 0, 1)
        x = F.conv2d(x,
                     weight=weight,
                     bias=None,
                     stride=self.stride,
                     padding=self.padding,
                     groups=batch)

        # fused_modulate path (always taken in this code)
        x = x.view(batch, self.out_c, height, width)
        x = x + self.bias.view(1, -1, 1, 1)
        x = self.activate(x) * self.activate_scale
        return x
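A quick sketch of how the modulated convolution above is exercised; the shapes are illustrative, and code plays the role of the per-input-channel style scales:

import torch
from models.basic_layer import ModulationConvBlock

block = ModulationConvBlock(input_dim=256, output_dim=256, kernel_size=3)
x = torch.randn(2, 256, 32, 32)   # (batch, channels, height, width)
code = torch.randn(2, 256)        # one modulation scale per input channel, per sample
y = block(x, code)                # modulate -> demodulate -> grouped conv over the batch
print(y.shape)                    # torch.Size([2, 256, 32, 32])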
class AliasConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(AliasConvBlock, self).__init__()
        self.use_bias = True
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # initialize convolution (note: both branches are currently identical)
        if norm == 'sn':
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x

class AliasResBlocks(nn.Module):
    def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
        super(AliasResBlocks, self).__init__()
        self.model = []
        for i in range(num_blocks):
            self.model += [AliasResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)

class AliasResBlock(nn.Module):
    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        super(AliasResBlock, self).__init__()

        model = []
        model += [AliasConvBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
        model += [AliasConvBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out

##################################################################################
# Sequential Models
##################################################################################
class ResBlocks(nn.Module):
    def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlocks, self).__init__()
        self.model = []
        for i in range(num_blocks):
            self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super(MLP, self).__init__()
        self.model = []
        self.model += [linearBlock(input_dim, input_dim, norm=norm, activation=activ)]
        self.model += [linearBlock(input_dim, dim, norm=norm, activation=activ)]
        for i in range(n_blk - 2):
            self.model += [linearBlock(dim, dim, norm=norm, activation=activ)]
        self.model += [linearBlock(dim, output_dim, norm='none', activation='none')]  # no output activations
        self.model = nn.Sequential(*self.model)

    # def forward(self, style0, style1, a=0):
    #     return self.model[3]((1 - a) * self.model[0:3](style0.view(style0.size(0), -1)) + a * self.model[0:3](
    #         style1.view(style1.size(0), -1)))
    def forward(self, style0, style1=None, a=0):
        # style1 is overridden with style0, so the interpolation path is effectively disabled
        style1 = style0
        return self.model[3]((1 - a) * self.model[0:3](style0.view(style0.size(0), -1)) + a * self.model[0:3](
            style1.view(style1.size(0), -1)))

##################################################################################
# Basic Blocks
##################################################################################
class ResBlock(nn.Module):
    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlock, self).__init__()

        model = []
        model += [ConvBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
        model += [ConvBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out


class ConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(ConvBlock, self).__init__()
        self.use_bias = True
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # initialize convolution (note: both branches are currently identical)
        if norm == 'sn':
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x

class linearBlock(nn.Module):
    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(linearBlock, self).__init__()
        use_bias = True
        # initialize the fully connected layer
        if norm == 'sn':
            self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
        else:
            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x):
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out

##################################################################################
# Normalization layers
##################################################################################
class AdaptiveInstanceNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # weight and bias are dynamically assigned
        self.weight = None
        self.bias = None
        # just dummy buffers, not used
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        b, c = x.size(0), x.size(1)
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)

        # apply instance norm
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])

        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)

        return out.view(b, c, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.num_features) + ')'


class LayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        shape = [-1] + [1] * (x.dim() - 1)
        # print(x.size())
        if x.size(0) == 1:
            # These two lines run much faster in pytorch 0.4 than the two lines listed below.
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)

        x = (x - mean) / (std + self.eps)

        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """
    Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
    and the PyTorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
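A minimal usage sketch for the SpectralNorm wrapper defined above; one power iteration re-estimates the largest singular value of the wrapped weight on every forward pass:

import torch
import torch.nn as nn
from models.basic_layer import SpectralNorm

sn_fc = SpectralNorm(nn.Linear(128, 64))   # weight is re-registered as weight_bar/_u/_v
x = torch.randn(4, 128)
y = sn_fc(x)                               # forward rescales the weight by the estimated sigma
print(y.shape)                             # torch.Size([4, 64])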
models/c2pDis.py
ADDED
@@ -0,0 +1,313 @@
from .basic_layer import *
import math
from torch.nn import Parameter
# from pytorch_metric_learning import losses

'''
Margin code is borrowed from https://github.com/MuggleWang/CosFace_pytorch and https://github.com/wujiyang/Face_Pytorch.
'''
def cosine_sim(x1, x2, dim=1, eps=1e-8):
    ip = torch.mm(x1, x2.t())  # w 7*512
    w1 = torch.norm(x1, 2, dim)
    w2 = torch.norm(x2, 2, dim)
    return ip / torch.ger(w1, w2).clamp(min=eps)

class MarginCosineProduct(nn.Module):
    r"""Implementation of large-margin cosine distance.
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        s: norm of input feature
        m: margin
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.40):
        super(MarginCosineProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.Tensor(out_features, in_features))  # 7 512
        nn.init.xavier_uniform_(self.weight)
        # stdv = 1. / math.sqrt(self.weight.size(1))
        # self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input, label):
        cosine = cosine_sim(input, self.weight)  # 1*512 7*512
        # cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # --------------------------- convert label to one-hot ---------------------------
        # https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1.0)
        # ------------- torch.where(out_i = {x_i if condition_i else y_i}) -------------
        output = self.s * (cosine - one_hot * self.m)

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', s=' + str(self.s) \
               + ', m=' + str(self.m) + ')'

class ArcMarginProduct(nn.Module):
    def __init__(self, in_feature=128, out_feature=10575, s=32.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.s = s
        self.m = m
        self.weight = Parameter(torch.Tensor(out_feature, in_feature))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)

        # make the function cos(theta+m) monotonically decreasing while theta is in [0°, 180°]
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, label):
        # cos(theta)
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # cos(theta + m)
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)

        # one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu')
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s

        return output


class MultiMarginProduct(nn.Module):
    def __init__(self, in_feature=128, out_feature=10575, s=32.0, m1=0.20, m2=0.35, easy_margin=False):
        super(MultiMarginProduct, self).__init__()
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.weight = Parameter(torch.Tensor(out_feature, in_feature))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m1 = math.cos(m1)
        self.sin_m1 = math.sin(m1)

        # make the function cos(theta+m) monotonically decreasing while theta is in [0°, 180°]
        self.th = math.cos(math.pi - m1)
        self.mm = math.sin(math.pi - m1) * m1

    def forward(self, x, label):
        # cos(theta)
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # cos(theta + m1)
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m1 - sine * self.sin_m1

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)

        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # additive angular margin
        output = output - one_hot * self.m2                    # additive cosine margin
        output = output * self.s

        return output


class CPDis(nn.Module):
    """PatchGAN."""
    def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
        super(CPDis, self).__init__()

        layers = []
        if norm == 'SN':
            layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
        else:
            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            if norm == 'SN':
                layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
            else:
                layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        # k_size = int(image_size / np.power(2, repeat_num))
        if norm == 'SN':
            layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
        else:
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = curr_dim * 2

        self.main = nn.Sequential(*layers)
        if norm == 'SN':
            self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
        else:
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)

    def forward(self, x):
        if x.ndim == 5:
            x = x.squeeze(0)
        assert x.ndim == 4, x.ndim
        h = self.main(x)
        # out_real = self.conv1(h)
        out_makeup = self.conv1(h)
        # return out_real.squeeze(), out_makeup.squeeze()
        return out_makeup


class CPDis_cls(nn.Module):
    """PatchGAN."""
    def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
        super(CPDis_cls, self).__init__()

        layers = []
        if norm == 'SN':
            layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
        else:
            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            if norm == 'SN':
                layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
            else:
                layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        # k_size = int(image_size / np.power(2, repeat_num))
        if norm == 'SN':
            layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
        else:
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = curr_dim * 2

        self.main = nn.Sequential(*layers)
        if norm == 'SN':
            self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
            self.classifier_pool = nn.AdaptiveAvgPool2d(1)
            self.classifier_conv = nn.Conv2d(512, 512, 1, 1, 0)
            self.classifier = MarginCosineProduct(512, 7)  # ArcMarginProduct(512, 7)
            print("Using Large Margin Cosine Loss.")
        else:
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)

    def forward(self, x, label):
        if x.ndim == 5:
            x = x.squeeze(0)
        assert x.ndim == 4, x.ndim
        h = self.main(x)  # ([1, 512, 31, 31])
        # print(out_cls.shape)
        out_cls = self.classifier_pool(h)
        # print(out_cls.shape)
        out_cls = self.classifier_conv(out_cls)
        # print(out_cls.shape)
        out_cls = torch.squeeze(out_cls, -1)
        out_cls = torch.squeeze(out_cls, -1)
        out_cls = self.classifier(out_cls, label)
        out_makeup = self.conv1(h)  # torch.Size([1, 1, 30, 30])
        # return out_real.squeeze(), out_makeup.squeeze()
        return out_makeup, out_cls

class SpectralNorm(object):
    def __init__(self):
        self.name = "weight"
        # print(self.name)
        self.power_iterations = 1

    def compute_weight(self, module):
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        w = getattr(module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        return w / sigma.expand_as(w)

    @staticmethod
    def apply(module):
        name = "weight"
        fn = SpectralNorm()

        try:
            u = getattr(module, name + "_u")
            v = getattr(module, name + "_v")
            w = getattr(module, name + "_bar")
        except AttributeError:
            w = getattr(module, name)
            height = w.data.shape[0]
            width = w.view(height, -1).data.shape[1]
            u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
            v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
            w_bar = Parameter(w.data)

            # del module._parameters[name]

            module.register_parameter(name + "_u", u)
            module.register_parameter(name + "_v", v)
            module.register_parameter(name + "_bar", w_bar)

            # remove w from the parameter list
            del module._parameters[name]

        setattr(module, name, fn.compute_weight(module))

        # recompute the weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module):
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_u']
        del module._parameters[self.name + '_v']
        del module._parameters[self.name + '_bar']
        module.register_parameter(self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        setattr(module, self.name, self.compute_weight(module))

def spectral_norm(module):
    SpectralNorm.apply(module)
    return module

def remove_spectral_norm(module):
    name = 'weight'
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}"
                     .format(name, module))
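As a quick illustration of the margin head above, a minimal sketch using the same 512-dim, 7-class setup that CPDis_cls wires in; the random features here stand in for backbone embeddings:

import torch
import torch.nn.functional as F
from models.c2pDis import MarginCosineProduct

head = MarginCosineProduct(in_features=512, out_features=7, s=30.0, m=0.40)
feats = torch.randn(8, 512)         # embeddings from a backbone
labels = torch.randint(0, 7, (8,))  # ground-truth class indices
logits = head(feats, labels)        # margin m subtracted on the target class, scaled by s
loss = F.cross_entropy(logits, labels)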
models/c2pGen.py
ADDED
@@ -0,0 +1,266 @@
from .basic_layer import *
import torchvision.models as models
import os


class AliasNet(nn.Module):
    def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, activ='relu', pad_type='reflect'):
        super(AliasNet, self).__init__()
        self.RGBEnc = AliasRGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
        self.RGBDec = AliasRGBDecoder(self.RGBEnc.output_dim, output_dim, n_downsample, n_res, res_norm='in',
                                      activ=activ, pad_type=pad_type)

    def forward(self, x):
        x = self.RGBEnc(x)
        x = self.RGBDec(x)
        return x


class AliasRGBEncoder(nn.Module):
    def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
        super(AliasRGBEncoder, self).__init__()
        self.model = []
        self.model += [AliasConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        # downsampling blocks
        for i in range(n_downsample):
            self.model += [AliasConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        # residual blocks
        self.model += [AliasResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)
        self.output_dim = dim

    def forward(self, x):
        return self.model(x)


class AliasRGBDecoder(nn.Module):
    def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
        super(AliasRGBDecoder, self).__init__()
        # self.model = []
        # # AdaIN residual blocks
        # self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
        # # upsampling blocks
        # for i in range(n_upsample):
        #     self.model += [nn.Upsample(scale_factor=2, mode='nearest'),
        #                    ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
        #     dim //= 2
        # # use reflection padding in the last conv layer
        # self.model += [ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
        # self.model = nn.Sequential(*self.model)
        self.Res_Blocks = AliasResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)
        self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_1 = AliasConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
        dim //= 2
        self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_2 = AliasConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
        dim //= 2
        self.conv_3 = AliasConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)

    def forward(self, x):
        x = self.Res_Blocks(x)
        # print(x.shape)
        x = self.upsample_block1(x)
        # print(x.shape)
        x = self.conv_1(x)
        # print(x_small.shape)
        x = self.upsample_block2(x)
        # print(x.shape)
        x = self.conv_2(x)
        # print(x_middle.shape)
        x = self.conv_3(x)
        # print(x_big.shape)
        return x


class C2PGen(nn.Module):
    def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, style_dim, mlp_dim, activ='relu', pad_type='reflect'):
        super(C2PGen, self).__init__()
        self.PBEnc = PixelBlockEncoder(input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type)
        self.RGBEnc = RGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
        self.RGBDec = RGBDecoder(self.RGBEnc.output_dim, output_dim, n_downsample, n_res, res_norm='adain',
                                 activ=activ, pad_type=pad_type)
        self.MLP = MLP(style_dim, 2048, mlp_dim, 3, norm='none', activ=activ)

    def forward(self, clipart, pixelart, s=1):
        feature = self.RGBEnc(clipart)
        code = self.PBEnc(pixelart)
        result, cellcode = self.fuse(feature, code, s)
        return result  # , cellcode  (return cellcode when visualizing the cell-size code)

    def fuse(self, content, style_code, s=1):
        # print("MLP input: code's shape:", style_code.shape)
        adain_params = self.MLP(style_code) * s  # [batch, 2048]
        # print("MLP output: adain_params's shape", adain_params.shape)
        # self.assign_adain_params(adain_params, self.RGBDec)
        images = self.RGBDec(content, adain_params)
        return images, adain_params

    def assign_adain_params(self, adain_params, model):
        # assign the adain_params to the AdaIN layers in model
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                mean = adain_params[:, :m.num_features]
                std = adain_params[:, m.num_features:2 * m.num_features]
                m.bias = mean.contiguous().view(-1)
                m.weight = std.contiguous().view(-1)
                if adain_params.size(1) > 2 * m.num_features:
                    adain_params = adain_params[:, 2 * m.num_features:]

    def get_num_adain_params(self, model):
        # return the number of AdaIN parameters needed by the model
        num_adain_params = 0
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                num_adain_params += 2 * m.num_features
        return num_adain_params


class PixelBlockEncoder(nn.Module):
    def __init__(self, input_dim, dim, style_dim, norm, activ, pad_type):
        super(PixelBlockEncoder, self).__init__()
        vgg19 = models.vgg.vgg19()
        vgg19.classifier._modules['6'] = nn.Linear(4096, 7, bias=True)
        vgg19.load_state_dict(torch.load('./pixelart_vgg19.pth' if not os.environ['PIX_MODEL'] else os.environ['PIX_MODEL'], map_location=torch.device('cpu')))
        self.vgg = vgg19.features
        for p in self.vgg.parameters():
            p.requires_grad = False
        # vgg19 = models.vgg.vgg19(pretrained=False)
        # vgg19.load_state_dict(torch.load('./vgg.pth'))
        # self.vgg = vgg19.features
        # for p in self.vgg.parameters():
        #     p.requires_grad = False

        self.conv1 = ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)  # 3->64, concat
        dim = dim * 2
        self.conv2 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)  # 128->128
        dim = dim * 2
        self.conv3 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)  # 256->256
        dim = dim * 2
        self.conv4 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)  # 512->512
        dim = dim * 2

        self.model = []
        self.model += [nn.AdaptiveAvgPool2d(1)]  # global average pooling
        self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
        self.model = nn.Sequential(*self.model)
        self.output_dim = dim

    def get_features(self, image, model, layers=None):
        if layers is None:
            layers = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1', '19': 'conv4_1'}
        features = {}
        x = image
        # model._modules is a dictionary holding each module in the model
        for name, layer in model._modules.items():
            x = layer(x)
            if name in layers:
                features[layers[name]] = x
        return features

    def componet_enc(self, x):
        # x [16,3,256,256]
        # factor_img [16,7,256,256]
        vgg_aux = self.get_features(x, self.vgg)  # x is a 3-channel grayscale image
        # x = torch.cat([x, factor_img], dim=1)  # [16,3+7,256,256]
        x = self.conv1(x)  # 64 256 256
        x = torch.cat([x, vgg_aux['conv1_1']], dim=1)  # 128 256 256
        x = self.conv2(x)  # 128 128 128
        x = torch.cat([x, vgg_aux['conv2_1']], dim=1)  # 256 128 128
        x = self.conv3(x)  # 256 64 64
        x = torch.cat([x, vgg_aux['conv3_1']], dim=1)  # 512 64 64
        x = self.conv4(x)  # 512 32 32
        x = torch.cat([x, vgg_aux['conv4_1']], dim=1)  # 1024 32 32
        x = self.model(x)
        return x

    def forward(self, x):
        code = self.componet_enc(x)
        return code


class RGBEncoder(nn.Module):
    def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
        super(RGBEncoder, self).__init__()
        self.model = []
        self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        # downsampling blocks
        for i in range(n_downsample):
            self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        # residual blocks
        self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)
        self.output_dim = dim

    def forward(self, x):
        return self.model(x)


class RGBDecoder(nn.Module):
    def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
        super(RGBDecoder, self).__init__()
        # self.model = []
        # # AdaIN residual blocks
        # self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
        # # upsampling blocks
        # for i in range(n_upsample):
        #     self.model += [nn.Upsample(scale_factor=2, mode='nearest'),
        #                    ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
        #     dim //= 2
        # # use reflection padding in the last conv layer
        # self.model += [ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
        # self.model = nn.Sequential(*self.model)
        # self.Res_Blocks = ModulationResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)
        self.mod_conv_1 = ModulationConvBlock(256, 256, 3)
        self.mod_conv_2 = ModulationConvBlock(256, 256, 3)
        self.mod_conv_3 = ModulationConvBlock(256, 256, 3)
        self.mod_conv_4 = ModulationConvBlock(256, 256, 3)
        self.mod_conv_5 = ModulationConvBlock(256, 256, 3)
        self.mod_conv_6 = ModulationConvBlock(256, 256, 3)
        self.mod_conv_7 = ModulationConvBlock(256, 256, 3)
        self.mod_conv_8 = ModulationConvBlock(256, 256, 3)
        self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_1 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
        dim //= 2
        self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_2 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
        dim //= 2
        self.conv_3 = ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)

    # def forward(self, x):
    #     residual = x
    #     out = self.model(x)
    #     out += residual
    #     return out
    def forward(self, x, code):
        # four residual pairs of modulated convolutions, each pair consuming its own
        # 256-dim slice of the 2048-dim style code; note that mod_conv_2 is reused for
        # every step after the first (mod_conv_3..mod_conv_8 are defined but never called)
        residual = x
        x = self.mod_conv_1(x, code[:, :256])
        x = self.mod_conv_2(x, code[:, 256 * 1:256 * 2])
        x += residual
        residual = x
        x = self.mod_conv_2(x, code[:, 256 * 2:256 * 3])
        x = self.mod_conv_2(x, code[:, 256 * 3:256 * 4])
        x += residual
        residual = x
        x = self.mod_conv_2(x, code[:, 256 * 4:256 * 5])
        x = self.mod_conv_2(x, code[:, 256 * 5:256 * 6])
        x += residual
        residual = x
        x = self.mod_conv_2(x, code[:, 256 * 6:256 * 7])
        x = self.mod_conv_2(x, code[:, 256 * 7:256 * 8])
        x += residual
        # print(x.shape)
        x = self.upsample_block1(x)
        # print(x.shape)
        x = self.conv_1(x)
        # print(x_small.shape)
        x = self.upsample_block2(x)
        # print(x.shape)
        x = self.conv_2(x)
        # print(x_middle.shape)
        x = self.conv_3(x)
        # print(x_big.shape)
        return x
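A hedged sketch of how C2PGen is wired together. The hyperparameters below are illustrative (not taken from this commit), chosen so the encoder's 256 output channels match the modulated blocks and the MLP's 2048-dim output matches the eight 256-dim style slices; PixelBlockEncoder also needs PIX_MODEL to point at a real pixelart_vgg19.pth checkpoint before construction:

import os
import torch
from models.c2pGen import C2PGen

os.environ['PIX_MODEL'] = 'pixelart_vgg19.pth'  # assumed local checkpoint path
gen = C2PGen(input_dim=3, output_dim=3, dim=64, n_downsample=2, n_res=8,
             style_dim=256, mlp_dim=256)
clipart = torch.randn(1, 3, 256, 256)           # content input
pixelart = torch.randn(1, 3, 256, 256)          # style (pixel-cell) reference
with torch.no_grad():
    out = gen(clipart, pixelart)                # stylized result, same spatial size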
models/networks.py
ADDED
|
@@ -0,0 +1,244 @@
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
from .c2pGen import *
from .p2cGen import *
from .c2pDis import *


class Identity(nn.Module):
    def forward(self, x):
        return x


def get_norm_layer(norm_type='instance'):
    """Return a normalization layer

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        def norm_layer(x): return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
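Worth noting (usage sketch, not part of the diff): get_norm_layer returns a constructor rather than a layer, so callers invoke the result with a channel count:

norm_layer = get_norm_layer('instance')
layer = norm_layer(64)  # equivalent to nn.InstanceNorm2d(64, affine=False, track_running_stats=False)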


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        # the original returned the exception instead of raising it
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
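To make the 'linear' policy concrete, a worked example with illustrative values (epoch_count=1, n_epochs=100, n_epochs_decay=100):

# epoch 50:  1.0 - max(0, 50 + 1 - 100) / 101 = 1.0     (flat phase)
# epoch 150: 1.0 - max(0, 150 + 1 - 100) / 101 ~ 0.495  (mid-decay)
# epoch 200: 1.0 - max(0, 200 + 1 - 100) / 101 = 0.0    (fully decayed)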


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    #print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
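A quick usage sketch (illustrative module, not part of the diff); net.apply() walks every submodule, so the conv and the batch norm below each hit their own branch of init_func:

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
init_weights(net, init_type='xavier', init_gain=0.02)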


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        gain (float)       -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    gpu_ids = [0]
    if len(gpu_ids) > 0:
        # assert(torch.cuda.is_available())  # uncomment this for using gpu
        net.to(torch.device("cpu"))  # change this to gpu_ids[0] for using gpu
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
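For the GPU path the two inline comments describe, the marked lines would become something like the following (a sketch only; it requires a CUDA build of torch, per the note in app.py):

assert torch.cuda.is_available()
net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs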


def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a generator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        output_nc (int)    -- the number of channels in output images
        ngf (int)          -- the number of filters in the last conv layer
        netG (str)         -- the architecture's name: c2pGen | p2cGen | antialias
        norm (str)         -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- whether to use dropout layers.
        init_type (str)    -- the name of our initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a generator
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'c2pGen':  # style_dim mlp_dim
        net = C2PGen(input_nc, output_nc, ngf, 2, 4, 256, 256, activ='relu', pad_type='reflect')
        #print('c2pgen resblock is 8')
    elif netG == 'p2cGen':
        net = P2CGen(input_nc, output_nc, ngf, 2, 3, activ='relu', pad_type='reflect')
    elif netG == 'antialias':
        net = AliasNet(input_nc, output_nc, ngf, 2, 3, activ='relu', pad_type='reflect')
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)
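These are exactly the calls pixelization.py's Model.load() makes further down: one c2pGen generator and one antialias net, both placed on the device hardcoded in init_net:

G_A_net = define_G(3, 3, 64, "c2pGen", "instance", False, "normal", 0.02, [0])
alias_net = define_G(3, 3, 64, "antialias", "instance", False, "normal", 0.02, [0])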


def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a discriminator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: CPDis | CPDis_cls
        n_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a discriminator
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'CPDis':
        net = CPDis(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
    elif netD == 'CPDis_cls':
        net = CPDis_cls(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)


class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str)            -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (float) -- label for a real image
            target_fake_label (float) -- label for a fake image

        Note: Do not use sigmoid as the last layer of the discriminator.
        LSGAN needs no sigmoid; vanilla GANs handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor)   -- typically the prediction from a discriminator
            target_is_real (bool) -- whether the ground truth label is for real or fake images

        Returns:
            A label tensor filled with the ground truth label, with the size of the input
        """
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate loss given the discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor)   -- typically the prediction output from a discriminator
            target_is_real (bool) -- whether the ground truth label is for real or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss
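A minimal GANLoss usage sketch (training-side code that is not part of this Space; pred_real and pred_fake are hypothetical discriminator outputs):

criterion = GANLoss('lsgan')
pred_real = torch.randn(4, 1, 30, 30)  # hypothetical D output on real images
pred_fake = torch.randn(4, 1, 30, 30)  # hypothetical D output on fakes
loss_D = criterion(pred_real, True) + criterion(pred_fake, False)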
models/p2cGen.py ADDED
@@ -0,0 +1,76 @@
from .basic_layer import *


class P2CGen(nn.Module):
    def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, activ='relu', pad_type='reflect'):
        super(P2CGen, self).__init__()
        self.RGBEnc = RGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
        self.RGBDec = RGBDecoder(self.RGBEnc.output_dim, output_dim, n_downsample, n_res, res_norm='in',
                                 activ=activ, pad_type=pad_type)

    def forward(self, x):
        x = self.RGBEnc(x)
        x = self.RGBDec(x)
        return x


class RGBEncoder(nn.Module):
    def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
        super(RGBEncoder, self).__init__()
        self.model = []
        self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        # downsampling blocks
        for i in range(n_downsample):
            self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        # residual blocks
        self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)
        self.output_dim = dim

    def forward(self, x):
        return self.model(x)


class RGBDecoder(nn.Module):
    def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
        super(RGBDecoder, self).__init__()
        self.Res_Blocks = ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)
        self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_1 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
        dim //= 2
        self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_2 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
        dim //= 2
        # reflection padding in the last conv layer
        self.conv_3 = ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)

    def forward(self, x):
        x = self.Res_Blocks(x)
        x = self.upsample_block1(x)
        x = self.conv_1(x)
        x = self.upsample_block2(x)
        x = self.conv_2(x)
        x = self.conv_3(x)
        return x
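Shape flow for P2CGen as configured by define_G (dim=64, n_downsample=2, n_res=3), sketched under the assumption that ConvBlock/ResBlocks from models/basic_layer.py preserve the sizes their strides imply:

import torch
net = P2CGen(3, 3, 64, n_downsample=2, n_res=3)
x = torch.randn(1, 3, 256, 256)
# encoder: 3x256x256 -> 64x256x256 -> 128x128x128 -> 256x64x64; decoder mirrors it back
assert net(x).shape == (1, 3, 256, 256)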
pixelization.py ADDED
@@ -0,0 +1,151 @@
import os
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from models.networks import define_G
import glob


class Model():
    def __init__(self, device="cpu"):
        self.device = torch.device(device)
        self.G_A_net = None
        self.alias_net = None
        self.ref_t = None

    def load(self):
        with torch.no_grad():
            self.G_A_net = define_G(3, 3, 64, "c2pGen", "instance", False, "normal", 0.02, [0])
            self.alias_net = define_G(3, 3, 64, "antialias", "instance", False, "normal", 0.02, [0])

            # init_net wraps the nets in DataParallel, so each checkpoint key needs a "module." prefix;
            # fall back to local files when the env vars are unset or empty
            G_A_state = torch.load(os.environ.get('NET_MODEL') or "160_net_G_A.pth", map_location=str(self.device))
            for p in list(G_A_state.keys()):
                G_A_state["module." + str(p)] = G_A_state.pop(p)
            self.G_A_net.load_state_dict(G_A_state)

            alias_state = torch.load(os.environ.get('ALIAS_MODEL') or "alias_net.pth", map_location=str(self.device))
            for p in list(alias_state.keys()):
                alias_state["module." + str(p)] = alias_state.pop(p)
            self.alias_net.load_state_dict(alias_state)

            ref_img = Image.open("reference.png").convert('L')
            self.ref_t = process(greyscale(ref_img)).to(self.device)

    def pixelize(self, in_img, out_img):
        with torch.no_grad():
            in_img = Image.open(in_img).convert('RGB')
            in_t = process(in_img).to(self.device)

            out_t = self.alias_net(self.G_A_net(in_t, self.ref_t))

            save(out_t, out_img)

    def pixelize_modified(self, in_img, pixel_size, upscale_after) -> Image.Image:
        with torch.no_grad():
            in_img = in_img.convert('RGB')

            # cap the input at 1024x1024 so a very large image doesn't blow up processing
            if in_img.size[0] > 1024 or in_img.size[1] > 1024:
                in_img.thumbnail((1024, 1024), Image.NEAREST)

            # Image.resize returns a new image; the original dropped the result, so the
            # pixel_size rescale never took effect -- assign it back
            in_img = in_img.resize((in_img.size[0] * 4 // pixel_size, in_img.size[1] * 4 // pixel_size))

            in_t = process(in_img).to(self.device)

            out_t = self.alias_net(self.G_A_net(in_t, self.ref_t))
            img = to_image(out_t, pixel_size, upscale_after)
            return img

def to_image(tensor, pixel_size, upscale_after):
    img = tensor.data[0].cpu().float().numpy()
    img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
    img = img.astype(np.uint8)
    img = Image.fromarray(img)
    img = img.resize((img.size[0] // 4, img.size[1] // 4), resample=Image.Resampling.NEAREST)
    if upscale_after:
        img = img.resize((img.size[0] * pixel_size, img.size[1] * pixel_size), resample=Image.Resampling.NEAREST)

    return img


def greyscale(img):
    gray = np.array(img.convert('L'))
    tmp = np.expand_dims(gray, axis=2)
    tmp = np.concatenate((tmp, tmp, tmp), axis=-1)
    return Image.fromarray(tmp)

def process(img):
    # center-crop to the nearest multiples of 4, since the generator downsamples twice
    ow, oh = img.size

    nw = int(round(ow / 4) * 4)
    nh = int(round(oh / 4) * 4)

    left = (ow - nw) // 2
    top = (oh - nh) // 2
    right = left + nw
    bottom = top + nh

    img = img.crop((left, top, right, bottom))

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    return trans(img)[None, :, :, :]
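A worked example of the 4-multiple center crop (illustrative sizes): a 1001x602 input becomes 1000x600, since round(1001/4)*4 = 1000 and round(602/4)*4 = 600 (Python's round-half-to-even sends 150.5 to 150):

from PIL import Image
t = process(Image.new('RGB', (1001, 602)))
assert tuple(t.shape) == (1, 3, 600, 1000)  # NCHW; note PIL size is (W, H)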

def save(tensor, file):
    img = tensor.data[0].cpu().float().numpy()
    img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
    img = img.astype(np.uint8)
    img = Image.fromarray(img)
    img = img.resize((img.size[0] // 4, img.size[1] // 4), resample=Image.Resampling.NEAREST)
    img = img.resize((img.size[0] * 4, img.size[1] * 4), resample=Image.Resampling.NEAREST)
    img.save(file)

def pixelize_cli():
    import argparse
    parser = argparse.ArgumentParser(description='Pixelization')
    parser.add_argument('--input', type=str, default=None, required=True, help='path to image or directory')
    parser.add_argument('--output', type=str, default=None, required=False, help='path to save image/images')
    parser.add_argument('--cpu', action='store_true', help='use CPU instead of GPU')

    args = parser.parse_args()
    in_path = args.input
    out_path = args.output
    use_cpu = args.cpu

    if not os.path.exists(os.environ.get('ALIAS_MODEL') or "alias_net.pth"):
        print("missing models")
        return  # bail out early instead of crashing later in Model.load()

    pairs = []

    if os.path.isdir(in_path):
        in_images = glob.glob(in_path + "/*.png") + glob.glob(in_path + "/*.jpg")
        if not out_path:
            out_path = os.path.join(in_path, "outputs")
            if not os.path.exists(out_path):
                os.makedirs(out_path)
        elif os.path.isfile(out_path):
            print("output can't be a file if input is a directory")
            return
        for i in in_images:
            pairs += [(i, i.replace(in_path, out_path))]
    elif os.path.isfile(in_path):
        if not out_path:
            base, ext = os.path.splitext(in_path)
            out_path = base + "_pixelized" + ext
        else:
            if os.path.isdir(out_path):
                _, file = os.path.split(in_path)
                out_path = os.path.join(out_path, file)
        pairs = [(in_path, out_path)]

    m = Model(device="cpu" if use_cpu else "cuda")
    m.load()

    for in_file, out_file in pairs:
        print("PROCESSING", in_file, "TO", out_file)
        m.pixelize(in_file, out_file)

if __name__ == "__main__":
    pixelize_cli()
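Typical invocations of the CLI above, with the equivalent Python shown as a sketch (file names are hypothetical):

# python pixelization.py --input photo.png --cpu
# python pixelization.py --input ./imgs --output ./imgs/outputs --cpu
m = Model(device="cpu")
m.load()
m.pixelize("photo.png", "photo_pixelized.png")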
reference.png ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
torch
torchvision
transforms  # note: the code only uses torchvision.transforms; this separate pin may be unintentional
numpy==1.24.1
pillow