Spaces:
Runtime error
Runtime error
birdortyedi
commited on
Commit
·
2a92dc2
1
Parent(s):
2529c2e
Add application file
Browse files- app.py +79 -0
- configs/__pycache__/default.cpython-37.pyc +0 -0
- configs/default.py +114 -0
- images/examples/10_Nashville.jpg +0 -0
- images/examples/11_Sutro.jpg +0 -0
- images/examples/12_Toaster.jpg +0 -0
- images/examples/14_Willow.jpg +0 -0
- images/examples/15_X-ProII.jpg +0 -0
- images/examples/16_Lo-Fi.jpg +0 -0
- images/examples/18_Gingham.jpg +0 -0
- images/examples/1_Clarendon.jpg +0 -0
- images/examples/2_Brannan.jpg +0 -0
- images/examples/30_Perpetua.jpg +0 -0
- images/examples/3_Mayfair.jpg +0 -0
- images/examples/4_Hudson.jpg +0 -0
- images/examples/5_Amaro.jpg +0 -0
- images/examples/6_1977.jpg +0 -0
- images/examples/8_Valencia.jpg +0 -0
- images/examples/98_He-Fe.jpg +0 -0
- images/examples/9_Lo-Fi.jpg +0 -0
- layers/__pycache__/blocks.cpython-37.pyc +0 -0
- layers/__pycache__/normalization.cpython-37.pyc +0 -0
- layers/blocks.py +93 -0
- layers/normalization.py +16 -0
- modeling/__pycache__/arch.cpython-37.pyc +0 -0
- modeling/__pycache__/base.cpython-37.pyc +0 -0
- modeling/__pycache__/build.cpython-37.pyc +0 -0
- modeling/arch.py +272 -0
- modeling/base.py +60 -0
- modeling/build.py +25 -0
- requirements.txt +6 -0
- utils/__pycache__/data_utils.cpython-37.pyc +0 -0
- utils/data_utils.py +29 -0
app.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import os
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.models as models
|
| 7 |
+
|
| 8 |
+
from configs.default import get_cfg_defaults
|
| 9 |
+
from modeling.build import build_model
|
| 10 |
+
from utils.data_utils import linear_scaling
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
url = "https://www.dropbox.com/s/uxvax5sjx5iysyl/cifr.pth?dl=0"
|
| 14 |
+
r = requests.get(url, stream=True)
|
| 15 |
+
if not os.path.exists("cifr.pth"):
|
| 16 |
+
with open("cifr.pth", 'wb') as f:
|
| 17 |
+
for data in r:
|
| 18 |
+
f.write(data)
|
| 19 |
+
|
| 20 |
+
cfg = get_cfg_defaults()
|
| 21 |
+
cfg.MODEL.CKPT = "cifr.pth"
|
| 22 |
+
net, _ = build_model(cfg)
|
| 23 |
+
net = net.eval()
|
| 24 |
+
vgg16 = models.vgg16(pretrained=True).features.eval()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def load_checkpoints_from_ckpt(ckpt_path):
|
| 28 |
+
checkpoints = torch.load(ckpt_path, map_location=torch.device('cuda'))
|
| 29 |
+
net.load_state_dict(checkpoints["ifr"])
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
load_checkpoints_from_ckpt(cfg.MODEL.CKPT)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def filter_removal(img):
|
| 36 |
+
arr = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
|
| 37 |
+
arr = torch.tensor(arr).float() / 255.
|
| 38 |
+
arr = linear_scaling(arr)
|
| 39 |
+
with torch.no_grad():
|
| 40 |
+
feat = vgg16(arr)
|
| 41 |
+
out, _ = net(arr, feat)
|
| 42 |
+
out = torch.clamp(out, max=1., min=0.)
|
| 43 |
+
return out.squeeze(0).permute(1, 2, 0).numpy()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
title = "Contrastive Instagram Filter Removal (CIFR)"
|
| 47 |
+
description = "This is the demo for CIFR, contrastive strategy for filter removal on fashionable images on Instagram. " \
|
| 48 |
+
"To use it, simply upload your filtered image, or click one of the examples to load them."
|
| 49 |
+
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.07486'>Contrastive Instagram Filter Removal (CIFR)</a> | <a href='https://github.com/birdortyedi/cifr-pytorch'>Github Repo</a></p>"
|
| 50 |
+
|
| 51 |
+
gr.Interface(
|
| 52 |
+
filter_removal,
|
| 53 |
+
gr.inputs.Image(shape=(256, 256)),
|
| 54 |
+
gr.outputs.Image(),
|
| 55 |
+
title=title,
|
| 56 |
+
description=description,
|
| 57 |
+
article=article,
|
| 58 |
+
allow_flagging=False,
|
| 59 |
+
examples_per_page=17,
|
| 60 |
+
examples=[
|
| 61 |
+
["images/examples/98_He-Fe.jpg"],
|
| 62 |
+
["images/examples/2_Brannan.jpg"],
|
| 63 |
+
["images/examples/12_Toaster.jpg"],
|
| 64 |
+
["images/examples/18_Gingham.jpg"],
|
| 65 |
+
["images/examples/11_Sutro.jpg"],
|
| 66 |
+
["images/examples/9_Lo-Fi.jpg"],
|
| 67 |
+
["images/examples/3_Mayfair.jpg"],
|
| 68 |
+
["images/examples/4_Hudson.jpg"],
|
| 69 |
+
["images/examples/5_Amaro.jpg"],
|
| 70 |
+
["images/examples/6_1977.jpg"],
|
| 71 |
+
["images/examples/8_Valencia.jpg"],
|
| 72 |
+
["images/examples/16_Lo-Fi.jpg"],
|
| 73 |
+
["images/examples/10_Nashville.jpg"],
|
| 74 |
+
["images/examples/15_X-ProII.jpg"],
|
| 75 |
+
["images/examples/14_Willow.jpg"],
|
| 76 |
+
["images/examples/30_Perpetua.jpg"],
|
| 77 |
+
["images/examples/1_Clarendon.jpg"],
|
| 78 |
+
]
|
| 79 |
+
).launch()
|
configs/__pycache__/default.cpython-37.pyc
ADDED
|
Binary file (2.68 kB). View file
|
|
|
configs/default.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from yacs.config import CfgNode as CN
|
| 2 |
+
|
| 3 |
+
_C = CN()
|
| 4 |
+
|
| 5 |
+
_C.SYSTEM = CN()
|
| 6 |
+
_C.SYSTEM.NUM_GPU = 2
|
| 7 |
+
_C.SYSTEM.NUM_WORKERS = 4
|
| 8 |
+
|
| 9 |
+
_C.WANDB = CN()
|
| 10 |
+
_C.WANDB.PROJECT_NAME = "contrastive-style-learning-for-ifr"
|
| 11 |
+
_C.WANDB.ENTITY = "vvgl-ozu"
|
| 12 |
+
_C.WANDB.RUN = 3
|
| 13 |
+
_C.WANDB.LOG_DIR = ""
|
| 14 |
+
_C.WANDB.NUM_ROW = 0
|
| 15 |
+
|
| 16 |
+
_C.TRAIN = CN()
|
| 17 |
+
_C.TRAIN.NUM_TOTAL_STEP = 200000
|
| 18 |
+
_C.TRAIN.START_STEP = 0
|
| 19 |
+
_C.TRAIN.BATCH_SIZE = 16
|
| 20 |
+
_C.TRAIN.SHUFFLE = True
|
| 21 |
+
_C.TRAIN.LOG_INTERVAL = 100
|
| 22 |
+
_C.TRAIN.EVAL_INTERVAL = 1000
|
| 23 |
+
_C.TRAIN.SAVE_INTERVAL = 1000
|
| 24 |
+
_C.TRAIN.SAVE_DIR = "./weights"
|
| 25 |
+
_C.TRAIN.RESUME = True
|
| 26 |
+
_C.TRAIN.VISUALIZE_INTERVAL = 100
|
| 27 |
+
_C.TRAIN.TUNE = False
|
| 28 |
+
|
| 29 |
+
_C.MODEL = CN()
|
| 30 |
+
_C.MODEL.NAME = "cifr"
|
| 31 |
+
_C.MODEL.IS_TRAIN = True
|
| 32 |
+
_C.MODEL.NUM_CLASS = 17
|
| 33 |
+
_C.MODEL.CKPT = ""
|
| 34 |
+
_C.MODEL.PRETRAINED = ""
|
| 35 |
+
|
| 36 |
+
_C.MODEL.IFR = CN()
|
| 37 |
+
_C.MODEL.IFR.NAME = "ContrastiveInstaFilterRemovalNetwork"
|
| 38 |
+
_C.MODEL.IFR.NUM_CHANNELS = 32
|
| 39 |
+
_C.MODEL.IFR.DESTYLER_CHANNELS = 32
|
| 40 |
+
_C.MODEL.IFR.SOLVER = CN()
|
| 41 |
+
_C.MODEL.IFR.SOLVER.LR = 2e-4
|
| 42 |
+
_C.MODEL.IFR.SOLVER.BETAS = (0.5, 0.999)
|
| 43 |
+
_C.MODEL.IFR.SOLVER.SCHEDULER = []
|
| 44 |
+
_C.MODEL.IFR.SOLVER.DECAY_RATE = 0.
|
| 45 |
+
_C.MODEL.IFR.DS_FACTOR = 4
|
| 46 |
+
|
| 47 |
+
_C.MODEL.PATCH = CN()
|
| 48 |
+
_C.MODEL.PATCH.NUM_CHANNELS = 256
|
| 49 |
+
_C.MODEL.PATCH.NUM_PATCHES = 256
|
| 50 |
+
_C.MODEL.PATCH.NUM_LAYERS = 6
|
| 51 |
+
_C.MODEL.PATCH.USE_MLP = True
|
| 52 |
+
_C.MODEL.PATCH.SHUFFLE_Y = True
|
| 53 |
+
_C.MODEL.PATCH.LR = 1e-4
|
| 54 |
+
_C.MODEL.PATCH.BETAS = (0.5, 0.999)
|
| 55 |
+
_C.MODEL.PATCH.T = 0.07
|
| 56 |
+
|
| 57 |
+
_C.MODEL.D = CN()
|
| 58 |
+
_C.MODEL.D.NAME = "1-ChOutputDiscriminator"
|
| 59 |
+
_C.MODEL.D.NUM_CHANNELS = 32
|
| 60 |
+
_C.MODEL.D.NUM_CRITICS = 3
|
| 61 |
+
_C.MODEL.D.SOLVER = CN()
|
| 62 |
+
_C.MODEL.D.SOLVER.LR = 1e-4
|
| 63 |
+
_C.MODEL.D.SOLVER.BETAS = (0.5, 0.999)
|
| 64 |
+
_C.MODEL.D.SOLVER.SCHEDULER = []
|
| 65 |
+
_C.MODEL.D.SOLVER.DECAY_RATE = 0.01
|
| 66 |
+
|
| 67 |
+
_C.ESRGAN = CN()
|
| 68 |
+
_C.ESRGAN.WEIGHTS = "weights/RealESRGAN_x{}plus.pth"
|
| 69 |
+
|
| 70 |
+
_C.FASHIONMASKRCNN = CN()
|
| 71 |
+
_C.FASHIONMASKRCNN.CFG_PATH = "configs/fashion.yaml"
|
| 72 |
+
_C.FASHIONMASKRCNN.WEIGHTS = "weights/fashion.pth"
|
| 73 |
+
_C.FASHIONMASKRCNN.SCORE_THRESH_TEST = 0.6
|
| 74 |
+
_C.FASHIONMASKRCNN.MIN_SIZE_TEST = 512
|
| 75 |
+
|
| 76 |
+
_C.OPTIM = CN()
|
| 77 |
+
_C.OPTIM.GP = 10.
|
| 78 |
+
_C.OPTIM.MASK = 1
|
| 79 |
+
_C.OPTIM.RECON = 1.4
|
| 80 |
+
_C.OPTIM.SEMANTIC = 1e-1
|
| 81 |
+
_C.OPTIM.TEXTURE = 2e-1
|
| 82 |
+
_C.OPTIM.ADVERSARIAL = 1e-3
|
| 83 |
+
_C.OPTIM.AUX = 0.5
|
| 84 |
+
_C.OPTIM.CONTRASTIVE = 0.1
|
| 85 |
+
_C.OPTIM.NLL = 1.0
|
| 86 |
+
|
| 87 |
+
_C.DATASET = CN()
|
| 88 |
+
_C.DATASET.NAME = "IFFI"
|
| 89 |
+
_C.DATASET.ROOT = "../../Downloads/IFFI-dataset/train" # "../../Downloads/IFFI-dataset/train"
|
| 90 |
+
_C.DATASET.TEST_ROOT = "../../Datasets/IFFI-dataset/test" # "../../Downloads/IFFI-dataset/test"
|
| 91 |
+
_C.DATASET.DS_TEST_ROOT = "../../Downloads/IFFI-dataset/test/" # "../../Downloads/IFFI-dataset/test"
|
| 92 |
+
_C.DATASET.DS_JSON_FILE = "../../Downloads/IFFI-dataset-only-orgs/instances_default.json"
|
| 93 |
+
_C.DATASET.SIZE = 256
|
| 94 |
+
_C.DATASET.CROP_SIZE = 512
|
| 95 |
+
_C.DATASET.MEAN = [0.5, 0.5, 0.5]
|
| 96 |
+
_C.DATASET.STD = [0.5, 0.5, 0.5]
|
| 97 |
+
|
| 98 |
+
_C.TEST = CN()
|
| 99 |
+
_C.TEST.OUTPUT_DIR = "./outputs"
|
| 100 |
+
_C.TEST.ABLATION = False
|
| 101 |
+
_C.TEST.WEIGHTS = ""
|
| 102 |
+
_C.TEST.BATCH_SIZE = 32
|
| 103 |
+
_C.TEST.IMG_ID = 52
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def get_cfg_defaults():
|
| 107 |
+
"""Get a yacs CfgNode object with default values for my_project."""
|
| 108 |
+
# Return a clone so that the defaults will not be altered
|
| 109 |
+
# This is for the "local variable" use pattern
|
| 110 |
+
return _C.clone()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# provide a way to import the defaults as a global singleton:
|
| 114 |
+
cfg = _C # users can `from config import cfg`
|
images/examples/10_Nashville.jpg
ADDED
|
images/examples/11_Sutro.jpg
ADDED
|
images/examples/12_Toaster.jpg
ADDED
|
images/examples/14_Willow.jpg
ADDED
|
images/examples/15_X-ProII.jpg
ADDED
|
images/examples/16_Lo-Fi.jpg
ADDED
|
images/examples/18_Gingham.jpg
ADDED
|
images/examples/1_Clarendon.jpg
ADDED
|
images/examples/2_Brannan.jpg
ADDED
|
images/examples/30_Perpetua.jpg
ADDED
|
images/examples/3_Mayfair.jpg
ADDED
|
images/examples/4_Hudson.jpg
ADDED
|
images/examples/5_Amaro.jpg
ADDED
|
images/examples/6_1977.jpg
ADDED
|
images/examples/8_Valencia.jpg
ADDED
|
images/examples/98_He-Fe.jpg
ADDED
|
images/examples/9_Lo-Fi.jpg
ADDED
|
layers/__pycache__/blocks.cpython-37.pyc
ADDED
|
Binary file (2.93 kB). View file
|
|
|
layers/__pycache__/normalization.cpython-37.pyc
ADDED
|
Binary file (859 Bytes). View file
|
|
|
layers/blocks.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
from layers.normalization import AdaIN
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DestyleResBlock(nn.Module):
|
| 7 |
+
def __init__(self, channels_out, kernel_size, channels_in=None, stride=1, dilation=1, padding=1, use_dropout=False):
|
| 8 |
+
super(DestyleResBlock, self).__init__()
|
| 9 |
+
|
| 10 |
+
# uses 1x1 convolutions for downsampling
|
| 11 |
+
if not channels_in or channels_in == channels_out:
|
| 12 |
+
channels_in = channels_out
|
| 13 |
+
self.projection = None
|
| 14 |
+
else:
|
| 15 |
+
self.projection = nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=stride, dilation=1)
|
| 16 |
+
self.use_dropout = use_dropout
|
| 17 |
+
|
| 18 |
+
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
|
| 19 |
+
self.lrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 20 |
+
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
|
| 21 |
+
self.adain = AdaIN()
|
| 22 |
+
if self.use_dropout:
|
| 23 |
+
self.dropout = nn.Dropout()
|
| 24 |
+
self.lrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 25 |
+
|
| 26 |
+
def forward(self, x, feat):
|
| 27 |
+
residual = x
|
| 28 |
+
out = self.conv1(x)
|
| 29 |
+
out = self.lrelu1(out)
|
| 30 |
+
out = self.conv2(out)
|
| 31 |
+
_, _, h, w = out.size()
|
| 32 |
+
out = self.adain(out, feat)
|
| 33 |
+
if self.use_dropout:
|
| 34 |
+
out = self.dropout(out)
|
| 35 |
+
if self.projection:
|
| 36 |
+
residual = self.projection(x)
|
| 37 |
+
out = out + residual
|
| 38 |
+
out = self.lrelu2(out)
|
| 39 |
+
return out
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ResBlock(nn.Module):
|
| 43 |
+
def __init__(self, channels_out, kernel_size, channels_in=None, stride=1, dilation=1, padding=1, use_dropout=False):
|
| 44 |
+
super(ResBlock, self).__init__()
|
| 45 |
+
|
| 46 |
+
# uses 1x1 convolutions for downsampling
|
| 47 |
+
if not channels_in or channels_in == channels_out:
|
| 48 |
+
channels_in = channels_out
|
| 49 |
+
self.projection = None
|
| 50 |
+
else:
|
| 51 |
+
self.projection = nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=stride, dilation=1)
|
| 52 |
+
self.use_dropout = use_dropout
|
| 53 |
+
|
| 54 |
+
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
|
| 55 |
+
self.lrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 56 |
+
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
|
| 57 |
+
self.n2 = nn.BatchNorm2d(channels_out)
|
| 58 |
+
if self.use_dropout:
|
| 59 |
+
self.dropout = nn.Dropout()
|
| 60 |
+
self.lrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 61 |
+
|
| 62 |
+
def forward(self, x):
|
| 63 |
+
residual = x
|
| 64 |
+
out = self.conv1(x)
|
| 65 |
+
out = self.lrelu1(out)
|
| 66 |
+
out = self.conv2(out)
|
| 67 |
+
# out = self.n2(out)
|
| 68 |
+
if self.use_dropout:
|
| 69 |
+
out = self.dropout(out)
|
| 70 |
+
if self.projection:
|
| 71 |
+
residual = self.projection(x)
|
| 72 |
+
out = out + residual
|
| 73 |
+
out = self.lrelu2(out)
|
| 74 |
+
return out
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class Destyler(nn.Module):
|
| 78 |
+
def __init__(self, in_features, num_features):
|
| 79 |
+
super(Destyler, self).__init__()
|
| 80 |
+
self.fc1 = nn.Linear(in_features, num_features)
|
| 81 |
+
self.fc2 = nn.Linear(num_features, num_features)
|
| 82 |
+
self.fc3 = nn.Linear(num_features, num_features)
|
| 83 |
+
self.fc4 = nn.Linear(num_features, num_features)
|
| 84 |
+
self.fc5 = nn.Linear(num_features, num_features)
|
| 85 |
+
|
| 86 |
+
def forward(self, x):
|
| 87 |
+
x = self.fc1(x)
|
| 88 |
+
x = self.fc2(x)
|
| 89 |
+
x = self.fc3(x)
|
| 90 |
+
x = self.fc4(x)
|
| 91 |
+
x = self.fc5(x)
|
| 92 |
+
return x
|
| 93 |
+
|
layers/normalization.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class AdaIN(nn.Module):
|
| 6 |
+
def __init__(self):
|
| 7 |
+
super().__init__()
|
| 8 |
+
|
| 9 |
+
def forward(self, x, y):
|
| 10 |
+
ch = y.size(1)
|
| 11 |
+
sigma, mu = torch.split(y.unsqueeze(-1).unsqueeze(-1), [ch // 2, ch // 2], dim=1)
|
| 12 |
+
|
| 13 |
+
x_mu = x.mean(dim=[2, 3], keepdim=True)
|
| 14 |
+
x_sigma = x.std(dim=[2, 3], keepdim=True)
|
| 15 |
+
|
| 16 |
+
return sigma * ((x - x_mu) / x_sigma) + mu
|
modeling/__pycache__/arch.cpython-37.pyc
ADDED
|
Binary file (8.89 kB). View file
|
|
|
modeling/__pycache__/base.cpython-37.pyc
ADDED
|
Binary file (2.66 kB). View file
|
|
|
modeling/__pycache__/build.cpython-37.pyc
ADDED
|
Binary file (1.22 kB). View file
|
|
|
modeling/arch.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn.utils import spectral_norm
|
| 4 |
+
|
| 5 |
+
from modeling.base import BaseNetwork
|
| 6 |
+
from layers.blocks import DestyleResBlock, Destyler, ResBlock
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class IFRNet(BaseNetwork):
|
| 10 |
+
def __init__(self, base_n_channels, destyler_n_channels):
|
| 11 |
+
super(IFRNet, self).__init__()
|
| 12 |
+
self.destyler = Destyler(in_features=32768, num_features=destyler_n_channels) # from vgg features
|
| 13 |
+
|
| 14 |
+
self.ds_fc1 = nn.Linear(destyler_n_channels, base_n_channels * 2)
|
| 15 |
+
self.ds_res1 = DestyleResBlock(channels_in=3, channels_out=base_n_channels, kernel_size=5, stride=1, padding=2)
|
| 16 |
+
self.ds_fc2 = nn.Linear(destyler_n_channels, base_n_channels * 4)
|
| 17 |
+
self.ds_res2 = DestyleResBlock(channels_in=base_n_channels, channels_out=base_n_channels * 2, kernel_size=3, stride=2, padding=1)
|
| 18 |
+
self.ds_fc3 = nn.Linear(destyler_n_channels, base_n_channels * 4)
|
| 19 |
+
self.ds_res3 = DestyleResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
|
| 20 |
+
self.ds_fc4 = nn.Linear(destyler_n_channels, base_n_channels * 8)
|
| 21 |
+
self.ds_res4 = DestyleResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 4, kernel_size=3, stride=2, padding=1)
|
| 22 |
+
self.ds_fc5 = nn.Linear(destyler_n_channels, base_n_channels * 8)
|
| 23 |
+
self.ds_res5 = DestyleResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
|
| 24 |
+
self.ds_fc6 = nn.Linear(destyler_n_channels, base_n_channels * 16)
|
| 25 |
+
self.ds_res6 = DestyleResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 8, kernel_size=3, stride=2, padding=1)
|
| 26 |
+
|
| 27 |
+
self.upsample = nn.UpsamplingNearest2d(scale_factor=2.0)
|
| 28 |
+
|
| 29 |
+
self.res1 = ResBlock(channels_in=base_n_channels * 8, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
|
| 30 |
+
self.res2 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
|
| 31 |
+
self.res3 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
|
| 32 |
+
self.res4 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
|
| 33 |
+
self.res5 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels, kernel_size=3, stride=1, padding=1)
|
| 34 |
+
|
| 35 |
+
self.conv1 = nn.Conv2d(base_n_channels, 3, kernel_size=3, stride=1, padding=1)
|
| 36 |
+
|
| 37 |
+
self.init_weights(init_type="normal", gain=0.02)
|
| 38 |
+
|
| 39 |
+
def forward(self, x, vgg_feat):
|
| 40 |
+
b_size, ch, h, w = vgg_feat.size()
|
| 41 |
+
vgg_feat = vgg_feat.view(b_size, ch * h * w)
|
| 42 |
+
vgg_feat = self.destyler(vgg_feat)
|
| 43 |
+
|
| 44 |
+
out = self.ds_res1(x, self.ds_fc1(vgg_feat))
|
| 45 |
+
out = self.ds_res2(out, self.ds_fc2(vgg_feat))
|
| 46 |
+
out = self.ds_res3(out, self.ds_fc3(vgg_feat))
|
| 47 |
+
out = self.ds_res4(out, self.ds_fc4(vgg_feat))
|
| 48 |
+
out = self.ds_res5(out, self.ds_fc5(vgg_feat))
|
| 49 |
+
aux = self.ds_res6(out, self.ds_fc6(vgg_feat))
|
| 50 |
+
|
| 51 |
+
out = self.upsample(aux)
|
| 52 |
+
out = self.res1(out)
|
| 53 |
+
out = self.res2(out)
|
| 54 |
+
out = self.upsample(out)
|
| 55 |
+
out = self.res3(out)
|
| 56 |
+
out = self.res4(out)
|
| 57 |
+
out = self.upsample(out)
|
| 58 |
+
out = self.res5(out)
|
| 59 |
+
out = self.conv1(out)
|
| 60 |
+
|
| 61 |
+
return out, aux
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class CIFR_Encoder(IFRNet):
|
| 65 |
+
def __init__(self, base_n_channels, destyler_n_channels):
|
| 66 |
+
super(CIFR_Encoder, self).__init__(base_n_channels, destyler_n_channels)
|
| 67 |
+
|
| 68 |
+
def forward(self, x, vgg_feat):
|
| 69 |
+
b_size, ch, h, w = vgg_feat.size()
|
| 70 |
+
vgg_feat = vgg_feat.view(b_size, ch * h * w)
|
| 71 |
+
vgg_feat = self.destyler(vgg_feat)
|
| 72 |
+
|
| 73 |
+
feat1 = self.ds_res1(x, self.ds_fc1(vgg_feat))
|
| 74 |
+
feat2 = self.ds_res2(feat1, self.ds_fc2(vgg_feat))
|
| 75 |
+
feat3 = self.ds_res3(feat2, self.ds_fc3(vgg_feat))
|
| 76 |
+
feat4 = self.ds_res4(feat3, self.ds_fc4(vgg_feat))
|
| 77 |
+
feat5 = self.ds_res5(feat4, self.ds_fc5(vgg_feat))
|
| 78 |
+
feat6 = self.ds_res6(feat5, self.ds_fc6(vgg_feat))
|
| 79 |
+
|
| 80 |
+
feats = [feat1, feat2, feat3, feat4, feat5, feat6]
|
| 81 |
+
|
| 82 |
+
out = self.upsample(feat6)
|
| 83 |
+
out = self.res1(out)
|
| 84 |
+
out = self.res2(out)
|
| 85 |
+
out = self.upsample(out)
|
| 86 |
+
out = self.res3(out)
|
| 87 |
+
out = self.res4(out)
|
| 88 |
+
out = self.upsample(out)
|
| 89 |
+
out = self.res5(out)
|
| 90 |
+
out = self.conv1(out)
|
| 91 |
+
|
| 92 |
+
return out, feats
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Normalize(nn.Module):
|
| 96 |
+
def __init__(self, power=2):
|
| 97 |
+
super(Normalize, self).__init__()
|
| 98 |
+
self.power = power
|
| 99 |
+
|
| 100 |
+
def forward(self, x):
|
| 101 |
+
norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
|
| 102 |
+
out = x.div(norm + 1e-7)
|
| 103 |
+
return out
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class PatchSampleF(BaseNetwork):
|
| 107 |
+
def __init__(self, base_n_channels, style_or_content, use_mlp=False, nc=256):
|
| 108 |
+
# potential issues: currently, we use the same patch_ids for multiple images in the batch
|
| 109 |
+
super(PatchSampleF, self).__init__()
|
| 110 |
+
self.is_content = True if style_or_content == "content" else False
|
| 111 |
+
self.l2norm = Normalize(2)
|
| 112 |
+
self.use_mlp = use_mlp
|
| 113 |
+
self.nc = nc # hard-coded
|
| 114 |
+
|
| 115 |
+
self.mlp_0 = nn.Sequential(*[nn.Linear(base_n_channels, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
|
| 116 |
+
self.mlp_1 = nn.Sequential(*[nn.Linear(base_n_channels * 2, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
|
| 117 |
+
self.mlp_2 = nn.Sequential(*[nn.Linear(base_n_channels * 2, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
|
| 118 |
+
self.mlp_3 = nn.Sequential(*[nn.Linear(base_n_channels * 4, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
|
| 119 |
+
self.mlp_4 = nn.Sequential(*[nn.Linear(base_n_channels * 4, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
|
| 120 |
+
self.mlp_5 = nn.Sequential(*[nn.Linear(base_n_channels * 8, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
|
| 121 |
+
self.init_weights(init_type="normal", gain=0.02)
|
| 122 |
+
|
| 123 |
+
@staticmethod
|
| 124 |
+
def gram_matrix(x):
|
| 125 |
+
# a, b, c, d = x.size() # a=batch size(=1)
|
| 126 |
+
a, b = x.size()
|
| 127 |
+
# b=number of feature maps
|
| 128 |
+
# (c,d)=dimensions of a f. map (N=c*d)
|
| 129 |
+
|
| 130 |
+
# features = x.view(a * b, c * d) # resise F_XL into \hat F_XL
|
| 131 |
+
|
| 132 |
+
G = torch.mm(x, x.t()) # compute the gram product
|
| 133 |
+
|
| 134 |
+
# we 'normalize' the values of the gram matrix
|
| 135 |
+
# by dividing by the number of element in each feature maps.
|
| 136 |
+
return G.div(a * b)
|
| 137 |
+
|
| 138 |
+
def forward(self, feats, num_patches=64, patch_ids=None):
|
| 139 |
+
return_ids = []
|
| 140 |
+
return_feats = []
|
| 141 |
+
|
| 142 |
+
for feat_id, feat in enumerate(feats):
|
| 143 |
+
B, C, H, W = feat.shape
|
| 144 |
+
feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
|
| 145 |
+
if num_patches > 0:
|
| 146 |
+
if patch_ids is not None:
|
| 147 |
+
patch_id = patch_ids[feat_id]
|
| 148 |
+
else:
|
| 149 |
+
patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
|
| 150 |
+
patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
|
| 151 |
+
x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1])
|
| 152 |
+
else:
|
| 153 |
+
x_sample = feat_reshape
|
| 154 |
+
patch_id = []
|
| 155 |
+
if self.use_mlp:
|
| 156 |
+
mlp = getattr(self, 'mlp_%d' % feat_id)
|
| 157 |
+
x_sample = mlp(x_sample)
|
| 158 |
+
if not self.is_content:
|
| 159 |
+
x_sample = self.gram_matrix(x_sample)
|
| 160 |
+
return_ids.append(patch_id)
|
| 161 |
+
x_sample = self.l2norm(x_sample)
|
| 162 |
+
|
| 163 |
+
if num_patches == 0:
|
| 164 |
+
x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
|
| 165 |
+
return_feats.append(x_sample)
|
| 166 |
+
return return_feats, return_ids
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class MLP(nn.Module):
|
| 170 |
+
def __init__(self, base_n_channels, out_features=14):
|
| 171 |
+
super(MLP, self).__init__()
|
| 172 |
+
self.aux_classifier = nn.Sequential(
|
| 173 |
+
nn.Conv2d(base_n_channels * 8, base_n_channels * 4, kernel_size=3, stride=1, padding=1),
|
| 174 |
+
nn.MaxPool2d(2),
|
| 175 |
+
nn.Conv2d(base_n_channels * 4, base_n_channels * 2, kernel_size=3, stride=1, padding=1),
|
| 176 |
+
nn.MaxPool2d(2),
|
| 177 |
+
# nn.Conv2d(base_n_channels * 2, base_n_channels * 1, kernel_size=3, stride=1, padding=1),
|
| 178 |
+
# nn.MaxPool2d(2),
|
| 179 |
+
Flatten(),
|
| 180 |
+
nn.Linear(base_n_channels * 8 * 8 * 2, out_features),
|
| 181 |
+
# nn.Softmax(dim=-1)
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def forward(self, x):
|
| 185 |
+
return self.aux_classifier(x)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class Flatten(nn.Module):
|
| 189 |
+
def forward(self, input):
|
| 190 |
+
"""
|
| 191 |
+
Note that input.size(0) is usually the batch size.
|
| 192 |
+
So what it does is that given any input with input.size(0) # of batches,
|
| 193 |
+
will flatten to be 1 * nb_elements.
|
| 194 |
+
"""
|
| 195 |
+
batch_size = input.size(0)
|
| 196 |
+
out = input.view(batch_size, -1)
|
| 197 |
+
return out # (batch_size, *size)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class Discriminator(BaseNetwork):
|
| 201 |
+
def __init__(self, base_n_channels):
|
| 202 |
+
"""
|
| 203 |
+
img_size : (int, int, int)
|
| 204 |
+
Height and width must be powers of 2. E.g. (32, 32, 1) or
|
| 205 |
+
(64, 128, 3). Last number indicates number of channels, e.g. 1 for
|
| 206 |
+
grayscale or 3 for RGB
|
| 207 |
+
"""
|
| 208 |
+
super(Discriminator, self).__init__()
|
| 209 |
+
|
| 210 |
+
self.image_to_features = nn.Sequential(
|
| 211 |
+
spectral_norm(nn.Conv2d(3, base_n_channels, 5, 2, 2)),
|
| 212 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 213 |
+
spectral_norm(nn.Conv2d(base_n_channels, 2 * base_n_channels, 5, 2, 2)),
|
| 214 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 215 |
+
spectral_norm(nn.Conv2d(2 * base_n_channels, 2 * base_n_channels, 5, 2, 2)),
|
| 216 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 217 |
+
spectral_norm(nn.Conv2d(2 * base_n_channels, 4 * base_n_channels, 5, 2, 2)),
|
| 218 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 219 |
+
# spectral_norm(nn.Conv2d(4 * base_n_channels, 4 * base_n_channels, 5, 2, 2)),
|
| 220 |
+
# nn.LeakyReLU(0.2, inplace=True),
|
| 221 |
+
spectral_norm(nn.Conv2d(4 * base_n_channels, 8 * base_n_channels, 5, 1, 1)),
|
| 222 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
output_size = 8 * base_n_channels * 3 * 3
|
| 226 |
+
self.features_to_prob = nn.Sequential(
|
| 227 |
+
spectral_norm(nn.Conv2d(8 * base_n_channels, 2 * base_n_channels, 5, 2, 1)),
|
| 228 |
+
Flatten(),
|
| 229 |
+
nn.Linear(output_size, 1)
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
self.init_weights(init_type="normal", gain=0.02)
|
| 233 |
+
|
| 234 |
+
def forward(self, input_data):
|
| 235 |
+
x = self.image_to_features(input_data)
|
| 236 |
+
return self.features_to_prob(x)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class PatchDiscriminator(Discriminator):
|
| 240 |
+
def __init__(self, base_n_channels):
|
| 241 |
+
super(PatchDiscriminator, self).__init__(base_n_channels)
|
| 242 |
+
|
| 243 |
+
self.features_to_prob = nn.Sequential(
|
| 244 |
+
spectral_norm(nn.Conv2d(8 * base_n_channels, 1, 1)),
|
| 245 |
+
Flatten()
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def forward(self, input_data):
|
| 249 |
+
x = self.image_to_features(input_data)
|
| 250 |
+
return self.features_to_prob(x)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == '__main__':
|
| 254 |
+
import torchvision
|
| 255 |
+
ifrnet = CIFR_Encoder(32, 128).cuda()
|
| 256 |
+
x = torch.rand((2, 3, 256, 256)).cuda()
|
| 257 |
+
vgg16 = torchvision.models.vgg16(pretrained=True).features.eval().cuda()
|
| 258 |
+
with torch.no_grad():
|
| 259 |
+
vgg_feat = vgg16(x)
|
| 260 |
+
output, feats = ifrnet(x, vgg_feat)
|
| 261 |
+
print(output.size())
|
| 262 |
+
for i, feat in enumerate(feats):
|
| 263 |
+
print(i, feat.size())
|
| 264 |
+
|
| 265 |
+
disc = Discriminator(32).cuda()
|
| 266 |
+
d_out = disc(output)
|
| 267 |
+
print(d_out.size())
|
| 268 |
+
|
| 269 |
+
patch_disc = PatchDiscriminator(32).cuda()
|
| 270 |
+
p_d_out = patch_disc(output)
|
| 271 |
+
print(p_d_out.size())
|
| 272 |
+
|
modeling/base.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class BaseNetwork(nn.Module):
|
| 5 |
+
def __init__(self):
|
| 6 |
+
super(BaseNetwork, self).__init__()
|
| 7 |
+
|
| 8 |
+
def forward(self, x, y):
|
| 9 |
+
pass
|
| 10 |
+
|
| 11 |
+
def print_network(self):
|
| 12 |
+
if isinstance(self, list):
|
| 13 |
+
self = self[0]
|
| 14 |
+
num_params = 0
|
| 15 |
+
for param in self.parameters():
|
| 16 |
+
num_params += param.numel()
|
| 17 |
+
print('Network [%s] was created. Total number of parameters: %.1f million. '
|
| 18 |
+
'To see the architecture, do print(network).'
|
| 19 |
+
% (type(self).__name__, num_params / 1000000))
|
| 20 |
+
|
| 21 |
+
def set_requires_grad(self, requires_grad=False):
|
| 22 |
+
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
| 23 |
+
Parameters:
|
| 24 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
| 25 |
+
"""
|
| 26 |
+
for param in self.parameters():
|
| 27 |
+
param.requires_grad = requires_grad
|
| 28 |
+
|
| 29 |
+
def init_weights(self, init_type='xavier', gain=0.02):
|
| 30 |
+
def init_func(m):
|
| 31 |
+
classname = m.__class__.__name__
|
| 32 |
+
if classname.find('BatchNorm2d') != -1:
|
| 33 |
+
if hasattr(m, 'weight') and m.weight is not None:
|
| 34 |
+
nn.init.normal_(m.weight.data, 1.0, gain)
|
| 35 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
| 36 |
+
nn.init.constant_(m.bias.data, 0.0)
|
| 37 |
+
elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
| 38 |
+
if init_type == 'normal':
|
| 39 |
+
nn.init.normal_(m.weight.data, 0.0, gain)
|
| 40 |
+
elif init_type == 'xavier':
|
| 41 |
+
nn.init.xavier_normal_(m.weight.data, gain=gain)
|
| 42 |
+
elif init_type == 'xavier_uniform':
|
| 43 |
+
nn.init.xavier_uniform_(m.weight.data, gain=1.0)
|
| 44 |
+
elif init_type == 'kaiming':
|
| 45 |
+
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
| 46 |
+
elif init_type == 'orthogonal':
|
| 47 |
+
nn.init.orthogonal_(m.weight.data, gain=gain)
|
| 48 |
+
elif init_type == 'none': # uses pytorch's default init method
|
| 49 |
+
m.reset_parameters()
|
| 50 |
+
else:
|
| 51 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
| 52 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
| 53 |
+
nn.init.constant_(m.bias.data, 0.0)
|
| 54 |
+
|
| 55 |
+
self.apply(init_func)
|
| 56 |
+
|
| 57 |
+
# propagate to children
|
| 58 |
+
for m in self.children():
|
| 59 |
+
if hasattr(m, 'init_weights'):
|
| 60 |
+
m.init_weights(init_type, gain)
|
modeling/build.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from modeling.arch import IFRNet, CIFR_Encoder, Discriminator, PatchDiscriminator, MLP, PatchSampleF
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def build_model(args):
|
| 5 |
+
if args.MODEL.NAME.lower() == "ifrnet":
|
| 6 |
+
net = IFRNet(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
|
| 7 |
+
mlp = MLP(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, out_features=args.MODEL.NUM_CLASS)
|
| 8 |
+
elif args.MODEL.NAME.lower() == "cifr":
|
| 9 |
+
net = CIFR_Encoder(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
|
| 10 |
+
mlp = None
|
| 11 |
+
elif args.MODEL.NAME.lower() == "ifr-no-aux":
|
| 12 |
+
net = IFRNet(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
|
| 13 |
+
mlp = None
|
| 14 |
+
else:
|
| 15 |
+
raise NotImplementedError
|
| 16 |
+
return net, mlp
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def build_discriminators(args):
|
| 20 |
+
return Discriminator(base_n_channels=args.MODEL.D.NUM_CHANNELS), PatchDiscriminator(base_n_channels=args.MODEL.D.NUM_CHANNELS)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def build_patch_sampler(args):
|
| 24 |
+
return PatchSampleF(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, style_or_content="content", use_mlp=args.MODEL.PATCH.USE_MLP, nc=args.MODEL.PATCH.NUM_CHANNELS), \
|
| 25 |
+
PatchSampleF(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, style_or_content="style", use_mlp=args.MODEL.PATCH.USE_MLP, nc=args.MODEL.PATCH.NUM_CHANNELS)
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==2.9.4
|
| 2 |
+
numpy==1.21.2
|
| 3 |
+
requests==2.27.1
|
| 4 |
+
torch==1.10.1
|
| 5 |
+
torchvision==0.11.2
|
| 6 |
+
yacs==0.1.8
|
utils/__pycache__/data_utils.cpython-37.pyc
ADDED
|
Binary file (1.07 kB). View file
|
|
|
utils/data_utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def linear_scaling(x):
|
| 5 |
+
return (x * 255.) / 127.5 - 1.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def linear_unscaling(x):
|
| 9 |
+
return (x + 1.) * 127.5 / 255.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def read_json(path):
|
| 13 |
+
"""
|
| 14 |
+
:param path (str or os.Path): JSON file path.
|
| 15 |
+
:return: (Dict): the data in the JSON file.
|
| 16 |
+
"""
|
| 17 |
+
with open(path) as f:
|
| 18 |
+
data = json.load(f)
|
| 19 |
+
return data
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def write_json(path, datagroup):
|
| 23 |
+
"""
|
| 24 |
+
:param path (str or os.Path): File path for the output JSON file.
|
| 25 |
+
:param datagroup (Dict): The data which should be dump to the JSON file.
|
| 26 |
+
:return: void.
|
| 27 |
+
"""
|
| 28 |
+
with open(path, "w+") as f:
|
| 29 |
+
json.dump(datagroup, f)
|