Upload app cpu version
This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- .gitignore +138 -0
- README.md +1 -1
- app.py +50 -0
- app_config.py +9 -0
- checkpoints/epoch_10/colornet.pth +3 -0
- checkpoints/epoch_10/discriminator.pth +3 -0
- checkpoints/epoch_10/embed_net.pth +3 -0
- checkpoints/epoch_10/learning_state.pth +3 -0
- checkpoints/epoch_10/nonlocal_net.pth +3 -0
- checkpoints/epoch_12/colornet.pth +3 -0
- checkpoints/epoch_12/discriminator.pth +3 -0
- checkpoints/epoch_12/embed_net.pth +3 -0
- checkpoints/epoch_12/learning_state.pth +3 -0
- checkpoints/epoch_12/nonlocal_net.pth +3 -0
- checkpoints/epoch_16/colornet.pth +3 -0
- checkpoints/epoch_16/discriminator.pth +3 -0
- checkpoints/epoch_16/embed_net.pth +3 -0
- checkpoints/epoch_16/learning_state.pth +3 -0
- checkpoints/epoch_16/nonlocal_net.pth +3 -0
- checkpoints/epoch_20/colornet.pth +3 -0
- checkpoints/epoch_20/discriminator.pth +3 -0
- checkpoints/epoch_20/embed_net.pth +3 -0
- checkpoints/epoch_20/learning_state.pth +3 -0
- checkpoints/epoch_20/nonlocal_net.pth +3 -0
- requirements.txt +0 -0
- sample_input/ref1.jpg +0 -0
- sample_input/video1.mp4 +3 -0
- src/__init__.py +0 -0
- src/data/dataloader.py +332 -0
- src/data/functional.py +84 -0
- src/data/transforms.py +348 -0
- src/inference.py +174 -0
- src/losses.py +277 -0
- src/metrics.py +225 -0
- src/models/CNN/ColorVidNet.py +141 -0
- src/models/CNN/FrameColor.py +76 -0
- src/models/CNN/GAN_models.py +212 -0
- src/models/CNN/NonlocalNet.py +437 -0
- src/models/CNN/__init__.py +0 -0
- src/models/__init__.py +0 -0
- src/models/vit/__init__.py +0 -0
- src/models/vit/blocks.py +80 -0
- src/models/vit/config.py +22 -0
- src/models/vit/config.yml +132 -0
- src/models/vit/decoder.py +34 -0
- src/models/vit/embed.py +52 -0
- src/models/vit/factory.py +45 -0
- src/models/vit/utils.py +71 -0
- src/models/vit/vit.py +199 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,138 @@
flagged/
sample_output/
wandb/
.vscode
.DS_Store
*ckpt*/
# Custom
*.pt
data/local
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
README.md
CHANGED
@@ -1,5 +1,5 @@
 ---
-title:
+title: Exemplar-based Video Colorization using Vision Transformer (CPU version)
 emoji: 🏃
 colorFrom: green
 colorTo: yellow
app.py
ADDED
@@ -0,0 +1,50 @@
import gradio as gr
from src.inference import SwinTExCo
import cv2
import os
from PIL import Image
import time
import app_config as cfg


model = SwinTExCo(weights_path=cfg.ckpt_path)

def video_colorization(video_path, ref_image, progress=gr.Progress()):
    # Initialize video reader
    video_reader = cv2.VideoCapture(video_path)
    fps = video_reader.get(cv2.CAP_PROP_FPS)
    height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
    num_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))

    # Initialize reference image
    ref_image = Image.fromarray(ref_image)

    # Initialize video writer
    output_path = os.path.join(os.path.dirname(video_path), os.path.basename(video_path).split('.')[0] + '_colorized.mp4')
    video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

    # Init progress bar

    for colorized_frame, _ in zip(model.predict_video(video_reader, ref_image), progress.tqdm(range(num_frames), desc="Colorizing video", unit="frames")):
        colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_RGB2BGR)
        video_writer.write(colorized_frame)

    # for i in progress.tqdm(range(1000)):
    #     time.sleep(0.5)

    video_writer.release()

    return output_path

app = gr.Interface(
    fn=video_colorization,
    inputs=[gr.Video(format="mp4", sources="upload", label="Input video (grayscale)", interactive=True),
            gr.Image(sources="upload", label="Reference image (color)")],
    outputs=gr.Video(label="Output video (colorized)"),
    title=cfg.TITLE,
    description=cfg.DESCRIPTION
).queue()


app.launch()
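
Note (not part of the commit): the handler above is wired to the Gradio UI, but the same pipeline can be exercised headlessly. The sketch below is a hedged example that reuses only calls appearing in this diff (SwinTExCo.predict_video, OpenCV read/write) together with the bundled sample files; the checkpoint directory and output filename are assumptions.

import cv2
from PIL import Image
from src.inference import SwinTExCo

# Load the same weights the Space uses (app_config.ckpt_path points at epoch_20).
model = SwinTExCo(weights_path="checkpoints/epoch_20")

video_reader = cv2.VideoCapture("sample_input/video1.mp4")
fps = video_reader.get(cv2.CAP_PROP_FPS)
width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
writer = cv2.VideoWriter("video1_colorized.mp4", cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

ref_image = Image.open("sample_input/ref1.jpg").convert("RGB")

# predict_video yields RGB frames; OpenCV expects BGR when writing.
for rgb_frame in model.predict_video(video_reader, ref_image):
    writer.write(cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR))

writer.release()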
app_config.py
ADDED
@@ -0,0 +1,9 @@
ckpt_path = 'checkpoints/epoch_20'
TITLE = 'Deep Exemplar-based Video Colorization using Vision Transformer'
DESCRIPTION = '''
<center>
This is a demo app of the thesis: <b>Deep Exemplar-based Video Colorization using Vision Transformer</b>.<br/>
The code is available at: <i>The link will be updated soon</i>.<br/>
Our previous work was also written into paper and accepted at the <a href="https://ictc.org/program_proceeding">ICTC 2023 conference</a> (Section <i>B1-4</i>).
</center>
'''.strip()
checkpoints/epoch_10/colornet.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1ecb43b5e02b77bec5342e2e296d336bf8f384a07d3c809d1a548fd5fb1e7365
size 131239411

checkpoints/epoch_10/discriminator.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ce8968a9d3d2f99b1bc1e32080507e0d671cee00b66200105c8839be684b84b4
size 45073068

checkpoints/epoch_10/embed_net.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fc711755a75c43025dabe9407cbd11d164eaa9e21f26430d0c16c7493410d902
size 110352261

checkpoints/epoch_10/learning_state.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8d09b1e96fdf0205930a21928449a44c51cedd965cc0d573068c73971bcb8bd2
size 748166487

checkpoints/epoch_10/nonlocal_net.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86c97d6803d625a0dff8c6c09b70852371906eb5ef77df0277c27875666a68e2
size 73189765

checkpoints/epoch_12/colornet.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50f4b92cd59f4c88c0c1d7c93652413d54b1b96d729fc4b93e235887b5164f28
size 131239846

checkpoints/epoch_12/discriminator.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2b54b0bad6ceec33569cc5833cbf03ed8ddbb5f07998aa634badf8298d3cd15f
size 45073513

checkpoints/epoch_12/embed_net.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
size 110352698

checkpoints/epoch_12/learning_state.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8f8bb4dbb3cb8e497a9a2079947f0221823fa8b44695e2d2ad8478be48464fad
size 748166934

checkpoints/epoch_12/nonlocal_net.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c1f76b53dad7bf15c7d26aa106c95387e75751b8c31fafef2bd73ea7d77160cb
size 73190208

checkpoints/epoch_16/colornet.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:81ec9cff0ad5b0d920179fa7a9cc229e1424bfc796b7134604ff66b97d748c49
size 131239846

checkpoints/epoch_16/discriminator.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42262d5ed7596f38e65774085222530eee57da8dfaa7fe1aa223d824ed166f62
size 45073513

checkpoints/epoch_16/embed_net.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
size 110352698

checkpoints/epoch_16/learning_state.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ea4cf81341750ebf517c696a0f6241bfeede0584b0ce75ad208e3ffc8280877f
size 748166934

checkpoints/epoch_16/nonlocal_net.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:85b63363bc9c79732df78ba50ed19491ed86e961214bbd1f796a871334eba516
size 73190208

checkpoints/epoch_20/colornet.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c524f4e5df5f6ce91db1973a30de55299ebcbbde1edd2009718d3b4cd2631339
size 131239846

checkpoints/epoch_20/discriminator.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fcd80950c796fcfe6e4b6bdeeb358776700458d868da94ee31df3d1d37779310
size 45073513

checkpoints/epoch_20/embed_net.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
size 110352698

checkpoints/epoch_20/learning_state.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4b1163b210b246b07d8f1c50eb3766d97c6f03bf409c854d00b7c69edb6d7391
size 748166934

checkpoints/epoch_20/nonlocal_net.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:031e5f38cc79eb3c0ed51ca2ad3c8921fdda2fa05946c357f84881259de74e6d
size 73190208
requirements.txt
ADDED
Binary file (434 Bytes).
sample_input/ref1.jpg
ADDED
(image file)
sample_input/video1.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:077ebcd3cf6c020c95732e74a0fe1fab9b80102bc14d5e201b12c4917e0c0d1d
size 1011726
src/__init__.py
ADDED
File without changes
src/data/dataloader.py
ADDED
@@ -0,0 +1,332 @@
import numpy as np
import pandas as pd
from src.utils import (
    CenterPadCrop_numpy,
    Distortion_with_flow_cpu,
    Distortion_with_flow_gpu,
    Normalize,
    RGB2Lab,
    ToTensor,
    Normalize,
    RGB2Lab,
    ToTensor,
    CenterPad,
    read_flow,
    SquaredPadding
)
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from numpy import random
import os
from PIL import Image
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage import map_coordinates
import glob


def image_loader(path):
    with open(path, "rb") as f:
        with Image.open(f) as img:
            return img.convert("RGB")


class CenterCrop(object):
    """
    center crop the numpy array
    """

    def __init__(self, image_size):
        self.h0, self.w0 = image_size

    def __call__(self, input_numpy):
        if input_numpy.ndim == 3:
            h, w, channel = input_numpy.shape
            output_numpy = np.zeros((self.h0, self.w0, channel))
            output_numpy = input_numpy[
                (h - self.h0) // 2 : (h - self.h0) // 2 + self.h0, (w - self.w0) // 2 : (w - self.w0) // 2 + self.w0, :
            ]
        else:
            h, w = input_numpy.shape
            output_numpy = np.zeros((self.h0, self.w0))
            output_numpy = input_numpy[
                (h - self.h0) // 2 : (h - self.h0) // 2 + self.h0, (w - self.w0) // 2 : (w - self.w0) // 2 + self.w0
            ]
        return output_numpy


class VideosDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        video_data_root,
        flow_data_root,
        mask_data_root,
        imagenet_folder,
        annotation_file_path,
        image_size,
        num_refs=5,  # max = 20
        image_transform=None,
        real_reference_probability=1,
        nonzero_placeholder_probability=0.5,
    ):
        self.video_data_root = video_data_root
        self.flow_data_root = flow_data_root
        self.mask_data_root = mask_data_root
        self.imagenet_folder = imagenet_folder
        self.image_transform = image_transform
        self.CenterPad = CenterPad(image_size)
        self.Resize = transforms.Resize(image_size)
        self.ToTensor = ToTensor()
        self.CenterCrop = transforms.CenterCrop(image_size)
        self.SquaredPadding = SquaredPadding(image_size[0])
        self.num_refs = num_refs

        assert os.path.exists(self.video_data_root), "find no video dataroot"
        assert os.path.exists(self.flow_data_root), "find no flow dataroot"
        assert os.path.exists(self.imagenet_folder), "find no imagenet folder"
        # self.epoch = epoch
        self.image_pairs = pd.read_csv(annotation_file_path, dtype=str)
        self.real_len = len(self.image_pairs)
        # self.image_pairs = pd.concat([self.image_pairs] * self.epoch, ignore_index=True)
        self.real_reference_probability = real_reference_probability
        self.nonzero_placeholder_probability = nonzero_placeholder_probability
        print("##### parsing image pairs in %s: %d pairs #####" % (video_data_root, self.__len__()))

    def __getitem__(self, index):
        (
            video_name,
            prev_frame,
            current_frame,
            flow_forward_name,
            mask_name,
            reference_1_name,
            reference_2_name,
            reference_3_name,
            reference_4_name,
            reference_5_name
        ) = self.image_pairs.iloc[index, :5+self.num_refs].values.tolist()

        video_path = os.path.join(self.video_data_root, video_name)
        flow_path = os.path.join(self.flow_data_root, video_name)
        mask_path = os.path.join(self.mask_data_root, video_name)

        prev_frame_path = os.path.join(video_path, prev_frame)
        current_frame_path = os.path.join(video_path, current_frame)
        list_frame_path = glob.glob(os.path.join(video_path, '*'))
        list_frame_path.sort()

        reference_1_path = os.path.join(self.imagenet_folder, reference_1_name)
        reference_2_path = os.path.join(self.imagenet_folder, reference_2_name)
        reference_3_path = os.path.join(self.imagenet_folder, reference_3_name)
        reference_4_path = os.path.join(self.imagenet_folder, reference_4_name)
        reference_5_path = os.path.join(self.imagenet_folder, reference_5_name)

        flow_forward_path = os.path.join(flow_path, flow_forward_name)
        mask_path = os.path.join(mask_path, mask_name)

        # reference_gt_1_path = prev_frame_path
        # reference_gt_2_path = current_frame_path
        try:
            I1 = Image.open(prev_frame_path).convert("RGB")
            I2 = Image.open(current_frame_path).convert("RGB")
            try:
                I_reference_video = Image.open(list_frame_path[0]).convert("RGB")  # Get first frame
            except:
                I_reference_video = Image.open(current_frame_path).convert("RGB")  # Get current frame if error

            reference_list = [reference_1_path, reference_2_path, reference_3_path, reference_4_path, reference_5_path]
            while reference_list:  # run until getting the colorized reference
                reference_path = random.choice(reference_list)
                I_reference_video_real = Image.open(reference_path)
                if I_reference_video_real.mode == 'L':
                    reference_list.remove(reference_path)
                else:
                    break
            if not reference_list:
                I_reference_video_real = I_reference_video

            flow_forward = read_flow(flow_forward_path)  # numpy

            mask = Image.open(mask_path)  # PIL
            mask = self.Resize(mask)
            mask = np.array(mask)
            # mask = self.SquaredPadding(mask, return_pil=False, return_paddings=False)
            # binary mask
            mask[mask < 240] = 0
            mask[mask >= 240] = 1
            mask = self.ToTensor(mask)

            # transform
            I1 = self.image_transform(I1)
            I2 = self.image_transform(I2)
            I_reference_video = self.image_transform(I_reference_video)
            I_reference_video_real = self.image_transform(I_reference_video_real)
            flow_forward = self.ToTensor(flow_forward)
            flow_forward = self.Resize(flow_forward)  # , return_pil=False, return_paddings=False, dtype=np.float32)

            if np.random.random() < self.real_reference_probability:
                I_reference_output = I_reference_video_real  # Use reference from imagenet
                placeholder = torch.zeros_like(I1)
                self_ref_flag = torch.zeros_like(I1)
            else:
                I_reference_output = I_reference_video  # Use reference from ground truth
                placeholder = I2 if np.random.random() < self.nonzero_placeholder_probability else torch.zeros_like(I1)
                self_ref_flag = torch.ones_like(I1)

            outputs = [
                I1,
                I2,
                I_reference_output,
                flow_forward,
                mask,
                placeholder,
                self_ref_flag,
                video_name + prev_frame,
                video_name + current_frame,
                reference_path
            ]

        except Exception as e:
            print("error in reading image pair: %s" % str(self.image_pairs[index]))
            print(e)
            return self.__getitem__(np.random.randint(0, len(self.image_pairs)))
        return outputs

    def __len__(self):
        return len(self.image_pairs)


def parse_imgnet_images(pairs_file):
    pairs = []
    with open(pairs_file, "r") as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip().split("|")
            image_a = line[0]
            image_b = line[1]
            pairs.append((image_a, image_b))
    return pairs


class VideosDataset_ImageNet(data.Dataset):
    def __init__(
        self,
        imagenet_data_root,
        pairs_file,
        image_size,
        transforms_imagenet=None,
        distortion_level=3,
        brightnessjitter=0,
        nonzero_placeholder_probability=0.5,
        extra_reference_transform=None,
        real_reference_probability=1,
        distortion_device='cpu'
    ):
        self.imagenet_data_root = imagenet_data_root
        self.image_pairs = pd.read_csv(pairs_file, names=['i1', 'i2'])
        self.transforms_imagenet_raw = transforms_imagenet
        self.extra_reference_transform = transforms.Compose(extra_reference_transform)
        self.real_reference_probability = real_reference_probability
        self.transforms_imagenet = transforms.Compose(transforms_imagenet)
        self.image_size = image_size
        self.real_len = len(self.image_pairs)
        self.distortion_level = distortion_level
        self.distortion_transform = Distortion_with_flow_cpu() if distortion_device == 'cpu' else Distortion_with_flow_gpu()
        self.brightnessjitter = brightnessjitter
        self.flow_transform = transforms.Compose([CenterPadCrop_numpy(self.image_size), ToTensor()])
        self.nonzero_placeholder_probability = nonzero_placeholder_probability
        self.ToTensor = ToTensor()
        self.Normalize = Normalize()
        print("##### parsing imageNet pairs in %s: %d pairs #####" % (imagenet_data_root, self.__len__()))

    def __getitem__(self, index):
        pa, pb = self.image_pairs.iloc[index].values.tolist()
        if np.random.random() > 0.5:
            pa, pb = pb, pa

        image_a_path = os.path.join(self.imagenet_data_root, pa)
        image_b_path = os.path.join(self.imagenet_data_root, pb)

        I1 = image_loader(image_a_path)
        I2 = I1
        I_reference_video = I1
        I_reference_video_real = image_loader(image_b_path)
        # print("i'm here get image 2")
        # generate the flow
        alpha = np.random.rand() * self.distortion_level
        distortion_range = 50
        random_state = np.random.RandomState(None)
        shape = self.image_size[0], self.image_size[1]
        # dx: flow on the vertical direction; dy: flow on the horizontal direction
        forward_dx = (
            gaussian_filter((random_state.rand(*shape) * 2 - 1), distortion_range, mode="constant", cval=0) * alpha * 1000
        )
        forward_dy = (
            gaussian_filter((random_state.rand(*shape) * 2 - 1), distortion_range, mode="constant", cval=0) * alpha * 1000
        )
        # print("i'm here get image 3")
        for transform in self.transforms_imagenet_raw:
            if type(transform) is RGB2Lab:
                I1_raw = I1
            I1 = transform(I1)
        for transform in self.transforms_imagenet_raw:
            if type(transform) is RGB2Lab:
                I2 = self.distortion_transform(I2, forward_dx, forward_dy)
                I2_raw = I2
            I2 = transform(I2)
        # print("i'm here get image 4")
        I2[0:1, :, :] = I2[0:1, :, :] + torch.randn(1) * self.brightnessjitter

        I_reference_video = self.extra_reference_transform(I_reference_video)
        for transform in self.transforms_imagenet_raw:
            I_reference_video = transform(I_reference_video)

        I_reference_video_real = self.transforms_imagenet(I_reference_video_real)
        # print("i'm here get image 5")
        flow_forward_raw = np.stack((forward_dy, forward_dx), axis=-1)
        flow_forward = self.flow_transform(flow_forward_raw)

        # update the mask for the pixels on the border
        grid_x, grid_y = np.meshgrid(np.arange(self.image_size[0]), np.arange(self.image_size[1]), indexing="ij")
        grid = np.stack((grid_y, grid_x), axis=-1)
        grid_warp = grid + flow_forward_raw
        location_y = grid_warp[:, :, 0].flatten()
        location_x = grid_warp[:, :, 1].flatten()
        I2_raw = np.array(I2_raw).astype(float)
        I21_r = map_coordinates(I2_raw[:, :, 0], np.stack((location_x, location_y)), cval=-1).reshape(
            (self.image_size[0], self.image_size[1])
        )
        I21_g = map_coordinates(I2_raw[:, :, 1], np.stack((location_x, location_y)), cval=-1).reshape(
            (self.image_size[0], self.image_size[1])
        )
        I21_b = map_coordinates(I2_raw[:, :, 2], np.stack((location_x, location_y)), cval=-1).reshape(
            (self.image_size[0], self.image_size[1])
        )
        I21_raw = np.stack((I21_r, I21_g, I21_b), axis=2)
        mask = np.ones((self.image_size[0], self.image_size[1]))
        mask[(I21_raw[:, :, 0] == -1) & (I21_raw[:, :, 1] == -1) & (I21_raw[:, :, 2] == -1)] = 0
        mask[abs(I21_raw - I1_raw).sum(axis=-1) > 50] = 0
        mask = self.ToTensor(mask)
        # print("i'm here get image 6")
        if np.random.random() < self.real_reference_probability:
            I_reference_output = I_reference_video_real
            placeholder = torch.zeros_like(I1)
            self_ref_flag = torch.zeros_like(I1)
        else:
            I_reference_output = I_reference_video
            placeholder = I2 if np.random.random() < self.nonzero_placeholder_probability else torch.zeros_like(I1)
            self_ref_flag = torch.ones_like(I1)

        # except Exception as e:
        #     if combo_path is not None:
        #         print("problem in ", combo_path)
        #     print("problem in, ", image_a_path)
        #     print(e)
        #     return self.__getitem__(np.random.randint(0, len(self.image_pairs)))
        # print("i'm here get image 7")
        return [I1, I2, I_reference_output, flow_forward, mask, placeholder, self_ref_flag, "holder", pb, pa]

    def __len__(self):
        return len(self.image_pairs)
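
Note (not part of the commit): a hedged sketch of how VideosDataset above would typically be wrapped in a PyTorch DataLoader for training. Every path below is a placeholder, and the transform pipeline is only an assumption based on the utilities this file imports from src.utils; the commit itself does not include a training script.

import torch.utils.data as data
import torchvision.transforms as transforms
from src.data.dataloader import VideosDataset
from src.utils import CenterPad, RGB2Lab, ToTensor, Normalize

image_size = (224, 224)
dataset = VideosDataset(
    video_data_root="data/local/videos",          # placeholder path
    flow_data_root="data/local/flows",            # placeholder path
    mask_data_root="data/local/masks",            # placeholder path
    imagenet_folder="data/local/imagenet",        # placeholder path
    annotation_file_path="data/local/pairs.csv",  # placeholder path
    image_size=image_size,
    image_transform=transforms.Compose([CenterPad(image_size), RGB2Lab(), ToTensor(), Normalize()]),
)
loader = data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)

# Each item is the 10-element list built in __getitem__ above.
I1, I2, I_ref, flow_forward, mask, placeholder, self_ref_flag, prev_name, curr_name, ref_path = next(iter(loader))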
src/data/functional.py
ADDED
@@ -0,0 +1,84 @@
from __future__ import division

import torch
import numbers
import collections
import numpy as np
from PIL import Image, ImageOps


def _is_pil_image(img):
    return isinstance(img, Image.Image)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def to_mytensor(pic):
    pic_arr = np.array(pic)
    if pic_arr.ndim == 2:
        pic_arr = pic_arr[..., np.newaxis]
    img = torch.from_numpy(pic_arr.transpose((2, 0, 1)))
    if not isinstance(img, torch.FloatTensor):
        return img.float()  # no normalize .div(255)
    else:
        return img


def normalize(tensor, mean, std):
    if not _is_tensor_image(tensor):
        raise TypeError("tensor is not a torch image.")
    if tensor.size(0) == 1:
        tensor.sub_(mean).div_(std)
    else:
        for t, m, s in zip(tensor, mean, std):
            t.sub_(m).div_(s)
    return tensor


def resize(img, size, interpolation=Image.BILINEAR):
    if not _is_pil_image(img):
        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
    if not isinstance(size, int) and (not isinstance(size, collections.Iterable) or len(size) != 2):
        raise TypeError("Got inappropriate size arg: {}".format(size))

    if not isinstance(size, int):
        return img.resize(size[::-1], interpolation)

    w, h = img.size
    if (w <= h and w == size) or (h <= w and h == size):
        return img
    if w < h:
        ow = size
        oh = int(round(size * h / w))
    else:
        oh = size
        ow = int(round(size * w / h))
    return img.resize((ow, oh), interpolation)


def pad(img, padding, fill=0):
    if not _is_pil_image(img):
        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError("Got inappropriate fill arg")

    if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)


def crop(img, i, j, h, w):
    if not _is_pil_image(img):
        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

    return img.crop((j, i, j + w, i + h))
src/data/transforms.py
ADDED
@@ -0,0 +1,348 @@
from __future__ import division

import collections
import numbers
import random

import torch
from PIL import Image
from skimage import color

import src.data.functional as F

__all__ = [
    "Compose",
    "Concatenate",
    "ToTensor",
    "Normalize",
    "Resize",
    "Scale",
    "CenterCrop",
    "Pad",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "RandomSizedCrop",
    "FiveCrop",
    "TenCrop",
    "RGB2Lab",
]


def CustomFunc(inputs, func, *args, **kwargs):
    im_l = func(inputs[0], *args, **kwargs)
    im_ab = func(inputs[1], *args, **kwargs)
    warp_ba = func(inputs[2], *args, **kwargs)
    warp_aba = func(inputs[3], *args, **kwargs)
    im_gbl_ab = func(inputs[4], *args, **kwargs)
    bgr_mc_im = func(inputs[5], *args, **kwargs)

    layer_data = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im]

    for l in range(5):
        layer = inputs[6 + l]
        err_ba = func(layer[0], *args, **kwargs)
        err_ab = func(layer[1], *args, **kwargs)

        layer_data.append([err_ba, err_ab])

    return layer_data


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, inputs):
        for t in self.transforms:
            inputs = t(inputs)
        return inputs


class Concatenate(object):
    """
    Input: [im_l, im_ab, inputs]
    inputs = [warp_ba_l, warp_ba_ab, warp_aba, err_pm, err_aba]

    Output: [im_l, err_pm, warp_ba, warp_aba, im_ab, err_aba]
    """

    def __call__(self, inputs):
        im_l = inputs[0]
        im_ab = inputs[1]
        warp_ba = inputs[2]
        warp_aba = inputs[3]
        im_glb_ab = inputs[4]
        bgr_mc_im = inputs[5]
        bgr_mc_im = bgr_mc_im[[2, 1, 0], ...]

        err_ba = []
        err_ab = []

        for l in range(5):
            layer = inputs[6 + l]
            err_ba.append(layer[0])
            err_ab.append(layer[1])

        cerr_ba = torch.cat(err_ba, 0)
        cerr_ab = torch.cat(err_ab, 0)

        return (im_l, cerr_ba, warp_ba, warp_aba, im_glb_ab, bgr_mc_im, im_ab, cerr_ab)


class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, inputs):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return CustomFunc(inputs, F.to_mytensor)


class Normalize(object):
    """Normalize an tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __call__(self, inputs):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """

        im_l = F.normalize(inputs[0], 50, 1)  # [0, 100]
        im_ab = F.normalize(inputs[1], (0, 0), (1, 1))  # [-100, 100]

        inputs[2][0:1, :, :] = F.normalize(inputs[2][0:1, :, :], 50, 1)
        inputs[2][1:3, :, :] = F.normalize(inputs[2][1:3, :, :], (0, 0), (1, 1))
        warp_ba = inputs[2]

        inputs[3][0:1, :, :] = F.normalize(inputs[3][0:1, :, :], 50, 1)
        inputs[3][1:3, :, :] = F.normalize(inputs[3][1:3, :, :], (0, 0), (1, 1))
        warp_aba = inputs[3]

        im_gbl_ab = F.normalize(inputs[4], (0, 0), (1, 1))  # [-100, 100]

        bgr_mc_im = F.normalize(inputs[5], (123.68, 116.78, 103.938), (1, 1, 1))

        layer_data = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im]

        for l in range(5):
            layer = inputs[6 + l]
            err_ba = F.normalize(layer[0], 127, 2)  # [0, 255]
            err_ab = F.normalize(layer[1], 127, 2)  # [0, 255]
            layer_data.append([err_ba, err_ab])

        return layer_data


class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, inputs):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return CustomFunc(inputs, F.resize, self.size, self.interpolation)


class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, inputs):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.padding > 0:
            inputs = CustomFunc(inputs, F.pad, self.padding)

        i, j, h, w = self.get_params(inputs[0], self.size)
        return CustomFunc(inputs, F.crop, i, j, h, w)


class CenterCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = (h - th) // 2
        j = (w - tw) // 2
        return i, j, th, tw

    def __call__(self, inputs):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.padding > 0:
            inputs = CustomFunc(inputs, F.pad, self.padding)

        i, j, h, w = self.get_params(inputs[0], self.size)
        return CustomFunc(inputs, F.crop, i, j, h, w)


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a probability of 0.5."""

    def __call__(self, inputs):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """

        if random.random() < 0.5:
            return CustomFunc(inputs, F.hflip)
        return inputs


class RGB2Lab(object):
    def __call__(self, inputs):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """

    def __call__(self, inputs):
        image_lab = color.rgb2lab(inputs[0])
        warp_ba_lab = color.rgb2lab(inputs[2])
        warp_aba_lab = color.rgb2lab(inputs[3])
        im_gbl_lab = color.rgb2lab(inputs[4])

        inputs[0] = image_lab[:, :, :1]  # l channel
        inputs[1] = image_lab[:, :, 1:]  # ab channel
        inputs[2] = warp_ba_lab  # lab channel
        inputs[3] = warp_aba_lab  # lab channel
        inputs[4] = im_gbl_lab[:, :, 1:]  # ab channel

        return inputs
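
Note (not part of the commit): these transforms act on a list rather than a single image. A hedged sketch of the expected 11-slot layout (slots 0-5 plus five [err_ba, err_ab] pairs, following CustomFunc and Concatenate above), using dummy arrays that are purely placeholders:

import numpy as np
from src.data import transforms as custom_T

h, w = 64, 64
inputs = [
    np.zeros((h, w, 1), np.float32),  # 0: im_l
    np.zeros((h, w, 2), np.float32),  # 1: im_ab
    np.zeros((h, w, 3), np.float32),  # 2: warp_ba (Lab)
    np.zeros((h, w, 3), np.float32),  # 3: warp_aba (Lab)
    np.zeros((h, w, 2), np.float32),  # 4: im_gbl_ab
    np.zeros((h, w, 3), np.float32),  # 5: bgr_mc_im
] + [[np.zeros((h, w, 1), np.float32), np.zeros((h, w, 1), np.float32)] for _ in range(5)]  # 6-10: [err_ba, err_ab]

pipeline = custom_T.Compose([custom_T.ToTensor(), custom_T.Normalize()])
outputs = pipeline(inputs)  # same layout, now torch tensors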
src/inference.py
ADDED
@@ -0,0 +1,174 @@
from src.models.CNN.ColorVidNet import ColorVidNet
from src.models.vit.embed import SwinModel
from src.models.CNN.NonlocalNet import WarpNet
from src.models.CNN.FrameColor import frame_colorization
import torch
from src.models.vit.utils import load_params
import os
import cv2
from PIL import Image
from PIL import ImageEnhance as IE
import torchvision.transforms as T
from src.utils import (
    RGB2Lab,
    ToTensor,
    Normalize,
    uncenter_l,
    tensor_lab2rgb
)
import numpy as np
from tqdm import tqdm

class SwinTExCo:
    def __init__(self, weights_path, swin_backbone='swinv2-cr-t-224', device=None):
        if device == None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        self.embed_net = SwinModel(pretrained_model=swin_backbone, device=self.device).to(self.device)
        self.nonlocal_net = WarpNet(feature_channel=128).to(self.device)
        self.colornet = ColorVidNet(7).to(self.device)

        self.embed_net.eval()
        self.nonlocal_net.eval()
        self.colornet.eval()

        self.__load_models(self.embed_net, os.path.join(weights_path, "embed_net.pth"))
        self.__load_models(self.nonlocal_net, os.path.join(weights_path, "nonlocal_net.pth"))
        self.__load_models(self.colornet, os.path.join(weights_path, "colornet.pth"))

        self.processor = T.Compose([
            T.Resize((224,224)),
            RGB2Lab(),
            ToTensor(),
            Normalize()
        ])

        pass

    def __load_models(self, model, weight_path):
        params = load_params(weight_path, self.device)
        model.load_state_dict(params, strict=True)

    def __preprocess_reference(self, img):
        color_enhancer = IE.Color(img)
        img = color_enhancer.enhance(1.5)
        return img

    def __upscale_image(self, large_IA_l, I_current_ab_predict):
        H, W = large_IA_l.shape[2:]
        large_current_ab_predict = torch.nn.functional.interpolate(I_current_ab_predict,
                                                                   size=(H,W),
                                                                   mode="bilinear",
                                                                   align_corners=False)
        large_IA_l = torch.cat((large_IA_l, large_current_ab_predict.cpu()), dim=1)
        large_current_rgb_predict = tensor_lab2rgb(large_IA_l)
        return large_current_rgb_predict

    def __proccess_sample(self, curr_frame, I_last_lab_predict, I_reference_lab, features_B):
        large_IA_lab = ToTensor()(RGB2Lab()(curr_frame)).unsqueeze(0)
        large_IA_l = large_IA_lab[:, 0:1, :, :]

        IA_lab = self.processor(curr_frame)
        IA_lab = IA_lab.unsqueeze(0).to(self.device)
        IA_l = IA_lab[:, 0:1, :, :]
        if I_last_lab_predict is None:
            I_last_lab_predict = torch.zeros_like(IA_lab).to(self.device)

        with torch.no_grad():
            I_current_ab_predict, _ = frame_colorization(
                IA_l,
                I_reference_lab,
                I_last_lab_predict,
                features_B,
                self.embed_net,
                self.nonlocal_net,
                self.colornet,
                luminance_noise=0,
                temperature=1e-10,
                joint_training=False
            )
            I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)

        IA_predict_rgb = self.__upscale_image(large_IA_l, I_current_ab_predict)
        IA_predict_rgb = (IA_predict_rgb.squeeze(0).cpu().numpy() * 255.)
        IA_predict_rgb = np.clip(IA_predict_rgb, 0, 255).astype(np.uint8)

        return I_last_lab_predict, IA_predict_rgb

    def predict_video(self, video, ref_image):
        ref_image = self.__preprocess_reference(ref_image)

        I_last_lab_predict = None

        IB_lab = self.processor(ref_image)
        IB_lab = IB_lab.unsqueeze(0).to(self.device)

        with torch.no_grad():
            I_reference_lab = IB_lab
            I_reference_l = I_reference_lab[:, 0:1, :, :]
            I_reference_ab = I_reference_lab[:, 1:3, :, :]
            I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
            features_B = self.embed_net(I_reference_rgb)

        # PBAR = tqdm(total=int(video.get(cv2.CAP_PROP_FRAME_COUNT)), desc="Colorizing video", unit="frame")
        while video.isOpened():
            # PBAR.update(1)
            ret, curr_frame = video.read()

            if not ret:
                break

            curr_frame = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
            curr_frame = Image.fromarray(curr_frame)

            I_last_lab_predict, IA_predict_rgb = self.__proccess_sample(curr_frame, I_last_lab_predict, I_reference_lab, features_B)

            IA_predict_rgb = IA_predict_rgb.transpose(1,2,0)

            yield IA_predict_rgb

        # PBAR.close()
        video.release()

    def predict_image(self, image, ref_image):
        ref_image = self.__preprocess_reference(ref_image)

        I_last_lab_predict = None

        IB_lab = self.processor(ref_image)
        IB_lab = IB_lab.unsqueeze(0).to(self.device)

        with torch.no_grad():
            I_reference_lab = IB_lab
            I_reference_l = I_reference_lab[:, 0:1, :, :]
            I_reference_ab = I_reference_lab[:, 1:3, :, :]
            I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
            features_B = self.embed_net(I_reference_rgb)

        curr_frame = image
        I_last_lab_predict, IA_predict_rgb = self.__proccess_sample(curr_frame, I_last_lab_predict, I_reference_lab, features_B)

        IA_predict_rgb = IA_predict_rgb.transpose(1,2,0)

        return IA_predict_rgb

if __name__ == "__main__":
    model = SwinTExCo('checkpoints/epoch_20/')

    # Initialize video reader and writer
    video = cv2.VideoCapture('sample_input/video_2.mp4')
    fps = video.get(cv2.CAP_PROP_FPS)
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    video_writer = cv2.VideoWriter('sample_output/video_2_ref_2.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

    # Initialize reference image
    ref_image = Image.open('sample_input/refs_2/ref2.jpg').convert('RGB')

    for colorized_frame in model.predict_video(video, ref_image):
        video_writer.write(colorized_frame)

    video_writer.release()
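
Note (not part of the commit): besides the video entry point used by app.py, SwinTExCo also exposes predict_image. A minimal hedged sketch, assuming the epoch_20 checkpoints are present; the input frame here reuses the bundled reference image purely as a stand-in for a grayscale frame.

from PIL import Image
from src.inference import SwinTExCo

model = SwinTExCo(weights_path="checkpoints/epoch_20")

frame = Image.open("sample_input/ref1.jpg").convert("RGB")      # stand-in input frame
ref_image = Image.open("sample_input/ref1.jpg").convert("RGB")  # color exemplar

colorized = model.predict_image(frame, ref_image)  # H x W x 3 uint8 RGB array
Image.fromarray(colorized).save("colorized_frame.png")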
src/losses.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from src.utils import feature_normalize
|
4 |
+
|
5 |
+
|
6 |
+
### START### CONTEXTUAL LOSS ####
|
7 |
+
class ContextualLoss(nn.Module):
|
8 |
+
"""
|
9 |
+
input is Al, Bl, channel = 1, range ~ [0, 255]
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self):
|
13 |
+
super(ContextualLoss, self).__init__()
|
14 |
+
return None
|
15 |
+
|
16 |
+
def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
|
17 |
+
"""
|
18 |
+
X_features&Y_features are are feature vectors or feature 2d array
|
19 |
+
h: bandwidth
|
20 |
+
return the per-sample loss
|
21 |
+
"""
|
22 |
+
batch_size = X_features.shape[0]
|
23 |
+
feature_depth = X_features.shape[1]
|
24 |
+
|
25 |
+
# to normalized feature vectors
|
26 |
+
if feature_centering:
|
27 |
+
X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
|
28 |
+
dim=-1
|
29 |
+
)
|
30 |
+
Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
|
31 |
+
dim=-1
|
32 |
+
)
|
33 |
+
X_features = feature_normalize(X_features).view(
|
34 |
+
batch_size, feature_depth, -1
|
35 |
+
) # batch_size * feature_depth * feature_size^2
|
36 |
+
Y_features = feature_normalize(Y_features).view(
|
37 |
+
batch_size, feature_depth, -1
|
38 |
+
) # batch_size * feature_depth * feature_size^2
|
39 |
+
|
40 |
+
# conine distance = 1 - similarity
|
41 |
+
X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
|
42 |
+
d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
|
43 |
+
|
44 |
+
# normalized distance: dij_bar
|
45 |
+
d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2
|
46 |
+
|
47 |
+
# pairwise affinity
|
48 |
+
w = torch.exp((1 - d_norm) / h)
|
49 |
+
A_ij = w / torch.sum(w, dim=-1, keepdim=True)
|
50 |
+
|
51 |
+
# contextual loss per sample
|
52 |
+
CX = torch.mean(torch.max(A_ij, dim=1)[0], dim=-1)
|
53 |
+
return -torch.log(CX)
|
54 |
+
|
55 |
+
|
56 |
+
class ContextualLoss_forward(nn.Module):
|
57 |
+
"""
|
58 |
+
input is Al, Bl, channel = 1, range ~ [0, 255]
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __init__(self):
|
62 |
+
super(ContextualLoss_forward, self).__init__()
|
63 |
+
return None
|
64 |
+
|
65 |
+
def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
|
66 |
+
"""
|
67 |
+
X_features & Y_features are feature vectors or 2D feature maps
|
68 |
+
h: bandwidth
|
69 |
+
return the per-sample loss
|
70 |
+
"""
|
71 |
+
batch_size = X_features.shape[0]
|
72 |
+
feature_depth = X_features.shape[1]
|
73 |
+
|
74 |
+
# to normalized feature vectors
|
75 |
+
if feature_centering:
|
76 |
+
X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
|
77 |
+
dim=-1
|
78 |
+
)
|
79 |
+
Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
|
80 |
+
dim=-1
|
81 |
+
)
|
82 |
+
X_features = feature_normalize(X_features).view(
|
83 |
+
batch_size, feature_depth, -1
|
84 |
+
) # batch_size * feature_depth * feature_size^2
|
85 |
+
Y_features = feature_normalize(Y_features).view(
|
86 |
+
batch_size, feature_depth, -1
|
87 |
+
) # batch_size * feature_depth * feature_size^2
|
88 |
+
|
89 |
+
# cosine distance = 1 - similarity
|
90 |
+
X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
|
91 |
+
d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
|
92 |
+
|
93 |
+
# normalized distance: dij_bar
|
94 |
+
d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2
|
95 |
+
|
96 |
+
# pairwise affinity
|
97 |
+
w = torch.exp((1 - d_norm) / h)
|
98 |
+
A_ij = w / torch.sum(w, dim=-1, keepdim=True)
|
99 |
+
|
100 |
+
# contextual loss per sample
|
101 |
+
CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1)
|
102 |
+
return -torch.log(CX)
|
103 |
+
|
104 |
+
|
105 |
+
### END### CONTEXTUAL LOSS ####
|
106 |
+
|
107 |
+
|
108 |
+
##########################
|
109 |
+
|
110 |
+
|
111 |
+
def mse_loss_fn(input, target=0):
|
112 |
+
return torch.mean((input - target) ** 2)
|
113 |
+
|
114 |
+
|
115 |
+
### START### PERCEPTUAL LOSS ###
|
116 |
+
def Perceptual_loss(domain_invariant, weight_perceptual):
|
117 |
+
instancenorm = nn.InstanceNorm2d(512, affine=False)
|
118 |
+
|
119 |
+
def __call__(A_relu5_1, predict_relu5_1):
|
120 |
+
if domain_invariant:
|
121 |
+
feat_loss = (
|
122 |
+
mse_loss_fn(instancenorm(predict_relu5_1), instancenorm(A_relu5_1.detach())) * weight_perceptual * 1e5 * 0.2
|
123 |
+
)
|
124 |
+
else:
|
125 |
+
feat_loss = mse_loss_fn(predict_relu5_1, A_relu5_1.detach()) * weight_perceptual
|
126 |
+
return feat_loss
|
127 |
+
|
128 |
+
return __call__
|
129 |
+
|
130 |
+
|
131 |
+
### END### PERCEPTUAL LOSS ###
|
132 |
+
|
133 |
+
|
134 |
+
def l1_loss_fn(input, target=0):
|
135 |
+
return torch.mean(torch.abs(input - target))
|
136 |
+
|
137 |
+
|
138 |
+
### END#################
|
139 |
+
|
140 |
+
|
141 |
+
### START### ADVERSARIAL LOSS ###
|
142 |
+
def generator_loss_fn(real_data_lab, fake_data_lab, discriminator, weight_gan, device):
|
143 |
+
if weight_gan > 0:
|
144 |
+
y_pred_fake, _ = discriminator(fake_data_lab)
|
145 |
+
y_pred_real, _ = discriminator(real_data_lab)
|
146 |
+
|
147 |
+
y = torch.ones_like(y_pred_real)
|
148 |
+
generator_loss = (
|
149 |
+
(
|
150 |
+
torch.mean((y_pred_real - torch.mean(y_pred_fake) + y) ** 2)
|
151 |
+
+ torch.mean((y_pred_fake - torch.mean(y_pred_real) - y) ** 2)
|
152 |
+
)
|
153 |
+
/ 2
|
154 |
+
* weight_gan
|
155 |
+
)
|
156 |
+
return generator_loss
|
157 |
+
|
158 |
+
return torch.Tensor([0]).to(device)
|
159 |
+
|
160 |
+
|
161 |
+
def discriminator_loss_fn(real_data_lab, fake_data_lab, discriminator):
|
162 |
+
y_pred_fake, _ = discriminator(fake_data_lab.detach())
|
163 |
+
y_pred_real, _ = discriminator(real_data_lab.detach())
|
164 |
+
|
165 |
+
y = torch.ones_like(y_pred_real)
|
166 |
+
discriminator_loss = (
|
167 |
+
torch.mean((y_pred_real - torch.mean(y_pred_fake) - y) ** 2)
|
168 |
+
+ torch.mean((y_pred_fake - torch.mean(y_pred_real) + y) ** 2)
|
169 |
+
) / 2
|
170 |
+
return discriminator_loss
|
171 |
+
|
172 |
+
|
173 |
+
### END### ADVERSARIAL LOSS #####
|
174 |
+
|
175 |
+
|
176 |
+
def consistent_loss_fn(
|
177 |
+
I_current_lab_predict,
|
178 |
+
I_last_ab_predict,
|
179 |
+
I_current_nonlocal_lab_predict,
|
180 |
+
I_last_nonlocal_lab_predict,
|
181 |
+
flow_forward,
|
182 |
+
mask,
|
183 |
+
warping_layer,
|
184 |
+
weight_consistent=0.02,
|
185 |
+
weight_nonlocal_consistent=0.0,
|
186 |
+
device="cuda",
|
187 |
+
):
|
188 |
+
def weighted_mse_loss(input, target, weights):
|
189 |
+
out = (input - target) ** 2
|
190 |
+
out = out * weights.expand_as(out)
|
191 |
+
return out.mean()
|
192 |
+
|
193 |
+
def consistent():
|
194 |
+
I_current_lab_predict_warp = warping_layer(I_current_lab_predict, flow_forward)
|
195 |
+
I_current_ab_predict_warp = I_current_lab_predict_warp[:, 1:3, :, :]
|
196 |
+
consistent_loss = weighted_mse_loss(I_current_ab_predict_warp, I_last_ab_predict, mask) * weight_consistent
|
197 |
+
return consistent_loss
|
198 |
+
|
199 |
+
def nonlocal_consistent():
|
200 |
+
I_current_nonlocal_lab_predict_warp = warping_layer(I_current_nonlocal_lab_predict, flow_forward)
|
201 |
+
nonlocal_consistent_loss = (
|
202 |
+
weighted_mse_loss(
|
203 |
+
I_current_nonlocal_lab_predict_warp[:, 1:3, :, :],
|
204 |
+
I_last_nonlocal_lab_predict[:, 1:3, :, :],
|
205 |
+
mask,
|
206 |
+
)
|
207 |
+
* weight_nonlocal_consistent
|
208 |
+
)
|
209 |
+
|
210 |
+
return nonlocal_consistent_loss
|
211 |
+
|
212 |
+
consistent_loss = consistent() if weight_consistent else torch.Tensor([0]).to(device)
|
213 |
+
nonlocal_consistent_loss = nonlocal_consistent() if weight_nonlocal_consistent else torch.Tensor([0]).to(device)
|
214 |
+
|
215 |
+
return consistent_loss + nonlocal_consistent_loss
|
216 |
+
|
217 |
+
|
218 |
+
### END### CONSISTENCY LOSS #####
|
219 |
+
|
220 |
+
|
221 |
+
### START### SMOOTHNESS LOSS ###
|
222 |
+
def smoothness_loss_fn(
|
223 |
+
I_current_l,
|
224 |
+
I_current_lab,
|
225 |
+
I_current_ab_predict,
|
226 |
+
A_relu2_1,
|
227 |
+
weighted_layer_color,
|
228 |
+
nonlocal_weighted_layer,
|
229 |
+
weight_smoothness=5.0,
|
230 |
+
weight_nonlocal_smoothness=0.0,
|
231 |
+
device="cuda",
|
232 |
+
):
|
233 |
+
def smoothness(scale_factor=1.0):
|
234 |
+
I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
|
235 |
+
IA_ab_weighed = weighted_layer_color(
|
236 |
+
I_current_lab,
|
237 |
+
I_current_lab_predict,
|
238 |
+
patch_size=3,
|
239 |
+
alpha=10,
|
240 |
+
scale_factor=scale_factor,
|
241 |
+
)
|
242 |
+
smoothness_loss = (
|
243 |
+
mse_loss_fn(
|
244 |
+
nn.functional.interpolate(I_current_ab_predict, scale_factor=scale_factor),
|
245 |
+
IA_ab_weighed,
|
246 |
+
)
|
247 |
+
* weight_smoothness
|
248 |
+
)
|
249 |
+
|
250 |
+
return smoothness_loss
|
251 |
+
|
252 |
+
def nonlocal_smoothness(scale_factor=0.25, alpha_nonlocal_smoothness=0.5):
|
253 |
+
nonlocal_smooth_feature = feature_normalize(A_relu2_1)
|
254 |
+
I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
|
255 |
+
I_current_ab_weighted_nonlocal = nonlocal_weighted_layer(
|
256 |
+
I_current_lab_predict,
|
257 |
+
nonlocal_smooth_feature.detach(),
|
258 |
+
patch_size=3,
|
259 |
+
alpha=alpha_nonlocal_smoothness,
|
260 |
+
scale_factor=scale_factor,
|
261 |
+
)
|
262 |
+
nonlocal_smoothness_loss = (
|
263 |
+
mse_loss_fn(
|
264 |
+
nn.functional.interpolate(I_current_ab_predict, scale_factor=scale_factor),
|
265 |
+
I_current_ab_weighted_nonlocal,
|
266 |
+
)
|
267 |
+
* weight_nonlocal_smoothness
|
268 |
+
)
|
269 |
+
return nonlocal_smoothness_loss
|
270 |
+
|
271 |
+
smoothness_loss = smoothness() if weight_smoothness else torch.Tensor([0]).to(device)
|
272 |
+
nonlocal_smoothness_loss = nonlocal_smoothness() if weight_nonlocal_smoothness else torch.Tensor([0]).to(device)
|
273 |
+
|
274 |
+
return smoothness_loss + nonlocal_smoothness_loss
|
275 |
+
|
276 |
+
|
277 |
+
### END### SMOOTHNESS LOSS #####
|
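For reference, a minimal usage sketch of the contextual loss defined above (the two classes differ only in which axis the max affinity is taken over). It assumes this repository is importable as src and uses random feature maps purely to illustrate shapes.

import torch
from src.losses import ContextualLoss_forward

ctx_loss = ContextualLoss_forward()
X = torch.randn(2, 64, 16, 16)  # predicted feature map (B, C, H, W)
Y = torch.randn(2, 64, 16, 16)  # reference feature map (B, C, H, W)
loss = ctx_loss(X, Y, h=0.1)    # one value per sample, shape (2,)
print(loss.shape, loss.mean().item())
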
src/metrics.py
ADDED
@@ -0,0 +1,225 @@
1 |
+
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
|
2 |
+
import numpy as np
|
3 |
+
import lpips
|
4 |
+
import torch
|
5 |
+
from pytorch_fid.fid_score import calculate_frechet_distance
|
6 |
+
from pytorch_fid.inception import InceptionV3
|
7 |
+
import torch.nn as nn
|
8 |
+
import cv2
|
9 |
+
from scipy import stats
|
10 |
+
import os
|
11 |
+
|
12 |
+
def calc_ssim(pred_image, gt_image):
|
13 |
+
'''
|
14 |
+
Structural Similarity Index (SSIM) is a perceptual metric that quantifies the image quality degradation that is
|
15 |
+
caused by processing such as data compression or by losses in data transmission.
|
16 |
+
|
17 |
+
# Arguments
|
18 |
+
pred_image: PIL.Image
|
19 |
+
gt_image: PIL.Image
|
20 |
+
# Returns
|
21 |
+
ssim: float (-1.0, 1.0)
|
22 |
+
'''
|
23 |
+
pred_image = np.array(pred_image.convert('RGB')).astype(np.float32)
|
24 |
+
gt_image = np.array(gt_image.convert('RGB')).astype(np.float32)
|
25 |
+
ssim = structural_similarity(pred_image, gt_image, channel_axis=2, data_range=255.)
|
26 |
+
return ssim
|
27 |
+
|
28 |
+
def calc_psnr(pred_image, gt_image):
|
29 |
+
'''
|
30 |
+
Peak Signal-to-Noise Ratio (PSNR) is an expression for the ratio between the maximum possible value (power) of a signal
|
31 |
+
and the power of distorting noise that affects the quality of its representation.
|
32 |
+
|
33 |
+
# Arguments
|
34 |
+
pred_image: PIL.Image
|
35 |
+
gt_image: PIL.Image
|
36 |
+
# Returns
|
37 |
+
psnr: float
|
38 |
+
'''
|
39 |
+
pred_image = np.array(pred_image.convert('RGB')).astype(np.float32)
|
40 |
+
gt_image = np.array(gt_image.convert('RGB')).astype(np.float32)
|
41 |
+
|
42 |
+
psnr = peak_signal_noise_ratio(gt_image, pred_image, data_range=255.)
|
43 |
+
return psnr
|
44 |
+
|
45 |
+
class LPIPS_utils:
|
46 |
+
def __init__(self, device = 'cuda'):
|
47 |
+
self.loss_fn = lpips.LPIPS(net='vgg', spatial=True)  # net can be 'squeeze', 'vgg', or 'alex'
|
48 |
+
self.loss_fn = self.loss_fn.to(device)
|
49 |
+
self.device = device
|
50 |
+
|
51 |
+
def compare_lpips(self, img_fake, img_real, data_range=255.):  # inputs: tensors/arrays shaped (1, C, H, W) or (H, W, C)
|
52 |
+
img_fake = torch.from_numpy(np.array(img_fake).astype(np.float32)/data_range)
|
53 |
+
img_real = torch.from_numpy(np.array(img_real).astype(np.float32)/data_range)
|
54 |
+
if img_fake.ndim==3:
|
55 |
+
img_fake = img_fake.permute(2,0,1).unsqueeze(0)
|
56 |
+
img_real = img_real.permute(2,0,1).unsqueeze(0)
|
57 |
+
img_fake = img_fake.to(self.device)
|
58 |
+
img_real = img_real.to(self.device)
|
59 |
+
|
60 |
+
dist = self.loss_fn.forward(img_fake,img_real)
|
61 |
+
return dist.mean().item()
|
62 |
+
|
63 |
+
class FID_utils(nn.Module):
|
64 |
+
"""Class for computing the Fréchet Inception Distance (FID) metric score.
|
65 |
+
It is implemented as a class in order to hold the inception model instance
|
66 |
+
in its state.
|
67 |
+
Parameters
|
68 |
+
----------
|
69 |
+
resize_input : bool (optional)
|
70 |
+
Whether or not to resize the input images to the image size (299, 299)
|
71 |
+
on which the inception model was trained. Since the model is fully
|
72 |
+
convolutional, the score also works without resizing. In literature
|
73 |
+
and when working with GANs people tend to set this value to True,
|
74 |
+
however, for internal evaluation this is not necessary.
|
75 |
+
device : str or torch.device
|
76 |
+
The device on which to run the inception model.
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, resize_input=True, device="cuda"):
|
80 |
+
super(FID_utils, self).__init__()
|
81 |
+
self.device = device
|
82 |
+
if self.device is None:
|
83 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
84 |
+
#self.model = InceptionV3(resize_input=resize_input).to(device)
|
85 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
|
86 |
+
self.model = InceptionV3([block_idx]).to(device)
|
87 |
+
self.model = self.model.eval()
|
88 |
+
|
89 |
+
def get_activations(self,batch): # 1 c h w
|
90 |
+
with torch.no_grad():
|
91 |
+
pred = self.model(batch)[0]
|
92 |
+
# If model output is not scalar, apply global spatial average pooling.
|
93 |
+
# This happens if you choose a dimensionality not equal 2048.
|
94 |
+
if pred.size(2) != 1 or pred.size(3) != 1:
|
95 |
+
#pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
|
96 |
+
print("error in get activations!")
|
97 |
+
#pred = pred.squeeze(3).squeeze(2).cpu().numpy()
|
98 |
+
return pred
|
99 |
+
|
100 |
+
|
101 |
+
def _get_mu_sigma(self, batch, data_range):
|
102 |
+
"""Compute the inception mu and sigma for a batch of images.
|
103 |
+
Parameters
|
104 |
+
----------
|
105 |
+
images : np.ndarray
|
106 |
+
A batch of images with shape (n_images,3, width, height).
|
107 |
+
Returns
|
108 |
+
-------
|
109 |
+
mu : np.ndarray
|
110 |
+
The array of mean activations with shape (2048,).
|
111 |
+
sigma : np.ndarray
|
112 |
+
The covariance matrix of activations with shape (2048, 2048).
|
113 |
+
"""
|
114 |
+
# forward pass
|
115 |
+
if batch.ndim ==3 and batch.shape[2]==3:
|
116 |
+
batch=batch.permute(2,0,1).unsqueeze(0)
|
117 |
+
batch /= data_range
|
118 |
+
#batch = torch.tensor(batch)#.unsqueeze(1).repeat((1, 3, 1, 1))
|
119 |
+
batch = batch.to(self.device, torch.float32)
|
120 |
+
#(activations,) = self.model(batch)
|
121 |
+
activations = self.get_activations(batch)
|
122 |
+
activations = activations.detach().cpu().numpy().squeeze(3).squeeze(2)
|
123 |
+
|
124 |
+
# compute statistics
|
125 |
+
mu = np.mean(activations,axis=0)
|
126 |
+
sigma = np.cov(activations, rowvar=False)
|
127 |
+
|
128 |
+
return mu, sigma
|
129 |
+
|
130 |
+
def score(self, images_1, images_2, data_range=255.):
|
131 |
+
"""Compute the FID score.
|
132 |
+
The input batches should have the shape (n_images,3, width, height). or (h,w,3)
|
133 |
+
Parameters
|
134 |
+
----------
|
135 |
+
images_1 : np.ndarray
|
136 |
+
First batch of images.
|
137 |
+
images_2 : np.ndarray
|
138 |
+
Second batch of images.
|
139 |
+
Returns
|
140 |
+
-------
|
141 |
+
score : float
|
142 |
+
The FID score.
|
143 |
+
"""
|
144 |
+
images_1 = torch.from_numpy(np.array(images_1).astype(np.float32))
|
145 |
+
images_2 = torch.from_numpy(np.array(images_2).astype(np.float32))
|
146 |
+
images_1 = images_1.to(self.device)
|
147 |
+
images_2 = images_2.to(self.device)
|
148 |
+
|
149 |
+
mu_1, sigma_1 = self._get_mu_sigma(images_1,data_range)
|
150 |
+
mu_2, sigma_2 = self._get_mu_sigma(images_2,data_range)
|
151 |
+
score = calculate_frechet_distance(mu_1, sigma_1, mu_2, sigma_2)
|
152 |
+
|
153 |
+
return score
|
154 |
+
|
155 |
+
def JS_divergence(p, q):
|
156 |
+
M = (p + q) / 2
|
157 |
+
return 0.5 * stats.entropy(p, M) + 0.5 * stats.entropy(q, M)
|
158 |
+
|
159 |
+
|
160 |
+
def compute_JS_bgr(input_dir, dilation=1):
|
161 |
+
input_img_list = os.listdir(input_dir)
|
162 |
+
input_img_list.sort()
|
163 |
+
# print(input_img_list)
|
164 |
+
|
165 |
+
hist_b_list = [] # [img1_histb, img2_histb, ...]
|
166 |
+
hist_g_list = []
|
167 |
+
hist_r_list = []
|
168 |
+
|
169 |
+
for img_name in input_img_list:
|
170 |
+
# print(os.path.join(input_dir, img_name))
|
171 |
+
img_in = cv2.imread(os.path.join(input_dir, img_name))
|
172 |
+
H, W, C = img_in.shape
|
173 |
+
|
174 |
+
hist_b = cv2.calcHist([img_in], [0], None, [256], [0,256]) # B
|
175 |
+
hist_g = cv2.calcHist([img_in], [1], None, [256], [0,256]) # G
|
176 |
+
hist_r = cv2.calcHist([img_in], [2], None, [256], [0,256]) # R
|
177 |
+
|
178 |
+
hist_b = hist_b / (H * W)
|
179 |
+
hist_g = hist_g / (H * W)
|
180 |
+
hist_r = hist_r / (H * W)
|
181 |
+
|
182 |
+
hist_b_list.append(hist_b)
|
183 |
+
hist_g_list.append(hist_g)
|
184 |
+
hist_r_list.append(hist_r)
|
185 |
+
|
186 |
+
JS_b_list = []
|
187 |
+
JS_g_list = []
|
188 |
+
JS_r_list = []
|
189 |
+
|
190 |
+
for i in range(len(hist_b_list)):
|
191 |
+
if i + dilation > len(hist_b_list) - 1:
|
192 |
+
break
|
193 |
+
hist_b_img1 = hist_b_list[i]
|
194 |
+
hist_b_img2 = hist_b_list[i + dilation]
|
195 |
+
JS_b = JS_divergence(hist_b_img1, hist_b_img2)
|
196 |
+
JS_b_list.append(JS_b)
|
197 |
+
|
198 |
+
hist_g_img1 = hist_g_list[i]
|
199 |
+
hist_g_img2 = hist_g_list[i+dilation]
|
200 |
+
JS_g = JS_divergence(hist_g_img1, hist_g_img2)
|
201 |
+
JS_g_list.append(JS_g)
|
202 |
+
|
203 |
+
hist_r_img1 = hist_r_list[i]
|
204 |
+
hist_r_img2 = hist_r_list[i+dilation]
|
205 |
+
JS_r = JS_divergence(hist_r_img1, hist_r_img2)
|
206 |
+
JS_r_list.append(JS_r)
|
207 |
+
|
208 |
+
return JS_b_list, JS_g_list, JS_r_list
|
209 |
+
|
210 |
+
|
211 |
+
def calc_cdc(vid_folder, dilation=[1, 2, 4], weight=[1/3, 1/3, 1/3]):
|
212 |
+
mean_b, mean_g, mean_r = 0, 0, 0
|
213 |
+
for d, w in zip(dilation, weight):
|
214 |
+
JS_b_list_one, JS_g_list_one, JS_r_list_one = compute_JS_bgr(vid_folder, d)
|
215 |
+
mean_b += w * np.mean(JS_b_list_one)
|
216 |
+
mean_g += w * np.mean(JS_g_list_one)
|
217 |
+
mean_r += w * np.mean(JS_r_list_one)
|
218 |
+
|
219 |
+
cdc = np.mean([mean_b, mean_g, mean_r])
|
220 |
+
return cdc
|
221 |
+
|
222 |
+
|
223 |
+
|
224 |
+
|
225 |
+
|
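A minimal usage sketch for the single-image metrics above (calc_ssim and calc_psnr take PIL images). It assumes the repository is importable as src and that the heavier dependencies pulled in by src.metrics (lpips, pytorch_fid) are installed; the two random images are stand-ins, not data from this Space.

import numpy as np
from PIL import Image
from src.metrics import calc_ssim, calc_psnr

pred = Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
gt = Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
print("SSIM:", calc_ssim(pred, gt))  # close to 0 for unrelated noise, 1.0 for identical images
print("PSNR:", calc_psnr(pred, gt))  # in dB; higher is better
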
src/models/CNN/ColorVidNet.py
ADDED
@@ -0,0 +1,141 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.parallel
|
4 |
+
|
5 |
+
class ColorVidNet(nn.Module):
|
6 |
+
def __init__(self, ic):
|
7 |
+
super(ColorVidNet, self).__init__()
|
8 |
+
self.conv1_1 = nn.Sequential(nn.Conv2d(ic, 32, 3, 1, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1, 1))
|
9 |
+
self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1)
|
10 |
+
self.conv1_2norm = nn.BatchNorm2d(64, affine=False)
|
11 |
+
self.conv1_2norm_ss = nn.Conv2d(64, 64, 1, 2, bias=False, groups=64)
|
12 |
+
self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1)
|
13 |
+
self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1)
|
14 |
+
self.conv2_2norm = nn.BatchNorm2d(128, affine=False)
|
15 |
+
self.conv2_2norm_ss = nn.Conv2d(128, 128, 1, 2, bias=False, groups=128)
|
16 |
+
self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
|
17 |
+
self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
|
18 |
+
self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1)
|
19 |
+
self.conv3_3norm = nn.BatchNorm2d(256, affine=False)
|
20 |
+
self.conv3_3norm_ss = nn.Conv2d(256, 256, 1, 2, bias=False, groups=256)
|
21 |
+
self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
|
22 |
+
self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
|
23 |
+
self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1)
|
24 |
+
self.conv4_3norm = nn.BatchNorm2d(512, affine=False)
|
25 |
+
self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
|
26 |
+
self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
|
27 |
+
self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
|
28 |
+
self.conv5_3norm = nn.BatchNorm2d(512, affine=False)
|
29 |
+
self.conv6_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
|
30 |
+
self.conv6_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
|
31 |
+
self.conv6_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
|
32 |
+
self.conv6_3norm = nn.BatchNorm2d(512, affine=False)
|
33 |
+
self.conv7_1 = nn.Conv2d(512, 512, 3, 1, 1)
|
34 |
+
self.conv7_2 = nn.Conv2d(512, 512, 3, 1, 1)
|
35 |
+
self.conv7_3 = nn.Conv2d(512, 512, 3, 1, 1)
|
36 |
+
self.conv7_3norm = nn.BatchNorm2d(512, affine=False)
|
37 |
+
self.conv8_1 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
|
38 |
+
self.conv3_3_short = nn.Conv2d(256, 256, 3, 1, 1)
|
39 |
+
self.conv8_2 = nn.Conv2d(256, 256, 3, 1, 1)
|
40 |
+
self.conv8_3 = nn.Conv2d(256, 256, 3, 1, 1)
|
41 |
+
self.conv8_3norm = nn.BatchNorm2d(256, affine=False)
|
42 |
+
self.conv9_1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
|
43 |
+
self.conv2_2_short = nn.Conv2d(128, 128, 3, 1, 1)
|
44 |
+
self.conv9_2 = nn.Conv2d(128, 128, 3, 1, 1)
|
45 |
+
self.conv9_2norm = nn.BatchNorm2d(128, affine=False)
|
46 |
+
self.conv10_1 = nn.ConvTranspose2d(128, 128, 4, 2, 1)
|
47 |
+
self.conv1_2_short = nn.Conv2d(64, 128, 3, 1, 1)
|
48 |
+
self.conv10_2 = nn.Conv2d(128, 128, 3, 1, 1)
|
49 |
+
self.conv10_ab = nn.Conv2d(128, 2, 1, 1)
|
50 |
+
|
51 |
+
# add self.relux_x
|
52 |
+
self.relu1_1 = nn.PReLU()
|
53 |
+
self.relu1_2 = nn.PReLU()
|
54 |
+
self.relu2_1 = nn.PReLU()
|
55 |
+
self.relu2_2 = nn.PReLU()
|
56 |
+
self.relu3_1 = nn.PReLU()
|
57 |
+
self.relu3_2 = nn.PReLU()
|
58 |
+
self.relu3_3 = nn.PReLU()
|
59 |
+
self.relu4_1 = nn.PReLU()
|
60 |
+
self.relu4_2 = nn.PReLU()
|
61 |
+
self.relu4_3 = nn.PReLU()
|
62 |
+
self.relu5_1 = nn.PReLU()
|
63 |
+
self.relu5_2 = nn.PReLU()
|
64 |
+
self.relu5_3 = nn.PReLU()
|
65 |
+
self.relu6_1 = nn.PReLU()
|
66 |
+
self.relu6_2 = nn.PReLU()
|
67 |
+
self.relu6_3 = nn.PReLU()
|
68 |
+
self.relu7_1 = nn.PReLU()
|
69 |
+
self.relu7_2 = nn.PReLU()
|
70 |
+
self.relu7_3 = nn.PReLU()
|
71 |
+
self.relu8_1_comb = nn.PReLU()
|
72 |
+
self.relu8_2 = nn.PReLU()
|
73 |
+
self.relu8_3 = nn.PReLU()
|
74 |
+
self.relu9_1_comb = nn.PReLU()
|
75 |
+
self.relu9_2 = nn.PReLU()
|
76 |
+
self.relu10_1_comb = nn.PReLU()
|
77 |
+
self.relu10_2 = nn.LeakyReLU(0.2, True)
|
78 |
+
|
79 |
+
self.conv8_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(512, 256, 3, 1, 1))
|
80 |
+
self.conv9_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(256, 128, 3, 1, 1))
|
81 |
+
self.conv10_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(128, 128, 3, 1, 1))
|
82 |
+
|
83 |
+
self.conv1_2norm = nn.InstanceNorm2d(64)
|
84 |
+
self.conv2_2norm = nn.InstanceNorm2d(128)
|
85 |
+
self.conv3_3norm = nn.InstanceNorm2d(256)
|
86 |
+
self.conv4_3norm = nn.InstanceNorm2d(512)
|
87 |
+
self.conv5_3norm = nn.InstanceNorm2d(512)
|
88 |
+
self.conv6_3norm = nn.InstanceNorm2d(512)
|
89 |
+
self.conv7_3norm = nn.InstanceNorm2d(512)
|
90 |
+
self.conv8_3norm = nn.InstanceNorm2d(256)
|
91 |
+
self.conv9_2norm = nn.InstanceNorm2d(128)
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
"""x: gray image (1 channel), ab(2 channel), ab_err, ba_err"""
|
95 |
+
conv1_1 = self.relu1_1(self.conv1_1(x))
|
96 |
+
conv1_2 = self.relu1_2(self.conv1_2(conv1_1))
|
97 |
+
conv1_2norm = self.conv1_2norm(conv1_2)
|
98 |
+
conv1_2norm_ss = self.conv1_2norm_ss(conv1_2norm)
|
99 |
+
conv2_1 = self.relu2_1(self.conv2_1(conv1_2norm_ss))
|
100 |
+
conv2_2 = self.relu2_2(self.conv2_2(conv2_1))
|
101 |
+
conv2_2norm = self.conv2_2norm(conv2_2)
|
102 |
+
conv2_2norm_ss = self.conv2_2norm_ss(conv2_2norm)
|
103 |
+
conv3_1 = self.relu3_1(self.conv3_1(conv2_2norm_ss))
|
104 |
+
conv3_2 = self.relu3_2(self.conv3_2(conv3_1))
|
105 |
+
conv3_3 = self.relu3_3(self.conv3_3(conv3_2))
|
106 |
+
conv3_3norm = self.conv3_3norm(conv3_3)
|
107 |
+
conv3_3norm_ss = self.conv3_3norm_ss(conv3_3norm)
|
108 |
+
conv4_1 = self.relu4_1(self.conv4_1(conv3_3norm_ss))
|
109 |
+
conv4_2 = self.relu4_2(self.conv4_2(conv4_1))
|
110 |
+
conv4_3 = self.relu4_3(self.conv4_3(conv4_2))
|
111 |
+
conv4_3norm = self.conv4_3norm(conv4_3)
|
112 |
+
conv5_1 = self.relu5_1(self.conv5_1(conv4_3norm))
|
113 |
+
conv5_2 = self.relu5_2(self.conv5_2(conv5_1))
|
114 |
+
conv5_3 = self.relu5_3(self.conv5_3(conv5_2))
|
115 |
+
conv5_3norm = self.conv5_3norm(conv5_3)
|
116 |
+
conv6_1 = self.relu6_1(self.conv6_1(conv5_3norm))
|
117 |
+
conv6_2 = self.relu6_2(self.conv6_2(conv6_1))
|
118 |
+
conv6_3 = self.relu6_3(self.conv6_3(conv6_2))
|
119 |
+
conv6_3norm = self.conv6_3norm(conv6_3)
|
120 |
+
conv7_1 = self.relu7_1(self.conv7_1(conv6_3norm))
|
121 |
+
conv7_2 = self.relu7_2(self.conv7_2(conv7_1))
|
122 |
+
conv7_3 = self.relu7_3(self.conv7_3(conv7_2))
|
123 |
+
conv7_3norm = self.conv7_3norm(conv7_3)
|
124 |
+
conv8_1 = self.conv8_1(conv7_3norm)
|
125 |
+
conv3_3_short = self.conv3_3_short(conv3_3norm)
|
126 |
+
conv8_1_comb = self.relu8_1_comb(conv8_1 + conv3_3_short)
|
127 |
+
conv8_2 = self.relu8_2(self.conv8_2(conv8_1_comb))
|
128 |
+
conv8_3 = self.relu8_3(self.conv8_3(conv8_2))
|
129 |
+
conv8_3norm = self.conv8_3norm(conv8_3)
|
130 |
+
conv9_1 = self.conv9_1(conv8_3norm)
|
131 |
+
conv2_2_short = self.conv2_2_short(conv2_2norm)
|
132 |
+
conv9_1_comb = self.relu9_1_comb(conv9_1 + conv2_2_short)
|
133 |
+
conv9_2 = self.relu9_2(self.conv9_2(conv9_1_comb))
|
134 |
+
conv9_2norm = self.conv9_2norm(conv9_2)
|
135 |
+
conv10_1 = self.conv10_1(conv9_2norm)
|
136 |
+
conv1_2_short = self.conv1_2_short(conv1_2norm)
|
137 |
+
conv10_1_comb = self.relu10_1_comb(conv10_1 + conv1_2_short)
|
138 |
+
conv10_2 = self.relu10_2(self.conv10_2(conv10_1_comb))
|
139 |
+
conv10_ab = self.conv10_ab(conv10_2)
|
140 |
+
|
141 |
+
return torch.tanh(conv10_ab) * 128
|
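A minimal shape check for ColorVidNet, as a sketch under assumptions: the repository is importable as src, and ic=7 matches the 7-channel stack assembled in src/models/CNN/FrameColor.py below (L, warped ab, similarity map, previous-frame Lab). Height and width should be divisible by 8, since the encoder downsamples three times by stride 2 before the decoder upsamples back.

import torch
from src.models.CNN.ColorVidNet import ColorVidNet

net = ColorVidNet(ic=7)
x = torch.randn(1, 7, 224, 224)
ab = net(x)
print(ab.shape)  # torch.Size([1, 2, 224, 224]); tanh output scaled to roughly [-128, 128]
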
src/models/CNN/FrameColor.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
import torch
|
2 |
+
from src.utils import *
|
3 |
+
from src.models.vit.vit import FeatureTransform
|
4 |
+
|
5 |
+
|
6 |
+
def warp_color(
|
7 |
+
IA_l,
|
8 |
+
IB_lab,
|
9 |
+
features_B,
|
10 |
+
embed_net,
|
11 |
+
nonlocal_net,
|
12 |
+
temperature=0.01,
|
13 |
+
):
|
14 |
+
IA_rgb_from_gray = gray2rgb_batch(IA_l)
|
15 |
+
|
16 |
+
with torch.no_grad():
|
17 |
+
A_feat0, A_feat1, A_feat2, A_feat3 = embed_net(IA_rgb_from_gray)
|
18 |
+
B_feat0, B_feat1, B_feat2, B_feat3 = features_B
|
19 |
+
|
20 |
+
A_feat0 = feature_normalize(A_feat0)
|
21 |
+
A_feat1 = feature_normalize(A_feat1)
|
22 |
+
A_feat2 = feature_normalize(A_feat2)
|
23 |
+
A_feat3 = feature_normalize(A_feat3)
|
24 |
+
|
25 |
+
B_feat0 = feature_normalize(B_feat0)
|
26 |
+
B_feat1 = feature_normalize(B_feat1)
|
27 |
+
B_feat2 = feature_normalize(B_feat2)
|
28 |
+
B_feat3 = feature_normalize(B_feat3)
|
29 |
+
|
30 |
+
return nonlocal_net(
|
31 |
+
IB_lab,
|
32 |
+
A_feat0,
|
33 |
+
A_feat1,
|
34 |
+
A_feat2,
|
35 |
+
A_feat3,
|
36 |
+
B_feat0,
|
37 |
+
B_feat1,
|
38 |
+
B_feat2,
|
39 |
+
B_feat3,
|
40 |
+
temperature=temperature,
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
def frame_colorization(
|
45 |
+
IA_l,
|
46 |
+
IB_lab,
|
47 |
+
IA_last_lab,
|
48 |
+
features_B,
|
49 |
+
embed_net,
|
50 |
+
nonlocal_net,
|
51 |
+
colornet,
|
52 |
+
joint_training=True,
|
53 |
+
luminance_noise=0,
|
54 |
+
temperature=0.01,
|
55 |
+
):
|
56 |
+
if luminance_noise:
|
57 |
+
IA_l = IA_l + torch.randn_like(IA_l, requires_grad=False) * luminance_noise
|
58 |
+
|
59 |
+
with torch.autograd.set_grad_enabled(joint_training):
|
60 |
+
nonlocal_BA_lab, similarity_map = warp_color(
|
61 |
+
IA_l,
|
62 |
+
IB_lab,
|
63 |
+
features_B,
|
64 |
+
embed_net,
|
65 |
+
nonlocal_net,
|
66 |
+
temperature=temperature,
|
67 |
+
)
|
68 |
+
nonlocal_BA_ab = nonlocal_BA_lab[:, 1:3, :, :]
|
69 |
+
IA_ab_predict = colornet(
|
70 |
+
torch.cat(
|
71 |
+
(IA_l, nonlocal_BA_ab, similarity_map, IA_last_lab),
|
72 |
+
dim=1,
|
73 |
+
)
|
74 |
+
)
|
75 |
+
|
76 |
+
return IA_ab_predict, nonlocal_BA_lab
|
src/models/CNN/GAN_models.py
ADDED
@@ -0,0 +1,212 @@
1 |
+
# DCGAN-like generator and discriminator
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.nn import Parameter
|
6 |
+
|
7 |
+
|
8 |
+
def l2normalize(v, eps=1e-12):
|
9 |
+
return v / (v.norm() + eps)
|
10 |
+
|
11 |
+
|
12 |
+
class SpectralNorm(nn.Module):
|
13 |
+
def __init__(self, module, name="weight", power_iterations=1):
|
14 |
+
super(SpectralNorm, self).__init__()
|
15 |
+
self.module = module
|
16 |
+
self.name = name
|
17 |
+
self.power_iterations = power_iterations
|
18 |
+
if not self._made_params():
|
19 |
+
self._make_params()
|
20 |
+
|
21 |
+
def _update_u_v(self):
|
22 |
+
u = getattr(self.module, self.name + "_u")
|
23 |
+
v = getattr(self.module, self.name + "_v")
|
24 |
+
w = getattr(self.module, self.name + "_bar")
|
25 |
+
|
26 |
+
height = w.data.shape[0]
|
27 |
+
for _ in range(self.power_iterations):
|
28 |
+
v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
|
29 |
+
u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
|
30 |
+
|
31 |
+
sigma = u.dot(w.view(height, -1).mv(v))
|
32 |
+
setattr(self.module, self.name, w / sigma.expand_as(w))
|
33 |
+
|
34 |
+
def _made_params(self):
|
35 |
+
try:
|
36 |
+
u = getattr(self.module, self.name + "_u")
|
37 |
+
v = getattr(self.module, self.name + "_v")
|
38 |
+
w = getattr(self.module, self.name + "_bar")
|
39 |
+
return True
|
40 |
+
except AttributeError:
|
41 |
+
return False
|
42 |
+
|
43 |
+
def _make_params(self):
|
44 |
+
w = getattr(self.module, self.name)
|
45 |
+
|
46 |
+
height = w.data.shape[0]
|
47 |
+
width = w.view(height, -1).data.shape[1]
|
48 |
+
|
49 |
+
u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
|
50 |
+
v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
|
51 |
+
u.data = l2normalize(u.data)
|
52 |
+
v.data = l2normalize(v.data)
|
53 |
+
w_bar = Parameter(w.data)
|
54 |
+
|
55 |
+
del self.module._parameters[self.name]
|
56 |
+
|
57 |
+
self.module.register_parameter(self.name + "_u", u)
|
58 |
+
self.module.register_parameter(self.name + "_v", v)
|
59 |
+
self.module.register_parameter(self.name + "_bar", w_bar)
|
60 |
+
|
61 |
+
def forward(self, *args):
|
62 |
+
self._update_u_v()
|
63 |
+
return self.module.forward(*args)
|
64 |
+
|
65 |
+
|
66 |
+
class Generator(nn.Module):
|
67 |
+
def __init__(self, z_dim):
|
68 |
+
super(Generator, self).__init__()
|
69 |
+
self.z_dim = z_dim
|
70 |
+
|
71 |
+
self.model = nn.Sequential(
|
72 |
+
nn.ConvTranspose2d(z_dim, 512, 4, stride=1),
|
73 |
+
nn.InstanceNorm2d(512),
|
74 |
+
nn.ReLU(),
|
75 |
+
nn.ConvTranspose2d(512, 256, 4, stride=2, padding=(1, 1)),
|
76 |
+
nn.InstanceNorm2d(256),
|
77 |
+
nn.ReLU(),
|
78 |
+
nn.ConvTranspose2d(256, 128, 4, stride=2, padding=(1, 1)),
|
79 |
+
nn.InstanceNorm2d(128),
|
80 |
+
nn.ReLU(),
|
81 |
+
nn.ConvTranspose2d(128, 64, 4, stride=2, padding=(1, 1)),
|
82 |
+
nn.InstanceNorm2d(64),
|
83 |
+
nn.ReLU(),
|
84 |
+
nn.ConvTranspose2d(64, channels, 3, stride=1, padding=(1, 1)),
|
85 |
+
nn.Tanh(),
|
86 |
+
)
|
87 |
+
|
88 |
+
def forward(self, z):
|
89 |
+
return self.model(z.view(-1, self.z_dim, 1, 1))
|
90 |
+
|
91 |
+
|
92 |
+
channels = 3
|
93 |
+
leak = 0.1
|
94 |
+
w_g = 4
|
95 |
+
|
96 |
+
|
97 |
+
class Discriminator(nn.Module):
|
98 |
+
def __init__(self):
|
99 |
+
super(Discriminator, self).__init__()
|
100 |
+
|
101 |
+
self.conv1 = SpectralNorm(nn.Conv2d(channels, 64, 3, stride=1, padding=(1, 1)))
|
102 |
+
self.conv2 = SpectralNorm(nn.Conv2d(64, 64, 4, stride=2, padding=(1, 1)))
|
103 |
+
self.conv3 = SpectralNorm(nn.Conv2d(64, 128, 3, stride=1, padding=(1, 1)))
|
104 |
+
self.conv4 = SpectralNorm(nn.Conv2d(128, 128, 4, stride=2, padding=(1, 1)))
|
105 |
+
self.conv5 = SpectralNorm(nn.Conv2d(128, 256, 3, stride=1, padding=(1, 1)))
|
106 |
+
self.conv6 = SpectralNorm(nn.Conv2d(256, 256, 4, stride=2, padding=(1, 1)))
|
107 |
+
self.conv7 = SpectralNorm(nn.Conv2d(256, 256, 3, stride=1, padding=(1, 1)))
|
108 |
+
self.conv8 = SpectralNorm(nn.Conv2d(256, 512, 4, stride=2, padding=(1, 1)))
|
109 |
+
self.fc = SpectralNorm(nn.Linear(w_g * w_g * 512, 1))
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
m = x
|
113 |
+
m = nn.LeakyReLU(leak)(self.conv1(m))
|
114 |
+
m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(64)(self.conv2(m)))
|
115 |
+
m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(128)(self.conv3(m)))
|
116 |
+
m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(128)(self.conv4(m)))
|
117 |
+
m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv5(m)))
|
118 |
+
m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv6(m)))
|
119 |
+
m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv7(m)))
|
120 |
+
m = nn.LeakyReLU(leak)(self.conv8(m))
|
121 |
+
|
122 |
+
return self.fc(m.view(-1, w_g * w_g * 512))
|
123 |
+
|
124 |
+
|
125 |
+
class Self_Attention(nn.Module):
|
126 |
+
"""Self attention Layer"""
|
127 |
+
|
128 |
+
def __init__(self, in_dim):
|
129 |
+
super(Self_Attention, self).__init__()
|
130 |
+
self.chanel_in = in_dim
|
131 |
+
|
132 |
+
self.query_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 1, kernel_size=1))
|
133 |
+
self.key_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 1, kernel_size=1))
|
134 |
+
self.value_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1))
|
135 |
+
self.gamma = nn.Parameter(torch.zeros(1))
|
136 |
+
|
137 |
+
self.softmax = nn.Softmax(dim=-1) #
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
"""
|
141 |
+
inputs :
|
142 |
+
x : input feature maps( B X C X W X H)
|
143 |
+
returns :
|
144 |
+
out : self attention value + input feature
|
145 |
+
attention: B X N X N (N is Width*Height)
|
146 |
+
"""
|
147 |
+
m_batchsize, C, width, height = x.size()
|
148 |
+
proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X CX(N)
|
149 |
+
proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H)
|
150 |
+
energy = torch.bmm(proj_query, proj_key) # transpose check
|
151 |
+
attention = self.softmax(energy) # BX (N) X (N)
|
152 |
+
proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N
|
153 |
+
|
154 |
+
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
|
155 |
+
out = out.view(m_batchsize, C, width, height)
|
156 |
+
|
157 |
+
out = self.gamma * out + x
|
158 |
+
return out
|
159 |
+
|
160 |
+
class Discriminator_x64_224(nn.Module):
|
161 |
+
"""
|
162 |
+
Discriminative Network
|
163 |
+
"""
|
164 |
+
|
165 |
+
def __init__(self, in_size=6, ndf=64):
|
166 |
+
super(Discriminator_x64_224, self).__init__()
|
167 |
+
self.in_size = in_size
|
168 |
+
self.ndf = ndf
|
169 |
+
|
170 |
+
self.layer1 = nn.Sequential(SpectralNorm(nn.Conv2d(self.in_size, self.ndf, 4, 2, 1)), nn.LeakyReLU(0.2, inplace=True))
|
171 |
+
|
172 |
+
self.layer2 = nn.Sequential(
|
173 |
+
SpectralNorm(nn.Conv2d(self.ndf, self.ndf, 4, 2, 1)),
|
174 |
+
nn.InstanceNorm2d(self.ndf),
|
175 |
+
nn.LeakyReLU(0.2, inplace=True),
|
176 |
+
)
|
177 |
+
self.attention = Self_Attention(self.ndf)
|
178 |
+
self.layer3 = nn.Sequential(
|
179 |
+
SpectralNorm(nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1)),
|
180 |
+
nn.InstanceNorm2d(self.ndf * 2),
|
181 |
+
nn.LeakyReLU(0.2, inplace=True),
|
182 |
+
)
|
183 |
+
self.layer4 = nn.Sequential(
|
184 |
+
SpectralNorm(nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1)),
|
185 |
+
nn.InstanceNorm2d(self.ndf * 4),
|
186 |
+
nn.LeakyReLU(0.2, inplace=True),
|
187 |
+
)
|
188 |
+
self.layer5 = nn.Sequential(
|
189 |
+
SpectralNorm(nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1)),
|
190 |
+
nn.InstanceNorm2d(self.ndf * 8),
|
191 |
+
nn.LeakyReLU(0.2, inplace=True),
|
192 |
+
)
|
193 |
+
self.layer6 = nn.Sequential(
|
194 |
+
SpectralNorm(nn.Conv2d(self.ndf * 8, self.ndf * 16, 4, 2, 1)),
|
195 |
+
nn.InstanceNorm2d(self.ndf * 16),
|
196 |
+
nn.LeakyReLU(0.2, inplace=True),
|
197 |
+
)
|
198 |
+
|
199 |
+
self.last = SpectralNorm(nn.Conv2d(self.ndf * 16, 1, [3, 3], 1, 0))
|
200 |
+
|
201 |
+
def forward(self, input):
|
202 |
+
feature1 = self.layer1(input)
|
203 |
+
feature2 = self.layer2(feature1)
|
204 |
+
feature_attention = self.attention(feature2)
|
205 |
+
feature3 = self.layer3(feature_attention)
|
206 |
+
feature4 = self.layer4(feature3)
|
207 |
+
feature5 = self.layer5(feature4)
|
208 |
+
feature6 = self.layer6(feature5)
|
209 |
+
output = self.last(feature6)
|
210 |
+
output = F.avg_pool2d(output, output.size()[2:]).view(output.size()[0], -1)
|
211 |
+
|
212 |
+
return output, feature4
|
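A minimal shape check for the self-attention discriminator, again a sketch under assumptions: the repository is importable as src, and the default in_size=6 suggests two 3-channel Lab images stacked along the channel axis, although that pairing is an assumption here rather than something stated in this file.

import torch
from src.models.CNN.GAN_models import Discriminator_x64_224

disc = Discriminator_x64_224(in_size=6, ndf=64)
x = torch.randn(2, 6, 224, 224)
logit, feat = disc(x)
print(logit.shape, feat.shape)  # torch.Size([2, 1]) and torch.Size([2, 256, 14, 14])
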
src/models/CNN/NonlocalNet.py
ADDED
@@ -0,0 +1,437 @@
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from src.utils import uncenter_l
|
6 |
+
|
7 |
+
|
8 |
+
def find_local_patch(x, patch_size):
|
9 |
+
"""
|
10 |
+
> We take a tensor `x` and return a tensor `x_unfold` that contains all the patches of size
|
11 |
+
`patch_size` in `x`
|
12 |
+
|
13 |
+
Args:
|
14 |
+
x: the input tensor
|
15 |
+
patch_size: the size of the patch to be extracted.
|
16 |
+
"""
|
17 |
+
|
18 |
+
N, C, H, W = x.shape
|
19 |
+
x_unfold = F.unfold(x, kernel_size=(patch_size, patch_size), padding=(patch_size // 2, patch_size // 2), stride=(1, 1))
|
20 |
+
|
21 |
+
return x_unfold.view(N, x_unfold.shape[1], H, W)
|
22 |
+
|
23 |
+
|
24 |
+
class WeightedAverage(nn.Module):
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
):
|
28 |
+
super(WeightedAverage, self).__init__()
|
29 |
+
|
30 |
+
def forward(self, x_lab, patch_size=3, alpha=1, scale_factor=1):
|
31 |
+
"""
|
32 |
+
It takes a 3-channel image (L, A, B) and returns a 2-channel image (A, B) where each pixel is a
|
33 |
+
weighted average of the A and B values of the pixels in a 3x3 neighborhood around it
|
34 |
+
|
35 |
+
Args:
|
36 |
+
x_lab: the input image in LAB color space
|
37 |
+
patch_size: the size of the patch to use for the local average. Defaults to 3
|
38 |
+
alpha: the higher the alpha, the smoother the output. Defaults to 1
|
39 |
+
scale_factor: the scale factor of the input image. Defaults to 1
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
The output of the forward function is a tensor of size (batch_size, 2, height, width)
|
43 |
+
"""
|
44 |
+
# alpha=0: less smooth; alpha=inf: smoother
|
45 |
+
x_lab = F.interpolate(x_lab, scale_factor=scale_factor)
|
46 |
+
l = x_lab[:, 0:1, :, :]
|
47 |
+
a = x_lab[:, 1:2, :, :]
|
48 |
+
b = x_lab[:, 2:3, :, :]
|
49 |
+
local_l = find_local_patch(l, patch_size)
|
50 |
+
local_a = find_local_patch(a, patch_size)
|
51 |
+
local_b = find_local_patch(b, patch_size)
|
52 |
+
local_difference_l = (local_l - l) ** 2
|
53 |
+
correlation = nn.functional.softmax(-1 * local_difference_l / alpha, dim=1)
|
54 |
+
|
55 |
+
return torch.cat(
|
56 |
+
(
|
57 |
+
torch.sum(correlation * local_a, dim=1, keepdim=True),
|
58 |
+
torch.sum(correlation * local_b, dim=1, keepdim=True),
|
59 |
+
),
|
60 |
+
1,
|
61 |
+
)
|
62 |
+
|
63 |
+
|
64 |
+
class WeightedAverage_color(nn.Module):
|
65 |
+
"""
|
66 |
+
smooth the image according to the color distance in the LAB space
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
):
|
72 |
+
super(WeightedAverage_color, self).__init__()
|
73 |
+
|
74 |
+
def forward(self, x_lab, x_lab_predict, patch_size=3, alpha=1, scale_factor=1):
|
75 |
+
"""
|
76 |
+
It takes the predicted a and b channels, and the original a and b channels, and finds the
|
77 |
+
weighted average of the predicted a and b channels based on the similarity of the original a and
|
78 |
+
b channels to the predicted a and b channels
|
79 |
+
|
80 |
+
Args:
|
81 |
+
x_lab: the input image in LAB color space
|
82 |
+
x_lab_predict: the predicted LAB image
|
83 |
+
patch_size: the size of the patch to use for the local color correction. Defaults to 3
|
84 |
+
alpha: controls the smoothness of the output. Defaults to 1
|
85 |
+
scale_factor: the scale factor of the input image. Defaults to 1
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
The return is the weighted average of the local a and b channels.
|
89 |
+
"""
|
90 |
+
""" alpha=0: less smooth; alpha=inf: smoother """
|
91 |
+
x_lab = F.interpolate(x_lab, scale_factor=scale_factor)
|
92 |
+
l = uncenter_l(x_lab[:, 0:1, :, :])
|
93 |
+
a = x_lab[:, 1:2, :, :]
|
94 |
+
b = x_lab[:, 2:3, :, :]
|
95 |
+
a_predict = x_lab_predict[:, 1:2, :, :]
|
96 |
+
b_predict = x_lab_predict[:, 2:3, :, :]
|
97 |
+
local_l = find_local_patch(l, patch_size)
|
98 |
+
local_a = find_local_patch(a, patch_size)
|
99 |
+
local_b = find_local_patch(b, patch_size)
|
100 |
+
local_a_predict = find_local_patch(a_predict, patch_size)
|
101 |
+
local_b_predict = find_local_patch(b_predict, patch_size)
|
102 |
+
|
103 |
+
local_color_difference = (local_l - l) ** 2 + (local_a - a) ** 2 + (local_b - b) ** 2
|
104 |
+
# so that sum of weights equal to 1
|
105 |
+
correlation = nn.functional.softmax(-1 * local_color_difference / alpha, dim=1)
|
106 |
+
|
107 |
+
return torch.cat(
|
108 |
+
(
|
109 |
+
torch.sum(correlation * local_a_predict, dim=1, keepdim=True),
|
110 |
+
torch.sum(correlation * local_b_predict, dim=1, keepdim=True),
|
111 |
+
),
|
112 |
+
1,
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
class NonlocalWeightedAverage(nn.Module):
|
117 |
+
def __init__(
|
118 |
+
self,
|
119 |
+
):
|
120 |
+
super(NonlocalWeightedAverage, self).__init__()
|
121 |
+
|
122 |
+
def forward(self, x_lab, feature, patch_size=3, alpha=0.1, scale_factor=1):
|
123 |
+
"""
|
124 |
+
It takes in a feature map and a label map, and returns a smoothed label map
|
125 |
+
|
126 |
+
Args:
|
127 |
+
x_lab: the input image in LAB color space
|
128 |
+
feature: the feature map of the input image
|
129 |
+
patch_size: the size of the patch to be used for the correlation matrix. Defaults to 3
|
130 |
+
alpha: the higher the alpha, the smoother the output.
|
131 |
+
scale_factor: the scale factor of the input image. Defaults to 1
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
weighted_ab is the weighted ab channel of the image.
|
135 |
+
"""
|
136 |
+
# alpha=0: less smooth; alpha=inf: smoother
|
137 |
+
# input feature is normalized feature
|
138 |
+
x_lab = F.interpolate(x_lab, scale_factor=scale_factor)
|
139 |
+
batch_size, channel, height, width = x_lab.shape
|
140 |
+
feature = F.interpolate(feature, size=(height, width))
|
141 |
+
batch_size = x_lab.shape[0]
|
142 |
+
x_ab = x_lab[:, 1:3, :, :].view(batch_size, 2, -1)
|
143 |
+
x_ab = x_ab.permute(0, 2, 1)
|
144 |
+
|
145 |
+
local_feature = find_local_patch(feature, patch_size)
|
146 |
+
local_feature = local_feature.view(batch_size, local_feature.shape[1], -1)
|
147 |
+
|
148 |
+
correlation_matrix = torch.matmul(local_feature.permute(0, 2, 1), local_feature)
|
149 |
+
correlation_matrix = nn.functional.softmax(correlation_matrix / alpha, dim=-1)
|
150 |
+
|
151 |
+
weighted_ab = torch.matmul(correlation_matrix, x_ab)
|
152 |
+
weighted_ab = weighted_ab.permute(0, 2, 1).contiguous()
|
153 |
+
weighted_ab = weighted_ab.view(batch_size, 2, height, width)
|
154 |
+
return weighted_ab
|
155 |
+
|
156 |
+
|
157 |
+
class CorrelationLayer(nn.Module):
|
158 |
+
def __init__(self, search_range):
|
159 |
+
super(CorrelationLayer, self).__init__()
|
160 |
+
self.search_range = search_range
|
161 |
+
|
162 |
+
def forward(self, x1, x2, alpha=1, raw_output=False, metric="similarity"):
|
163 |
+
"""
|
164 |
+
It takes two tensors, x1 and x2, and returns a tensor of shape (batch_size, (search_range * 2 +
|
165 |
+
1) ** 2, height, width) where each element is the dot product of the corresponding patch in x1
|
166 |
+
and x2
|
167 |
+
|
168 |
+
Args:
|
169 |
+
x1: the first image
|
170 |
+
x2: the image to be warped
|
171 |
+
alpha: the temperature parameter for the softmax function. Defaults to 1
|
172 |
+
raw_output: if True, return the raw output of the network, otherwise return the softmax
|
173 |
+
output. Defaults to False
|
174 |
+
metric: "similarity" or "subtraction". Defaults to similarity
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
The output of the forward function is a softmax of the correlation volume.
|
178 |
+
"""
|
179 |
+
shape = list(x1.size())
|
180 |
+
shape[1] = (self.search_range * 2 + 1) ** 2
|
181 |
+
cv = torch.zeros(shape).to(torch.device("cuda"))
|
182 |
+
|
183 |
+
for i in range(-self.search_range, self.search_range + 1):
|
184 |
+
for j in range(-self.search_range, self.search_range + 1):
|
185 |
+
if i < 0:
|
186 |
+
slice_h, slice_h_r = slice(None, i), slice(-i, None)
|
187 |
+
elif i > 0:
|
188 |
+
slice_h, slice_h_r = slice(i, None), slice(None, -i)
|
189 |
+
else:
|
190 |
+
slice_h, slice_h_r = slice(None), slice(None)
|
191 |
+
|
192 |
+
if j < 0:
|
193 |
+
slice_w, slice_w_r = slice(None, j), slice(-j, None)
|
194 |
+
elif j > 0:
|
195 |
+
slice_w, slice_w_r = slice(j, None), slice(None, -j)
|
196 |
+
else:
|
197 |
+
slice_w, slice_w_r = slice(None), slice(None)
|
198 |
+
|
199 |
+
if metric == "similarity":
|
200 |
+
cv[:, (self.search_range * 2 + 1) * i + j, slice_h, slice_w] = (
|
201 |
+
x1[:, :, slice_h, slice_w] * x2[:, :, slice_h_r, slice_w_r]
|
202 |
+
).sum(1)
|
203 |
+
else: # patchwise subtraction
|
204 |
+
cv[:, (self.search_range * 2 + 1) * i + j, slice_h, slice_w] = -(
|
205 |
+
(x1[:, :, slice_h, slice_w] - x2[:, :, slice_h_r, slice_w_r]) ** 2
|
206 |
+
).sum(1)
|
207 |
+
|
208 |
+
# TODO sigmoid?
|
209 |
+
if raw_output:
|
210 |
+
return cv
|
211 |
+
else:
|
212 |
+
return nn.functional.softmax(cv / alpha, dim=1)
|
213 |
+
|
214 |
+
|
215 |
+
class WTA_scale(torch.autograd.Function):
|
216 |
+
"""
|
217 |
+
We can implement our own custom autograd Functions by subclassing
|
218 |
+
torch.autograd.Function and implementing the forward and backward passes
|
219 |
+
which operate on Tensors.
|
220 |
+
"""
|
221 |
+
|
222 |
+
@staticmethod
|
223 |
+
def forward(ctx, input, scale=1e-4):
|
224 |
+
"""
|
225 |
+
In the forward pass we receive a Tensor containing the input and return a
|
226 |
+
Tensor containing the output. You can cache arbitrary Tensors for use in the
|
227 |
+
backward pass using the save_for_backward method.
|
228 |
+
"""
|
229 |
+
activation_max, index_max = torch.max(input, -1, keepdim=True)
|
230 |
+
input_scale = input * scale # default: 1e-4
|
231 |
+
# input_scale = input * scale # default: 1e-4
|
232 |
+
output_max_scale = torch.where(input == activation_max, input, input_scale)
|
233 |
+
|
234 |
+
mask = (input == activation_max).type(torch.float)
|
235 |
+
ctx.save_for_backward(input, mask)
|
236 |
+
return output_max_scale
|
237 |
+
|
238 |
+
@staticmethod
|
239 |
+
def backward(ctx, grad_output):
|
240 |
+
"""
|
241 |
+
In the backward pass we receive a Tensor containing the gradient of the loss
|
242 |
+
with respect to the output, and we need to compute the gradient of the loss
|
243 |
+
with respect to the input.
|
244 |
+
"""
|
245 |
+
input, mask = ctx.saved_tensors
|
246 |
+
mask_ones = torch.ones_like(mask)
|
247 |
+
mask_small_ones = torch.ones_like(mask) * 1e-4
|
248 |
+
# mask_small_ones = torch.ones_like(mask) * 1e-4
|
249 |
+
|
250 |
+
grad_scale = torch.where(mask == 1, mask_ones, mask_small_ones)
|
251 |
+
grad_input = grad_output.clone() * grad_scale
|
252 |
+
return grad_input, None
|
253 |
+
|
254 |
+
|
255 |
+
class ResidualBlock(nn.Module):
|
256 |
+
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
|
257 |
+
super(ResidualBlock, self).__init__()
|
258 |
+
self.padding1 = nn.ReflectionPad2d(padding)
|
259 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride)
|
260 |
+
self.bn1 = nn.InstanceNorm2d(out_channels)
|
261 |
+
self.prelu = nn.PReLU()
|
262 |
+
self.padding2 = nn.ReflectionPad2d(padding)
|
263 |
+
self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride)
|
264 |
+
self.bn2 = nn.InstanceNorm2d(out_channels)
|
265 |
+
|
266 |
+
def forward(self, x):
|
267 |
+
residual = x
|
268 |
+
out = self.padding1(x)
|
269 |
+
out = self.conv1(out)
|
270 |
+
out = self.bn1(out)
|
271 |
+
out = self.prelu(out)
|
272 |
+
out = self.padding2(out)
|
273 |
+
out = self.conv2(out)
|
274 |
+
out = self.bn2(out)
|
275 |
+
out += residual
|
276 |
+
out = self.prelu(out)
|
277 |
+
return out
|
278 |
+
|
279 |
+
|
280 |
+
class WarpNet(nn.Module):
|
281 |
+
"""input is Al, Bl, channel = 1, range~[0,255]"""
|
282 |
+
|
283 |
+
def __init__(self, feature_channel=128):
|
284 |
+
super(WarpNet, self).__init__()
|
285 |
+
self.feature_channel = feature_channel
|
286 |
+
self.in_channels = self.feature_channel * 4
|
287 |
+
self.inter_channels = 256
|
288 |
+
# 44*44
|
289 |
+
self.layer2_1 = nn.Sequential(
|
290 |
+
nn.ReflectionPad2d(1),
|
291 |
+
# nn.Conv2d(128, 128, kernel_size=3, padding=0, stride=1),
|
292 |
+
# nn.Conv2d(96, 128, kernel_size=3, padding=20, stride=1),
|
293 |
+
nn.Conv2d(96, 128, kernel_size=3, padding=0, stride=1),
|
294 |
+
nn.InstanceNorm2d(128),
|
295 |
+
nn.PReLU(),
|
296 |
+
nn.ReflectionPad2d(1),
|
297 |
+
nn.Conv2d(128, self.feature_channel, kernel_size=3, padding=0, stride=2),
|
298 |
+
nn.InstanceNorm2d(self.feature_channel),
|
299 |
+
nn.PReLU(),
|
300 |
+
nn.Dropout(0.2),
|
301 |
+
)
|
302 |
+
self.layer3_1 = nn.Sequential(
|
303 |
+
nn.ReflectionPad2d(1),
|
304 |
+
# nn.Conv2d(256, 128, kernel_size=3, padding=0, stride=1),
|
305 |
+
# nn.Conv2d(192, 128, kernel_size=3, padding=10, stride=1),
|
306 |
+
nn.Conv2d(192, 128, kernel_size=3, padding=0, stride=1),
|
307 |
+
nn.InstanceNorm2d(128),
|
308 |
+
nn.PReLU(),
|
309 |
+
nn.ReflectionPad2d(1),
|
310 |
+
nn.Conv2d(128, self.feature_channel, kernel_size=3, padding=0, stride=1),
|
311 |
+
nn.InstanceNorm2d(self.feature_channel),
|
312 |
+
nn.PReLU(),
|
313 |
+
nn.Dropout(0.2),
|
314 |
+
)
|
315 |
+
|
316 |
+
# 22*22->44*44
|
317 |
+
self.layer4_1 = nn.Sequential(
|
318 |
+
nn.ReflectionPad2d(1),
|
319 |
+
# nn.Conv2d(512, 256, kernel_size=3, padding=0, stride=1),
|
320 |
+
# nn.Conv2d(384, 256, kernel_size=3, padding=5, stride=1),
|
321 |
+
nn.Conv2d(384, 256, kernel_size=3, padding=0, stride=1),
|
322 |
+
nn.InstanceNorm2d(256),
|
323 |
+
            nn.PReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, self.feature_channel, kernel_size=3, padding=0, stride=1),
            nn.InstanceNorm2d(self.feature_channel),
            nn.PReLU(),
            nn.Upsample(scale_factor=2),
            nn.Dropout(0.2),
        )

        # 11*11->44*44
        self.layer5_1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            # nn.Conv2d(1024, 256, kernel_size=3, padding=0, stride=1),
            # nn.Conv2d(768, 256, kernel_size=2, padding=2, stride=1),
            nn.Conv2d(768, 256, kernel_size=3, padding=0, stride=1),
            nn.InstanceNorm2d(256),
            nn.PReLU(),
            nn.Upsample(scale_factor=2),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, self.feature_channel, kernel_size=3, padding=0, stride=1),
            nn.InstanceNorm2d(self.feature_channel),
            nn.PReLU(),
            nn.Upsample(scale_factor=2),
            nn.Dropout(0.2),
        )

        self.layer = nn.Sequential(
            ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1),
            ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1),
            ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1),
        )

        self.theta = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)

        self.upsampling = nn.Upsample(scale_factor=4)

    def forward(
        self,
        B_lab_map,
        A_relu2_1,
        A_relu3_1,
        A_relu4_1,
        A_relu5_1,
        B_relu2_1,
        B_relu3_1,
        B_relu4_1,
        B_relu5_1,
        temperature=0.001 * 5,
        detach_flag=False,
        WTA_scale_weight=1,
    ):
        batch_size = B_lab_map.shape[0]
        channel = B_lab_map.shape[1]
        image_height = B_lab_map.shape[2]
        image_width = B_lab_map.shape[3]
        feature_height = int(image_height / 4)
        feature_width = int(image_width / 4)

        # scale feature size to 44*44
        A_feature2_1 = self.layer2_1(A_relu2_1)
        B_feature2_1 = self.layer2_1(B_relu2_1)
        A_feature3_1 = self.layer3_1(A_relu3_1)
        B_feature3_1 = self.layer3_1(B_relu3_1)
        A_feature4_1 = self.layer4_1(A_relu4_1)
        B_feature4_1 = self.layer4_1(B_relu4_1)
        A_feature5_1 = self.layer5_1(A_relu5_1)
        B_feature5_1 = self.layer5_1(B_relu5_1)

        # concatenate features
        if A_feature5_1.shape[2] != A_feature2_1.shape[2] or A_feature5_1.shape[3] != A_feature2_1.shape[3]:
            A_feature5_1 = F.pad(A_feature5_1, (0, 0, 1, 1), "replicate")
            B_feature5_1 = F.pad(B_feature5_1, (0, 0, 1, 1), "replicate")

        A_features = self.layer(torch.cat((A_feature2_1, A_feature3_1, A_feature4_1, A_feature5_1), 1))
        B_features = self.layer(torch.cat((B_feature2_1, B_feature3_1, B_feature4_1, B_feature5_1), 1))

        # pairwise cosine similarity
        theta = self.theta(A_features).view(batch_size, self.inter_channels, -1)  # 2*256*(feature_height*feature_width)
        theta = theta - theta.mean(dim=-1, keepdim=True)  # center the feature
        theta_norm = torch.norm(theta, 2, 1, keepdim=True) + sys.float_info.epsilon
        theta = torch.div(theta, theta_norm)
        theta_permute = theta.permute(0, 2, 1)  # 2*(feature_height*feature_width)*256
        phi = self.phi(B_features).view(batch_size, self.inter_channels, -1)  # 2*256*(feature_height*feature_width)
        phi = phi - phi.mean(dim=-1, keepdim=True)  # center the feature
        phi_norm = torch.norm(phi, 2, 1, keepdim=True) + sys.float_info.epsilon
        phi = torch.div(phi, phi_norm)
        f = torch.matmul(theta_permute, phi)  # 2*(feature_height*feature_width)*(feature_height*feature_width)
        if detach_flag:
            f = f.detach()

        f_similarity = f.unsqueeze_(dim=1)
        similarity_map = torch.max(f_similarity, -1, keepdim=True)[0]
        similarity_map = similarity_map.view(batch_size, 1, feature_height, feature_width)

        # f can be negative
        f_WTA = f if WTA_scale_weight == 1 else WTA_scale.apply(f, WTA_scale_weight)
        f_WTA = f_WTA / temperature
        f_div_C = F.softmax(f_WTA.squeeze_(), dim=-1)  # 2*1936*1936

        # downsample the reference color
        B_lab = F.avg_pool2d(B_lab_map, 4)
        B_lab = B_lab.view(batch_size, channel, -1)
        B_lab = B_lab.permute(0, 2, 1)  # 2*1936*channel

        # multiply the corr map with color
        y = torch.matmul(f_div_C, B_lab)  # 2*1936*channel
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, channel, feature_height, feature_width)  # 2*3*44*44
        y = self.upsampling(y)
        similarity_map = self.upsampling(similarity_map)

        return y, similarity_map
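
For readers skimming the diff, the forward pass above reduces to a softmax-weighted warp of the reference colors by feature similarity. A minimal, self-contained sketch of that step with random tensors (the theta/phi 1x1 projections, WTA scaling and residual feature fusion are omitted; all shapes are illustrative only):

    import torch
    import torch.nn.functional as F

    B, C_feat, h, w = 2, 256, 44, 44
    A_feat = torch.randn(B, C_feat, h, w)        # features of the frame to colorize
    B_feat = torch.randn(B, C_feat, h, w)        # features of the reference frame
    B_lab_map = torch.randn(B, 3, 4 * h, 4 * w)  # reference Lab image at full resolution

    # channel-wise centering and L2 normalization, as in the forward above
    theta = A_feat.view(B, C_feat, -1)
    theta = theta - theta.mean(dim=-1, keepdim=True)
    theta = theta / (theta.norm(dim=1, keepdim=True) + 1e-8)
    phi = B_feat.view(B, C_feat, -1)
    phi = phi - phi.mean(dim=-1, keepdim=True)
    phi = phi / (phi.norm(dim=1, keepdim=True) + 1e-8)

    f = torch.matmul(theta.permute(0, 2, 1), phi)   # (B, h*w, h*w) cosine similarities
    f_div_C = F.softmax(f / 0.005, dim=-1)          # temperature = 0.001 * 5 as above

    B_lab = F.avg_pool2d(B_lab_map, 4).view(B, 3, -1).permute(0, 2, 1)
    y = torch.matmul(f_div_C, B_lab).permute(0, 2, 1).view(B, 3, h, w)  # warped color, (B, 3, 44, 44)
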
src/models/CNN/__init__.py
ADDED
File without changes

src/models/__init__.py
ADDED
File without changes

src/models/vit/__init__.py
ADDED
File without changes

src/models/vit/blocks.py
ADDED
@@ -0,0 +1,80 @@
import torch.nn as nn
from timm.models.layers import DropPath


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout, out_dim=None):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        if out_dim is None:
            out_dim = dim
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        return self

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, heads, dropout):
        super().__init__()
        self.heads = heads
        head_dim = dim // heads
        self.scale = head_dim**-0.5
        self.attn = None

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        return self

    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x, attn


class Block(nn.Module):
    def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads, dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, mask=None, return_attention=False):
        y, attn = self.attn(self.norm1(x), mask)
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
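
A quick sanity check of the Block API defined above (the import path assumes the Space's repo root is on PYTHONPATH; the shapes correspond to a 384x384 input with 16-pixel patches):

    import torch
    from src.models.vit.blocks import Block

    blk = Block(dim=192, heads=3, mlp_dim=768, dropout=0.1, drop_path=0.0)
    tokens = torch.randn(2, 577, 192)            # (batch, 1 cls token + 24*24 patches, dim)
    out = blk(tokens)                            # (2, 577, 192)
    attn = blk(tokens, return_attention=True)    # (2, 3, 577, 577) attention maps
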
src/models/vit/config.py
ADDED
@@ -0,0 +1,22 @@
import yaml
from pathlib import Path

import os


def load_config():
    return yaml.load(
        open(Path(__file__).parent / "config.yml", "r"), Loader=yaml.FullLoader
    )


def check_os_environ(key, use):
    if key not in os.environ:
        raise ValueError(
            f"{key} is not defined in the os variables, it is required for {use}."
        )


def dataset_dir():
    check_os_environ("DATASET", "data loading")
    return os.environ["DATASET"]
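
Usage is straightforward once the repo root is importable; dataset_dir only resolves if the DATASET environment variable is set (the path below is hypothetical):

    import os
    from src.models.vit.config import load_config, dataset_dir

    cfg = load_config()                  # parses src/models/vit/config.yml into nested dicts
    os.environ["DATASET"] = "/data/datasets"
    print(dataset_dir())                 # /data/datasets
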
src/models/vit/config.yml
ADDED
@@ -0,0 +1,132 @@
model:
  # deit
  deit_tiny_distilled_patch16_224:
    image_size: 224
    patch_size: 16
    d_model: 192
    n_heads: 3
    n_layers: 12
    normalization: deit
    distilled: true
  deit_small_distilled_patch16_224:
    image_size: 224
    patch_size: 16
    d_model: 384
    n_heads: 6
    n_layers: 12
    normalization: deit
    distilled: true
  deit_base_distilled_patch16_224:
    image_size: 224
    patch_size: 16
    d_model: 768
    n_heads: 12
    n_layers: 12
    normalization: deit
    distilled: true
  deit_base_distilled_patch16_384:
    image_size: 384
    patch_size: 16
    d_model: 768
    n_heads: 12
    n_layers: 12
    normalization: deit
    distilled: true
  # vit
  vit_base_patch8_384:
    image_size: 384
    patch_size: 8
    d_model: 768
    n_heads: 12
    n_layers: 12
    normalization: vit
    distilled: false
  vit_tiny_patch16_384:
    image_size: 384
    patch_size: 16
    d_model: 192
    n_heads: 3
    n_layers: 12
    normalization: vit
    distilled: false
  vit_small_patch16_384:
    image_size: 384
    patch_size: 16
    d_model: 384
    n_heads: 6
    n_layers: 12
    normalization: vit
    distilled: false
  vit_base_patch16_384:
    image_size: 384
    patch_size: 16
    d_model: 768
    n_heads: 12
    n_layers: 12
    normalization: vit
    distilled: false
  vit_large_patch16_384:
    image_size: 384
    patch_size: 16
    d_model: 1024
    n_heads: 16
    n_layers: 24
    normalization: vit
  vit_small_patch32_384:
    image_size: 384
    patch_size: 32
    d_model: 384
    n_heads: 6
    n_layers: 12
    normalization: vit
    distilled: false
  vit_base_patch32_384:
    image_size: 384
    patch_size: 32
    d_model: 768
    n_heads: 12
    n_layers: 12
    normalization: vit
  vit_large_patch32_384:
    image_size: 384
    patch_size: 32
    d_model: 1024
    n_heads: 16
    n_layers: 24
    normalization: vit
decoder:
  linear: {}
  deeplab_dec:
    encoder_layer: -1
  mask_transformer:
    drop_path_rate: 0.0
    dropout: 0.1
    n_layers: 2
dataset:
  ade20k:
    epochs: 64
    eval_freq: 2
    batch_size: 8
    learning_rate: 0.001
    im_size: 512
    crop_size: 512
    window_size: 512
    window_stride: 512
  pascal_context:
    epochs: 256
    eval_freq: 8
    batch_size: 16
    learning_rate: 0.001
    im_size: 520
    crop_size: 480
    window_size: 480
    window_stride: 320
  cityscapes:
    epochs: 216
    eval_freq: 4
    batch_size: 8
    learning_rate: 0.01
    im_size: 1024
    crop_size: 768
    window_size: 768
    window_stride: 512
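
Once loaded through load_config, each section above is a plain dictionary; for example (key names and values taken from the YAML as listed):

    from src.models.vit.config import load_config

    cfg = load_config()
    vit = cfg["model"]["vit_base_patch16_384"]
    print(vit["d_model"], vit["n_heads"], vit["n_layers"])   # 768 12 12
    print(cfg["dataset"]["ade20k"]["crop_size"])             # 512
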
src/models/vit/decoder.py
ADDED
@@ -0,0 +1,34 @@
import torch.nn as nn
from einops import rearrange
from src.models.vit.utils import init_weights


class DecoderLinear(nn.Module):
    def __init__(
        self,
        n_cls,
        d_encoder,
        scale_factor,
        dropout_rate=0.3,
    ):
        super().__init__()
        self.scale_factor = scale_factor
        self.head = nn.Linear(d_encoder, n_cls)
        self.upsampling = nn.Upsample(scale_factor=scale_factor**2, mode="linear")
        self.norm = nn.LayerNorm((n_cls, 24 * scale_factor, 24 * scale_factor))
        self.dropout = nn.Dropout(dropout_rate)
        self.gelu = nn.GELU()
        self.apply(init_weights)

    def forward(self, x, img_size):
        H, _ = img_size
        x = self.head(x)  # (2, 577, 64)
        x = x.transpose(2, 1)  # (2, 64, 576)
        x = self.upsampling(x)  # (2, 64, 576*scale_factor*scale_factor)
        x = x.transpose(2, 1)  # (2, 576*scale_factor*scale_factor, 64)
        x = rearrange(x, "b (h w) c -> b c h w", h=H // (16 // self.scale_factor))  # (2, 64, 24*scale_factor, 24*scale_factor)
        x = self.norm(x)
        x = self.dropout(x)
        x = self.gelu(x)

        return x  # (2, 64, a, a)
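
DecoderLinear maps patch tokens back onto a spatial grid; a shape walk-through with dummy tokens, assuming a 384x384 input and 16-pixel patches (so 24*24 = 576 patch tokens, cls token already stripped):

    import torch
    from src.models.vit.decoder import DecoderLinear

    dec = DecoderLinear(n_cls=512, d_encoder=768, scale_factor=2)
    tokens = torch.randn(2, 576, 768)          # (batch, patch tokens, encoder dim)
    feat = dec(tokens, img_size=(384, 384))    # (2, 512, 48, 48): 24 * scale_factor per side
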
src/models/vit/embed.py
ADDED
@@ -0,0 +1,52 @@
from torch import nn
from timm import create_model
from torchvision.transforms import Normalize


class SwinModel(nn.Module):
    def __init__(self, pretrained_model="swinv2-cr-t-224", device="cuda") -> None:
        """
        vit_tiny_patch16_224.augreg_in21k_ft_in1k
        swinv2_cr_tiny_ns_224.sw_in1k
        """
        super().__init__()
        self.device = device
        self.pretrained_model = pretrained_model
        if pretrained_model == "swinv2-cr-t-224":
            self.pretrained = create_model(
                "swinv2_cr_tiny_ns_224.sw_in1k",
                pretrained=True,
                features_only=True,
                out_indices=[-4, -3, -2, -1],
            ).to(device)
        elif pretrained_model == "swinv2-t-256":
            self.pretrained = create_model(
                "swinv2_tiny_window16_256.ms_in1k",
                pretrained=True,
                features_only=True,
                out_indices=[-4, -3, -2, -1],
            ).to(device)
        elif pretrained_model == "swinv2-cr-s-224":
            self.pretrained = create_model(
                "swinv2_cr_small_ns_224.sw_in1k",
                pretrained=True,
                features_only=True,
                out_indices=[-4, -3, -2, -1],
            ).to(device)
        else:
            raise NotImplementedError

        self.pretrained.eval()
        self.normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.upsample = nn.Upsample(scale_factor=2)

        for params in self.pretrained.parameters():
            params.requires_grad = False

    def forward(self, x):
        outputs = self.pretrained(x)
        if self.pretrained_model in ["swinv2-t-256"]:
            for i in range(len(outputs)):
                outputs[i] = outputs[i].permute(0, 3, 1, 2)  # Change channel-last to channel-first
        outputs = [self.upsample(feat) for feat in outputs]

        return outputs
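
SwinModel is a frozen timm feature extractor; a hedged usage sketch (assumes a timm version that accepts the model names and negative out_indices used above, and that the pretrained weights can be downloaded on first use; device set to CPU to match the CPU Space):

    import torch
    from src.models.vit.embed import SwinModel

    embed = SwinModel(pretrained_model="swinv2-cr-t-224", device="cpu")
    img = embed.normalizer(torch.rand(1, 3, 224, 224))   # ImageNet-normalized input
    feats = embed(img)                                   # four pyramid levels, each upsampled 2x
    for f in feats:
        print(f.shape)
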
src/models/vit/factory.py
ADDED
@@ -0,0 +1,45 @@
import os
import torch
from timm.models.vision_transformer import default_cfgs
from timm.models.helpers import load_pretrained, load_custom_pretrained
from src.models.vit.utils import checkpoint_filter_fn
from src.models.vit.vit import VisionTransformer


def create_vit(model_cfg):
    model_cfg = model_cfg.copy()
    backbone = model_cfg.pop("backbone")

    model_cfg.pop("normalization")
    model_cfg["n_cls"] = 1000
    mlp_expansion_ratio = 4
    model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]

    if backbone in default_cfgs:
        default_cfg = default_cfgs[backbone]
    else:
        default_cfg = dict(
            pretrained=False,
            num_classes=1000,
            drop_rate=0.0,
            drop_path_rate=0.0,
            drop_block_rate=None,
        )

    default_cfg["input_size"] = (
        3,
        model_cfg["image_size"][0],
        model_cfg["image_size"][1],
    )
    model = VisionTransformer(**model_cfg)
    if backbone == "vit_base_patch8_384":
        path = os.path.expandvars("$TORCH_HOME/hub/checkpoints/vit_base_patch8_384.pth")
        state_dict = torch.load(path, map_location="cpu")
        filtered_dict = checkpoint_filter_fn(state_dict, model)
        model.load_state_dict(filtered_dict, strict=True)
    elif "deit" in backbone:
        load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
    else:
        load_custom_pretrained(model, default_cfg)

    return model
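
create_vit expects a flat dict that merges one config.yml model entry with a backbone name. Note that the scalar image_size from the YAML must be turned into an (H, W) tuple because the code above indexes it, and that the pretrained-weight loading path depends on the installed timm version exposing dict-style default_cfgs and the DeiT entries. A sketch under those assumptions:

    from src.models.vit.config import load_config
    from src.models.vit.factory import create_vit

    model_cfg = load_config()["model"]["deit_base_distilled_patch16_384"].copy()
    model_cfg["backbone"] = "deit_base_distilled_patch16_384"
    model_cfg["image_size"] = (model_cfg["image_size"], model_cfg["image_size"])
    model_cfg["dropout"] = 0.1
    model_cfg["drop_path_rate"] = 0.0
    model = create_vit(model_cfg)   # builds VisionTransformer and loads DeiT weights via timm
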
src/models/vit/utils.py
ADDED
@@ -0,0 +1,71 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from collections import OrderedDict


def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    posemb_tok, posemb_grid = (
        posemb[:, :num_extra_tokens],
        posemb[0, num_extra_tokens:],
    )
    if grid_old_shape is None:
        gs_old_h = int(math.sqrt(len(posemb_grid)))
        gs_old_w = gs_old_h
    else:
        gs_old_h, gs_old_w = grid_old_shape

    gs_h, gs_w = grid_new_shape
    posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb


def init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)


def checkpoint_filter_fn(state_dict, model):
    """convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    if "model" in state_dict:
        # For deit models
        state_dict = state_dict["model"]
    num_extra_tokens = 1 + ("dist_token" in state_dict.keys())
    patch_size = model.patch_size
    image_size = model.patch_embed.image_size
    for k, v in state_dict.items():
        if k == "pos_embed" and v.shape != model.pos_embed.shape:
            # To resize pos embedding when using model at different size from pretrained weights
            v = resize_pos_embed(
                v,
                None,
                (image_size[0] // patch_size, image_size[1] // patch_size),
                num_extra_tokens,
            )
        out_dict[k] = v
    return out_dict


def load_params(ckpt_file, device):
    # params = torch.load(ckpt_file, map_location=f'cuda:{local_rank}')
    # new_params = []
    # for key, value in params.items():
    #     new_params.append(("module."+key if has_module else key, value))
    # return OrderedDict(new_params)
    params = torch.load(ckpt_file, map_location=device)
    new_params = []
    for key, value in params.items():
        new_params.append((key, value))
    return OrderedDict(new_params)
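
resize_pos_embed is the helper most likely to be reused outside checkpoint loading; a quick check with dummy position embeddings (1 cls token plus a 24x24 grid resized to 28x28):

    import torch
    from src.models.vit.utils import resize_pos_embed

    posemb = torch.randn(1, 1 + 24 * 24, 768)
    resized = resize_pos_embed(posemb, None, (28, 28), num_extra_tokens=1)
    print(resized.shape)    # torch.Size([1, 785, 768])
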
src/models/vit/vit.py
ADDED
@@ -0,0 +1,199 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import _load_weights
from timm.models.layers import trunc_normal_
from typing import List

from src.models.vit.utils import init_weights, resize_pos_embed
from src.models.vit.blocks import Block
from src.models.vit.decoder import DecoderLinear


class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, embed_dim, channels):
        super().__init__()

        self.image_size = image_size
        if image_size[0] % patch_size != 0 or image_size[1] % patch_size != 0:
            raise ValueError("image dimensions must be divisible by the patch size")
        self.grid_size = image_size[0] // patch_size, image_size[1] // patch_size
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.patch_size = patch_size

        self.proj = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, im):
        B, C, H, W = im.shape
        x = self.proj(im).flatten(2).transpose(1, 2)
        return x


class VisionTransformer(nn.Module):
    def __init__(
        self,
        image_size,
        patch_size,
        n_layers,
        d_model,
        d_ff,
        n_heads,
        n_cls,
        dropout=0.1,
        drop_path_rate=0.0,
        distilled=False,
        channels=3,
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            image_size,
            patch_size,
            d_model,
            channels,
        )
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.n_cls = n_cls

        # cls and pos tokens
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.distilled = distilled
        if self.distilled:
            self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
            self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 2, d_model))
            self.head_dist = nn.Linear(d_model, n_cls)
        else:
            self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, d_model))

        # transformer blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList([Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)])

        # output head
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_cls)

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        if self.distilled:
            trunc_normal_(self.dist_token, std=0.02)
        self.pre_logits = nn.Identity()

        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token", "dist_token"}

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        _load_weights(self, checkpoint_path, prefix)

    def forward(self, im, head_out_idx: List[int], n_dim_output=3, return_features=False):
        B, _, H, W = im.shape
        PS = self.patch_size
        assert n_dim_output == 3 or n_dim_output == 4, "n_dim_output must be 3 or 4"
        x = self.patch_embed(im)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.distilled:
            dist_tokens = self.dist_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
        else:
            x = torch.cat((cls_tokens, x), dim=1)

        pos_embed = self.pos_embed
        num_extra_tokens = 1 + self.distilled
        if x.shape[1] != pos_embed.shape[1]:
            pos_embed = resize_pos_embed(
                pos_embed,
                self.patch_embed.grid_size,
                (H // PS, W // PS),
                num_extra_tokens,
            )
        x = x + pos_embed
        x = self.dropout(x)
        device = x.device

        if n_dim_output == 3:
            heads_out = torch.zeros(size=(len(head_out_idx), B, (H // PS) ** 2 + 1, self.d_model)).to(device)
        else:
            heads_out = torch.zeros(size=(len(head_out_idx), B, self.d_model, H // PS, H // PS)).to(device)
        self.register_buffer("heads_out", heads_out)

        head_idx = 0
        for idx_layer, blk in enumerate(self.blocks):
            x = blk(x)
            if idx_layer in head_out_idx:
                if n_dim_output == 3:
                    heads_out[head_idx] = x
                else:
                    heads_out[head_idx] = x[:, 1:, :].reshape((-1, 24, 24, self.d_model)).permute(0, 3, 1, 2)
                head_idx += 1

        x = self.norm(x)

        if return_features:
            return heads_out

        if self.distilled:
            x, x_dist = x[:, 0], x[:, 1]
            x = self.head(x)
            x_dist = self.head_dist(x_dist)
            x = (x + x_dist) / 2
        else:
            x = x[:, 0]
            x = self.head(x)
        return x

    def get_attention_map(self, im, layer_id):
        if layer_id >= self.n_layers or layer_id < 0:
            raise ValueError(f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}.")
        B, _, H, W = im.shape
        PS = self.patch_size

        x = self.patch_embed(im)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.distilled:
            dist_tokens = self.dist_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
        else:
            x = torch.cat((cls_tokens, x), dim=1)

        pos_embed = self.pos_embed
        num_extra_tokens = 1 + self.distilled
        if x.shape[1] != pos_embed.shape[1]:
            pos_embed = resize_pos_embed(
                pos_embed,
                self.patch_embed.grid_size,
                (H // PS, W // PS),
                num_extra_tokens,
            )
        x = x + pos_embed

        for i, blk in enumerate(self.blocks):
            if i < layer_id:
                x = blk(x)
            else:
                return blk(x, return_attention=True)


class FeatureTransform(nn.Module):
    def __init__(self, img_size, d_encoder, nls_list=[128, 256, 512, 512], scale_factor_list=[8, 4, 2, 1]):
        super(FeatureTransform, self).__init__()
        self.img_size = img_size

        self.decoder_0 = DecoderLinear(n_cls=nls_list[0], d_encoder=d_encoder, scale_factor=scale_factor_list[0])
        self.decoder_1 = DecoderLinear(n_cls=nls_list[1], d_encoder=d_encoder, scale_factor=scale_factor_list[1])
        self.decoder_2 = DecoderLinear(n_cls=nls_list[2], d_encoder=d_encoder, scale_factor=scale_factor_list[2])
        self.decoder_3 = DecoderLinear(n_cls=nls_list[3], d_encoder=d_encoder, scale_factor=scale_factor_list[3])

    def forward(self, x_list):
        feat_3 = self.decoder_3(x_list[3][:, 1:, :], self.img_size)  # (2, 512, 24, 24)
        feat_2 = self.decoder_2(x_list[2][:, 1:, :], self.img_size)  # (2, 512, 48, 48)
        feat_1 = self.decoder_1(x_list[1][:, 1:, :], self.img_size)  # (2, 256, 96, 96)
        feat_0 = self.decoder_0(x_list[0][:, 1:, :], self.img_size)  # (2, 128, 192, 192)
        return feat_0, feat_1, feat_2, feat_3
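
Putting the pieces together: VisionTransformer can return the token outputs of selected blocks, and FeatureTransform decodes them into a multi-scale feature pyramid. A sketch with random input (the block indices below are arbitrary illustrations, not values taken from the training code):

    import torch
    from src.models.vit.vit import VisionTransformer, FeatureTransform

    vit = VisionTransformer(
        image_size=(384, 384), patch_size=16, n_layers=12,
        d_model=768, d_ff=3072, n_heads=12, n_cls=1000,
    )
    im = torch.randn(2, 3, 384, 384)
    feats = vit(im, head_out_idx=[2, 5, 8, 11], n_dim_output=3, return_features=True)  # (4, 2, 577, 768)

    decode = FeatureTransform(img_size=(384, 384), d_encoder=768)
    f0, f1, f2, f3 = decode(feats)
    # (2, 128, 192, 192), (2, 256, 96, 96), (2, 512, 48, 48), (2, 512, 24, 24)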