Add application file
- app.py +94 -29
- example/generate_examples.py +49 -0
- inference.py +410 -0
- losses.py +248 -0
- models/__init__.py +3 -0
- models/anime_gan.py +112 -0
- models/anime_gan_v2.py +65 -0
- models/anime_gan_v3.py +14 -0
- models/conv_blocks.py +171 -0
- models/layers.py +28 -0
- models/vgg.py +80 -0
- predict.py +35 -0
- requirements.txt +8 -9
- train.py +163 -0
- trainer/__init__.py +437 -0
- utils/__init__.py +21 -0
- utils/common.py +188 -0
- utils/fast_numpyio.py +43 -0
- utils/image_processing.py +135 -0
- utils/logger.py +24 -0
app.py
CHANGED
|
@@ -1,33 +1,98 @@

import os
import cv2
import numpy as np
import gradio as gr
from inference import Predictor
from utils.image_processing import resize_image

os.makedirs('output', exist_ok=True)


def inference(
    image: np.ndarray,
    style,
    imgsz=None,
):
    if imgsz is not None:
        imgsz = int(imgsz)

    retain_color = False

    weight = {
        "AnimeGAN_Hayao": "hayao",
        "AnimeGAN_Shinkai": "shinkai",
        "AnimeGANv2_Hayao": "hayao:v2",
        "AnimeGANv2_Shinkai": "shinkai:v2",
        "AnimeGANv2_Arcane": "arcane:v2",
    }[style]
    predictor = Predictor(
        weight,
        device='cpu',
        retain_color=retain_color,
        imgsz=imgsz,
    )

    save_path = f"output/out.jpg"
    image = resize_image(image, width=imgsz)
    anime_image = predictor.transform(image)[0]
    cv2.imwrite(save_path, anime_image[..., ::-1])
    return anime_image, save_path


title = "AnimeGANv2: To produce your own animation."
description = r"""Turn your photo into anime style 😊"""
article = r"""
[](https://github.com/ptran1203/pytorch-animeGAN)
### 🗻 Demo

"""

gr.Interface(
    fn=inference,
    inputs=[
        gr.components.Image(label="Input"),
        gr.Dropdown(
            [
                'AnimeGAN_Hayao',
                'AnimeGAN_Shinkai',
                'AnimeGANv2_Hayao',
                'AnimeGANv2_Shinkai',
                'AnimeGANv2_Arcane',
            ],
            type="value",
            value='AnimeGANv2_Hayao',
            label='Style'
        ),
        gr.Dropdown(
            [
                None,
                416,
                512,
                768,
                1024,
                1536,
            ],
            type="value",
            value=None,
            label='Image size'
        )
    ],
    outputs=[
        gr.components.Image(type="numpy", label="Output (The whole image)"),
        gr.components.File(label="Download the output image")
    ],
    title=title,
    description=description,
    article=article,
    allow_flagging="never",
    examples=[
        ['example/face/girl4.jpg', 'AnimeGANv2_Arcane', None],
        ['example/face/leo.jpg', 'AnimeGANv2_Arcane', None],
        ['example/face/cap.jpg', 'AnimeGANv2_Arcane', None],
        ['example/face/anne.jpg', 'AnimeGANv2_Arcane', None],
        # ['example/boy2.jpg', 'AnimeGANv3_Arcane', "No"],
        # ['example/cap.jpg', 'AnimeGANv3_Arcane', "No"],
        ['example/landscape/pexels-camilacarneiro-6318793.jpg', 'AnimeGANv2_Hayao', None],
        ['example/landscape/pexels-nandhukumar-450441.jpg', 'AnimeGANv2_Hayao', None],
    ]
).launch()
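The dropdown labels map to release weight tags ("AnimeGANv2_Hayao" -> "hayao:v2", and so on). A minimal sketch of reproducing one UI run from plain Python, assuming the released hayao:v2 weight is downloadable and the bundled example image exists; this is an illustration, not part of the app:

import cv2
from inference import Predictor
from utils.image_processing import resize_image

img = cv2.imread("example/face/girl4.jpg")[..., ::-1]          # BGR -> RGB, matching what Gradio passes
img = resize_image(img, width=512)                              # same preprocessing the UI applies
anime = Predictor("hayao:v2", device="cpu", imgsz=512).transform(img)[0]
cv2.imwrite("out.jpg", anime[..., ::-1])                        # back to BGR for OpenCV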
example/generate_examples.py
ADDED
|
@@ -0,0 +1,49 @@

import os
import cv2
import re

REG = re.compile(r"[0-9]{3}")
dir_ = './example/result'
readme = './README.md'


def anime_2_input(fi):
    return fi.replace("_anime", "")

def rename(f):
    return f.replace(" ", "").replace("(", "").replace(")", "")

def rename_back(f):
    nums = REG.search(f)
    if nums:
        nums = nums.group()
        return f.replace(nums, f"{nums[0]} ({nums[1:]})")

    return f.replace('jpeg', 'jpg')

def copyfile(src, dest):
    # copy and resize
    im = cv2.imread(src)

    if im is None:
        raise FileNotFoundError(src)

    h, w = im.shape[1], im.shape[0]

    s = 448
    size = (s, round(s * w / h))
    im = cv2.resize(im, size)

    print(w, h, im.shape)
    cv2.imwrite(dest, im)

files = os.listdir(dir_)
new_files = []
for f in files:
    input_ver = os.path.join(dir_, anime_2_input(f))
    copyfile(f"dataset/test/HR_photo/{rename_back(anime_2_input(f))}", rename(input_ver))

    os.rename(
        os.path.join(dir_, f),
        os.path.join(dir_, rename(f))
    )
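A quick, hedged sanity check of the renaming helpers above; they are pure string functions, so this can run without the dataset in place:

print(anime_2_input("123_anime.jpg"))  # -> "123.jpg"
print(rename("1 (23).jpg"))            # -> "123.jpg"
print(rename_back("123.jpg"))          # -> "1 (23).jpg"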
inference.py
ADDED
|
@@ -0,0 +1,410 @@

import os
import time
import shutil

import torch
import cv2
import numpy as np

from models.anime_gan import GeneratorV1
from models.anime_gan_v2 import GeneratorV2
from models.anime_gan_v3 import GeneratorV3
from utils.common import load_checkpoint, RELEASED_WEIGHTS
from utils.image_processing import resize_image, normalize_input, denormalize_input
from utils import read_image, is_image_file, is_video_file
from tqdm import tqdm
from color_transfer import color_transfer_pytorch


try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

try:
    import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
    from moviepy.video.io.VideoFileClip import VideoFileClip
except ImportError:
    ffmpeg_writer = None
    VideoFileClip = None


def profile(func):
    def wrap(*args, **kwargs):
        started_at = time.time()
        result = func(*args, **kwargs)
        elapsed = time.time() - started_at
        print(f"Processed in {elapsed:.3f}s")
        return result
    return wrap


def auto_load_weight(weight, version=None, map_location=None):
    """Auto load Generator version from weight."""
    weight_name = os.path.basename(weight).lower()
    if version is not None:
        version = version.lower()
        assert version in {"v1", "v2", "v3"}, f"Version {version} does not exist"
        # If version is provided, use it.
        cls = {
            "v1": GeneratorV1,
            "v2": GeneratorV2,
            "v3": GeneratorV3
        }[version]
    else:
        # Try to get the class from the name of the weight file.
        # For convenience, the weight should start with the class name,
        # e.g: Generatorv2_{anything}.pt
        if weight_name in RELEASED_WEIGHTS:
            version = RELEASED_WEIGHTS[weight_name][0]
            return auto_load_weight(weight, version=version, map_location=map_location)

        elif weight_name.startswith("generatorv2"):
            cls = GeneratorV2
        elif weight_name.startswith("generatorv3"):
            cls = GeneratorV3
        elif weight_name.startswith("generator"):
            cls = GeneratorV1
        else:
            raise ValueError((f"Can not get Model from {weight_name}, "
                              "you might need to explicitly specify version"))
    model = cls()
    load_checkpoint(model, weight, strip_optimizer=True, map_location=map_location)
    model.eval()
    return model


class Predictor:
    """
    Generic class for transferring an image to an anime-like image.
    """
    def __init__(
        self,
        weight='hayao',
        device='cuda',
        amp=True,
        retain_color=False,
        imgsz=None,
    ):
        if not torch.cuda.is_available():
            device = 'cpu'
            # Amp not working on cpu
            amp = False
            print("Use CPU device")
        else:
            print(f"Use GPU {torch.cuda.get_device_name()}")

        self.imgsz = imgsz
        self.retain_color = retain_color
        self.amp = amp  # Automatic Mixed Precision
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
        self.device = torch.device(device)
        self.G = auto_load_weight(weight, map_location=device)
        self.G.to(self.device)

    def transform_and_show(
        self,
        image_path,
        figsize=(18, 10),
        save_path=None
    ):
        image = resize_image(read_image(image_path))
        anime_img = self.transform(image)
        anime_img = anime_img.astype('uint8')

        fig = plt.figure(figsize=figsize)
        fig.add_subplot(1, 2, 1)
        # plt.title("Input")
        plt.imshow(image)
        plt.axis('off')
        fig.add_subplot(1, 2, 2)
        # plt.title("Anime style")
        plt.imshow(anime_img[0])
        plt.axis('off')
        plt.tight_layout()
        plt.show()
        if save_path is not None:
            plt.savefig(save_path)

    def transform(self, image, denorm=True):
        '''
        Transform an image to animation

        @Arguments:
            - image: np.array, shape = (Batch, height, width, channels)

        @Returns:
            - anime version of image: np.array
        '''
        with torch.no_grad():
            image = self.preprocess_images(image)
            # image = image.to(self.device)
            # with autocast(self.device_type, enabled=self.amp):
            # print(image.dtype, self.G)
            fake = self.G(image)
            # Transfer the color of the fake image so it looks similar to the input image
            if self.retain_color:
                fake = color_transfer_pytorch(fake, image)
                fake = (fake / 0.5) - 1.0  # remap to [-1, 1]
            fake = fake.detach().cpu().numpy()
            # Channel last
            fake = fake.transpose(0, 2, 3, 1)

            if denorm:
                fake = denormalize_input(fake, dtype=np.uint8)
        return fake

    def read_and_resize(self, path, max_size=1536):
        image = read_image(path)
        _, ext = os.path.splitext(path)
        h, w = image.shape[:2]
        if self.imgsz is not None:
            image = resize_image(image, width=self.imgsz)
        elif max(h, w) > max_size:
            print(f"Image {os.path.basename(path)} is too big ({h}x{w}), resize to max size {max_size}")
            image = resize_image(
                image,
                width=max_size if w > h else None,
                height=max_size if w < h else None,
            )
            cv2.imwrite(path.replace(ext, ".jpg"), image[:, :, ::-1])
        else:
            image = resize_image(image)
        # image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # image = np.stack([image, image, image], -1)
        # cv2.imwrite(path.replace(ext, ".jpg"), image[:,:,::-1])
        return image

    @profile
    def transform_file(self, file_path, save_path):
        if not is_image_file(save_path):
            raise ValueError(f"{save_path} is not valid")

        image = self.read_and_resize(file_path)
        anime_img = self.transform(image)[0]
        cv2.imwrite(save_path, anime_img[..., ::-1])
        print(f"Anime image saved to {save_path}")
        return anime_img

    @profile
    def transform_gif(self, file_path, save_path, batch_size=4):
        import imageio

        def _preprocess_gif(img):
            if img.shape[-1] == 4:
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
            return resize_image(img)

        images = imageio.mimread(file_path)
        images = np.stack([
            _preprocess_gif(img)
            for img in images
        ])

        print(images.shape)

        anime_gif = np.zeros_like(images)

        for i in tqdm(range(0, len(images), batch_size)):
            end = i + batch_size
            anime_gif[i: end] = self.transform(
                images[i: end]
            )

        if end < len(images) - 1:
            # transform last frames
            print("LAST", images[end:].shape)
            anime_gif[end:] = self.transform(images[end:])

        print(anime_gif.shape)
        imageio.mimsave(
            save_path,
            anime_gif,
        )
        print(f"Anime image saved to {save_path}")

    @profile
    def transform_in_dir(self, img_dir, dest_dir, max_images=0, img_size=(512, 512)):
        '''
        Read all images from img_dir, transform and write the result
        to dest_dir
        '''
        os.makedirs(dest_dir, exist_ok=True)

        files = os.listdir(img_dir)
        files = [f for f in files if is_image_file(f)]
        print(f'Found {len(files)} images in {img_dir}')

        if max_images:
            files = files[:max_images]

        bar = tqdm(files)
        for fname in bar:
            path = os.path.join(img_dir, fname)
            image = self.read_and_resize(path)
            anime_img = self.transform(image)[0]
            # anime_img = resize_image(anime_img, width=320)
            ext = fname.split('.')[-1]
            fname = fname.replace(f'.{ext}', '')
            cv2.imwrite(os.path.join(dest_dir, f'{fname}.jpg'), anime_img[..., ::-1])
            bar.set_description(f"{fname} {image.shape}")

    def transform_video(self, input_path, output_path, batch_size=4, start=0, end=0):
        '''
        Transform a video to its animation version
        https://github.com/lengstrom/fast-style-transfer/blob/master/evaluate.py#L21
        '''
        if VideoFileClip is None:
            raise ImportError("moviepy is not installed, please install with `pip install moviepy>=1.0.3`")
        # Force to None
        end = end or None

        if not os.path.isfile(input_path):
            raise FileNotFoundError(f'{input_path} does not exist')

        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        is_gg_drive = '/drive/' in output_path
        temp_file = ''

        if is_gg_drive:
            # Writing directly into google drive can be inefficient
            temp_file = f'tmp_anime.{output_path.split(".")[-1]}'

        def transform_and_write(frames, count, writer):
            anime_images = self.transform(frames)
            for i in range(0, count):
                img = np.clip(anime_images[i], 0, 255)
                writer.write_frame(img)

        video_clip = VideoFileClip(input_path, audio=False)
        if start or end:
            video_clip = video_clip.subclip(start, end)

        video_writer = ffmpeg_writer.FFMPEG_VideoWriter(
            temp_file or output_path,
            video_clip.size, video_clip.fps,
            codec="libx264",
            # preset="medium", bitrate="2000k",
            ffmpeg_params=None)

        total_frames = round(video_clip.fps * video_clip.duration)
        print(f'Transforming video {input_path}, {total_frames} frames, size: {video_clip.size}')

        batch_shape = (batch_size, video_clip.size[1], video_clip.size[0], 3)
        frame_count = 0
        frames = np.zeros(batch_shape, dtype=np.float32)
        for frame in tqdm(video_clip.iter_frames(), total=total_frames):
            try:
                frames[frame_count] = frame
                frame_count += 1
                if frame_count == batch_size:
                    transform_and_write(frames, frame_count, video_writer)
                    frame_count = 0
            except Exception as e:
                print(e)
                break

        # The last frames
        if frame_count != 0:
            transform_and_write(frames, frame_count, video_writer)

        if temp_file:
            # move to output path
            shutil.move(temp_file, output_path)

        print(f'Animation video saved to {output_path}')
        video_writer.close()

    def preprocess_images(self, images):
        '''
        Preprocess image for inference

        @Arguments:
            - images: np.ndarray

        @Returns
            - images: torch.tensor
        '''
        images = images.astype(np.float32)

        # Normalize to [-1, 1]
        images = normalize_input(images)
        images = torch.from_numpy(images)

        images = images.to(self.device)

        # Add batch dim
        if len(images.shape) == 3:
            images = images.unsqueeze(0)

        # channel first
        images = images.permute(0, 3, 1, 2)

        return images


def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--weight',
        type=str,
        default="hayao:v2",
        help=f'Model weight, can be path or pretrained {tuple(RELEASED_WEIGHTS.keys())}'
    )
    parser.add_argument('--src', type=str, help='Source, can be a directory containing images, an image file or a video file.')
    parser.add_argument('--device', type=str, default='cuda', help='Device, cuda or cpu')
    parser.add_argument('--imgsz', type=int, default=None, help='Resize image to specified size if provided')
    parser.add_argument('--out', type=str, default='inference_images', help='Output, can be directory or file')
    parser.add_argument(
        '--retain-color',
        action='store_true',
        help='If provided the generated image will retain the original color of the input image')
    # Video params
    parser.add_argument('--batch-size', type=int, default=4, help='Batch size when inferring video')
    parser.add_argument('--start', type=int, default=0, help='Start time of video (second)')
    parser.add_argument('--end', type=int, default=0, help='End time of video (second), 0 if not set')

    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()

    predictor = Predictor(
        args.weight,
        args.device,
        retain_color=args.retain_color,
        imgsz=args.imgsz,
    )

    if not os.path.exists(args.src):
        raise FileNotFoundError(args.src)

    if is_video_file(args.src):
        predictor.transform_video(
            args.src,
            args.out,
            args.batch_size,
            start=args.start,
            end=args.end
        )
    elif os.path.isdir(args.src):
        predictor.transform_in_dir(args.src, args.out)
    elif os.path.isfile(args.src):
        save_path = args.out
        if not is_image_file(args.out):
            os.makedirs(args.out, exist_ok=True)
            save_path = os.path.join(args.out, os.path.basename(args.src))

        if args.src.endswith('.gif'):
            # GIF file
            predictor.transform_gif(args.src, save_path, args.batch_size)
        else:
            predictor.transform_file(args.src, save_path)
    else:
        raise NotImplementedError(f"{args.src} is not supported")
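Besides the CLI entry point above (for example `python inference.py --weight hayao:v2 --src example/face/girl4.jpg --out inference_images --device cpu`, using only flags defined in parse_args), the Predictor can be driven programmatically. A minimal sketch, assuming the released hayao:v2 weight is resolvable by load_checkpoint and the example image exists:

import cv2
from inference import Predictor
from utils import read_image
from utils.image_processing import resize_image

predictor = Predictor("hayao:v2", device="cpu")             # falls back to CPU when CUDA is unavailable
image = resize_image(read_image("example/face/girl4.jpg"))   # RGB ndarray
anime = predictor.transform(image)[0]                        # (H, W, 3) uint8, RGB
cv2.imwrite("anime_out.jpg", anime[..., ::-1])               # back to BGR for OpenCV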
losses.py
ADDED
|
@@ -0,0 +1,248 @@

import torch
import torch.nn.functional as F
import torch.nn as nn
from models.vgg import Vgg19
from utils.image_processing import gram


def to_gray_scale(image):
    # https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional/_color.py#L33
    # Images are assumed to be in range [-1, 1]
    image = (image + 1.0) / 2.0  # To [0, 1]
    r, g, b = image.unbind(dim=-3)
    l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
    l_img = l_img.unsqueeze(dim=-3)
    l_img = l_img.to(image.dtype)
    l_img = l_img.expand(image.shape)
    l_img = l_img / 0.5 - 1.0  # To [-1, 1]
    return l_img


class ColorLoss(nn.Module):
    def __init__(self):
        super(ColorLoss, self).__init__()
        self.l1 = nn.L1Loss()
        self.huber = nn.SmoothL1Loss()
        # self._rgb_to_yuv_kernel = torch.tensor([
        #     [0.299, -0.14714119, 0.61497538],
        #     [0.587, -0.28886916, -0.51496512],
        #     [0.114, 0.43601035, -0.10001026]
        # ]).float()

        self._rgb_to_yuv_kernel = torch.tensor([
            [0.299, 0.587, 0.114],
            [-0.14714119, -0.28886916, 0.43601035],
            [0.61497538, -0.51496512, -0.10001026],
        ]).float()

    def to(self, device):
        new_self = super(ColorLoss, self).to(device)
        new_self._rgb_to_yuv_kernel = new_self._rgb_to_yuv_kernel.to(device)
        return new_self

    def rgb_to_yuv(self, image):
        '''
        https://en.wikipedia.org/wiki/YUV

        output: Image of shape (H, W, C) (channel last)
        '''
        # [-1, 1] -> [0, 1]
        image = (image + 1.0) / 2.0
        image = image.permute(0, 2, 3, 1)  # To channel last

        yuv_img = image @ self._rgb_to_yuv_kernel.T

        return yuv_img

    def forward(self, image, image_g):
        image = self.rgb_to_yuv(image)
        image_g = self.rgb_to_yuv(image_g)
        # After converting to YUV, both images are channel last
        return (
            self.l1(image[:, :, :, 0], image_g[:, :, :, 0])
            + self.huber(image[:, :, :, 1], image_g[:, :, :, 1])
            + self.huber(image[:, :, :, 2], image_g[:, :, :, 2])
        )


class AnimeGanLoss:
    def __init__(self, args, device, gray_adv=False):
        if isinstance(device, str):
            device = torch.device(device)

        self.content_loss = nn.L1Loss().to(device)
        self.gram_loss = nn.L1Loss().to(device)
        self.color_loss = ColorLoss().to(device)
        self.wadvg = args.wadvg
        self.wadvd = args.wadvd
        self.wcon = args.wcon
        self.wgra = args.wgra
        self.wcol = args.wcol
        self.wtvar = args.wtvar
        # If true, use the grayscale image to calculate the adversarial loss
        self.gray_adv = gray_adv
        self.vgg19 = Vgg19().to(device).eval()
        self.adv_type = args.gan_loss
        self.bce_loss = nn.BCEWithLogitsLoss()

    def compute_loss_G(self, fake_img, img, fake_logit, anime_gray):
        '''
        Compute loss for Generator

        @Args:
            - fake_img: generated image
            - img: real image
            - fake_logit: output of Discriminator given fake image
            - anime_gray: grayscale of anime image

        @Returns:
            - Adversarial Loss of fake logits
            - Content loss between real and fake features (vgg19)
            - Gram loss between anime and fake features (Vgg19)
            - Color loss between image and fake image
            - Total variation loss of fake image
        '''
        fake_feat = self.vgg19(fake_img)
        gray_feat = self.vgg19(anime_gray)
        img_feat = self.vgg19(img)
        # fake_gray_feat = self.vgg19(to_gray_scale(fake_img))

        return [
            # Want to be real image.
            self.wadvg * self.adv_loss_g(fake_logit),
            self.wcon * self.content_loss(img_feat, fake_feat),
            self.wgra * self.gram_loss(gram(gray_feat), gram(fake_feat)),
            self.wcol * self.color_loss(img, fake_img),
            self.wtvar * self.total_variation_loss(fake_img)
        ]

    def compute_loss_D(
        self,
        fake_img_d,
        real_anime_d,
        real_anime_gray_d,
        real_anime_smooth_gray_d=None
    ):
        if self.gray_adv:
            # Treat the grayscale image as real
            return (
                self.adv_loss_d_real(real_anime_gray_d)
                + self.adv_loss_d_fake(fake_img_d)
                + 0.3 * self.adv_loss_d_fake(real_anime_smooth_gray_d)
            )
        else:
            return (
                # Classify real anime as real
                self.adv_loss_d_real(real_anime_d)
                # Classify generated as fake
                + self.adv_loss_d_fake(fake_img_d)
                # Classify real anime gray as fake
                # + self.adv_loss_d_fake(real_anime_gray_d)
                # Classify real anime smooth gray as fake
                # + 0.1 * self.adv_loss_d_fake(real_anime_smooth_gray_d)
            )

    def total_variation_loss(self, fake_img):
        """
        A smoothness loss, like the smoothness prior in an MRF.
        V(y) = || y_{n+1} - y_n ||_2
        """
        # Channel first -> channel last
        fake_img = fake_img.permute(0, 2, 3, 1)

        def _l2(x):
            # sum(t ** 2) / 2
            return torch.sum(x ** 2) / 2

        dh = fake_img[:, :-1, ...] - fake_img[:, 1:, ...]
        dw = fake_img[:, :, :-1, ...] - fake_img[:, :, 1:, ...]
        return _l2(dh) / dh.numel() + _l2(dw) / dw.numel()

    def content_loss_vgg(self, image, reconstruction):
        feat = self.vgg19(image)
        re_feat = self.vgg19(reconstruction)
        feature_loss = self.content_loss(feat, re_feat)
        content_loss = self.content_loss(image, reconstruction)
        return feature_loss  # + 0.5 * content_loss

    def adv_loss_d_real(self, pred):
        """Push pred to class 1 (real)"""
        if self.adv_type == 'hinge':
            return torch.mean(F.relu(1.0 - pred))

        elif self.adv_type == 'lsgan':
            # pred = torch.sigmoid(pred)
            return torch.mean(torch.square(pred - 1.0))

        elif self.adv_type == 'bce':
            return self.bce_loss(pred, torch.ones_like(pred))

        raise ValueError(f'Do not support loss type {self.adv_type}')

    def adv_loss_d_fake(self, pred):
        """Push pred to class 0 (fake)"""
        if self.adv_type == 'hinge':
            return torch.mean(F.relu(1.0 + pred))

        elif self.adv_type == 'lsgan':
            # pred = torch.sigmoid(pred)
            return torch.mean(torch.square(pred))

        elif self.adv_type == 'bce':
            return self.bce_loss(pred, torch.zeros_like(pred))

        raise ValueError(f'Do not support loss type {self.adv_type}')

    def adv_loss_g(self, pred):
        """Push pred to class 1 (real)"""
        if self.adv_type == 'hinge':
            return -torch.mean(pred)

        elif self.adv_type == 'lsgan':
            # pred = torch.sigmoid(pred)
            return torch.mean(torch.square(pred - 1.0))

        elif self.adv_type == 'bce':
            return self.bce_loss(pred, torch.ones_like(pred))

        raise ValueError(f'Do not support loss type {self.adv_type}')


class LossSummary:
    def __init__(self):
        self.reset()

    def reset(self):
        self.loss_g_adv = []
        self.loss_content = []
        self.loss_gram = []
        self.loss_color = []
        self.loss_d_adv = []

    def update_loss_G(self, adv, gram, color, content):
        self.loss_g_adv.append(adv.cpu().detach().numpy())
        self.loss_gram.append(gram.cpu().detach().numpy())
        self.loss_color.append(color.cpu().detach().numpy())
        self.loss_content.append(content.cpu().detach().numpy())

    def update_loss_D(self, loss):
        self.loss_d_adv.append(loss.cpu().detach().numpy())

    def avg_loss_G(self):
        return (
            self._avg(self.loss_g_adv),
            self._avg(self.loss_gram),
            self._avg(self.loss_color),
            self._avg(self.loss_content),
        )

    def avg_loss_D(self):
        return self._avg(self.loss_d_adv)

    def get_loss_description(self):
        avg_adv, avg_gram, avg_color, avg_content = self.avg_loss_G()
        avg_adv_d = self.avg_loss_D()
        return f'loss G: adv {avg_adv:2f} con {avg_content:2f} gram {avg_gram:2f} color {avg_color:2f} / loss D: {avg_adv_d:2f}'

    @staticmethod
    def _avg(losses):
        return sum(losses) / len(losses)
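A minimal sketch of wiring AnimeGanLoss outside the Trainer. The loss weights below are placeholders for illustration, not the project defaults, and constructing the loss downloads the torchvision VGG19 weights:

from types import SimpleNamespace
import torch
from losses import AnimeGanLoss

args = SimpleNamespace(wadvg=10.0, wadvd=10.0, wcon=1.5, wgra=3.0,
                       wcol=30.0, wtvar=1.0, gan_loss="lsgan")
loss_fn = AnimeGanLoss(args, device="cpu")

fake = torch.rand(2, 3, 256, 256) * 2 - 1        # generator output in [-1, 1]
real = torch.rand(2, 3, 256, 256) * 2 - 1        # photo batch in [-1, 1]
anime_gray = torch.rand(2, 3, 256, 256) * 2 - 1  # grayscale anime batch in [-1, 1]
fake_logit = torch.randn(2, 1, 64, 64)           # discriminator output on the fake batch

adv, con, gra, col, tv = loss_fn.compute_loss_G(fake, real, fake_logit, anime_gray)
print(f"G loss = {(adv + con + gra + col + tv).item():.3f}")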
models/__init__.py
ADDED
|
@@ -0,0 +1,3 @@

from .anime_gan import GeneratorV1
from .anime_gan_v2 import GeneratorV2
from .anime_gan_v3 import GeneratorV3
models/anime_gan.py
ADDED
|
@@ -0,0 +1,112 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from .conv_blocks import DownConv
from .conv_blocks import UpConv
from .conv_blocks import SeparableConv2D
from .conv_blocks import InvertedResBlock
from .conv_blocks import ConvBlock
from .layers import get_norm
from utils.common import initialize_weights


class GeneratorV1(nn.Module):
    def __init__(self, dataset=''):
        super(GeneratorV1, self).__init__()
        self.name = f'{self.__class__.__name__}_{dataset}'
        bias = False

        self.encode_blocks = nn.Sequential(
            ConvBlock(3, 64, bias=bias),
            ConvBlock(64, 128, bias=bias),
            DownConv(128, bias=bias),
            ConvBlock(128, 128, bias=bias),
            SeparableConv2D(128, 256, bias=bias),
            DownConv(256, bias=bias),
            ConvBlock(256, 256, bias=bias),
        )

        self.res_blocks = nn.Sequential(
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
            InvertedResBlock(256, 256),
        )

        self.decode_blocks = nn.Sequential(
            ConvBlock(256, 128, bias=bias),
            UpConv(128, bias=bias),
            SeparableConv2D(128, 128, bias=bias),
            ConvBlock(128, 128, bias=bias),
            UpConv(128, bias=bias),
            ConvBlock(128, 64, bias=bias),
            ConvBlock(64, 64, bias=bias),
            nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0, bias=bias),
            nn.Tanh(),
        )

        initialize_weights(self)

    def forward(self, x):
        out = self.encode_blocks(x)
        out = self.res_blocks(out)
        img = self.decode_blocks(out)

        return img


class Discriminator(nn.Module):
    def __init__(
        self,
        dataset=None,
        num_layers=1,
        use_sn=False,
        norm_type="instance",
    ):
        super(Discriminator, self).__init__()
        self.name = f'discriminator_{dataset}'
        self.bias = False
        channels = 32

        layers = [
            nn.Conv2d(3, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
            nn.LeakyReLU(0.2, True)
        ]

        in_channels = channels
        for i in range(num_layers):
            layers += [
                nn.Conv2d(in_channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=self.bias),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=1, padding=1, bias=self.bias),
                get_norm(norm_type, channels * 4),
                nn.LeakyReLU(0.2, True),
            ]
            in_channels = channels * 4
            channels *= 2

        channels *= 2
        layers += [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=self.bias),
            get_norm(norm_type, channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1, bias=self.bias),
        ]

        if use_sn:
            for i in range(len(layers)):
                if isinstance(layers[i], nn.Conv2d):
                    layers[i] = spectral_norm(layers[i])

        self.discriminate = nn.Sequential(*layers)

        initialize_weights(self)

    def forward(self, img):
        logits = self.discriminate(img)
        return logits
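A quick shape check for the V1 generator and discriminator, assuming the repo utilities (utils.common.initialize_weights, models.layers.get_norm) are importable:

import torch
from models.anime_gan import GeneratorV1, Discriminator

x = torch.randn(1, 3, 256, 256)
G = GeneratorV1().eval()
D = Discriminator(num_layers=1)
with torch.no_grad():
    fake = G(x)       # two DownConv / two UpConv stages keep the 256x256 resolution
    logit = D(fake)   # patch-level logits rather than a single scalar
print(fake.shape, logit.shape)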
models/anime_gan_v2.py
ADDED
|
@@ -0,0 +1,65 @@

import torch.nn as nn
import torch.nn.functional as F
from models.conv_blocks import InvertedResBlock
from models.conv_blocks import ConvBlock
from models.conv_blocks import UpConvLNormLReLU
from utils.common import initialize_weights


class GeneratorV2(nn.Module):
    def __init__(self, dataset=''):
        super(GeneratorV2, self).__init__()
        self.name = f'{self.__class__.__name__}_{dataset}'

        self.conv_block1 = nn.Sequential(
            ConvBlock(3, 32, kernel_size=7, stride=1, padding=3, norm_type="layer"),
            ConvBlock(32, 64, kernel_size=3, stride=2, padding=(0, 1, 0, 1), norm_type="layer"),
            ConvBlock(64, 64, kernel_size=3, stride=1, norm_type="layer"),
        )

        self.conv_block2 = nn.Sequential(
            ConvBlock(64, 128, kernel_size=3, stride=2, padding=(0, 1, 0, 1), norm_type="layer"),
            ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
        )

        self.res_blocks = nn.Sequential(
            ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
            InvertedResBlock(128, 256, expand_ratio=2, norm_type="layer"),
            InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
            InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
            InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
            ConvBlock(256, 128, kernel_size=3, stride=1, norm_type="layer"),
        )

        self.conv_block3 = nn.Sequential(
            # UpConvLNormLReLU(128, 128, norm_type="layer"),
            ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
            ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
        )

        self.conv_block4 = nn.Sequential(
            # UpConvLNormLReLU(128, 64, norm_type="layer"),
            ConvBlock(128, 64, kernel_size=3, stride=1, norm_type="layer"),
            ConvBlock(64, 64, kernel_size=3, stride=1, norm_type="layer"),
            ConvBlock(64, 32, kernel_size=7, padding=3, stride=1, norm_type="layer"),
        )

        self.decode_blocks = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
            nn.Tanh(),
        )

        initialize_weights(self)

    def forward(self, x):
        out = self.conv_block1(x)
        out = self.conv_block2(out)
        out = self.res_blocks(out)
        out = F.interpolate(out, scale_factor=2, mode="bilinear")
        out = self.conv_block3(out)
        out = F.interpolate(out, scale_factor=2, mode="bilinear")
        out = self.conv_block4(out)
        img = self.decode_blocks(out)

        return img

models/anime_gan_v3.py
ADDED

@@ -0,0 +1,14 @@

import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from models.conv_blocks import DownConv
from models.conv_blocks import UpConv
from models.conv_blocks import SeparableConv2D
from models.conv_blocks import InvertedResBlock
from models.conv_blocks import ConvBlock
from utils.common import initialize_weights


class GeneratorV3(nn.Module):
    pass
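GeneratorV3 is still a stub, but GeneratorV2 can be shape-checked the same way as V1; the two F.interpolate calls in its forward pass undo the two stride-2 downsampling blocks, so the output resolution matches the input. A small sketch:

import torch
from models.anime_gan_v2 import GeneratorV2

G = GeneratorV2().eval()
with torch.no_grad():
    out = G(torch.randn(1, 3, 256, 256))
print(out.shape)   # expected torch.Size([1, 3, 256, 256]), values in [-1, 1] from the final Tanh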
models/conv_blocks.py
ADDED
|
@@ -0,0 +1,171 @@

import torch.nn as nn
import torch.nn.functional as F
from utils.common import initialize_weights
from .layers import LayerNorm2d, get_norm


class DownConv(nn.Module):

    def __init__(self, channels, bias=False):
        super(DownConv, self).__init__()

        self.conv1 = SeparableConv2D(channels, channels, stride=2, bias=bias)
        self.conv2 = SeparableConv2D(channels, channels, stride=1, bias=bias)

    def forward(self, x):
        out1 = self.conv1(x)
        out2 = F.interpolate(x, scale_factor=0.5, mode='bilinear')
        out2 = self.conv2(out2)

        return out1 + out2


class UpConv(nn.Module):
    def __init__(self, channels, bias=False):
        super(UpConv, self).__init__()

        self.conv = SeparableConv2D(channels, channels, stride=1, bias=bias)

    def forward(self, x):
        out = F.interpolate(x, scale_factor=2.0, mode='bilinear')
        out = self.conv(out)
        return out


class UpConvLNormLReLU(nn.Module):
    """Upsample Conv block with Layer Norm and Leaky ReLU"""
    def __init__(self, in_channels, out_channels, norm_type="instance", bias=False):
        super(UpConvLNormLReLU, self).__init__()

        self.conv_block = ConvBlock(
            in_channels,
            out_channels,
            kernel_size=3,
            norm_type=norm_type,
            bias=bias,
        )

    def forward(self, x):
        out = F.interpolate(x, scale_factor=2.0, mode='bilinear')
        out = self.conv_block(out)
        return out


class SeparableConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, bias=False):
        super(SeparableConv2D, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3,
                                   stride=stride, padding=1, groups=in_channels, bias=bias)
        self.pointwise = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=1, stride=1, bias=bias)
        # self.pad =
        self.ins_norm1 = nn.InstanceNorm2d(in_channels)
        self.activation1 = nn.LeakyReLU(0.2, True)
        self.ins_norm2 = nn.InstanceNorm2d(out_channels)
        self.activation2 = nn.LeakyReLU(0.2, True)

        initialize_weights(self)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.ins_norm1(out)
        out = self.activation1(out)

        out = self.pointwise(out)
        out = self.ins_norm2(out)

        return self.activation2(out)


class ConvBlock(nn.Module):
    """Stack of Conv2D + Norm + LeakyReLU"""
    def __init__(
        self,
        channels,
        out_channels,
        kernel_size=3,
        stride=1,
        groups=1,
        padding=1,
        bias=False,
        norm_type="instance"
    ):
        super(ConvBlock, self).__init__()

        # if kernel_size == 3 and stride == 1:
        #     self.pad = nn.ReflectionPad2d((1, 1, 1, 1))
        # elif kernel_size == 7 and stride == 1:
        #     self.pad = nn.ReflectionPad2d((3, 3, 3, 3))
        # elif stride == 2:
        #     self.pad = nn.ReflectionPad2d((0, 1, 1, 0))
        # else:
        #     self.pad = None

        self.pad = nn.ReflectionPad2d(padding)
        self.conv = nn.Conv2d(
            channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            padding=0,
            bias=bias
        )
        self.ins_norm = get_norm(norm_type, out_channels)
        self.activation = nn.LeakyReLU(0.2, True)

        # initialize_weights(self)

    def forward(self, x):
        if self.pad is not None:
            x = self.pad(x)
        out = self.conv(x)
        out = self.ins_norm(out)
        out = self.activation(out)
        return out


class InvertedResBlock(nn.Module):
    def __init__(
        self,
        channels=256,
        out_channels=256,
        expand_ratio=2,
        norm_type="instance",
    ):
        super(InvertedResBlock, self).__init__()
        bottleneck_dim = round(expand_ratio * channels)
        self.conv_block = ConvBlock(
            channels,
            bottleneck_dim,
            kernel_size=1,
            padding=0,
            norm_type=norm_type,
            bias=False
        )
        self.conv_block2 = ConvBlock(
            bottleneck_dim,
            bottleneck_dim,
            groups=bottleneck_dim,
            norm_type=norm_type,
            bias=True
        )
        self.conv = nn.Conv2d(
            bottleneck_dim,
            out_channels,
            kernel_size=1,
            padding=0,
            bias=False
        )
        self.norm = get_norm(norm_type, out_channels)

    def forward(self, x):
        out = self.conv_block(x)
        out = self.conv_block2(out)
        # out = self.activation(out)
        out = self.conv(out)
        out = self.norm(out)

        if out.shape[1] != x.shape[1]:
            # Skip the residual connection when channel dims differ
            return out
        return out + x
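A small sketch showing when InvertedResBlock keeps its residual connection: the skip is added only when input and output channel counts match, otherwise the block returns the projected output alone:

import torch
from models.conv_blocks import InvertedResBlock

x = torch.randn(1, 128, 64, 64)
same = InvertedResBlock(128, 128, expand_ratio=2)   # channels match -> residual added
grow = InvertedResBlock(128, 256, expand_ratio=2)   # channels change -> no residual
print(same(x).shape, grow(x).shape)                  # (1, 128, 64, 64) and (1, 256, 64, 64)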
models/layers.py
ADDED
|
@@ -0,0 +1,28 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial


class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for channels of '2D' spatial NCHW tensors """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # https://pytorch.org/vision/0.12/_modules/torchvision/models/convnext.html
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x


def get_norm(norm_type, channels):
    if norm_type == "instance":
        return nn.InstanceNorm2d(channels)
    elif norm_type == "layer":
        # return LayerNorm2d
        return nn.GroupNorm(num_groups=1, num_channels=channels, affine=True)
        # return partial(nn.GroupNorm, 1, out_ch, 1e-5, True)
    else:
        raise ValueError(norm_type)
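Note that get_norm currently maps "layer" to GroupNorm with a single group, which normalizes across all channels and spatial positions; LayerNorm2d above is kept but unused. A tiny usage sketch:

import torch
from models.layers import get_norm

norm = get_norm("layer", 64)                    # nn.GroupNorm(1, 64)
print(norm(torch.randn(2, 64, 32, 32)).shape)   # torch.Size([2, 64, 32, 32])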
models/vgg.py
ADDED
|
@@ -0,0 +1,80 @@

import torchvision.models as models
import torch.nn as nn
import torch


class Vgg19(nn.Module):
    def __init__(self):
        super(Vgg19, self).__init__()
        self.vgg19 = self.get_vgg19().eval()
        vgg_mean = torch.tensor([0.485, 0.456, 0.406]).float()
        vgg_std = torch.tensor([0.229, 0.224, 0.225]).float()
        self.mean = vgg_mean.view(-1, 1, 1)
        self.std = vgg_std.view(-1, 1, 1)

    def to(self, device):
        new_self = super(Vgg19, self).to(device)
        new_self.mean = new_self.mean.to(device)
        new_self.std = new_self.std.to(device)
        return new_self

    def forward(self, x):
        return self.vgg19(self.normalize_vgg(x))

    @staticmethod
    def get_vgg19(last_layer='conv4_4'):
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        model_list = []

        i = 0
        j = 1
        for layer in vgg.children():
            if isinstance(layer, nn.MaxPool2d):
                i = 0
                j += 1

            elif isinstance(layer, nn.Conv2d):
                i += 1

            name = f'conv{j}_{i}'

            if name == last_layer:
                model_list.append(layer)
                break

            model_list.append(layer)

        model = nn.Sequential(*model_list)
        return model

    def normalize_vgg(self, image):
        '''
        Expect input in range [-1, 1]
        '''
        image = (image + 1.0) / 2.0
        return (image - self.mean) / self.std


if __name__ == '__main__':
    from PIL import Image
    import numpy as np
    from utils.image_processing import normalize_input

    image = Image.open("example/10.jpg")
    image = image.resize((224, 224))
    np_img = np.array(image).astype('float32')
    np_img = normalize_input(np_img)

    img = torch.from_numpy(np_img)
    img = img.permute(2, 0, 1)
    img = img.unsqueeze(0)

    vgg = Vgg19()

    feat = vgg(img)

    print(feat.shape)
predict.py
ADDED
|
@@ -0,0 +1,35 @@

from pathlib import Path
from inference import Predictor as MyPredictor
from utils import read_image
import cv2
import tempfile
from utils.image_processing import resize_image, normalize_input, denormalize_input
import numpy as np
from cog import BasePredictor, Path, Input


class Predictor(BasePredictor):
    def setup(self):
        pass

    def predict(
        self,
        image: Path = Input(description="Image"),
        model: str = Input(
            description="Style",
            default='Hayao:v2',
            choices=[
                'Hayao',
                'Shinkai',
                'Hayao:v2'
            ]
        )
    ) -> Path:
        version = model.split(":")[-1]
        predictor = MyPredictor(model, version)
        img = read_image(str(image))
        anime_img = predictor.transform(resize_image(img))[0]
        out_path = Path(tempfile.mkdtemp()) / "out.png"
        cv2.imwrite(str(out_path), anime_img[..., ::-1])
        return out_path
requirements.txt
CHANGED
|
@@ -1,9 +1,8 @@
|
|
| 1 |
-
torch
|
| 2 |
-
torchvision
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
opencv-python-headless
|
|
|
|
| 1 |
+
torch==2.3.0
|
| 2 |
+
torchvision==0.18.0
|
| 3 |
+
numpy==1.24.2
|
| 4 |
+
# ipython==7.21.0
|
| 5 |
+
opencv-python==4.9.0.80
|
| 6 |
+
tqdm==4.21.0
|
| 7 |
+
# moviepy==1.0.3
|
| 8 |
+
color_transfer_py==0.0.4
|
|
|
train.py
ADDED
@@ -0,0 +1,163 @@
+import torch
+import argparse
+import os
+from models.anime_gan import GeneratorV1
+from models.anime_gan_v2 import GeneratorV2
+from models.anime_gan_v3 import GeneratorV3
+from models.anime_gan import Discriminator
+from datasets import AnimeDataSet
+from utils.common import load_checkpoint
+from trainer import Trainer
+from utils.logger import get_logger
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--real_image_dir', type=str, default='dataset/train_photo')
+    parser.add_argument('--anime_image_dir', type=str, default='dataset/Hayao')
+    parser.add_argument('--test_image_dir', type=str, default='dataset/test/HR_photo')
+    parser.add_argument('--model', type=str, default='v1', help="AnimeGAN version, can be {'v1', 'v2', 'v3'}")
+    parser.add_argument('--epochs', type=int, default=70)
+    parser.add_argument('--init_epochs', type=int, default=10)
+    parser.add_argument('--batch_size', type=int, default=8)
+    parser.add_argument('--exp_dir', type=str, default='runs', help="Experiment directory")
+    parser.add_argument('--gan_loss', type=str, default='lsgan', help='lsgan / hinge / bce')
+    parser.add_argument('--resume', action='store_true', help="Continue from current dir")
+    parser.add_argument('--resume_G_init', type=str, default='False')
+    parser.add_argument('--resume_G', type=str, default='False')
+    parser.add_argument('--resume_D', type=str, default='False')
+    parser.add_argument('--device', type=str, default='cuda')
+    parser.add_argument('--use_sn', action='store_true')
+    parser.add_argument('--cache', action='store_true', help="Turn on disk cache")
+    parser.add_argument('--amp', action='store_true', help="Turn on Automatic Mixed Precision")
+    parser.add_argument('--save_interval', type=int, default=1)
+    parser.add_argument('--debug_samples', type=int, default=0)
+    parser.add_argument('--num_workers', type=int, default=2)
+    parser.add_argument('--imgsz', type=int, nargs="+", default=[256],
+                        help="Image sizes, can provide multiple values, image size will increase after a proportion of epochs")
+    parser.add_argument('--resize_method', type=str, default="crop",
+                        help="Resize image method if origin photo larger than imgsz")
+    # Loss stuff
+    parser.add_argument('--lr_g', type=float, default=2e-5)
+    parser.add_argument('--lr_d', type=float, default=4e-5)
+    parser.add_argument('--init_lr', type=float, default=1e-4)
+    parser.add_argument('--wadvg', type=float, default=300.0, help='Adversarial loss weight for G')
+    parser.add_argument('--wadvd', type=float, default=300.0, help='Adversarial loss weight for D')
+    parser.add_argument(
+        '--gray_adv', action='store_true',
+        help="If given, train adversarial with gray scale image instead of RGB image to reduce color effect of anime style")
+    # Loss weight VGG19
+    parser.add_argument('--wcon', type=float, default=1.5, help='Content loss weight')  # 1.5 for Hayao, 2.0 for Paprika, 1.2 for Shinkai
+    parser.add_argument('--wgra', type=float, default=5.0, help='Gram loss weight')  # 2.5 for Hayao, 0.6 for Paprika, 2.0 for Shinkai
+    parser.add_argument('--wcol', type=float, default=30.0, help='Color loss weight')  # 15. for Hayao, 50. for Paprika, 10. for Shinkai
+    parser.add_argument('--wtvar', type=float, default=1.0, help='Total variation loss')  # 1. for Hayao, 0.1 for Paprika, 1. for Shinkai
+    parser.add_argument('--d_layers', type=int, default=2, help='Discriminator conv layers')
+    parser.add_argument('--d_noise', action='store_true')
+
+    # DDP
+    parser.add_argument('--ddp', action='store_true')
+    parser.add_argument("--local-rank", default=0, type=int)
+    parser.add_argument("--world-size", default=2, type=int)
+
+    return parser.parse_args()
+
+
+def check_params(args):
+    # dataset/Hayao + dataset/train_photo -> train_photo_Hayao
+    args.dataset = f"{os.path.basename(args.real_image_dir)}_{os.path.basename(args.anime_image_dir)}"
+    assert args.gan_loss in {'lsgan', 'hinge', 'bce'}, f'{args.gan_loss} is not supported'
+
+
+def main(args, logger):
+    check_params(args)
+
+    if not torch.cuda.is_available():
+        logger.info("CUDA not found, use CPU")
+        # Just for debugging purpose, set to minimum config
+        # to avoid 🔥 the computer...
+        args.device = 'cpu'
+        args.debug_samples = 10
+        args.batch_size = 2
+    else:
+        logger.info(f"Use GPU: {torch.cuda.get_device_name(0)}")
+
+    norm_type = "instance"
+    if args.model == 'v1':
+        G = GeneratorV1(args.dataset)
+    elif args.model == 'v2':
+        G = GeneratorV2(args.dataset)
+        norm_type = "layer"
+    elif args.model == 'v3':
+        G = GeneratorV3(args.dataset)
+
+    D = Discriminator(
+        args.dataset,
+        num_layers=args.d_layers,
+        use_sn=args.use_sn,
+        norm_type=norm_type,
+    )
+
+    start_e = 0
+    start_e_init = 0
+
+    trainer = Trainer(
+        generator=G,
+        discriminator=D,
+        config=args,
+        logger=logger,
+    )
+
+    if args.resume_G_init.lower() != 'false':
+        start_e_init = load_checkpoint(G, args.resume_G_init) + 1
+        if args.local_rank == 0:
+            logger.info(f"G content weight loaded from {args.resume_G_init}")
+    elif args.resume_G.lower() != 'false' and args.resume_D.lower() != 'false':
+        # You should provide both
+        try:
+            start_e = load_checkpoint(G, args.resume_G)
+            if args.local_rank == 0:
+                logger.info(f"G weight loaded from {args.resume_G}")
+            load_checkpoint(D, args.resume_D)
+            if args.local_rank == 0:
+                logger.info(f"D weight loaded from {args.resume_D}")
+            # If both weights are loaded, turn off the G initialization phase
+            args.init_epochs = 0
+
+        except Exception as e:
+            print('Could not load checkpoint, train from scratch', e)
+    elif args.resume:
+        # Try to load from working dir
+        logger.info(f"Loading weight from {trainer.checkpoint_path_G}")
+        start_e = load_checkpoint(G, trainer.checkpoint_path_G)
+        logger.info(f"Loading weight from {trainer.checkpoint_path_D}")
+        load_checkpoint(D, trainer.checkpoint_path_D)
+        args.init_epochs = 0
+
+    dataset = AnimeDataSet(
+        args.anime_image_dir,
+        args.real_image_dir,
+        args.debug_samples,
+        args.cache,
+        imgsz=args.imgsz,
+        resize_method=args.resize_method,
+    )
+    if args.local_rank == 0:
+        logger.info(f"Start from epoch {start_e}, {start_e_init}")
+    trainer.train(dataset, start_e, start_e_init)
+
+if __name__ == '__main__':
+    args = parse_args()
+    real_name = os.path.basename(args.real_image_dir)
+    anime_name = os.path.basename(args.anime_image_dir)
+    args.exp_dir = f"{args.exp_dir}_{real_name}_{anime_name}"
+
+    os.makedirs(args.exp_dir, exist_ok=True)
+    logger = get_logger(os.path.join(args.exp_dir, "train.log"))
+
+    if args.local_rank == 0:
+        logger.info("# ==== Train Config ==== #")
+        for arg in vars(args):
+            logger.info(f"{arg} {getattr(args, arg)}")
+        logger.info("==========================")
+
+    main(args, logger)
trainer/__init__.py
ADDED
@@ -0,0 +1,437 @@
+import os
+import time
+import shutil
+
+import torch
+import cv2
+import torch.optim as optim
+import numpy as np
+from glob import glob
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn.parallel.distributed import DistributedDataParallel
+from torch.utils.data import Dataset, DataLoader
+from tqdm import tqdm
+from utils.image_processing import denormalize_input, preprocess_images, resize_image
+from losses import LossSummary, AnimeGanLoss, to_gray_scale
+from utils import load_checkpoint, save_checkpoint, read_image
+from utils.common import set_lr
+from color_transfer import color_transfer_pytorch
+
+
+def transfer_color_and_rescale(src, target):
+    """Transfer color from src image to target then rescale to [-1, 1]"""
+    out = color_transfer_pytorch(src, target)  # [0, 1]
+    out = (out / 0.5) - 1
+    return out
+
+def gaussian_noise():
+    gaussian_mean = torch.tensor(0.0)
+    gaussian_std = torch.tensor(0.1)
+    return torch.normal(gaussian_mean, gaussian_std)
+
+def convert_to_readable(seconds):
+    return time.strftime('%H:%M:%S', time.gmtime(seconds))
+
+
+def revert_to_np_image(image_tensor):
+    image = image_tensor.cpu().numpy()
+    # CHW
+    image = image.transpose(1, 2, 0)
+    image = denormalize_input(image, dtype=np.int16)
+    return image[..., ::-1]  # to RGB
+
+
+def save_generated_images(images: torch.Tensor, save_dir: str):
+    """Save generated images `(*, 3, H, W)` range [-1, 1] into disk"""
+    os.makedirs(save_dir, exist_ok=True)
+    images = images.clone().detach().cpu().numpy()
+    images = images.transpose(0, 2, 3, 1)
+    n_images = len(images)
+
+    for i in range(n_images):
+        img = images[i]
+        img = denormalize_input(img, dtype=np.int16)
+        img = img[..., ::-1]
+        cv2.imwrite(os.path.join(save_dir, f"G{i}.jpg"), img)
+
+
+class DDPTrainer:
+    def _init_distributed(self):
+        if self.cfg.ddp:
+            self.logger.info("Setting up DDP")
+            self.pg = torch.distributed.init_process_group(
+                backend="nccl",
+                rank=self.cfg.local_rank,
+                world_size=self.cfg.world_size
+            )
+            self.G = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.G, self.pg)
+            self.D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.D, self.pg)
+            torch.cuda.set_device(self.cfg.local_rank)
+            self.G.cuda(self.cfg.local_rank)
+            self.D.cuda(self.cfg.local_rank)
+            self.logger.info("Setting up DDP Done")
+
+    def _init_amp(self, enabled=False):
+        # self.scaler = torch.cuda.amp.GradScaler(enabled=enabled, growth_interval=100)
+        self.scaler_g = GradScaler(enabled=enabled)
+        self.scaler_d = GradScaler(enabled=enabled)
+        if self.cfg.ddp:
+            self.G = DistributedDataParallel(
+                self.G, device_ids=[self.cfg.local_rank],
+                output_device=self.cfg.local_rank,
+                find_unused_parameters=False)
+
+            self.D = DistributedDataParallel(
+                self.D, device_ids=[self.cfg.local_rank],
+                output_device=self.cfg.local_rank,
+                find_unused_parameters=False)
+            self.logger.info("Set DistributedDataParallel")
+
+
+class Trainer(DDPTrainer):
+    """
+    Base Trainer class
+    """
+
+    def __init__(
+        self,
+        generator,
+        discriminator,
+        config,
+        logger,
+    ) -> None:
+        self.G = generator
+        self.D = discriminator
+        self.cfg = config
+        self.max_norm = 10
+        self.device_type = 'cuda' if self.cfg.device.startswith('cuda') else 'cpu'
+        self.optimizer_g = optim.Adam(self.G.parameters(), lr=self.cfg.lr_g, betas=(0.5, 0.999))
+        self.optimizer_d = optim.Adam(self.D.parameters(), lr=self.cfg.lr_d, betas=(0.5, 0.999))
+        self.loss_tracker = LossSummary()
+        if self.cfg.ddp:
+            self.device = torch.device(f"cuda:{self.cfg.local_rank}")
+            logger.info(f"---------{self.cfg.local_rank} {self.device}")
+        else:
+            self.device = torch.device(self.cfg.device)
+        self.loss_fn = AnimeGanLoss(self.cfg, self.device, self.cfg.gray_adv)
+        self.logger = logger
+        self._init_working_dir()
+        self._init_distributed()
+        self._init_amp(enabled=self.cfg.amp)
+
+    def _init_working_dir(self):
+        """Init working directory for saving checkpoint, ..."""
+        os.makedirs(self.cfg.exp_dir, exist_ok=True)
+        Gname = self.G.name
+        Dname = self.D.name
+        self.checkpoint_path_G_init = os.path.join(self.cfg.exp_dir, f"{Gname}_init.pt")
+        self.checkpoint_path_G = os.path.join(self.cfg.exp_dir, f"{Gname}.pt")
+        self.checkpoint_path_D = os.path.join(self.cfg.exp_dir, f"{Dname}.pt")
+        self.save_image_dir = os.path.join(self.cfg.exp_dir, "generated_images")
+        self.example_image_dir = os.path.join(self.cfg.exp_dir, "train_images")
+        os.makedirs(self.save_image_dir, exist_ok=True)
+        os.makedirs(self.example_image_dir, exist_ok=True)
+
+    def init_weight_G(self, weight: str):
+        """Init Generator weight"""
+        return load_checkpoint(self.G, weight)
+
+    def init_weight_D(self, weight: str):
+        """Init Discriminator weight"""
+        return load_checkpoint(self.D, weight)
+
+    def pretrain_generator(self, train_loader, start_epoch):
+        """
+        Pretrain Generator to reconstruct input image.
+        """
+        init_losses = []
+        set_lr(self.optimizer_g, self.cfg.init_lr)
+        for epoch in range(start_epoch, self.cfg.init_epochs):
+            # Train with content loss only
+
+            pbar = tqdm(train_loader)
+            for data in pbar:
+                img = data["image"].to(self.device)
+
+                self.optimizer_g.zero_grad()
+
+                with autocast(enabled=self.cfg.amp):
+                    fake_img = self.G(img)
+                    loss = self.loss_fn.content_loss_vgg(img, fake_img)
+
+                self.scaler_g.scale(loss).backward()
+                self.scaler_g.step(self.optimizer_g)
+                self.scaler_g.update()
+
+                if self.cfg.ddp:
+                    torch.distributed.barrier()
+
+                init_losses.append(loss.cpu().detach().numpy())
+                avg_content_loss = sum(init_losses) / len(init_losses)
+                pbar.set_description(f'[Init Training G] content loss: {avg_content_loss:2f}')
+
+            save_checkpoint(self.G, self.checkpoint_path_G_init, self.optimizer_g, epoch)
+            if self.cfg.local_rank == 0:
+                self.generate_and_save(self.cfg.test_image_dir, subname='initg')
+                self.logger.info(f"Epoch {epoch}/{self.cfg.init_epochs}")
+
+        set_lr(self.optimizer_g, self.cfg.lr_g)
+
+    def train_epoch(self, epoch, train_loader):
+        pbar = tqdm(train_loader, total=len(train_loader))
+        for data in pbar:
+            img = data["image"].to(self.device)
+            anime = data["anime"].to(self.device)
+            anime_gray = data["anime_gray"].to(self.device)
+            anime_smt_gray = data["smooth_gray"].to(self.device)
+
+            # ---------------- TRAIN D ---------------- #
+            self.optimizer_d.zero_grad()
+
+            with autocast(enabled=self.cfg.amp):
+                fake_img = self.G(img)
+                # Add some Gaussian noise to images before feeding to D
+                if self.cfg.d_noise:
+                    fake_img += gaussian_noise()
+                    anime += gaussian_noise()
+                    anime_gray += gaussian_noise()
+                    anime_smt_gray += gaussian_noise()
+
+                if self.cfg.gray_adv:
+                    fake_img = to_gray_scale(fake_img)
+
+                fake_d = self.D(fake_img)
+                real_anime_d = self.D(anime)
+                real_anime_gray_d = self.D(anime_gray)
+                real_anime_smt_gray_d = self.D(anime_smt_gray)
+
+                loss_d = self.loss_fn.compute_loss_D(
+                    fake_d,
+                    real_anime_d,
+                    real_anime_gray_d,
+                    real_anime_smt_gray_d
+                )
+
+            self.scaler_d.scale(loss_d).backward()
+            self.scaler_d.unscale_(self.optimizer_d)
+            torch.nn.utils.clip_grad_norm_(self.D.parameters(), max_norm=self.max_norm)
+            self.scaler_d.step(self.optimizer_d)
+            self.scaler_d.update()
+            if self.cfg.ddp:
+                torch.distributed.barrier()
+            self.loss_tracker.update_loss_D(loss_d)
+
+            # ---------------- TRAIN G ---------------- #
+            self.optimizer_g.zero_grad()
+
+            with autocast(enabled=self.cfg.amp):
+                fake_img = self.G(img)
+
+                if self.cfg.gray_adv:
+                    fake_d = self.D(to_gray_scale(fake_img))
+                else:
+                    fake_d = self.D(fake_img)
+
+                (
+                    adv_loss, con_loss,
+                    gra_loss, col_loss,
+                    tv_loss
+                ) = self.loss_fn.compute_loss_G(
+                    fake_img,
+                    img,
+                    fake_d,
+                    anime_gray,
+                )
+                loss_g = adv_loss + con_loss + gra_loss + col_loss + tv_loss
+                if torch.isnan(adv_loss).any():
+                    self.logger.info("----------------------------------------------")
+                    self.logger.info(fake_d)
+                    self.logger.info(adv_loss)
+                    self.logger.info("----------------------------------------------")
+                    raise ValueError("NAN loss!!")
+
+            self.scaler_g.scale(loss_g).backward()
+            self.scaler_d.unscale_(self.optimizer_g)
+            grad = torch.nn.utils.clip_grad_norm_(self.G.parameters(), max_norm=self.max_norm)
+            self.scaler_g.step(self.optimizer_g)
+            self.scaler_g.update()
+            if self.cfg.ddp:
+                torch.distributed.barrier()
+
+            self.loss_tracker.update_loss_G(adv_loss, gra_loss, col_loss, con_loss)
+            pbar.set_description(f"{self.loss_tracker.get_loss_description()} - {grad:.3f}")
+
+    def get_train_loader(self, dataset):
+        if self.cfg.ddp:
+            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+        else:
+            train_sampler = None
+        return DataLoader(
+            dataset,
+            batch_size=self.cfg.batch_size,
+            num_workers=self.cfg.num_workers,
+            pin_memory=True,
+            shuffle=train_sampler is None,
+            sampler=train_sampler,
+            drop_last=True,
+            # collate_fn=collate_fn,
+        )
+
+    def maybe_increase_imgsz(self, epoch, train_dataset):
+        """
+        Increase image size at specific epoch
+        + 50% epochs train at imgsz[0]
+        + the rest 50% will increase every `len(epochs) / 2 / (len(imgsz) - 1)`
+
+        Args:
+            epoch: Current epoch
+            train_dataset: Dataset
+
+        Examples:
+        ```
+        epochs = 100
+        imgsz = [256, 352, 416, 512]
+        => [(0, 256), (50, 352), (66, 416), (82, 512)]
+        ```
+        """
+        epochs = self.cfg.epochs
+        imgsz = self.cfg.imgsz
+        num_size_remains = len(imgsz) - 1
+        half_epochs = epochs // 2
+
+        if len(imgsz) == 1:
+            new_size = imgsz[0]
+        elif epoch < half_epochs:
+            new_size = imgsz[0]
+        else:
+            per_epoch_increment = int(half_epochs / num_size_remains)
+            found = None
+            for i, size in enumerate(imgsz[:]):
+                if epoch < half_epochs + per_epoch_increment * i:
+                    found = size
+                    break
+            if not found:
+                found = imgsz[-1]
+            new_size = found
+
+        self.logger.info(f"Check {imgsz}, {new_size}, {train_dataset.imgsz}")
+        if new_size != train_dataset.imgsz:
+            train_dataset.set_imgsz(new_size)
+            self.logger.info(f"Increase image size to {new_size} at epoch {epoch}")
+
+    def train(self, train_dataset: Dataset, start_epoch=0, start_epoch_g=0):
+        """
+        Train Generator and Discriminator.
+        """
+        self.logger.info(self.device)
+        self.G.to(self.device)
+        self.D.to(self.device)
+
+        self.pretrain_generator(self.get_train_loader(train_dataset), start_epoch_g)
+
+        if self.cfg.local_rank == 0:
+            self.logger.info(f"Start training for {self.cfg.epochs} epochs")
+
+        for i, data in enumerate(train_dataset):
+            for k in data.keys():
+                image = data[k]
+                cv2.imwrite(
+                    os.path.join(self.example_image_dir, f"data_{k}_{i}.jpg"),
+                    revert_to_np_image(image)
+                )
+            if i == 2:
+                break
+
+        end = None
+        num_iter = 0
+        per_epoch_times = []
+        for epoch in range(start_epoch, self.cfg.epochs):
+            self.maybe_increase_imgsz(epoch, train_dataset)
+
+            start = time.time()
+            self.train_epoch(epoch, self.get_train_loader(train_dataset))
+
+            if epoch % self.cfg.save_interval == 0 and self.cfg.local_rank == 0:
+                save_checkpoint(self.G, self.checkpoint_path_G, self.optimizer_g, epoch)
+                save_checkpoint(self.D, self.checkpoint_path_D, self.optimizer_d, epoch)
+                self.generate_and_save(self.cfg.test_image_dir)
+
+                if epoch % 10 == 0:
+                    self.copy_results(epoch)
+
+            num_iter += 1
+
+            if self.cfg.local_rank == 0:
+                end = time.time()
+                if end is None:
+                    eta = 9999
+                else:
+                    per_epoch_time = (end - start)
+                    per_epoch_times.append(per_epoch_time)
+                    eta = np.mean(per_epoch_times) * (self.cfg.epochs - epoch)
+                    eta = convert_to_readable(eta)
+                self.logger.info(f"epoch {epoch}/{self.cfg.epochs}, ETA: {eta}")
+
+    def generate_and_save(
+        self,
+        image_dir,
+        max_imgs=15,
+        subname='gen'
+    ):
+        '''
+        Generate and save images
+        '''
+        start = time.time()
+        self.G.eval()
+
+        max_iter = max_imgs
+        fake_imgs = []
+        real_imgs = []
+        image_files = glob(os.path.join(image_dir, "*"))
+
+        for i, image_file in enumerate(image_files):
+            image = read_image(image_file)
+            image = resize_image(image)
+            real_imgs.append(image.copy())
+            image = preprocess_images(image)
+            image = image.to(self.device)
+            with torch.no_grad():
+                with autocast(enabled=self.cfg.amp):
+                    fake_img = self.G(image)
+                    # fake_img = to_gray_scale(fake_img)
+                fake_img = fake_img.detach().cpu().numpy()
+                # Channel first -> channel last
+                fake_img = fake_img.transpose(0, 2, 3, 1)
+                fake_imgs.append(denormalize_input(fake_img, dtype=np.int16)[0])
+
+            if i + 1 == max_iter:
+                break
+
+        # fake_imgs = np.concatenate(fake_imgs, axis=0)
+
+        for i, (real_img, fake_img) in enumerate(zip(real_imgs, fake_imgs)):
+            img = np.concatenate((real_img, fake_img), axis=1)  # Concatenate across width
+            save_path = os.path.join(self.save_image_dir, f'{subname}_{i}.jpg')
+            if not cv2.imwrite(save_path, img[..., ::-1]):
+                self.logger.info(f"Save generated image failed, {save_path}, {img.shape}")
+        elapsed = time.time() - start
+        self.logger.info(f"Generated {len(fake_imgs)} images in {elapsed:.3f}s.")
+
+    def copy_results(self, epoch):
+        """Copy result (Weight + Generated images) to each epoch folder
+        Every N epoch
+        """
+        copy_dir = os.path.join(self.cfg.exp_dir, f"epoch_{epoch}")
+        os.makedirs(copy_dir, exist_ok=True)
+
+        shutil.copy2(
+            self.checkpoint_path_G,
+            copy_dir
+        )
+
+        dest = os.path.join(copy_dir, os.path.basename(self.save_image_dir))
+        shutil.copytree(
+            self.save_image_dir,
+            dest,
+            dirs_exist_ok=True
+        )
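For reference, a standalone sketch (not part of the committed files) of the image-size schedule implemented by `maybe_increase_imgsz` above; it reproduces the mapping from the docstring example (epochs=100, imgsz=[256, 352, 416, 512]) and is illustrative only.

# Standalone restatement of the maybe_increase_imgsz schedule, for illustration.
def imgsz_for_epoch(epoch, epochs=100, imgsz=(256, 352, 416, 512)):
    half_epochs = epochs // 2
    if len(imgsz) == 1 or epoch < half_epochs:
        return imgsz[0]                      # first half of training stays at the smallest size
    per_epoch_increment = int(half_epochs / (len(imgsz) - 1))
    for i, size in enumerate(imgsz):
        if epoch < half_epochs + per_epoch_increment * i:
            return size                      # step up once per increment window
    return imgsz[-1]                         # tail of training uses the largest size

# Epochs 0-49 -> 256, 50-65 -> 352, 66-81 -> 416, 82-99 -> 512
print([(e, imgsz_for_epoch(e)) for e in (0, 50, 66, 82)])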
utils/__init__.py
ADDED
@@ -0,0 +1,21 @@
+from .common import *
+from .image_processing import *
+
+class DefaultArgs:
+    dataset = 'Hayao'
+    data_dir = '/content'
+    epochs = 10
+    batch_size = 1
+    checkpoint_dir = '/content/checkpoints'
+    save_image_dir = '/content/images'
+    display_image = True
+    save_interval = 2
+    debug_samples = 0
+    lr_g = 0.001
+    lr_d = 0.002
+    wadvg = 300.0
+    wadvd = 300.0
+    wcon = 1.5
+    wgra = 3
+    wcol = 10
+    use_sn = False
utils/common.py
ADDED
@@ -0,0 +1,188 @@
+import torch
+import gc
+import os
+import torch.nn as nn
+import urllib.request
+import cv2
+from tqdm import tqdm
+
+HTTP_PREFIXES = [
+    'http',
+    'data:image/jpeg',
+]
+
+
+RELEASED_WEIGHTS = {
+    "hayao:v1": (
+        "v1",
+        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
+    ),
+    "hayao": (
+        "v1",
+        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
+    ),
+    "shinkai:v1": (
+        "v1",
+        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
+    ),
+    "shinkai": (
+        "v1",
+        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
+    ),
+
+    ## VER 2 ##
+    "hayao:v2": (
+        # Dataset trained on Google Landmark micro as training real photo
+        "v2",
+        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.2/GeneratorV2_gldv2_Hayao.pt"
+    ),
+    "shinkai:v2": (
+        # Dataset trained on Google Landmark micro as training real photo
+        "v2",
+        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.2/GeneratorV2_gldv2_Shinkai.pt"
+    ),
+    ## Face portrait
+    "arcane:v2": (
+        "v2",
+        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.2/GeneratorV2_ffhq_Arcane_210624_e350.pt"
+    )
+}
+
+def is_image_file(path):
+    _, ext = os.path.splitext(path)
+    return ext.lower() in (".png", ".jpg", ".jpeg", ".webp")
+
+def is_video_file(path):
+    # https://moviepy-tburrows13.readthedocs.io/en/improve-docs/ref/VideoClip/VideoFileClip.html
+    _, ext = os.path.splitext(path)
+    return ext.lower() in (".mp4", ".mov", ".ogv", ".avi", ".mpeg")
+
+
+def read_image(path):
+    """
+    Read image from given path
+    """
+
+    if any(path.startswith(p) for p in HTTP_PREFIXES):
+        urllib.request.urlretrieve(path, "temp.jpg")
+        path = "temp.jpg"
+
+    img = cv2.imread(path)
+    if img.shape[-1] == 4:
+        # 4 channels image
+        img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
+    else:
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return img
+
+
+def save_checkpoint(model, path, optimizer=None, epoch=None):
+    checkpoint = {
+        'model_state_dict': model.state_dict(),
+        'epoch': epoch,
+    }
+    if optimizer is not None:
+        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+    torch.save(checkpoint, path)
+
+def maybe_remove_module(state_dict):
+    # Remove the 'module.' prefix added to state_dict keys in DDP training
+    # https://discuss.pytorch.org/t/why-are-state-dict-keys-getting-prepended-with-the-string-module/104627/3
+    new_state_dict = {}
+    module_str = 'module.'
+    for k, v in state_dict.items():
+
+        if k.startswith(module_str):
+            k = k[len(module_str):]
+        new_state_dict[k] = v
+    return new_state_dict
+
+
+def load_checkpoint(model, path, optimizer=None, strip_optimizer=False, map_location=None) -> int:
+    state_dict, path = load_state_dict(path, map_location)
+    model_state_dict = maybe_remove_module(state_dict['model_state_dict'])
+    model.load_state_dict(
+        model_state_dict,
+        strict=True
+    )
+    if 'optimizer_state_dict' in state_dict:
+        if optimizer is not None:
+            optimizer.load_state_dict(state_dict['optimizer_state_dict'])
+        if strip_optimizer:
+            del state_dict["optimizer_state_dict"]
+            torch.save(state_dict, path)
+            print(f"Optimizer stripped and saved to {path}")
+
+    epoch = state_dict.get('epoch', 0)
+    return epoch
+
+
+def load_state_dict(weight, map_location) -> dict:
+    if weight.lower() in RELEASED_WEIGHTS:
+        weight = _download_weight(weight.lower())
+
+    if map_location is None:
+        # auto select
+        map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
+    state_dict = torch.load(weight, map_location=map_location)
+
+    return state_dict, weight
+
+
+def initialize_weights(net):
+    for m in net.modules():
+        try:
+            if isinstance(m, nn.Conv2d):
+                # m.weight.data.normal_(0, 0.02)
+                torch.nn.init.xavier_uniform_(m.weight)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.ConvTranspose2d):
+                # m.weight.data.normal_(0, 0.02)
+                torch.nn.init.xavier_uniform_(m.weight)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                # m.weight.data.normal_(0, 0.02)
+                torch.nn.init.xavier_uniform_(m.weight)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+        except Exception as e:
+            # print(f'Skip layer {m}, {e}')
+            pass
+
+
+def set_lr(optimizer, lr):
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+class DownloadProgressBar(tqdm):
+    '''
+    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
+    '''
+    def update_to(self, b=1, bsize=1, tsize=None):
+        if tsize is not None:
+            self.total = tsize
+        self.update(b * bsize - self.n)
+
+
+def _download_weight(weight):
+    '''
+    Download weight and save to local file
+    '''
+    os.makedirs('.cache', exist_ok=True)
+    url = RELEASED_WEIGHTS[weight][1]
+    filename = os.path.basename(url)
+    save_path = f'.cache/{filename}'
+
+    if os.path.isfile(save_path):
+        return save_path
+
+    desc = f'Downloading {url} to {save_path}'
+    with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=desc) as t:
+        urllib.request.urlretrieve(url, save_path, reporthook=t.update_to)
+
+    return save_path
+
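For reference, a minimal sketch (not part of the committed files) of resolving and loading one of the RELEASED_WEIGHTS entries above into a generator. The GeneratorV2 constructor argument is an assumption based on how train.py builds it; the weight key itself comes from the table above.

# Hypothetical sketch: resolve a released weight key, download it into .cache/, and load it.
from models.anime_gan_v2 import GeneratorV2
from utils.common import load_checkpoint, RELEASED_WEIGHTS

key = "hayao:v2"
version, url = RELEASED_WEIGHTS[key]      # ("v2", "https://github.com/ptran1203/...")
G = GeneratorV2("train_photo_Hayao")      # dataset name argument is an assumption
epoch = load_checkpoint(G, key)           # key is resolved via load_state_dict() -> _download_weight()
print(f"{key} ({version}) loaded, checkpoint epoch: {epoch}")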
utils/fast_numpyio.py
ADDED
@@ -0,0 +1,43 @@
+# code from https://github.com/divideconcept/fastnumpyio/blob/main/fastnumpyio.py
+
+import sys
+import numpy as np
+import numpy.lib.format
+import struct
+
+def save(file, array):
+    magic_string=b"\x93NUMPY\x01\x00v\x00"
+    header=bytes(("{'descr': '"+array.dtype.descr[0][1]+"', 'fortran_order': False, 'shape': "+str(array.shape)+", }").ljust(127-len(magic_string))+"\n",'utf-8')
+    if type(file) == str:
+        file=open(file,"wb")
+    file.write(magic_string)
+    file.write(header)
+    file.write(array.data)
+
+def pack(array):
+    size=len(array.shape)
+    return bytes(array.dtype.byteorder.replace('=','<' if sys.byteorder == 'little' else '>')+array.dtype.kind,'utf-8')+array.dtype.itemsize.to_bytes(1,byteorder='little')+struct.pack(f'<B{size}I',size,*array.shape)+array.data
+
+def load(file):
+    if type(file) == str:
+        file=open(file,"rb")
+    header = file.read(128)
+    if not header:
+        return None
+    descr = str(header[19:25], 'utf-8').replace("'","").replace(" ","")
+    shape = tuple(int(num) for num in str(header[60:120], 'utf-8').replace(', }', '').replace('(', '').replace(')', '').split(','))
+    datasize = numpy.lib.format.descr_to_dtype(descr).itemsize
+    for dimension in shape:
+        datasize *= dimension
+    return np.ndarray(shape, dtype=descr, buffer=file.read(datasize))
+
+def unpack(data):
+    dtype = str(data[:2],'utf-8')
+    dtype += str(data[2])
+    size = data[3]
+    shape = struct.unpack_from(f'<{size}I', data, 4)
+    datasize=data[2]
+    for dimension in shape:
+        datasize *= dimension
+    return np.ndarray(shape, dtype=dtype, buffer=data[4+size*4:4+size*4+datasize])
+
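For reference, a small round-trip sketch (not part of the committed files) of the save/load and pack/unpack helpers above; the array shape and file name are arbitrary, and little-endian float32 is assumed since that is the layout these helpers are written for.

# Round-trip check for the fast numpy IO helpers (illustrative only).
import numpy as np
from utils.fast_numpyio import save, load, pack, unpack

arr = np.random.rand(4, 256, 256, 3).astype(np.float32)
save("sample.npy", arr)                     # write with the hand-rolled NPY header
restored = load("sample.npy")               # read it back without np.load
assert np.array_equal(arr, restored)

packed = pack(arr)                          # compact bytes form used for disk caching
assert np.array_equal(arr, unpack(packed))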
utils/image_processing.py
ADDED
@@ -0,0 +1,135 @@
+import torch
+import cv2
+import os
+import numpy as np
+from tqdm import tqdm
+
+
+def gram(input):
+    """
+    Calculate Gram Matrix
+
+    https://pytorch.org/tutorials/advanced/neural_style_tutorial.html#style-loss
+    """
+    b, c, w, h = input.size()
+
+    x = input.contiguous().view(b * c, w * h)
+
+    # x = x / 2
+
+    # Work around, torch.mm would generate some inf values.
+    # https://discuss.pytorch.org/t/gram-matrix-in-mixed-precision/166800/2
+    # x = torch.clamp(x, max=1.0e2, min=-1.0e2)
+    # x[x > 1.0e2] = 1.0e2
+    # x[x < -1.0e2] = -1.0e2
+
+    G = torch.mm(x, x.T)
+    G = torch.clamp(G, -64990.0, 64990.0)
+    # normalize by total elements
+    result = G.div(b * c * w * h)
+    return result
+
+
+def divisible(dim):
+    '''
+    Make width and height divisible by 32
+    '''
+    width, height = dim
+    return width - (width % 32), height - (height % 32)
+
+
+def resize_image(image, width=None, height=None, inter=cv2.INTER_AREA):
+    dim = None
+    h, w = image.shape[:2]
+
+    if width and height:
+        return cv2.resize(image, divisible((width, height)), interpolation=inter)
+
+    if width is None and height is None:
+        return cv2.resize(image, divisible((w, h)), interpolation=inter)
+
+    if width is None:
+        r = height / float(h)
+        dim = (int(w * r), height)
+
+    else:
+        r = width / float(w)
+        dim = (width, int(h * r))
+
+    return cv2.resize(image, divisible(dim), interpolation=inter)
+
+
+def normalize_input(images):
+    '''
+    [0, 255] -> [-1, 1]
+    '''
+    return images / 127.5 - 1.0
+
+
+def denormalize_input(images, dtype=None):
+    '''
+    [-1, 1] -> [0, 255]
+    '''
+    images = images * 127.5 + 127.5
+
+    if dtype is not None:
+        if isinstance(images, torch.Tensor):
+            images = images.type(dtype)
+        else:
+            # numpy.ndarray
+            images = images.astype(dtype)
+
+    return images
+
+
+def preprocess_images(images):
+    '''
+    Preprocess image for inference
+
+    @Arguments:
+        - images: np.ndarray
+
+    @Returns
+        - images: torch.tensor
+    '''
+    images = images.astype(np.float32)
+
+    # Normalize to [-1, 1]
+    images = normalize_input(images)
+    images = torch.from_numpy(images)
+
+    # Add batch dim
+    if len(images.shape) == 3:
+        images = images.unsqueeze(0)
+
+    # channel first
+    images = images.permute(0, 3, 1, 2)
+
+    return images
+
+def compute_data_mean(data_folder):
+    if not os.path.exists(data_folder):
+        raise FileNotFoundError(f'Folder {data_folder} does not exist')
+
+    image_files = os.listdir(data_folder)
+    total = np.zeros(3)
+
+    print(f"Compute mean (R, G, B) from {len(image_files)} images")
+
+    for img_file in tqdm(image_files):
+        path = os.path.join(data_folder, img_file)
+        image = cv2.imread(path)
+        total += image.mean(axis=(0, 1))
+
+    channel_mean = total / len(image_files)
+    mean = np.mean(channel_mean)
+
+    return mean - channel_mean[..., ::-1]  # Convert to BGR for training
+
+
+if __name__ == '__main__':
+    t = torch.rand(2, 14, 32, 32)
+
+    with torch.autocast("cpu"):
+        print(gram(t))
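For reference, a short sketch (not part of the committed files) of the preprocessing round trip these helpers implement: a [0, 255] HWC uint8 image becomes a [-1, 1] NCHW float tensor and can be mapped back. The input array here is synthetic.

# Round trip through the normalization helpers above.
import numpy as np
import torch
from utils.image_processing import resize_image, preprocess_images, denormalize_input

img = np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8)   # synthetic HWC image
img = resize_image(img, width=256)            # output width/height are forced divisible by 32
batch = preprocess_images(img)                # float32 tensor, shape (1, 3, H, W), range [-1, 1]

restored = denormalize_input(batch, dtype=torch.uint8)  # back to [0, 255]
print(batch.shape, float(batch.min()), float(batch.max()), restored.dtype)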
utils/logger.py
ADDED
@@ -0,0 +1,24 @@
+import logging
+
+
+def get_logger(path, *args, **kwargs):
+    # logger = logging.getLogger('train')
+    # logger.setLevel(logging.NOTSET)
+    # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    # # add filehandler
+    # fh = logging.FileHandler(path)
+    # fh.setLevel(logging.NOTSET)
+    # fh.setFormatter(formatter)
+    # ch = logging.StreamHandler()
+    # ch.setLevel(logging.ERROR)
+    # logger.addHandler(fh)
+    # logger.addHandler(ch)
+    # return logger
+    logging.basicConfig(format='%(asctime)s %(message)s',
+                        datefmt='%m/%d/%Y %I:%M:%S %p',
+                        handlers=[
+                            logging.FileHandler(path),
+                            logging.StreamHandler()
+                        ],
+                        level=logging.DEBUG)
+    return logging