duongttr committed
Commit 62b9b3d · Parent: d5f3a99

Upload app cpu version

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50):
  1. .gitattributes +1 -0
  2. .gitignore +138 -0
  3. README.md +1 -1
  4. app.py +50 -0
  5. app_config.py +9 -0
  6. checkpoints/epoch_10/colornet.pth +3 -0
  7. checkpoints/epoch_10/discriminator.pth +3 -0
  8. checkpoints/epoch_10/embed_net.pth +3 -0
  9. checkpoints/epoch_10/learning_state.pth +3 -0
  10. checkpoints/epoch_10/nonlocal_net.pth +3 -0
  11. checkpoints/epoch_12/colornet.pth +3 -0
  12. checkpoints/epoch_12/discriminator.pth +3 -0
  13. checkpoints/epoch_12/embed_net.pth +3 -0
  14. checkpoints/epoch_12/learning_state.pth +3 -0
  15. checkpoints/epoch_12/nonlocal_net.pth +3 -0
  16. checkpoints/epoch_16/colornet.pth +3 -0
  17. checkpoints/epoch_16/discriminator.pth +3 -0
  18. checkpoints/epoch_16/embed_net.pth +3 -0
  19. checkpoints/epoch_16/learning_state.pth +3 -0
  20. checkpoints/epoch_16/nonlocal_net.pth +3 -0
  21. checkpoints/epoch_20/colornet.pth +3 -0
  22. checkpoints/epoch_20/discriminator.pth +3 -0
  23. checkpoints/epoch_20/embed_net.pth +3 -0
  24. checkpoints/epoch_20/learning_state.pth +3 -0
  25. checkpoints/epoch_20/nonlocal_net.pth +3 -0
  26. requirements.txt +0 -0
  27. sample_input/ref1.jpg +0 -0
  28. sample_input/video1.mp4 +3 -0
  29. src/__init__.py +0 -0
  30. src/data/dataloader.py +332 -0
  31. src/data/functional.py +84 -0
  32. src/data/transforms.py +348 -0
  33. src/inference.py +174 -0
  34. src/losses.py +277 -0
  35. src/metrics.py +225 -0
  36. src/models/CNN/ColorVidNet.py +141 -0
  37. src/models/CNN/FrameColor.py +76 -0
  38. src/models/CNN/GAN_models.py +212 -0
  39. src/models/CNN/NonlocalNet.py +437 -0
  40. src/models/CNN/__init__.py +0 -0
  41. src/models/__init__.py +0 -0
  42. src/models/vit/__init__.py +0 -0
  43. src/models/vit/blocks.py +80 -0
  44. src/models/vit/config.py +22 -0
  45. src/models/vit/config.yml +132 -0
  46. src/models/vit/decoder.py +34 -0
  47. src/models/vit/embed.py +52 -0
  48. src/models/vit/factory.py +45 -0
  49. src/models/vit/utils.py +71 -0
  50. src/models/vit/vit.py +199 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,138 @@
+ flagged/
+ sample_output/
+ wandb/
+ .vscode
+ .DS_Store
+ *ckpt*/
+ # Custom
+ *.pt
+ data/local
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: SwinTExCo Cpu
+ title: Exemplar-based Video Colorization using Vision Transformer (CPU version)
  emoji: 🏃
  colorFrom: green
  colorTo: yellow
app.py ADDED
@@ -0,0 +1,50 @@
+ import gradio as gr
+ from src.inference import SwinTExCo
+ import cv2
+ import os
+ from PIL import Image
+ import time
+ import app_config as cfg
+
+
+ model = SwinTExCo(weights_path=cfg.ckpt_path)
+
+ def video_colorization(video_path, ref_image, progress=gr.Progress()):
+     # Initialize video reader
+     video_reader = cv2.VideoCapture(video_path)
+     fps = video_reader.get(cv2.CAP_PROP_FPS)
+     height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
+     num_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
+
+     # Initialize reference image
+     ref_image = Image.fromarray(ref_image)
+
+     # Initialize video writer
+     output_path = os.path.join(os.path.dirname(video_path), os.path.basename(video_path).split('.')[0] + '_colorized.mp4')
+     video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
+
+     # Init progress bar
+
+     for colorized_frame, _ in zip(model.predict_video(video_reader, ref_image), progress.tqdm(range(num_frames), desc="Colorizing video", unit="frames")):
+         colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_RGB2BGR)
+         video_writer.write(colorized_frame)
+
+     # for i in progress.tqdm(range(1000)):
+     #     time.sleep(0.5)
+
+     video_writer.release()
+
+     return output_path
+
+ app = gr.Interface(
+     fn=video_colorization,
+     inputs=[gr.Video(format="mp4", sources="upload", label="Input video (grayscale)", interactive=True),
+             gr.Image(sources="upload", label="Reference image (color)")],
+     outputs=gr.Video(label="Output video (colorized)"),
+     title=cfg.TITLE,
+     description=cfg.DESCRIPTION
+ ).queue()
+
+
+ app.launch()
app_config.py ADDED
@@ -0,0 +1,9 @@
+ ckpt_path = 'checkpoints/epoch_20'
+ TITLE = 'Deep Exemplar-based Video Colorization using Vision Transformer'
+ DESCRIPTION = '''
+ <center>
+ This is a demo app of the thesis: <b>Deep Exemplar-based Video Colorization using Vision Transformer</b>.<br/>
+ The code is available at: <i>The link will be updated soon</i>.<br/>
+ Our previous work was also written up as a paper and accepted at the <a href="https://ictc.org/program_proceeding">ICTC 2023 conference</a> (Section <i>B1-4</i>).
+ </center>
+ '''.strip()
checkpoints/epoch_10/colornet.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ecb43b5e02b77bec5342e2e296d336bf8f384a07d3c809d1a548fd5fb1e7365
3
+ size 131239411
checkpoints/epoch_10/discriminator.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce8968a9d3d2f99b1bc1e32080507e0d671cee00b66200105c8839be684b84b4
3
+ size 45073068
checkpoints/epoch_10/embed_net.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc711755a75c43025dabe9407cbd11d164eaa9e21f26430d0c16c7493410d902
3
+ size 110352261
checkpoints/epoch_10/learning_state.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d09b1e96fdf0205930a21928449a44c51cedd965cc0d573068c73971bcb8bd2
3
+ size 748166487
checkpoints/epoch_10/nonlocal_net.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86c97d6803d625a0dff8c6c09b70852371906eb5ef77df0277c27875666a68e2
3
+ size 73189765
checkpoints/epoch_12/colornet.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50f4b92cd59f4c88c0c1d7c93652413d54b1b96d729fc4b93e235887b5164f28
3
+ size 131239846
checkpoints/epoch_12/discriminator.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b54b0bad6ceec33569cc5833cbf03ed8ddbb5f07998aa634badf8298d3cd15f
3
+ size 45073513
checkpoints/epoch_12/embed_net.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
3
+ size 110352698
checkpoints/epoch_12/learning_state.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f8bb4dbb3cb8e497a9a2079947f0221823fa8b44695e2d2ad8478be48464fad
3
+ size 748166934
checkpoints/epoch_12/nonlocal_net.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1f76b53dad7bf15c7d26aa106c95387e75751b8c31fafef2bd73ea7d77160cb
3
+ size 73190208
checkpoints/epoch_16/colornet.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81ec9cff0ad5b0d920179fa7a9cc229e1424bfc796b7134604ff66b97d748c49
3
+ size 131239846
checkpoints/epoch_16/discriminator.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42262d5ed7596f38e65774085222530eee57da8dfaa7fe1aa223d824ed166f62
3
+ size 45073513
checkpoints/epoch_16/embed_net.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
3
+ size 110352698
checkpoints/epoch_16/learning_state.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea4cf81341750ebf517c696a0f6241bfeede0584b0ce75ad208e3ffc8280877f
3
+ size 748166934
checkpoints/epoch_16/nonlocal_net.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85b63363bc9c79732df78ba50ed19491ed86e961214bbd1f796a871334eba516
3
+ size 73190208
checkpoints/epoch_20/colornet.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c524f4e5df5f6ce91db1973a30de55299ebcbbde1edd2009718d3b4cd2631339
3
+ size 131239846
checkpoints/epoch_20/discriminator.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fcd80950c796fcfe6e4b6bdeeb358776700458d868da94ee31df3d1d37779310
3
+ size 45073513
checkpoints/epoch_20/embed_net.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73e2a156c0737e3d063af0e95e1e7176362e85120b88275a1aa02dcf488e1865
3
+ size 110352698
checkpoints/epoch_20/learning_state.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b1163b210b246b07d8f1c50eb3766d97c6f03bf409c854d00b7c69edb6d7391
3
+ size 748166934
checkpoints/epoch_20/nonlocal_net.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:031e5f38cc79eb3c0ed51ca2ad3c8921fdda2fa05946c357f84881259de74e6d
3
+ size 73190208
requirements.txt ADDED
Binary file (434 Bytes).
sample_input/ref1.jpg ADDED
sample_input/video1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:077ebcd3cf6c020c95732e74a0fe1fab9b80102bc14d5e201b12c4917e0c0d1d
3
+ size 1011726
src/__init__.py ADDED
File without changes
src/data/dataloader.py ADDED
@@ -0,0 +1,332 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ from src.utils import (
+ CenterPadCrop_numpy,
+ Distortion_with_flow_cpu,
+ Distortion_with_flow_gpu,
+ Normalize,
+ RGB2Lab,
+ ToTensor,
+ CenterPad,
+ read_flow,
+ SquaredPadding
+ )
17
+ import torch
18
+ import torch.utils.data as data
19
+ import torchvision.transforms as transforms
20
+ from numpy import random
21
+ import os
22
+ from PIL import Image
23
+ from scipy.ndimage.filters import gaussian_filter
24
+ from scipy.ndimage import map_coordinates
25
+ import glob
26
+
27
+
28
+
29
+ def image_loader(path):
30
+ with open(path, "rb") as f:
31
+ with Image.open(f) as img:
32
+ return img.convert("RGB")
33
+
34
+
35
+ class CenterCrop(object):
36
+ """
37
+ center crop the numpy array
38
+ """
39
+
40
+ def __init__(self, image_size):
41
+ self.h0, self.w0 = image_size
42
+
43
+ def __call__(self, input_numpy):
44
+ if input_numpy.ndim == 3:
45
+ h, w, channel = input_numpy.shape
46
+ output_numpy = np.zeros((self.h0, self.w0, channel))
47
+ output_numpy = input_numpy[
48
+ (h - self.h0) // 2 : (h - self.h0) // 2 + self.h0, (w - self.w0) // 2 : (w - self.w0) // 2 + self.w0, :
49
+ ]
50
+ else:
51
+ h, w = input_numpy.shape
52
+ output_numpy = np.zeros((self.h0, self.w0))
53
+ output_numpy = input_numpy[
54
+ (h - self.h0) // 2 : (h - self.h0) // 2 + self.h0, (w - self.w0) // 2 : (w - self.w0) // 2 + self.w0
55
+ ]
56
+ return output_numpy
57
+
58
+
59
+ class VideosDataset(torch.utils.data.Dataset):
60
+ def __init__(
61
+ self,
62
+ video_data_root,
63
+ flow_data_root,
64
+ mask_data_root,
65
+ imagenet_folder,
66
+ annotation_file_path,
67
+ image_size,
68
+ num_refs=5, # max = 20
69
+ image_transform=None,
70
+ real_reference_probability=1,
71
+ nonzero_placeholder_probability=0.5,
72
+ ):
73
+ self.video_data_root = video_data_root
74
+ self.flow_data_root = flow_data_root
75
+ self.mask_data_root = mask_data_root
76
+ self.imagenet_folder = imagenet_folder
77
+ self.image_transform = image_transform
78
+ self.CenterPad = CenterPad(image_size)
79
+ self.Resize = transforms.Resize(image_size)
80
+ self.ToTensor = ToTensor()
81
+ self.CenterCrop = transforms.CenterCrop(image_size)
82
+ self.SquaredPadding = SquaredPadding(image_size[0])
83
+ self.num_refs = num_refs
84
+
85
+ assert os.path.exists(self.video_data_root), "find no video dataroot"
86
+ assert os.path.exists(self.flow_data_root), "find no flow dataroot"
87
+ assert os.path.exists(self.imagenet_folder), "find no imagenet folder"
88
+ # self.epoch = epoch
89
+ self.image_pairs = pd.read_csv(annotation_file_path, dtype=str)
90
+ self.real_len = len(self.image_pairs)
91
+ # self.image_pairs = pd.concat([self.image_pairs] * self.epoch, ignore_index=True)
92
+ self.real_reference_probability = real_reference_probability
93
+ self.nonzero_placeholder_probability = nonzero_placeholder_probability
94
+ print("##### parsing image pairs in %s: %d pairs #####" % (video_data_root, self.__len__()))
95
+
96
+ def __getitem__(self, index):
97
+ (
98
+ video_name,
99
+ prev_frame,
100
+ current_frame,
101
+ flow_forward_name,
102
+ mask_name,
103
+ reference_1_name,
104
+ reference_2_name,
105
+ reference_3_name,
106
+ reference_4_name,
107
+ reference_5_name
108
+ ) = self.image_pairs.iloc[index, :5+self.num_refs].values.tolist()
109
+
110
+ video_path = os.path.join(self.video_data_root, video_name)
111
+ flow_path = os.path.join(self.flow_data_root, video_name)
112
+ mask_path = os.path.join(self.mask_data_root, video_name)
113
+
114
+ prev_frame_path = os.path.join(video_path, prev_frame)
115
+ current_frame_path = os.path.join(video_path, current_frame)
116
+ list_frame_path = glob.glob(os.path.join(video_path, '*'))
117
+ list_frame_path.sort()
118
+
119
+ reference_1_path = os.path.join(self.imagenet_folder, reference_1_name)
120
+ reference_2_path = os.path.join(self.imagenet_folder, reference_2_name)
121
+ reference_3_path = os.path.join(self.imagenet_folder, reference_3_name)
122
+ reference_4_path = os.path.join(self.imagenet_folder, reference_4_name)
123
+ reference_5_path = os.path.join(self.imagenet_folder, reference_5_name)
124
+
125
+ flow_forward_path = os.path.join(flow_path, flow_forward_name)
126
+ mask_path = os.path.join(mask_path, mask_name)
127
+
128
+ #reference_gt_1_path = prev_frame_path
129
+ #reference_gt_2_path = current_frame_path
130
+ try:
131
+ I1 = Image.open(prev_frame_path).convert("RGB")
132
+ I2 = Image.open(current_frame_path).convert("RGB")
133
+ try:
134
+ I_reference_video = Image.open(list_frame_path[0]).convert("RGB") # Get first frame
135
+ except:
136
+ I_reference_video = Image.open(current_frame_path).convert("RGB") # Get current frame if error
137
+
138
+ reference_list = [reference_1_path, reference_2_path, reference_3_path, reference_4_path, reference_5_path]
139
+ while reference_list: # run until getting the colorized reference
140
+ reference_path = random.choice(reference_list)
141
+ I_reference_video_real = Image.open(reference_path)
142
+ if I_reference_video_real.mode == 'L':
143
+ reference_list.remove(reference_path)
144
+ else:
145
+ break
146
+ if not reference_list:
147
+ I_reference_video_real = I_reference_video
148
+
149
+ flow_forward = read_flow(flow_forward_path) # numpy
150
+
151
+ mask = Image.open(mask_path) # PIL
152
+ mask = self.Resize(mask)
153
+ mask = np.array(mask)
154
+ # mask = self.SquaredPadding(mask, return_pil=False, return_paddings=False)
155
+ # binary mask
156
+ mask[mask < 240] = 0
157
+ mask[mask >= 240] = 1
158
+ mask = self.ToTensor(mask)
159
+
160
+ # transform
161
+ I1 = self.image_transform(I1)
162
+ I2 = self.image_transform(I2)
163
+ I_reference_video = self.image_transform(I_reference_video)
164
+ I_reference_video_real = self.image_transform(I_reference_video_real)
165
+ flow_forward = self.ToTensor(flow_forward)
166
+ flow_forward = self.Resize(flow_forward)#, return_pil=False, return_paddings=False, dtype=np.float32)
167
+
168
+
169
+ if np.random.random() < self.real_reference_probability:
170
+ I_reference_output = I_reference_video_real # Use reference from imagenet
171
+ placeholder = torch.zeros_like(I1)
172
+ self_ref_flag = torch.zeros_like(I1)
173
+ else:
174
+ I_reference_output = I_reference_video # Use reference from ground truth
175
+ placeholder = I2 if np.random.random() < self.nonzero_placeholder_probability else torch.zeros_like(I1)
176
+ self_ref_flag = torch.ones_like(I1)
177
+
178
+ outputs = [
179
+ I1,
180
+ I2,
181
+ I_reference_output,
182
+ flow_forward,
183
+ mask,
184
+ placeholder,
185
+ self_ref_flag,
186
+ video_name + prev_frame,
187
+ video_name + current_frame,
188
+ reference_path
189
+ ]
190
+
191
+ except Exception as e:
192
+ print("error in reading image pair: %s" % str(self.image_pairs[index]))
193
+ print(e)
194
+ return self.__getitem__(np.random.randint(0, len(self.image_pairs)))
195
+ return outputs
196
+
197
+ def __len__(self):
198
+ return len(self.image_pairs)
199
+
200
+
201
+ def parse_imgnet_images(pairs_file):
202
+ pairs = []
203
+ with open(pairs_file, "r") as f:
204
+ lines = f.readlines()
205
+ for line in lines:
206
+ line = line.strip().split("|")
207
+ image_a = line[0]
208
+ image_b = line[1]
209
+ pairs.append((image_a, image_b))
210
+ return pairs
211
+
212
+
213
+ class VideosDataset_ImageNet(data.Dataset):
214
+ def __init__(
215
+ self,
216
+ imagenet_data_root,
217
+ pairs_file,
218
+ image_size,
219
+ transforms_imagenet=None,
220
+ distortion_level=3,
221
+ brightnessjitter=0,
222
+ nonzero_placeholder_probability=0.5,
223
+ extra_reference_transform=None,
224
+ real_reference_probability=1,
225
+ distortion_device='cpu'
226
+ ):
227
+ self.imagenet_data_root = imagenet_data_root
228
+ self.image_pairs = pd.read_csv(pairs_file, names=['i1', 'i2'])
229
+ self.transforms_imagenet_raw = transforms_imagenet
230
+ self.extra_reference_transform = transforms.Compose(extra_reference_transform)
231
+ self.real_reference_probability = real_reference_probability
232
+ self.transforms_imagenet = transforms.Compose(transforms_imagenet)
233
+ self.image_size = image_size
234
+ self.real_len = len(self.image_pairs)
235
+ self.distortion_level = distortion_level
236
+ self.distortion_transform = Distortion_with_flow_cpu() if distortion_device == 'cpu' else Distortion_with_flow_gpu()
237
+ self.brightnessjitter = brightnessjitter
238
+ self.flow_transform = transforms.Compose([CenterPadCrop_numpy(self.image_size), ToTensor()])
239
+ self.nonzero_placeholder_probability = nonzero_placeholder_probability
240
+ self.ToTensor = ToTensor()
241
+ self.Normalize = Normalize()
242
+ print("##### parsing imageNet pairs in %s: %d pairs #####" % (imagenet_data_root, self.__len__()))
243
+
244
+ def __getitem__(self, index):
245
+ pa, pb = self.image_pairs.iloc[index].values.tolist()
246
+ if np.random.random() > 0.5:
247
+ pa, pb = pb, pa
248
+
249
+ image_a_path = os.path.join(self.imagenet_data_root, pa)
250
+ image_b_path = os.path.join(self.imagenet_data_root, pb)
251
+
252
+ I1 = image_loader(image_a_path)
253
+ I2 = I1
254
+ I_reference_video = I1
255
+ I_reference_video_real = image_loader(image_b_path)
256
+ # print("i'm here get image 2")
257
+ # generate the flow
258
+ alpha = np.random.rand() * self.distortion_level
259
+ distortion_range = 50
260
+ random_state = np.random.RandomState(None)
261
+ shape = self.image_size[0], self.image_size[1]
262
+ # dx: flow on the vertical direction; dy: flow on the horizontal direction
263
+ forward_dx = (
264
+ gaussian_filter((random_state.rand(*shape) * 2 - 1), distortion_range, mode="constant", cval=0) * alpha * 1000
265
+ )
266
+ forward_dy = (
267
+ gaussian_filter((random_state.rand(*shape) * 2 - 1), distortion_range, mode="constant", cval=0) * alpha * 1000
268
+ )
269
+ # print("i'm here get image 3")
270
+ for transform in self.transforms_imagenet_raw:
271
+ if type(transform) is RGB2Lab:
272
+ I1_raw = I1
273
+ I1 = transform(I1)
274
+ for transform in self.transforms_imagenet_raw:
275
+ if type(transform) is RGB2Lab:
276
+ I2 = self.distortion_transform(I2, forward_dx, forward_dy)
277
+ I2_raw = I2
278
+ I2 = transform(I2)
279
+ # print("i'm here get image 4")
280
+ I2[0:1, :, :] = I2[0:1, :, :] + torch.randn(1) * self.brightnessjitter
281
+
282
+ I_reference_video = self.extra_reference_transform(I_reference_video)
283
+ for transform in self.transforms_imagenet_raw:
284
+ I_reference_video = transform(I_reference_video)
285
+
286
+ I_reference_video_real = self.transforms_imagenet(I_reference_video_real)
287
+ # print("i'm here get image 5")
288
+ flow_forward_raw = np.stack((forward_dy, forward_dx), axis=-1)
289
+ flow_forward = self.flow_transform(flow_forward_raw)
290
+
291
+ # update the mask for the pixels on the border
292
+ grid_x, grid_y = np.meshgrid(np.arange(self.image_size[0]), np.arange(self.image_size[1]), indexing="ij")
293
+ grid = np.stack((grid_y, grid_x), axis=-1)
294
+ grid_warp = grid + flow_forward_raw
295
+ location_y = grid_warp[:, :, 0].flatten()
296
+ location_x = grid_warp[:, :, 1].flatten()
297
+ I2_raw = np.array(I2_raw).astype(float)
298
+ I21_r = map_coordinates(I2_raw[:, :, 0], np.stack((location_x, location_y)), cval=-1).reshape(
299
+ (self.image_size[0], self.image_size[1])
300
+ )
301
+ I21_g = map_coordinates(I2_raw[:, :, 1], np.stack((location_x, location_y)), cval=-1).reshape(
302
+ (self.image_size[0], self.image_size[1])
303
+ )
304
+ I21_b = map_coordinates(I2_raw[:, :, 2], np.stack((location_x, location_y)), cval=-1).reshape(
305
+ (self.image_size[0], self.image_size[1])
306
+ )
307
+ I21_raw = np.stack((I21_r, I21_g, I21_b), axis=2)
308
+ mask = np.ones((self.image_size[0], self.image_size[1]))
309
+ mask[(I21_raw[:, :, 0] == -1) & (I21_raw[:, :, 1] == -1) & (I21_raw[:, :, 2] == -1)] = 0
310
+ mask[abs(I21_raw - I1_raw).sum(axis=-1) > 50] = 0
311
+ mask = self.ToTensor(mask)
312
+ # print("i'm here get image 6")
313
+ if np.random.random() < self.real_reference_probability:
314
+ I_reference_output = I_reference_video_real
315
+ placeholder = torch.zeros_like(I1)
316
+ self_ref_flag = torch.zeros_like(I1)
317
+ else:
318
+ I_reference_output = I_reference_video
319
+ placeholder = I2 if np.random.random() < self.nonzero_placeholder_probability else torch.zeros_like(I1)
320
+ self_ref_flag = torch.ones_like(I1)
321
+
322
+ # except Exception as e:
323
+ # if combo_path is not None:
324
+ # print("problem in ", combo_path)
325
+ # print("problem in, ", image_a_path)
326
+ # print(e)
327
+ # return self.__getitem__(np.random.randint(0, len(self.image_pairs)))
328
+ # print("i'm here get image 7")
329
+ return [I1, I2, I_reference_output, flow_forward, mask, placeholder, self_ref_flag, "holder", pb, pa]
330
+
331
+ def __len__(self):
332
+ return len(self.image_pairs)
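A minimal sketch (not part of the commit) of the numpy CenterCrop helper defined at the top of dataloader.py above: it slices an (H, W, C) or (H, W) array down to the requested center window. The 480x640 frame is a hypothetical placeholder, and the import assumes the repository root (including src/utils.py and its dependencies) is on the Python path.

import numpy as np
from src.data.dataloader import CenterCrop

crop = CenterCrop((224, 224))                     # target (h0, w0)
frame = np.zeros((480, 640, 3), dtype=np.uint8)   # placeholder H x W x C frame
print(crop(frame).shape)                          # -> (224, 224, 3), taken from the center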
src/data/functional.py ADDED
@@ -0,0 +1,84 @@
1
+ from __future__ import division
2
+
3
+ import torch
4
+ import numbers
5
+ import collections
6
+ import numpy as np
7
+ from PIL import Image, ImageOps
8
+
9
+
10
+ def _is_pil_image(img):
11
+ return isinstance(img, Image.Image)
12
+
13
+
14
+ def _is_tensor_image(img):
15
+ return torch.is_tensor(img) and img.ndimension() == 3
16
+
17
+
18
+ def _is_numpy_image(img):
19
+ return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
20
+
21
+
22
+ def to_mytensor(pic):
23
+ pic_arr = np.array(pic)
24
+ if pic_arr.ndim == 2:
25
+ pic_arr = pic_arr[..., np.newaxis]
26
+ img = torch.from_numpy(pic_arr.transpose((2, 0, 1)))
27
+ if not isinstance(img, torch.FloatTensor):
28
+ return img.float() # no normalize .div(255)
29
+ else:
30
+ return img
31
+
32
+
33
+ def normalize(tensor, mean, std):
34
+ if not _is_tensor_image(tensor):
35
+ raise TypeError("tensor is not a torch image.")
36
+ if tensor.size(0) == 1:
37
+ tensor.sub_(mean).div_(std)
38
+ else:
39
+ for t, m, s in zip(tensor, mean, std):
40
+ t.sub_(m).div_(s)
41
+ return tensor
42
+
43
+
44
+ def resize(img, size, interpolation=Image.BILINEAR):
45
+ if not _is_pil_image(img):
46
+ raise TypeError("img should be PIL Image. Got {}".format(type(img)))
47
+ if not isinstance(size, int) and (not isinstance(size, collections.Iterable) or len(size) != 2):
48
+ raise TypeError("Got inappropriate size arg: {}".format(size))
49
+
50
+ if not isinstance(size, int):
51
+ return img.resize(size[::-1], interpolation)
52
+
53
+ w, h = img.size
54
+ if (w <= h and w == size) or (h <= w and h == size):
55
+ return img
56
+ if w < h:
57
+ ow = size
58
+ oh = int(round(size * h / w))
59
+ else:
60
+ oh = size
61
+ ow = int(round(size * w / h))
62
+ return img.resize((ow, oh), interpolation)
63
+
64
+
65
+ def pad(img, padding, fill=0):
66
+ if not _is_pil_image(img):
67
+ raise TypeError("img should be PIL Image. Got {}".format(type(img)))
68
+
69
+ if not isinstance(padding, (numbers.Number, tuple)):
70
+ raise TypeError("Got inappropriate padding arg")
71
+ if not isinstance(fill, (numbers.Number, str, tuple)):
72
+ raise TypeError("Got inappropriate fill arg")
73
+
74
+ if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
75
+ raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding)))
76
+
77
+ return ImageOps.expand(img, border=padding, fill=fill)
78
+
79
+
80
+ def crop(img, i, j, h, w):
81
+ if not _is_pil_image(img):
82
+ raise TypeError("img should be PIL Image. Got {}".format(type(img)))
83
+
84
+ return img.crop((j, i, j + w, i + h))
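A minimal sketch (not part of the commit) of the resize() convention in functional.py above: an integer size scales the shorter edge to that value and preserves aspect ratio, while a (h, w) pair resizes exactly (note the size[::-1] flip into PIL's (w, h) order). The blank 640x480 image is a placeholder; the import assumes the repository root is on the Python path.

from PIL import Image
from src.data.functional import resize

img = Image.new("RGB", (640, 480))   # PIL size is (width, height)
out = resize(img, 224)               # int size: shorter edge -> 224, aspect ratio kept
print(out.size)                      # -> (299, 224)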
src/data/transforms.py ADDED
@@ -0,0 +1,348 @@
1
+ from __future__ import division
2
+
3
+ import collections
4
+ import numbers
5
+ import random
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from skimage import color
10
+
11
+ import src.data.functional as F
12
+
13
+ __all__ = [
14
+ "Compose",
15
+ "Concatenate",
16
+ "ToTensor",
17
+ "Normalize",
18
+ "Resize",
19
+ "Scale",
20
+ "CenterCrop",
21
+ "Pad",
22
+ "RandomCrop",
23
+ "RandomHorizontalFlip",
24
+ "RandomVerticalFlip",
25
+ "RandomResizedCrop",
26
+ "RandomSizedCrop",
27
+ "FiveCrop",
28
+ "TenCrop",
29
+ "RGB2Lab",
30
+ ]
31
+
32
+
33
+ def CustomFunc(inputs, func, *args, **kwargs):
34
+ im_l = func(inputs[0], *args, **kwargs)
35
+ im_ab = func(inputs[1], *args, **kwargs)
36
+ warp_ba = func(inputs[2], *args, **kwargs)
37
+ warp_aba = func(inputs[3], *args, **kwargs)
38
+ im_gbl_ab = func(inputs[4], *args, **kwargs)
39
+ bgr_mc_im = func(inputs[5], *args, **kwargs)
40
+
41
+ layer_data = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im]
42
+
43
+ for l in range(5):
44
+ layer = inputs[6 + l]
45
+ err_ba = func(layer[0], *args, **kwargs)
46
+ err_ab = func(layer[1], *args, **kwargs)
47
+
48
+ layer_data.append([err_ba, err_ab])
49
+
50
+ return layer_data
51
+
52
+
53
+ class Compose(object):
54
+ """Composes several transforms together.
55
+
56
+ Args:
57
+ transforms (list of ``Transform`` objects): list of transforms to compose.
58
+
59
+ Example:
60
+ >>> transforms.Compose([
61
+ >>> transforms.CenterCrop(10),
62
+ >>> transforms.ToTensor(),
63
+ >>> ])
64
+ """
65
+
66
+ def __init__(self, transforms):
67
+ self.transforms = transforms
68
+
69
+ def __call__(self, inputs):
70
+ for t in self.transforms:
71
+ inputs = t(inputs)
72
+ return inputs
73
+
74
+
75
+ class Concatenate(object):
76
+ """
77
+ Input: [im_l, im_ab, inputs]
78
+ inputs = [warp_ba_l, warp_ba_ab, warp_aba, err_pm, err_aba]
79
+
80
+ Output:[im_l, err_pm, warp_ba, warp_aba, im_ab, err_aba]
81
+ """
82
+
83
+ def __call__(self, inputs):
84
+ im_l = inputs[0]
85
+ im_ab = inputs[1]
86
+ warp_ba = inputs[2]
87
+ warp_aba = inputs[3]
88
+ im_glb_ab = inputs[4]
89
+ bgr_mc_im = inputs[5]
90
+ bgr_mc_im = bgr_mc_im[[2, 1, 0], ...]
91
+
92
+ err_ba = []
93
+ err_ab = []
94
+
95
+ for l in range(5):
96
+ layer = inputs[6 + l]
97
+ err_ba.append(layer[0])
98
+ err_ab.append(layer[1])
99
+
100
+ cerr_ba = torch.cat(err_ba, 0)
101
+ cerr_ab = torch.cat(err_ab, 0)
102
+
103
+ return (im_l, cerr_ba, warp_ba, warp_aba, im_glb_ab, bgr_mc_im, im_ab, cerr_ab)
104
+
105
+
106
+ class ToTensor(object):
107
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
108
+
109
+ Converts a PIL Image or numpy.ndarray (H x W x C) in the range
110
+ [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
111
+ """
112
+
113
+ def __call__(self, inputs):
114
+ """
115
+ Args:
116
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
117
+
118
+ Returns:
119
+ Tensor: Converted image.
120
+ """
121
+ return CustomFunc(inputs, F.to_mytensor)
122
+
123
+
124
+ class Normalize(object):
125
+ """Normalize an tensor image with mean and standard deviation.
126
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
127
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
128
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
129
+
130
+ Args:
131
+ mean (sequence): Sequence of means for each channel.
132
+ std (sequence): Sequence of standard deviations for each channel.
133
+ """
134
+
135
+ def __call__(self, inputs):
136
+ """
137
+ Args:
138
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
139
+
140
+ Returns:
141
+ Tensor: Normalized Tensor image.
142
+ """
143
+
144
+ im_l = F.normalize(inputs[0], 50, 1) # [0, 100]
145
+ im_ab = F.normalize(inputs[1], (0, 0), (1, 1)) # [-100, 100]
146
+
147
+ inputs[2][0:1, :, :] = F.normalize(inputs[2][0:1, :, :], 50, 1)
148
+ inputs[2][1:3, :, :] = F.normalize(inputs[2][1:3, :, :], (0, 0), (1, 1))
149
+ warp_ba = inputs[2]
150
+
151
+ inputs[3][0:1, :, :] = F.normalize(inputs[3][0:1, :, :], 50, 1)
152
+ inputs[3][1:3, :, :] = F.normalize(inputs[3][1:3, :, :], (0, 0), (1, 1))
153
+ warp_aba = inputs[3]
154
+
155
+ im_gbl_ab = F.normalize(inputs[4], (0, 0), (1, 1)) # [-100, 100]
156
+
157
+ bgr_mc_im = F.normalize(inputs[5], (123.68, 116.78, 103.938), (1, 1, 1))
158
+
159
+ layer_data = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im]
160
+
161
+ for l in range(5):
162
+ layer = inputs[6 + l]
163
+ err_ba = F.normalize(layer[0], 127, 2) # [0, 255]
164
+ err_ab = F.normalize(layer[1], 127, 2) # [0, 255]
165
+ layer_data.append([err_ba, err_ab])
166
+
167
+ return layer_data
168
+
169
+
170
+ class Resize(object):
171
+ """Resize the input PIL Image to the given size.
172
+
173
+ Args:
174
+ size (sequence or int): Desired output size. If size is a sequence like
175
+ (h, w), output size will be matched to this. If size is an int,
176
+ smaller edge of the image will be matched to this number.
177
+ i.e, if height > width, then image will be rescaled to
178
+ (size * height / width, size)
179
+ interpolation (int, optional): Desired interpolation. Default is
180
+ ``PIL.Image.BILINEAR``
181
+ """
182
+
183
+ def __init__(self, size, interpolation=Image.BILINEAR):
184
+ assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
185
+ self.size = size
186
+ self.interpolation = interpolation
187
+
188
+ def __call__(self, inputs):
189
+ """
190
+ Args:
191
+ img (PIL Image): Image to be scaled.
192
+
193
+ Returns:
194
+ PIL Image: Rescaled image.
195
+ """
196
+ return CustomFunc(inputs, F.resize, self.size, self.interpolation)
197
+
198
+
199
+ class RandomCrop(object):
200
+ """Crop the given PIL Image at a random location.
201
+
202
+ Args:
203
+ size (sequence or int): Desired output size of the crop. If size is an
204
+ int instead of sequence like (h, w), a square crop (size, size) is
205
+ made.
206
+ padding (int or sequence, optional): Optional padding on each border
207
+ of the image. Default is 0, i.e no padding. If a sequence of length
208
+ 4 is provided, it is used to pad left, top, right, bottom borders
209
+ respectively.
210
+ """
211
+
212
+ def __init__(self, size, padding=0):
213
+ if isinstance(size, numbers.Number):
214
+ self.size = (int(size), int(size))
215
+ else:
216
+ self.size = size
217
+ self.padding = padding
218
+
219
+ @staticmethod
220
+ def get_params(img, output_size):
221
+ """Get parameters for ``crop`` for a random crop.
222
+
223
+ Args:
224
+ img (PIL Image): Image to be cropped.
225
+ output_size (tuple): Expected output size of the crop.
226
+
227
+ Returns:
228
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
229
+ """
230
+ w, h = img.size
231
+ th, tw = output_size
232
+ if w == tw and h == th:
233
+ return 0, 0, h, w
234
+
235
+ i = random.randint(0, h - th)
236
+ j = random.randint(0, w - tw)
237
+ return i, j, th, tw
238
+
239
+ def __call__(self, inputs):
240
+ """
241
+ Args:
242
+ img (PIL Image): Image to be cropped.
243
+
244
+ Returns:
245
+ PIL Image: Cropped image.
246
+ """
247
+ if self.padding > 0:
248
+ inputs = CustomFunc(inputs, F.pad, self.padding)
249
+
250
+ i, j, h, w = self.get_params(inputs[0], self.size)
251
+ return CustomFunc(inputs, F.crop, i, j, h, w)
252
+
253
+
254
+ class CenterCrop(object):
255
+ """Crop the given PIL Image at a random location.
256
+
257
+ Args:
258
+ size (sequence or int): Desired output size of the crop. If size is an
259
+ int instead of sequence like (h, w), a square crop (size, size) is
260
+ made.
261
+ padding (int or sequence, optional): Optional padding on each border
262
+ of the image. Default is 0, i.e no padding. If a sequence of length
263
+ 4 is provided, it is used to pad left, top, right, bottom borders
264
+ respectively.
265
+ """
266
+
267
+ def __init__(self, size, padding=0):
268
+ if isinstance(size, numbers.Number):
269
+ self.size = (int(size), int(size))
270
+ else:
271
+ self.size = size
272
+ self.padding = padding
273
+
274
+ @staticmethod
275
+ def get_params(img, output_size):
276
+ """Get parameters for ``crop`` for a random crop.
277
+
278
+ Args:
279
+ img (PIL Image): Image to be cropped.
280
+ output_size (tuple): Expected output size of the crop.
281
+
282
+ Returns:
283
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
284
+ """
285
+ w, h = img.size
286
+ th, tw = output_size
287
+ if w == tw and h == th:
288
+ return 0, 0, h, w
289
+
290
+ i = (h - th) // 2
291
+ j = (w - tw) // 2
292
+ return i, j, th, tw
293
+
294
+ def __call__(self, inputs):
295
+ """
296
+ Args:
297
+ img (PIL Image): Image to be cropped.
298
+
299
+ Returns:
300
+ PIL Image: Cropped image.
301
+ """
302
+ if self.padding > 0:
303
+ inputs = CustomFunc(inputs, F.pad, self.padding)
304
+
305
+ i, j, h, w = self.get_params(inputs[0], self.size)
306
+ return CustomFunc(inputs, F.crop, i, j, h, w)
307
+
308
+
309
+ class RandomHorizontalFlip(object):
310
+ """Horizontally flip the given PIL Image randomly with a probability of 0.5."""
311
+
312
+ def __call__(self, inputs):
313
+ """
314
+ Args:
315
+ img (PIL Image): Image to be flipped.
316
+
317
+ Returns:
318
+ PIL Image: Randomly flipped image.
319
+ """
320
+
321
+ if random.random() < 0.5:
322
+ return CustomFunc(inputs, F.hflip)
323
+ return inputs
324
+
325
+
326
+ class RGB2Lab(object):
+ def __call__(self, inputs):
+ """
+ Args:
+ inputs (list): transform inputs; RGB entries are converted to the Lab color space.
+
+ Returns:
+ list: inputs with the L and ab channels split out.
+ """
+ image_lab = color.rgb2lab(inputs[0])
+ warp_ba_lab = color.rgb2lab(inputs[2])
+ warp_aba_lab = color.rgb2lab(inputs[3])
+ im_gbl_lab = color.rgb2lab(inputs[4])
+
+ inputs[0] = image_lab[:, :, :1] # l channel
+ inputs[1] = image_lab[:, :, 1:] # ab channel
+ inputs[2] = warp_ba_lab # lab channel
+ inputs[3] = warp_aba_lab # lab channel
+ inputs[4] = im_gbl_lab[:, :, 1:] # ab channel
+
+ return inputs
src/inference.py ADDED
@@ -0,0 +1,174 @@
1
+ from src.models.CNN.ColorVidNet import ColorVidNet
2
+ from src.models.vit.embed import SwinModel
3
+ from src.models.CNN.NonlocalNet import WarpNet
4
+ from src.models.CNN.FrameColor import frame_colorization
5
+ import torch
6
+ from src.models.vit.utils import load_params
7
+ import os
8
+ import cv2
9
+ from PIL import Image
10
+ from PIL import ImageEnhance as IE
11
+ import torchvision.transforms as T
12
+ from src.utils import (
13
+ RGB2Lab,
14
+ ToTensor,
15
+ Normalize,
16
+ uncenter_l,
17
+ tensor_lab2rgb
18
+ )
19
+ import numpy as np
20
+ from tqdm import tqdm
21
+
22
+ class SwinTExCo:
23
+ def __init__(self, weights_path, swin_backbone='swinv2-cr-t-224', device=None):
24
+ if device is None:
25
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26
+ else:
27
+ self.device = device
28
+
29
+ self.embed_net = SwinModel(pretrained_model=swin_backbone, device=self.device).to(self.device)
30
+ self.nonlocal_net = WarpNet(feature_channel=128).to(self.device)
31
+ self.colornet = ColorVidNet(7).to(self.device)
32
+
33
+ self.embed_net.eval()
34
+ self.nonlocal_net.eval()
35
+ self.colornet.eval()
36
+
37
+ self.__load_models(self.embed_net, os.path.join(weights_path, "embed_net.pth"))
38
+ self.__load_models(self.nonlocal_net, os.path.join(weights_path, "nonlocal_net.pth"))
39
+ self.__load_models(self.colornet, os.path.join(weights_path, "colornet.pth"))
40
+
41
+ self.processor = T.Compose([
42
+ T.Resize((224,224)),
43
+ RGB2Lab(),
44
+ ToTensor(),
45
+ Normalize()
46
+ ])
47
+
48
+ pass
49
+
50
+ def __load_models(self, model, weight_path):
51
+ params = load_params(weight_path, self.device)
52
+ model.load_state_dict(params, strict=True)
53
+
54
+ def __preprocess_reference(self, img):
55
+ color_enhancer = IE.Color(img)
56
+ img = color_enhancer.enhance(1.5)
57
+ return img
58
+
59
+ def __upscale_image(self, large_IA_l, I_current_ab_predict):
60
+ H, W = large_IA_l.shape[2:]
61
+ large_current_ab_predict = torch.nn.functional.interpolate(I_current_ab_predict,
62
+ size=(H,W),
63
+ mode="bilinear",
64
+ align_corners=False)
65
+ large_IA_l = torch.cat((large_IA_l, large_current_ab_predict.cpu()), dim=1)
66
+ large_current_rgb_predict = tensor_lab2rgb(large_IA_l)
67
+ return large_current_rgb_predict
68
+
69
+ def __proccess_sample(self, curr_frame, I_last_lab_predict, I_reference_lab, features_B):
70
+ large_IA_lab = ToTensor()(RGB2Lab()(curr_frame)).unsqueeze(0)
71
+ large_IA_l = large_IA_lab[:, 0:1, :, :]
72
+
73
+ IA_lab = self.processor(curr_frame)
74
+ IA_lab = IA_lab.unsqueeze(0).to(self.device)
75
+ IA_l = IA_lab[:, 0:1, :, :]
76
+ if I_last_lab_predict is None:
77
+ I_last_lab_predict = torch.zeros_like(IA_lab).to(self.device)
78
+
79
+
80
+ with torch.no_grad():
81
+ I_current_ab_predict, _ = frame_colorization(
82
+ IA_l,
83
+ I_reference_lab,
84
+ I_last_lab_predict,
85
+ features_B,
86
+ self.embed_net,
87
+ self.nonlocal_net,
88
+ self.colornet,
89
+ luminance_noise=0,
90
+ temperature=1e-10,
91
+ joint_training=False
92
+ )
93
+ I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)
94
+
95
+ IA_predict_rgb = self.__upscale_image(large_IA_l, I_current_ab_predict)
96
+ IA_predict_rgb = (IA_predict_rgb.squeeze(0).cpu().numpy() * 255.)
97
+ IA_predict_rgb = np.clip(IA_predict_rgb, 0, 255).astype(np.uint8)
98
+
99
+ return I_last_lab_predict, IA_predict_rgb
100
+
101
+ def predict_video(self, video, ref_image):
102
+ ref_image = self.__preprocess_reference(ref_image)
103
+
104
+ I_last_lab_predict = None
105
+
106
+ IB_lab = self.processor(ref_image)
107
+ IB_lab = IB_lab.unsqueeze(0).to(self.device)
108
+
109
+ with torch.no_grad():
110
+ I_reference_lab = IB_lab
111
+ I_reference_l = I_reference_lab[:, 0:1, :, :]
112
+ I_reference_ab = I_reference_lab[:, 1:3, :, :]
113
+ I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
114
+ features_B = self.embed_net(I_reference_rgb)
115
+
116
+ #PBAR = tqdm(total=int(video.get(cv2.CAP_PROP_FRAME_COUNT)), desc="Colorizing video", unit="frame")
117
+ while video.isOpened():
118
+ #PBAR.update(1)
119
+ ret, curr_frame = video.read()
120
+
121
+ if not ret:
122
+ break
123
+
124
+ curr_frame = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2RGB)
125
+ curr_frame = Image.fromarray(curr_frame)
126
+
127
+ I_last_lab_predict, IA_predict_rgb = self.__proccess_sample(curr_frame, I_last_lab_predict, I_reference_lab, features_B)
128
+
129
+ IA_predict_rgb = IA_predict_rgb.transpose(1,2,0)
130
+
131
+ yield IA_predict_rgb
132
+
133
+ #PBAR.close()
134
+ video.release()
135
+
136
+ def predict_image(self, image, ref_image):
137
+ ref_image = self.__preprocess_reference(ref_image)
138
+
139
+ I_last_lab_predict = None
140
+
141
+ IB_lab = self.processor(ref_image)
142
+ IB_lab = IB_lab.unsqueeze(0).to(self.device)
143
+
144
+ with torch.no_grad():
145
+ I_reference_lab = IB_lab
146
+ I_reference_l = I_reference_lab[:, 0:1, :, :]
147
+ I_reference_ab = I_reference_lab[:, 1:3, :, :]
148
+ I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
149
+ features_B = self.embed_net(I_reference_rgb)
150
+
151
+ curr_frame = image
152
+ I_last_lab_predict, IA_predict_rgb = self.__proccess_sample(curr_frame, I_last_lab_predict, I_reference_lab, features_B)
153
+
154
+ IA_predict_rgb = IA_predict_rgb.transpose(1,2,0)
155
+
156
+ return IA_predict_rgb
157
+
158
+ if __name__ == "__main__":
159
+ model = SwinTExCo('checkpoints/epoch_20/')
160
+
161
+ # Initialize video reader and writer
162
+ video = cv2.VideoCapture('sample_input/video_2.mp4')
163
+ fps = video.get(cv2.CAP_PROP_FPS)
164
+ width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
165
+ height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
166
+ video_writer = cv2.VideoWriter('sample_output/video_2_ref_2.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
167
+
168
+ # Initialize reference image
169
+ ref_image = Image.open('sample_input/refs_2/ref2.jpg').convert('RGB')
170
+
171
+ for colorized_frame in model.predict_video(video, ref_image):
172
+ video_writer.write(colorized_frame)
173
+
174
+ video_writer.release()
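A minimal usage sketch (not part of the commit) for single-image colorization with SwinTExCo.predict_image(), complementing the predict_video() demo in the __main__ block above. The grayscale frame path is hypothetical; sample_input/ref1.jpg and the epoch_20 checkpoints come from this commit, and the repository root is assumed to be on the Python path.

from PIL import Image
from src.inference import SwinTExCo

model = SwinTExCo('checkpoints/epoch_20/')                          # loads embed_net / nonlocal_net / colornet weights
gray_frame = Image.open('sample_input/frame1.jpg').convert('RGB')   # hypothetical grayscale frame (stored as RGB)
ref_image = Image.open('sample_input/ref1.jpg').convert('RGB')      # color reference shipped with this commit
colorized = model.predict_image(gray_frame, ref_image)              # numpy uint8 array, H x W x 3, RGB
Image.fromarray(colorized).save('frame1_colorized.png')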
src/losses.py ADDED
@@ -0,0 +1,277 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from src.utils import feature_normalize
4
+
5
+
6
+ ### START### CONTEXTUAL LOSS ####
7
+ class ContextualLoss(nn.Module):
8
+ """
9
+ input is Al, Bl, channel = 1, range ~ [0, 255]
10
+ """
11
+
12
+ def __init__(self):
13
+ super(ContextualLoss, self).__init__()
14
+ return None
15
+
16
+ def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
17
+ """
18
+ X_features&Y_features are are feature vectors or feature 2d array
19
+ h: bandwidth
20
+ return the per-sample loss
21
+ """
22
+ batch_size = X_features.shape[0]
23
+ feature_depth = X_features.shape[1]
24
+
25
+ # to normalized feature vectors
26
+ if feature_centering:
27
+ X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
28
+ dim=-1
29
+ )
30
+ Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
31
+ dim=-1
32
+ )
33
+ X_features = feature_normalize(X_features).view(
34
+ batch_size, feature_depth, -1
35
+ ) # batch_size * feature_depth * feature_size^2
36
+ Y_features = feature_normalize(Y_features).view(
37
+ batch_size, feature_depth, -1
38
+ ) # batch_size * feature_depth * feature_size^2
39
+
40
+ # cosine distance = 1 - similarity
41
+ X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
42
+ d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
43
+
44
+ # normalized distance: dij_bar
45
+ d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2
46
+
47
+ # pairwise affinity
48
+ w = torch.exp((1 - d_norm) / h)
49
+ A_ij = w / torch.sum(w, dim=-1, keepdim=True)
50
+
51
+ # contextual loss per sample
52
+ CX = torch.mean(torch.max(A_ij, dim=1)[0], dim=-1)
53
+ return -torch.log(CX)
54
+
55
+
56
+ class ContextualLoss_forward(nn.Module):
57
+ """
58
+ input is Al, Bl, channel = 1, range ~ [0, 255]
59
+ """
60
+
61
+ def __init__(self):
62
+ super(ContextualLoss_forward, self).__init__()
63
+ return None
64
+
65
+ def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
66
+ """
67
+ X_features&Y_features are are feature vectors or feature 2d array
68
+ h: bandwidth
69
+ return the per-sample loss
70
+ """
71
+ batch_size = X_features.shape[0]
72
+ feature_depth = X_features.shape[1]
73
+
74
+ # to normalized feature vectors
75
+ if feature_centering:
76
+ X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
77
+ dim=-1
78
+ )
79
+ Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
80
+ dim=-1
81
+ )
82
+ X_features = feature_normalize(X_features).view(
83
+ batch_size, feature_depth, -1
84
+ ) # batch_size * feature_depth * feature_size^2
85
+ Y_features = feature_normalize(Y_features).view(
86
+ batch_size, feature_depth, -1
87
+ ) # batch_size * feature_depth * feature_size^2
88
+
89
+ # cosine distance = 1 - similarity
90
+ X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
91
+ d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
92
+
93
+ # normalized distance: dij_bar
94
+ d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2
95
+
96
+ # pairwise affinity
97
+ w = torch.exp((1 - d_norm) / h)
98
+ A_ij = w / torch.sum(w, dim=-1, keepdim=True)
99
+
100
+ # contextual loss per sample
101
+ CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1)
102
+ return -torch.log(CX)
103
+
104
+
105
+ ### END### CONTEXTUAL LOSS ####
106
+
107
+
108
+ ##########################
109
+
110
+
111
+ def mse_loss_fn(input, target=0):
112
+ return torch.mean((input - target) ** 2)
113
+
114
+
115
+ ### START### PERCEPTUAL LOSS ###
116
+ def Perceptual_loss(domain_invariant, weight_perceptual):
117
+ instancenorm = nn.InstanceNorm2d(512, affine=False)
118
+
119
+ def __call__(A_relu5_1, predict_relu5_1):
120
+ if domain_invariant:
121
+ feat_loss = (
122
+ mse_loss_fn(instancenorm(predict_relu5_1), instancenorm(A_relu5_1.detach())) * weight_perceptual * 1e5 * 0.2
123
+ )
124
+ else:
125
+ feat_loss = mse_loss_fn(predict_relu5_1, A_relu5_1.detach()) * weight_perceptual
126
+ return feat_loss
127
+
128
+ return __call__
129
+
130
+
131
+ ### END### PERCEPTUAL LOSS ###
132
+
133
+
134
+ def l1_loss_fn(input, target=0):
135
+ return torch.mean(torch.abs(input - target))
136
+
137
+
138
+ ### END#################
139
+
140
+
141
+ ### START### ADVERSARIAL LOSS ###
142
+ def generator_loss_fn(real_data_lab, fake_data_lab, discriminator, weight_gan, device):
143
+ if weight_gan > 0:
144
+ y_pred_fake, _ = discriminator(fake_data_lab)
145
+ y_pred_real, _ = discriminator(real_data_lab)
146
+
147
+ y = torch.ones_like(y_pred_real)
148
+ generator_loss = (
149
+ (
150
+ torch.mean((y_pred_real - torch.mean(y_pred_fake) + y) ** 2)
151
+ + torch.mean((y_pred_fake - torch.mean(y_pred_real) - y) ** 2)
152
+ )
153
+ / 2
154
+ * weight_gan
155
+ )
156
+ return generator_loss
157
+
158
+ return torch.Tensor([0]).to(device)
159
+
160
+
161
+ def discriminator_loss_fn(real_data_lab, fake_data_lab, discriminator):
162
+ y_pred_fake, _ = discriminator(fake_data_lab.detach())
163
+ y_pred_real, _ = discriminator(real_data_lab.detach())
164
+
165
+ y = torch.ones_like(y_pred_real)
166
+ discriminator_loss = (
167
+ torch.mean((y_pred_real - torch.mean(y_pred_fake) - y) ** 2)
168
+ + torch.mean((y_pred_fake - torch.mean(y_pred_real) + y) ** 2)
169
+ ) / 2
170
+ return discriminator_loss
171
+
172
+
173
+ ### END### ADVERSARIAL LOSS #####
174
+
175
+
176
+ def consistent_loss_fn(
177
+ I_current_lab_predict,
178
+ I_last_ab_predict,
179
+ I_current_nonlocal_lab_predict,
180
+ I_last_nonlocal_lab_predict,
181
+ flow_forward,
182
+ mask,
183
+ warping_layer,
184
+ weight_consistent=0.02,
185
+ weight_nonlocal_consistent=0.0,
186
+ device="cuda",
187
+ ):
188
+ def weighted_mse_loss(input, target, weights):
189
+ out = (input - target) ** 2
190
+ out = out * weights.expand_as(out)
191
+ return out.mean()
192
+
193
+ def consistent():
194
+ I_current_lab_predict_warp = warping_layer(I_current_lab_predict, flow_forward)
195
+ I_current_ab_predict_warp = I_current_lab_predict_warp[:, 1:3, :, :]
196
+ consistent_loss = weighted_mse_loss(I_current_ab_predict_warp, I_last_ab_predict, mask) * weight_consistent
197
+ return consistent_loss
198
+
199
+ def nonlocal_consistent():
200
+ I_current_nonlocal_lab_predict_warp = warping_layer(I_current_nonlocal_lab_predict, flow_forward)
201
+ nonlocal_consistent_loss = (
202
+ weighted_mse_loss(
203
+ I_current_nonlocal_lab_predict_warp[:, 1:3, :, :],
204
+ I_last_nonlocal_lab_predict[:, 1:3, :, :],
205
+ mask,
206
+ )
207
+ * weight_nonlocal_consistent
208
+ )
209
+
210
+ return nonlocal_consistent_loss
211
+
212
+ consistent_loss = consistent() if weight_consistent else torch.Tensor([0]).to(device)
213
+ nonlocal_consistent_loss = nonlocal_consistent() if weight_nonlocal_consistent else torch.Tensor([0]).to(device)
214
+
215
+ return consistent_loss + nonlocal_consistent_loss
216
+
217
+
218
+ ### END### CONSISTENCY LOSS #####
219
+
220
+
221
+ ### START### SMOOTHNESS LOSS ###
222
+ def smoothness_loss_fn(
223
+ I_current_l,
224
+ I_current_lab,
225
+ I_current_ab_predict,
226
+ A_relu2_1,
227
+ weighted_layer_color,
228
+ nonlocal_weighted_layer,
229
+ weight_smoothness=5.0,
230
+ weight_nonlocal_smoothness=0.0,
231
+ device="cuda",
232
+ ):
233
+ def smoothness(scale_factor=1.0):
234
+ I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
235
+ IA_ab_weighed = weighted_layer_color(
236
+ I_current_lab,
237
+ I_current_lab_predict,
238
+ patch_size=3,
239
+ alpha=10,
240
+ scale_factor=scale_factor,
241
+ )
242
+ smoothness_loss = (
243
+ mse_loss_fn(
244
+ nn.functional.interpolate(I_current_ab_predict, scale_factor=scale_factor),
245
+ IA_ab_weighed,
246
+ )
247
+ * weight_smoothness
248
+ )
249
+
250
+ return smoothness_loss
251
+
252
+ def nonlocal_smoothness(scale_factor=0.25, alpha_nonlocal_smoothness=0.5):
253
+ nonlocal_smooth_feature = feature_normalize(A_relu2_1)
254
+ I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
255
+ I_current_ab_weighted_nonlocal = nonlocal_weighted_layer(
256
+ I_current_lab_predict,
257
+ nonlocal_smooth_feature.detach(),
258
+ patch_size=3,
259
+ alpha=alpha_nonlocal_smoothness,
260
+ scale_factor=scale_factor,
261
+ )
262
+ nonlocal_smoothness_loss = (
263
+ mse_loss_fn(
264
+ nn.functional.interpolate(I_current_ab_predict, scale_factor=scale_factor),
265
+ I_current_ab_weighted_nonlocal,
266
+ )
267
+ * weight_nonlocal_smoothness
268
+ )
269
+ return nonlocal_smoothness_loss
270
+
271
+ smoothness_loss = smoothness() if weight_smoothness else torch.Tensor([0]).to(device)
272
+ nonlocal_smoothness_loss = nonlocal_smoothness() if weight_nonlocal_smoothness else torch.Tensor([0]).to(device)
273
+
274
+ return smoothness_loss + nonlocal_smoothness_loss
275
+
276
+
277
+ ### END### SMOOTHNESS LOSS #####
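A minimal sketch (not part of the commit) of how ContextualLoss above is called: it takes two (B, C, H, W) feature maps, optionally mean-centers them against the reference features, and returns one loss value per sample. The random tensors are placeholders standing in for embedding features; the import assumes the repository root is on the Python path.

import torch
from src.losses import ContextualLoss

ctx = ContextualLoss()
feat_pred = torch.rand(2, 64, 16, 16)   # placeholder features of the predicted frames
feat_ref = torch.rand(2, 64, 16, 16)    # placeholder features of the reference frames
loss = ctx(feat_pred, feat_ref, h=0.1)  # tensor of shape (2,): one value per sample
print(loss.mean().item())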
src/metrics.py ADDED
@@ -0,0 +1,225 @@
1
+ from skimage.metrics import structural_similarity, peak_signal_noise_ratio
2
+ import numpy as np
3
+ import lpips
4
+ import torch
5
+ from pytorch_fid.fid_score import calculate_frechet_distance
6
+ from pytorch_fid.inception import InceptionV3
7
+ import torch.nn as nn
8
+ import cv2
9
+ from scipy import stats
10
+ import os
11
+
12
+ def calc_ssim(pred_image, gt_image):
13
+ '''
14
+ Structural Similarity Index (SSIM) is a perceptual metric that quantifies the image quality degradation that is
15
+ caused by processing such as data compression or by losses in data transmission.
16
+
17
+ # Arguments
18
+ pred_image: PIL.Image
19
+ gt_image: PIL.Image
20
+ # Returns
21
+ ssim: float (-1.0, 1.0)
22
+ '''
23
+ pred_image = np.array(pred_image.convert('RGB')).astype(np.float32)
24
+ gt_image = np.array(gt_image.convert('RGB')).astype(np.float32)
25
+ ssim = structural_similarity(pred_image, gt_image, channel_axis=2, data_range=255.)
26
+ return ssim
27
+
28
+ def calc_psnr(pred_image, gt_image):
29
+ '''
30
+ Peak Signal-to-Noise Ratio (PSNR) is an expression for the ratio between the maximum possible value (power) of a signal
31
+ and the power of distorting noise that affects the quality of its representation.
32
+
33
+ # Arguments
34
+ pred_image: PIL.Image
35
+ gt_image: PIL.Image
36
+ # Returns
37
+ psnr: float
38
+ '''
39
+ pred_image = np.array(pred_image.convert('RGB')).astype(np.float32)
40
+ gt_image = np.array(gt_image.convert('RGB')).astype(np.float32)
41
+
42
+ psnr = peak_signal_noise_ratio(gt_image, pred_image, data_range=255.)
43
+ return psnr
44
+
45
+ class LPIPS_utils:
46
+ def __init__(self, device = 'cuda'):
47
+ self.loss_fn = lpips.LPIPS(net='vgg', spatial=True)  # net can be 'squeeze', 'vgg' or 'alex'
48
+ self.loss_fn = self.loss_fn.to(device)
49
+ self.device = device
50
+
51
+ def compare_lpips(self,img_fake, img_real, data_range=255.): # input: torch 1 c h w / h w c
52
+ img_fake = torch.from_numpy(np.array(img_fake).astype(np.float32)/data_range)
53
+ img_real = torch.from_numpy(np.array(img_real).astype(np.float32)/data_range)
54
+ if img_fake.ndim==3:
55
+ img_fake = img_fake.permute(2,0,1).unsqueeze(0)
56
+ img_real = img_real.permute(2,0,1).unsqueeze(0)
57
+ img_fake = img_fake.to(self.device)
58
+ img_real = img_real.to(self.device)
59
+
60
+ dist = self.loss_fn.forward(img_fake,img_real)
61
+ return dist.mean().item()
62
+
63
+ class FID_utils(nn.Module):
64
+ """Class for computing the Fréchet Inception Distance (FID) metric score.
65
+ It is implemented as a class in order to hold the inception model instance
66
+ in its state.
67
+ Parameters
68
+ ----------
69
+ resize_input : bool (optional)
70
+ Whether or not to resize the input images to the image size (299, 299)
71
+ on which the inception model was trained. Since the model is fully
72
+ convolutional, the score also works without resizing. In literature
73
+ and when working with GANs people tend to set this value to True,
74
+ however, for internal evaluation this is not necessary.
75
+ device : str or torch.device
76
+ The device on which to run the inception model.
77
+ """
78
+
79
+ def __init__(self, resize_input=True, device="cuda"):
80
+ super(FID_utils, self).__init__()
81
+ self.device = device
82
+ if self.device is None:
83
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
84
+ #self.model = InceptionV3(resize_input=resize_input).to(device)
85
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
86
+ self.model = InceptionV3([block_idx]).to(device)
87
+ self.model = self.model.eval()
88
+
89
+ def get_activations(self,batch): # 1 c h w
90
+ with torch.no_grad():
91
+ pred = self.model(batch)[0]
92
+ # If model output is not scalar, apply global spatial average pooling.
93
+ # This happens if you choose a dimensionality not equal 2048.
94
+ if pred.size(2) != 1 or pred.size(3) != 1:
95
+ #pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
96
+ print("error in get activations!")
97
+ #pred = pred.squeeze(3).squeeze(2).cpu().numpy()
98
+ return pred
99
+
100
+
101
+ def _get_mu_sigma(self, batch,data_range):
102
+ """Compute the inception mu and sigma for a batch of images.
103
+ Parameters
104
+ ----------
105
+ images : np.ndarray
106
+ A batch of images with shape (n_images,3, width, height).
107
+ Returns
108
+ -------
109
+ mu : np.ndarray
110
+ The array of mean activations with shape (2048,).
111
+ sigma : np.ndarray
112
+ The covariance matrix of activations with shape (2048, 2048).
113
+ """
114
+ # forward pass
115
+ if batch.ndim ==3 and batch.shape[2]==3:
116
+ batch=batch.permute(2,0,1).unsqueeze(0)
117
+ batch /= data_range
118
+ #batch = torch.tensor(batch)#.unsqueeze(1).repeat((1, 3, 1, 1))
119
+ batch = batch.to(self.device, torch.float32)
120
+ #(activations,) = self.model(batch)
121
+ activations = self.get_activations(batch)
122
+ activations = activations.detach().cpu().numpy().squeeze(3).squeeze(2)
123
+
124
+ # compute statistics
125
+ mu = np.mean(activations,axis=0)
126
+ sigma = np.cov(activations, rowvar=False)
127
+
128
+ return mu, sigma
129
+
130
+ def score(self, images_1, images_2, data_range=255.):
131
+ """Compute the FID score.
132
+ The input batches should have the shape (n_images,3, width, height). or (h,w,3)
133
+ Parameters
134
+ ----------
135
+ images_1 : np.ndarray
136
+ First batch of images.
137
+ images_2 : np.ndarray
138
+ Second batch of images.
139
+ Returns
140
+ -------
141
+ score : float
142
+ The FID score.
143
+ """
144
+ images_1 = torch.from_numpy(np.array(images_1).astype(np.float32))
145
+ images_2 = torch.from_numpy(np.array(images_2).astype(np.float32))
146
+ images_1 = images_1.to(self.device)
147
+ images_2 = images_2.to(self.device)
148
+
149
+ mu_1, sigma_1 = self._get_mu_sigma(images_1,data_range)
150
+ mu_2, sigma_2 = self._get_mu_sigma(images_2,data_range)
151
+ score = calculate_frechet_distance(mu_1, sigma_1, mu_2, sigma_2)
152
+
153
+ return score
154
+
155
+ def JS_divergence(p, q):
156
+ M = (p + q) / 2
157
+ return 0.5 * stats.entropy(p, M) + 0.5 * stats.entropy(q, M)
158
+
159
+
160
+ def compute_JS_bgr(input_dir, dilation=1):
161
+ input_img_list = os.listdir(input_dir)
162
+ input_img_list.sort()
163
+ # print(input_img_list)
164
+
165
+ hist_b_list = [] # [img1_histb, img2_histb, ...]
166
+ hist_g_list = []
167
+ hist_r_list = []
168
+
169
+ for img_name in input_img_list:
170
+ # print(os.path.join(input_dir, img_name))
171
+ img_in = cv2.imread(os.path.join(input_dir, img_name))
172
+ H, W, C = img_in.shape
173
+
174
+ hist_b = cv2.calcHist([img_in], [0], None, [256], [0,256]) # B
175
+ hist_g = cv2.calcHist([img_in], [1], None, [256], [0,256]) # G
176
+ hist_r = cv2.calcHist([img_in], [2], None, [256], [0,256]) # R
177
+
178
+ hist_b = hist_b / (H * W)
179
+ hist_g = hist_g / (H * W)
180
+ hist_r = hist_r / (H * W)
181
+
182
+ hist_b_list.append(hist_b)
183
+ hist_g_list.append(hist_g)
184
+ hist_r_list.append(hist_r)
185
+
186
+ JS_b_list = []
187
+ JS_g_list = []
188
+ JS_r_list = []
189
+
190
+ for i in range(len(hist_b_list)):
191
+ if i + dilation > len(hist_b_list) - 1:
192
+ break
193
+ hist_b_img1 = hist_b_list[i]
194
+ hist_b_img2 = hist_b_list[i + dilation]
195
+ JS_b = JS_divergence(hist_b_img1, hist_b_img2)
196
+ JS_b_list.append(JS_b)
197
+
198
+ hist_g_img1 = hist_g_list[i]
199
+ hist_g_img2 = hist_g_list[i+dilation]
200
+ JS_g = JS_divergence(hist_g_img1, hist_g_img2)
201
+ JS_g_list.append(JS_g)
202
+
203
+ hist_r_img1 = hist_r_list[i]
204
+ hist_r_img2 = hist_r_list[i+dilation]
205
+ JS_r = JS_divergence(hist_r_img1, hist_r_img2)
206
+ JS_r_list.append(JS_r)
207
+
208
+ return JS_b_list, JS_g_list, JS_r_list
209
+
210
+
211
+ def calc_cdc(vid_folder, dilation=[1, 2, 4], weight=[1/3, 1/3, 1/3]):
212
+ mean_b, mean_g, mean_r = 0, 0, 0
213
+ for d, w in zip(dilation, weight):
214
+ JS_b_list_one, JS_g_list_one, JS_r_list_one = compute_JS_bgr(vid_folder, d)
215
+ mean_b += w * np.mean(JS_b_list_one)
216
+ mean_g += w * np.mean(JS_g_list_one)
217
+ mean_r += w * np.mean(JS_r_list_one)
218
+
219
+ cdc = np.mean([mean_b, mean_g, mean_r])
220
+ return cdc
221
+
222
+
223
+
224
+
225
+
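A quick usage sketch of the per-image metrics and the folder-level CDC temporal-consistency score (the frame paths below are hypothetical placeholders):

from PIL import Image
from src.metrics import calc_ssim, calc_psnr, calc_cdc

pred = Image.open("sample_output/frame_0001.png")   # hypothetical colorized frame
gt = Image.open("sample_gt/frame_0001.png")         # hypothetical ground-truth frame
print("SSIM:", calc_ssim(pred, gt))
print("PSNR:", calc_psnr(pred, gt))
print("CDC :", calc_cdc("sample_output"))           # folder of decoded frames; lower = more temporally consistent colors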
src/models/CNN/ColorVidNet.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.parallel
4
+
5
+ class ColorVidNet(nn.Module):
6
+ def __init__(self, ic):
7
+ super(ColorVidNet, self).__init__()
8
+ self.conv1_1 = nn.Sequential(nn.Conv2d(ic, 32, 3, 1, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1, 1))
9
+ self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1)
10
+ self.conv1_2norm = nn.BatchNorm2d(64, affine=False)
11
+ self.conv1_2norm_ss = nn.Conv2d(64, 64, 1, 2, bias=False, groups=64)
12
+ self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1)
13
+ self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1)
14
+ self.conv2_2norm = nn.BatchNorm2d(128, affine=False)
15
+ self.conv2_2norm_ss = nn.Conv2d(128, 128, 1, 2, bias=False, groups=128)
16
+ self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
17
+ self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
18
+ self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1)
19
+ self.conv3_3norm = nn.BatchNorm2d(256, affine=False)
20
+ self.conv3_3norm_ss = nn.Conv2d(256, 256, 1, 2, bias=False, groups=256)
21
+ self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
22
+ self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
23
+ self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1)
24
+ self.conv4_3norm = nn.BatchNorm2d(512, affine=False)
25
+ self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
26
+ self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
27
+ self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
28
+ self.conv5_3norm = nn.BatchNorm2d(512, affine=False)
29
+ self.conv6_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
30
+ self.conv6_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
31
+ self.conv6_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
32
+ self.conv6_3norm = nn.BatchNorm2d(512, affine=False)
33
+ self.conv7_1 = nn.Conv2d(512, 512, 3, 1, 1)
34
+ self.conv7_2 = nn.Conv2d(512, 512, 3, 1, 1)
35
+ self.conv7_3 = nn.Conv2d(512, 512, 3, 1, 1)
36
+ self.conv7_3norm = nn.BatchNorm2d(512, affine=False)
37
+ self.conv8_1 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
38
+ self.conv3_3_short = nn.Conv2d(256, 256, 3, 1, 1)
39
+ self.conv8_2 = nn.Conv2d(256, 256, 3, 1, 1)
40
+ self.conv8_3 = nn.Conv2d(256, 256, 3, 1, 1)
41
+ self.conv8_3norm = nn.BatchNorm2d(256, affine=False)
42
+ self.conv9_1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
43
+ self.conv2_2_short = nn.Conv2d(128, 128, 3, 1, 1)
44
+ self.conv9_2 = nn.Conv2d(128, 128, 3, 1, 1)
45
+ self.conv9_2norm = nn.BatchNorm2d(128, affine=False)
46
+ self.conv10_1 = nn.ConvTranspose2d(128, 128, 4, 2, 1)
47
+ self.conv1_2_short = nn.Conv2d(64, 128, 3, 1, 1)
48
+ self.conv10_2 = nn.Conv2d(128, 128, 3, 1, 1)
49
+ self.conv10_ab = nn.Conv2d(128, 2, 1, 1)
50
+
51
+ # add self.relux_x
52
+ self.relu1_1 = nn.PReLU()
53
+ self.relu1_2 = nn.PReLU()
54
+ self.relu2_1 = nn.PReLU()
55
+ self.relu2_2 = nn.PReLU()
56
+ self.relu3_1 = nn.PReLU()
57
+ self.relu3_2 = nn.PReLU()
58
+ self.relu3_3 = nn.PReLU()
59
+ self.relu4_1 = nn.PReLU()
60
+ self.relu4_2 = nn.PReLU()
61
+ self.relu4_3 = nn.PReLU()
62
+ self.relu5_1 = nn.PReLU()
63
+ self.relu5_2 = nn.PReLU()
64
+ self.relu5_3 = nn.PReLU()
65
+ self.relu6_1 = nn.PReLU()
66
+ self.relu6_2 = nn.PReLU()
67
+ self.relu6_3 = nn.PReLU()
68
+ self.relu7_1 = nn.PReLU()
69
+ self.relu7_2 = nn.PReLU()
70
+ self.relu7_3 = nn.PReLU()
71
+ self.relu8_1_comb = nn.PReLU()
72
+ self.relu8_2 = nn.PReLU()
73
+ self.relu8_3 = nn.PReLU()
74
+ self.relu9_1_comb = nn.PReLU()
75
+ self.relu9_2 = nn.PReLU()
76
+ self.relu10_1_comb = nn.PReLU()
77
+ self.relu10_2 = nn.LeakyReLU(0.2, True)
78
+
79
+ self.conv8_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(512, 256, 3, 1, 1))
80
+ self.conv9_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(256, 128, 3, 1, 1))
81
+ self.conv10_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(128, 128, 3, 1, 1))
82
+
83
+ self.conv1_2norm = nn.InstanceNorm2d(64)
84
+ self.conv2_2norm = nn.InstanceNorm2d(128)
85
+ self.conv3_3norm = nn.InstanceNorm2d(256)
86
+ self.conv4_3norm = nn.InstanceNorm2d(512)
87
+ self.conv5_3norm = nn.InstanceNorm2d(512)
88
+ self.conv6_3norm = nn.InstanceNorm2d(512)
89
+ self.conv7_3norm = nn.InstanceNorm2d(512)
90
+ self.conv8_3norm = nn.InstanceNorm2d(256)
91
+ self.conv9_2norm = nn.InstanceNorm2d(128)
92
+
93
+ def forward(self, x):
94
+ """x: gray image (1 channel), ab(2 channel), ab_err, ba_err"""
95
+ conv1_1 = self.relu1_1(self.conv1_1(x))
96
+ conv1_2 = self.relu1_2(self.conv1_2(conv1_1))
97
+ conv1_2norm = self.conv1_2norm(conv1_2)
98
+ conv1_2norm_ss = self.conv1_2norm_ss(conv1_2norm)
99
+ conv2_1 = self.relu2_1(self.conv2_1(conv1_2norm_ss))
100
+ conv2_2 = self.relu2_2(self.conv2_2(conv2_1))
101
+ conv2_2norm = self.conv2_2norm(conv2_2)
102
+ conv2_2norm_ss = self.conv2_2norm_ss(conv2_2norm)
103
+ conv3_1 = self.relu3_1(self.conv3_1(conv2_2norm_ss))
104
+ conv3_2 = self.relu3_2(self.conv3_2(conv3_1))
105
+ conv3_3 = self.relu3_3(self.conv3_3(conv3_2))
106
+ conv3_3norm = self.conv3_3norm(conv3_3)
107
+ conv3_3norm_ss = self.conv3_3norm_ss(conv3_3norm)
108
+ conv4_1 = self.relu4_1(self.conv4_1(conv3_3norm_ss))
109
+ conv4_2 = self.relu4_2(self.conv4_2(conv4_1))
110
+ conv4_3 = self.relu4_3(self.conv4_3(conv4_2))
111
+ conv4_3norm = self.conv4_3norm(conv4_3)
112
+ conv5_1 = self.relu5_1(self.conv5_1(conv4_3norm))
113
+ conv5_2 = self.relu5_2(self.conv5_2(conv5_1))
114
+ conv5_3 = self.relu5_3(self.conv5_3(conv5_2))
115
+ conv5_3norm = self.conv5_3norm(conv5_3)
116
+ conv6_1 = self.relu6_1(self.conv6_1(conv5_3norm))
117
+ conv6_2 = self.relu6_2(self.conv6_2(conv6_1))
118
+ conv6_3 = self.relu6_3(self.conv6_3(conv6_2))
119
+ conv6_3norm = self.conv6_3norm(conv6_3)
120
+ conv7_1 = self.relu7_1(self.conv7_1(conv6_3norm))
121
+ conv7_2 = self.relu7_2(self.conv7_2(conv7_1))
122
+ conv7_3 = self.relu7_3(self.conv7_3(conv7_2))
123
+ conv7_3norm = self.conv7_3norm(conv7_3)
124
+ conv8_1 = self.conv8_1(conv7_3norm)
125
+ conv3_3_short = self.conv3_3_short(conv3_3norm)
126
+ conv8_1_comb = self.relu8_1_comb(conv8_1 + conv3_3_short)
127
+ conv8_2 = self.relu8_2(self.conv8_2(conv8_1_comb))
128
+ conv8_3 = self.relu8_3(self.conv8_3(conv8_2))
129
+ conv8_3norm = self.conv8_3norm(conv8_3)
130
+ conv9_1 = self.conv9_1(conv8_3norm)
131
+ conv2_2_short = self.conv2_2_short(conv2_2norm)
132
+ conv9_1_comb = self.relu9_1_comb(conv9_1 + conv2_2_short)
133
+ conv9_2 = self.relu9_2(self.conv9_2(conv9_1_comb))
134
+ conv9_2norm = self.conv9_2norm(conv9_2)
135
+ conv10_1 = self.conv10_1(conv9_2norm)
136
+ conv1_2_short = self.conv1_2_short(conv1_2norm)
137
+ conv10_1_comb = self.relu10_1_comb(conv10_1 + conv1_2_short)
138
+ conv10_2 = self.relu10_2(self.conv10_2(conv10_1_comb))
139
+ conv10_ab = self.conv10_ab(conv10_2)
140
+
141
+ return torch.tanh(conv10_ab) * 128
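A quick shape check for the colorization head; in this repo frame_colorization (next file) feeds it 7 planes (L, warped ab, similarity map, previous-frame Lab), and the tanh output keeps ab within [-128, 128]:

import torch
from src.models.CNN.ColorVidNet import ColorVidNet

net = ColorVidNet(ic=7)
x = torch.randn(1, 7, 224, 224)          # H and W should be multiples of 8 so the skip connections line up
ab = net(x)
print(ab.shape, float(ab.abs().max()))   # torch.Size([1, 2, 224, 224]), value <= 128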
src/models/CNN/FrameColor.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ from src.utils import *
3
+ from src.models.vit.vit import FeatureTransform
4
+
5
+
6
+ def warp_color(
7
+ IA_l,
8
+ IB_lab,
9
+ features_B,
10
+ embed_net,
11
+ nonlocal_net,
12
+ temperature=0.01,
13
+ ):
14
+ IA_rgb_from_gray = gray2rgb_batch(IA_l)
15
+
16
+ with torch.no_grad():
17
+ A_feat0, A_feat1, A_feat2, A_feat3 = embed_net(IA_rgb_from_gray)
18
+ B_feat0, B_feat1, B_feat2, B_feat3 = features_B
19
+
20
+ A_feat0 = feature_normalize(A_feat0)
21
+ A_feat1 = feature_normalize(A_feat1)
22
+ A_feat2 = feature_normalize(A_feat2)
23
+ A_feat3 = feature_normalize(A_feat3)
24
+
25
+ B_feat0 = feature_normalize(B_feat0)
26
+ B_feat1 = feature_normalize(B_feat1)
27
+ B_feat2 = feature_normalize(B_feat2)
28
+ B_feat3 = feature_normalize(B_feat3)
29
+
30
+ return nonlocal_net(
31
+ IB_lab,
32
+ A_feat0,
33
+ A_feat1,
34
+ A_feat2,
35
+ A_feat3,
36
+ B_feat0,
37
+ B_feat1,
38
+ B_feat2,
39
+ B_feat3,
40
+ temperature=temperature,
41
+ )
42
+
43
+
44
+ def frame_colorization(
45
+ IA_l,
46
+ IB_lab,
47
+ IA_last_lab,
48
+ features_B,
49
+ embed_net,
50
+ nonlocal_net,
51
+ colornet,
52
+ joint_training=True,
53
+ luminance_noise=0,
54
+ temperature=0.01,
55
+ ):
56
+ if luminance_noise:
57
+ IA_l = IA_l + torch.randn_like(IA_l, requires_grad=False) * luminance_noise
58
+
59
+ with torch.autograd.set_grad_enabled(joint_training):
60
+ nonlocal_BA_lab, similarity_map = warp_color(
61
+ IA_l,
62
+ IB_lab,
63
+ features_B,
64
+ embed_net,
65
+ nonlocal_net,
66
+ temperature=temperature,
67
+ )
68
+ nonlocal_BA_ab = nonlocal_BA_lab[:, 1:3, :, :]
69
+ IA_ab_predict = colornet(
70
+ torch.cat(
71
+ (IA_l, nonlocal_BA_ab, similarity_map, IA_last_lab),
72
+ dim=1,
73
+ )
74
+ )
75
+
76
+ return IA_ab_predict, nonlocal_BA_lab
src/models/CNN/GAN_models.py ADDED
@@ -0,0 +1,212 @@
1
+ # DCGAN-like generator and discriminator
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from torch.nn import Parameter
6
+
7
+
8
+ def l2normalize(v, eps=1e-12):
9
+ return v / (v.norm() + eps)
10
+
11
+
12
+ class SpectralNorm(nn.Module):
13
+ def __init__(self, module, name="weight", power_iterations=1):
14
+ super(SpectralNorm, self).__init__()
15
+ self.module = module
16
+ self.name = name
17
+ self.power_iterations = power_iterations
18
+ if not self._made_params():
19
+ self._make_params()
20
+
21
+ def _update_u_v(self):
22
+ u = getattr(self.module, self.name + "_u")
23
+ v = getattr(self.module, self.name + "_v")
24
+ w = getattr(self.module, self.name + "_bar")
25
+
26
+ height = w.data.shape[0]
27
+ for _ in range(self.power_iterations):
28
+ v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
29
+ u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
30
+
31
+ sigma = u.dot(w.view(height, -1).mv(v))
32
+ setattr(self.module, self.name, w / sigma.expand_as(w))
33
+
34
+ def _made_params(self):
35
+ try:
36
+ u = getattr(self.module, self.name + "_u")
37
+ v = getattr(self.module, self.name + "_v")
38
+ w = getattr(self.module, self.name + "_bar")
39
+ return True
40
+ except AttributeError:
41
+ return False
42
+
43
+ def _make_params(self):
44
+ w = getattr(self.module, self.name)
45
+
46
+ height = w.data.shape[0]
47
+ width = w.view(height, -1).data.shape[1]
48
+
49
+ u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
50
+ v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
51
+ u.data = l2normalize(u.data)
52
+ v.data = l2normalize(v.data)
53
+ w_bar = Parameter(w.data)
54
+
55
+ del self.module._parameters[self.name]
56
+
57
+ self.module.register_parameter(self.name + "_u", u)
58
+ self.module.register_parameter(self.name + "_v", v)
59
+ self.module.register_parameter(self.name + "_bar", w_bar)
60
+
61
+ def forward(self, *args):
62
+ self._update_u_v()
63
+ return self.module.forward(*args)
64
+
65
+
66
+ class Generator(nn.Module):
67
+ def __init__(self, z_dim):
68
+ super(Generator, self).__init__()
69
+ self.z_dim = z_dim
70
+
71
+ self.model = nn.Sequential(
72
+ nn.ConvTranspose2d(z_dim, 512, 4, stride=1),
73
+ nn.InstanceNorm2d(512),
74
+ nn.ReLU(),
75
+ nn.ConvTranspose2d(512, 256, 4, stride=2, padding=(1, 1)),
76
+ nn.InstanceNorm2d(256),
77
+ nn.ReLU(),
78
+ nn.ConvTranspose2d(256, 128, 4, stride=2, padding=(1, 1)),
79
+ nn.InstanceNorm2d(128),
80
+ nn.ReLU(),
81
+ nn.ConvTranspose2d(128, 64, 4, stride=2, padding=(1, 1)),
82
+ nn.InstanceNorm2d(64),
83
+ nn.ReLU(),
84
+ nn.ConvTranspose2d(64, channels, 3, stride=1, padding=(1, 1)),
85
+ nn.Tanh(),
86
+ )
87
+
88
+ def forward(self, z):
89
+ return self.model(z.view(-1, self.z_dim, 1, 1))
90
+
91
+
92
+ channels = 3
93
+ leak = 0.1
94
+ w_g = 4
95
+
96
+
97
+ class Discriminator(nn.Module):
98
+ def __init__(self):
99
+ super(Discriminator, self).__init__()
100
+
101
+ self.conv1 = SpectralNorm(nn.Conv2d(channels, 64, 3, stride=1, padding=(1, 1)))
102
+ self.conv2 = SpectralNorm(nn.Conv2d(64, 64, 4, stride=2, padding=(1, 1)))
103
+ self.conv3 = SpectralNorm(nn.Conv2d(64, 128, 3, stride=1, padding=(1, 1)))
104
+ self.conv4 = SpectralNorm(nn.Conv2d(128, 128, 4, stride=2, padding=(1, 1)))
105
+ self.conv5 = SpectralNorm(nn.Conv2d(128, 256, 3, stride=1, padding=(1, 1)))
106
+ self.conv6 = SpectralNorm(nn.Conv2d(256, 256, 4, stride=2, padding=(1, 1)))
107
+ self.conv7 = SpectralNorm(nn.Conv2d(256, 256, 3, stride=1, padding=(1, 1)))
108
+ self.conv8 = SpectralNorm(nn.Conv2d(256, 512, 4, stride=2, padding=(1, 1)))
109
+ self.fc = SpectralNorm(nn.Linear(w_g * w_g * 512, 1))
110
+
111
+ def forward(self, x):
112
+ m = x
113
+ m = nn.LeakyReLU(leak)(self.conv1(m))
114
+ m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(64)(self.conv2(m)))
115
+ m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(128)(self.conv3(m)))
116
+ m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(128)(self.conv4(m)))
117
+ m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv5(m)))
118
+ m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv6(m)))
119
+ m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv7(m)))
120
+ m = nn.LeakyReLU(leak)(self.conv8(m))
121
+
122
+ return self.fc(m.view(-1, w_g * w_g * 512))
123
+
124
+
125
+ class Self_Attention(nn.Module):
126
+ """Self attention Layer"""
127
+
128
+ def __init__(self, in_dim):
129
+ super(Self_Attention, self).__init__()
130
+ self.chanel_in = in_dim
131
+
132
+ self.query_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 1, kernel_size=1))
133
+ self.key_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 1, kernel_size=1))
134
+ self.value_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1))
135
+ self.gamma = nn.Parameter(torch.zeros(1))
136
+
137
+ self.softmax = nn.Softmax(dim=-1) #
138
+
139
+ def forward(self, x):
140
+ """
141
+ inputs :
142
+ x : input feature maps( B X C X W X H)
143
+ returns :
144
+ out : self attention value + input feature
145
+ attention: B X N X N (N is Width*Height)
146
+ """
147
+ m_batchsize, C, width, height = x.size()
148
+ proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X CX(N)
149
+ proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H)
150
+ energy = torch.bmm(proj_query, proj_key) # transpose check
151
+ attention = self.softmax(energy) # BX (N) X (N)
152
+ proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N
153
+
154
+ out = torch.bmm(proj_value, attention.permute(0, 2, 1))
155
+ out = out.view(m_batchsize, C, width, height)
156
+
157
+ out = self.gamma * out + x
158
+ return out
159
+
160
+ class Discriminator_x64_224(nn.Module):
161
+ """
162
+ Discriminative Network
163
+ """
164
+
165
+ def __init__(self, in_size=6, ndf=64):
166
+ super(Discriminator_x64_224, self).__init__()
167
+ self.in_size = in_size
168
+ self.ndf = ndf
169
+
170
+ self.layer1 = nn.Sequential(SpectralNorm(nn.Conv2d(self.in_size, self.ndf, 4, 2, 1)), nn.LeakyReLU(0.2, inplace=True))
171
+
172
+ self.layer2 = nn.Sequential(
173
+ SpectralNorm(nn.Conv2d(self.ndf, self.ndf, 4, 2, 1)),
174
+ nn.InstanceNorm2d(self.ndf),
175
+ nn.LeakyReLU(0.2, inplace=True),
176
+ )
177
+ self.attention = Self_Attention(self.ndf)
178
+ self.layer3 = nn.Sequential(
179
+ SpectralNorm(nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1)),
180
+ nn.InstanceNorm2d(self.ndf * 2),
181
+ nn.LeakyReLU(0.2, inplace=True),
182
+ )
183
+ self.layer4 = nn.Sequential(
184
+ SpectralNorm(nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1)),
185
+ nn.InstanceNorm2d(self.ndf * 4),
186
+ nn.LeakyReLU(0.2, inplace=True),
187
+ )
188
+ self.layer5 = nn.Sequential(
189
+ SpectralNorm(nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1)),
190
+ nn.InstanceNorm2d(self.ndf * 8),
191
+ nn.LeakyReLU(0.2, inplace=True),
192
+ )
193
+ self.layer6 = nn.Sequential(
194
+ SpectralNorm(nn.Conv2d(self.ndf * 8, self.ndf * 16, 4, 2, 1)),
195
+ nn.InstanceNorm2d(self.ndf * 16),
196
+ nn.LeakyReLU(0.2, inplace=True),
197
+ )
198
+
199
+ self.last = SpectralNorm(nn.Conv2d(self.ndf * 16, 1, [3, 3], 1, 0))
200
+
201
+ def forward(self, input):
202
+ feature1 = self.layer1(input)
203
+ feature2 = self.layer2(feature1)
204
+ feature_attention = self.attention(feature2)
205
+ feature3 = self.layer3(feature_attention)
206
+ feature4 = self.layer4(feature3)
207
+ feature5 = self.layer5(feature4)
208
+ feature6 = self.layer6(feature5)
209
+ output = self.last(feature6)
210
+ output = F.avg_pool2d(output, output.size()[2:]).view(output.size()[0], -1)
211
+
212
+ return output, feature4
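A minimal shape check for the self-attention discriminator, assuming the default 6 input planes (e.g. two 3-channel images stacked along the channel axis):

import torch
from src.models.CNN.GAN_models import Discriminator_x64_224

D = Discriminator_x64_224(in_size=6, ndf=64)
pair = torch.randn(2, 6, 224, 224)
logits, feat = D(pair)
print(logits.shape, feat.shape)          # torch.Size([2, 1]) torch.Size([2, 256, 14, 14])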
src/models/CNN/NonlocalNet.py ADDED
@@ -0,0 +1,437 @@
1
+ import sys
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from src.utils import uncenter_l
6
+
7
+
8
+ def find_local_patch(x, patch_size):
9
+ """
10
+ > We take a tensor `x` and return a tensor `x_unfold` that contains all the patches of size
11
+ `patch_size` in `x`
12
+
13
+ Args:
14
+ x: the input tensor
15
+ patch_size: the size of the patch to be extracted.
16
+ """
17
+
18
+ N, C, H, W = x.shape
19
+ x_unfold = F.unfold(x, kernel_size=(patch_size, patch_size), padding=(patch_size // 2, patch_size // 2), stride=(1, 1))
20
+
21
+ return x_unfold.view(N, x_unfold.shape[1], H, W)
22
+
23
+
24
+ class WeightedAverage(nn.Module):
25
+ def __init__(
26
+ self,
27
+ ):
28
+ super(WeightedAverage, self).__init__()
29
+
30
+ def forward(self, x_lab, patch_size=3, alpha=1, scale_factor=1):
31
+ """
32
+ It takes a 3-channel image (L, A, B) and returns a 2-channel image (A, B) where each pixel is a
33
+ weighted average of the A and B values of the pixels in a 3x3 neighborhood around it
34
+
35
+ Args:
36
+ x_lab: the input image in LAB color space
37
+ patch_size: the size of the patch to use for the local average. Defaults to 3
38
+ alpha: the higher the alpha, the smoother the output. Defaults to 1
39
+ scale_factor: the scale factor of the input image. Defaults to 1
40
+
41
+ Returns:
42
+ The output of the forward function is a tensor of size (batch_size, 2, height, width)
43
+ """
44
+ # alpha=0: less smooth; alpha=inf: smoother
45
+ x_lab = F.interpolate(x_lab, scale_factor=scale_factor)
46
+ l = x_lab[:, 0:1, :, :]
47
+ a = x_lab[:, 1:2, :, :]
48
+ b = x_lab[:, 2:3, :, :]
49
+ local_l = find_local_patch(l, patch_size)
50
+ local_a = find_local_patch(a, patch_size)
51
+ local_b = find_local_patch(b, patch_size)
52
+ local_difference_l = (local_l - l) ** 2
53
+ correlation = nn.functional.softmax(-1 * local_difference_l / alpha, dim=1)
54
+
55
+ return torch.cat(
56
+ (
57
+ torch.sum(correlation * local_a, dim=1, keepdim=True),
58
+ torch.sum(correlation * local_b, dim=1, keepdim=True),
59
+ ),
60
+ 1,
61
+ )
62
+
63
+
64
+ class WeightedAverage_color(nn.Module):
65
+ """
66
+ smooth the image according to the color distance in the LAB space
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ ):
72
+ super(WeightedAverage_color, self).__init__()
73
+
74
+ def forward(self, x_lab, x_lab_predict, patch_size=3, alpha=1, scale_factor=1):
75
+ """
76
+ It takes the predicted a and b channels, and the original a and b channels, and finds the
77
+ weighted average of the predicted a and b channels based on the similarity of the original a and
78
+ b channels to the predicted a and b channels
79
+
80
+ Args:
81
+ x_lab: the input image in LAB color space
82
+ x_lab_predict: the predicted LAB image
83
+ patch_size: the size of the patch to use for the local color correction. Defaults to 3
84
+ alpha: controls the smoothness of the output. Defaults to 1
85
+ scale_factor: the scale factor of the input image. Defaults to 1
86
+
87
+ Returns:
88
+ The return is the weighted average of the local a and b channels.
89
+ """
90
+ """ alpha=0: less smooth; alpha=inf: smoother """
91
+ x_lab = F.interpolate(x_lab, scale_factor=scale_factor)
92
+ l = uncenter_l(x_lab[:, 0:1, :, :])
93
+ a = x_lab[:, 1:2, :, :]
94
+ b = x_lab[:, 2:3, :, :]
95
+ a_predict = x_lab_predict[:, 1:2, :, :]
96
+ b_predict = x_lab_predict[:, 2:3, :, :]
97
+ local_l = find_local_patch(l, patch_size)
98
+ local_a = find_local_patch(a, patch_size)
99
+ local_b = find_local_patch(b, patch_size)
100
+ local_a_predict = find_local_patch(a_predict, patch_size)
101
+ local_b_predict = find_local_patch(b_predict, patch_size)
102
+
103
+ local_color_difference = (local_l - l) ** 2 + (local_a - a) ** 2 + (local_b - b) ** 2
104
+ # so that sum of weights equal to 1
105
+ correlation = nn.functional.softmax(-1 * local_color_difference / alpha, dim=1)
106
+
107
+ return torch.cat(
108
+ (
109
+ torch.sum(correlation * local_a_predict, dim=1, keepdim=True),
110
+ torch.sum(correlation * local_b_predict, dim=1, keepdim=True),
111
+ ),
112
+ 1,
113
+ )
114
+
115
+
116
+ class NonlocalWeightedAverage(nn.Module):
117
+ def __init__(
118
+ self,
119
+ ):
120
+ super(NonlocalWeightedAverage, self).__init__()
121
+
122
+ def forward(self, x_lab, feature, patch_size=3, alpha=0.1, scale_factor=1):
123
+ """
124
+ It takes in a feature map and a label map, and returns a smoothed label map
125
+
126
+ Args:
127
+ x_lab: the input image in LAB color space
128
+ feature: the feature map of the input image
129
+ patch_size: the size of the patch to be used for the correlation matrix. Defaults to 3
130
+ alpha: the higher the alpha, the smoother the output.
131
+ scale_factor: the scale factor of the input image. Defaults to 1
132
+
133
+ Returns:
134
+ weighted_ab is the weighted ab channel of the image.
135
+ """
136
+ # alpha=0: less smooth; alpha=inf: smoother
137
+ # input feature is normalized feature
138
+ x_lab = F.interpolate(x_lab, scale_factor=scale_factor)
139
+ batch_size, channel, height, width = x_lab.shape
140
+ feature = F.interpolate(feature, size=(height, width))
141
+ batch_size = x_lab.shape[0]
142
+ x_ab = x_lab[:, 1:3, :, :].view(batch_size, 2, -1)
143
+ x_ab = x_ab.permute(0, 2, 1)
144
+
145
+ local_feature = find_local_patch(feature, patch_size)
146
+ local_feature = local_feature.view(batch_size, local_feature.shape[1], -1)
147
+
148
+ correlation_matrix = torch.matmul(local_feature.permute(0, 2, 1), local_feature)
149
+ correlation_matrix = nn.functional.softmax(correlation_matrix / alpha, dim=-1)
150
+
151
+ weighted_ab = torch.matmul(correlation_matrix, x_ab)
152
+ weighted_ab = weighted_ab.permute(0, 2, 1).contiguous()
153
+ weighted_ab = weighted_ab.view(batch_size, 2, height, width)
154
+ return weighted_ab
155
+
156
+
157
+ class CorrelationLayer(nn.Module):
158
+ def __init__(self, search_range):
159
+ super(CorrelationLayer, self).__init__()
160
+ self.search_range = search_range
161
+
162
+ def forward(self, x1, x2, alpha=1, raw_output=False, metric="similarity"):
163
+ """
164
+ It takes two tensors, x1 and x2, and returns a tensor of shape (batch_size, (search_range * 2 +
165
+ 1) ** 2, height, width) where each element is the dot product of the corresponding patch in x1
166
+ and x2
167
+
168
+ Args:
169
+ x1: the first image
170
+ x2: the image to be warped
171
+ alpha: the temperature parameter for the softmax function. Defaults to 1
172
+ raw_output: if True, return the raw output of the network, otherwise return the softmax
173
+ output. Defaults to False
174
+ metric: "similarity" or "subtraction". Defaults to similarity
175
+
176
+ Returns:
177
+ The output of the forward function is a softmax of the correlation volume.
178
+ """
179
+ shape = list(x1.size())
180
+ shape[1] = (self.search_range * 2 + 1) ** 2
181
+ cv = torch.zeros(shape).to(x1.device)  # follow the input device instead of hardcoding CUDA
182
+
183
+ for i in range(-self.search_range, self.search_range + 1):
184
+ for j in range(-self.search_range, self.search_range + 1):
185
+ if i < 0:
186
+ slice_h, slice_h_r = slice(None, i), slice(-i, None)
187
+ elif i > 0:
188
+ slice_h, slice_h_r = slice(i, None), slice(None, -i)
189
+ else:
190
+ slice_h, slice_h_r = slice(None), slice(None)
191
+
192
+ if j < 0:
193
+ slice_w, slice_w_r = slice(None, j), slice(-j, None)
194
+ elif j > 0:
195
+ slice_w, slice_w_r = slice(j, None), slice(None, -j)
196
+ else:
197
+ slice_w, slice_w_r = slice(None), slice(None)
198
+
199
+ if metric == "similarity":
200
+ cv[:, (self.search_range * 2 + 1) * i + j, slice_h, slice_w] = (
201
+ x1[:, :, slice_h, slice_w] * x2[:, :, slice_h_r, slice_w_r]
202
+ ).sum(1)
203
+ else: # patchwise subtraction
204
+ cv[:, (self.search_range * 2 + 1) * i + j, slice_h, slice_w] = -(
205
+ (x1[:, :, slice_h, slice_w] - x2[:, :, slice_h_r, slice_w_r]) ** 2
206
+ ).sum(1)
207
+
208
+ # TODO sigmoid?
209
+ if raw_output:
210
+ return cv
211
+ else:
212
+ return nn.functional.softmax(cv / alpha, dim=1)
213
+
214
+
215
+ class WTA_scale(torch.autograd.Function):
216
+ """
217
+ We can implement our own custom autograd Functions by subclassing
218
+ torch.autograd.Function and implementing the forward and backward passes
219
+ which operate on Tensors.
220
+ """
221
+
222
+ @staticmethod
223
+ def forward(ctx, input, scale=1e-4):
224
+ """
225
+ In the forward pass we receive a Tensor containing the input and return a
226
+ Tensor containing the output. You can cache arbitrary Tensors for use in the
227
+ backward pass using the save_for_backward method.
228
+ """
229
+ activation_max, index_max = torch.max(input, -1, keepdim=True)
230
+ input_scale = input * scale # default: 1e-4
231
+ # input_scale = input * scale # default: 1e-4
232
+ output_max_scale = torch.where(input == activation_max, input, input_scale)
233
+
234
+ mask = (input == activation_max).type(torch.float)
235
+ ctx.save_for_backward(input, mask)
236
+ return output_max_scale
237
+
238
+ @staticmethod
239
+ def backward(ctx, grad_output):
240
+ """
241
+ In the backward pass we receive a Tensor containing the gradient of the loss
242
+ with respect to the output, and we need to compute the gradient of the loss
243
+ with respect to the input.
244
+ """
245
+ input, mask = ctx.saved_tensors
246
+ mask_ones = torch.ones_like(mask)
247
+ mask_small_ones = torch.ones_like(mask) * 1e-4
248
+ # mask_small_ones = torch.ones_like(mask) * 1e-4
249
+
250
+ grad_scale = torch.where(mask == 1, mask_ones, mask_small_ones)
251
+ grad_input = grad_output.clone() * grad_scale
252
+ return grad_input, None
253
+
254
+
255
+ class ResidualBlock(nn.Module):
256
+ def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
257
+ super(ResidualBlock, self).__init__()
258
+ self.padding1 = nn.ReflectionPad2d(padding)
259
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride)
260
+ self.bn1 = nn.InstanceNorm2d(out_channels)
261
+ self.prelu = nn.PReLU()
262
+ self.padding2 = nn.ReflectionPad2d(padding)
263
+ self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride)
264
+ self.bn2 = nn.InstanceNorm2d(out_channels)
265
+
266
+ def forward(self, x):
267
+ residual = x
268
+ out = self.padding1(x)
269
+ out = self.conv1(out)
270
+ out = self.bn1(out)
271
+ out = self.prelu(out)
272
+ out = self.padding2(out)
273
+ out = self.conv2(out)
274
+ out = self.bn2(out)
275
+ out += residual
276
+ out = self.prelu(out)
277
+ return out
278
+
279
+
280
+ class WarpNet(nn.Module):
281
+ """input is Al, Bl, channel = 1, range~[0,255]"""
282
+
283
+ def __init__(self, feature_channel=128):
284
+ super(WarpNet, self).__init__()
285
+ self.feature_channel = feature_channel
286
+ self.in_channels = self.feature_channel * 4
287
+ self.inter_channels = 256
288
+ # 44*44
289
+ self.layer2_1 = nn.Sequential(
290
+ nn.ReflectionPad2d(1),
291
+ # nn.Conv2d(128, 128, kernel_size=3, padding=0, stride=1),
292
+ # nn.Conv2d(96, 128, kernel_size=3, padding=20, stride=1),
293
+ nn.Conv2d(96, 128, kernel_size=3, padding=0, stride=1),
294
+ nn.InstanceNorm2d(128),
295
+ nn.PReLU(),
296
+ nn.ReflectionPad2d(1),
297
+ nn.Conv2d(128, self.feature_channel, kernel_size=3, padding=0, stride=2),
298
+ nn.InstanceNorm2d(self.feature_channel),
299
+ nn.PReLU(),
300
+ nn.Dropout(0.2),
301
+ )
302
+ self.layer3_1 = nn.Sequential(
303
+ nn.ReflectionPad2d(1),
304
+ # nn.Conv2d(256, 128, kernel_size=3, padding=0, stride=1),
305
+ # nn.Conv2d(192, 128, kernel_size=3, padding=10, stride=1),
306
+ nn.Conv2d(192, 128, kernel_size=3, padding=0, stride=1),
307
+ nn.InstanceNorm2d(128),
308
+ nn.PReLU(),
309
+ nn.ReflectionPad2d(1),
310
+ nn.Conv2d(128, self.feature_channel, kernel_size=3, padding=0, stride=1),
311
+ nn.InstanceNorm2d(self.feature_channel),
312
+ nn.PReLU(),
313
+ nn.Dropout(0.2),
314
+ )
315
+
316
+ # 22*22->44*44
317
+ self.layer4_1 = nn.Sequential(
318
+ nn.ReflectionPad2d(1),
319
+ # nn.Conv2d(512, 256, kernel_size=3, padding=0, stride=1),
320
+ # nn.Conv2d(384, 256, kernel_size=3, padding=5, stride=1),
321
+ nn.Conv2d(384, 256, kernel_size=3, padding=0, stride=1),
322
+ nn.InstanceNorm2d(256),
323
+ nn.PReLU(),
324
+ nn.ReflectionPad2d(1),
325
+ nn.Conv2d(256, self.feature_channel, kernel_size=3, padding=0, stride=1),
326
+ nn.InstanceNorm2d(self.feature_channel),
327
+ nn.PReLU(),
328
+ nn.Upsample(scale_factor=2),
329
+ nn.Dropout(0.2),
330
+ )
331
+
332
+ # 11*11->44*44
333
+ self.layer5_1 = nn.Sequential(
334
+ nn.ReflectionPad2d(1),
335
+ # nn.Conv2d(1024, 256, kernel_size=3, padding=0, stride=1),
336
+ # nn.Conv2d(768, 256, kernel_size=2, padding=2, stride=1),
337
+ nn.Conv2d(768, 256, kernel_size=3, padding=0, stride=1),
338
+ nn.InstanceNorm2d(256),
339
+ nn.PReLU(),
340
+ nn.Upsample(scale_factor=2),
341
+ nn.ReflectionPad2d(1),
342
+ nn.Conv2d(256, self.feature_channel, kernel_size=3, padding=0, stride=1),
343
+ nn.InstanceNorm2d(self.feature_channel),
344
+ nn.PReLU(),
345
+ nn.Upsample(scale_factor=2),
346
+ nn.Dropout(0.2),
347
+ )
348
+
349
+ self.layer = nn.Sequential(
350
+ ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1),
351
+ ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1),
352
+ ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1),
353
+ )
354
+
355
+ self.theta = nn.Conv2d(
356
+ in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
357
+ )
358
+ self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
359
+
360
+ self.upsampling = nn.Upsample(scale_factor=4)
361
+
362
+ def forward(
363
+ self,
364
+ B_lab_map,
365
+ A_relu2_1,
366
+ A_relu3_1,
367
+ A_relu4_1,
368
+ A_relu5_1,
369
+ B_relu2_1,
370
+ B_relu3_1,
371
+ B_relu4_1,
372
+ B_relu5_1,
373
+ temperature=0.001 * 5,
374
+ detach_flag=False,
375
+ WTA_scale_weight=1,
376
+ ):
377
+ batch_size = B_lab_map.shape[0]
378
+ channel = B_lab_map.shape[1]
379
+ image_height = B_lab_map.shape[2]
380
+ image_width = B_lab_map.shape[3]
381
+ feature_height = int(image_height / 4)
382
+ feature_width = int(image_width / 4)
383
+
384
+ # scale feature size to 44*44
385
+ A_feature2_1 = self.layer2_1(A_relu2_1)
386
+ B_feature2_1 = self.layer2_1(B_relu2_1)
387
+ A_feature3_1 = self.layer3_1(A_relu3_1)
388
+ B_feature3_1 = self.layer3_1(B_relu3_1)
389
+ A_feature4_1 = self.layer4_1(A_relu4_1)
390
+ B_feature4_1 = self.layer4_1(B_relu4_1)
391
+ A_feature5_1 = self.layer5_1(A_relu5_1)
392
+ B_feature5_1 = self.layer5_1(B_relu5_1)
393
+
394
+ # concatenate features
395
+ if A_feature5_1.shape[2] != A_feature2_1.shape[2] or A_feature5_1.shape[3] != A_feature2_1.shape[3]:
396
+ A_feature5_1 = F.pad(A_feature5_1, (0, 0, 1, 1), "replicate")
397
+ B_feature5_1 = F.pad(B_feature5_1, (0, 0, 1, 1), "replicate")
398
+
399
+ A_features = self.layer(torch.cat((A_feature2_1, A_feature3_1, A_feature4_1, A_feature5_1), 1))
400
+ B_features = self.layer(torch.cat((B_feature2_1, B_feature3_1, B_feature4_1, B_feature5_1), 1))
401
+
402
+ # pairwise cosine similarity
403
+ theta = self.theta(A_features).view(batch_size, self.inter_channels, -1) # 2*256*(feature_height*feature_width)
404
+ theta = theta - theta.mean(dim=-1, keepdim=True) # center the feature
405
+ theta_norm = torch.norm(theta, 2, 1, keepdim=True) + sys.float_info.epsilon
406
+ theta = torch.div(theta, theta_norm)
407
+ theta_permute = theta.permute(0, 2, 1) # 2*(feature_height*feature_width)*256
408
+ phi = self.phi(B_features).view(batch_size, self.inter_channels, -1) # 2*256*(feature_height*feature_width)
409
+ phi = phi - phi.mean(dim=-1, keepdim=True) # center the feature
410
+ phi_norm = torch.norm(phi, 2, 1, keepdim=True) + sys.float_info.epsilon
411
+ phi = torch.div(phi, phi_norm)
412
+ f = torch.matmul(theta_permute, phi) # 2*(feature_height*feature_width)*(feature_height*feature_width)
413
+ if detach_flag:
414
+ f = f.detach()
415
+
416
+ f_similarity = f.unsqueeze_(dim=1)
417
+ similarity_map = torch.max(f_similarity, -1, keepdim=True)[0]
418
+ similarity_map = similarity_map.view(batch_size, 1, feature_height, feature_width)
419
+
420
+ # f can be negative
421
+ f_WTA = f if WTA_scale_weight == 1 else WTA_scale.apply(f, WTA_scale_weight)
422
+ f_WTA = f_WTA / temperature
423
+ f_div_C = F.softmax(f_WTA.squeeze_(), dim=-1) # 2*1936*1936;
424
+
425
+ # downsample the reference color
426
+ B_lab = F.avg_pool2d(B_lab_map, 4)
427
+ B_lab = B_lab.view(batch_size, channel, -1)
428
+ B_lab = B_lab.permute(0, 2, 1) # 2*1936*channel
429
+
430
+ # multiply the corr map with color
431
+ y = torch.matmul(f_div_C, B_lab) # 2*1936*channel
432
+ y = y.permute(0, 2, 1).contiguous()
433
+ y = y.view(batch_size, channel, feature_height, feature_width) # 2*3*44*44
434
+ y = self.upsampling(y)
435
+ similarity_map = self.upsampling(similarity_map)
436
+
437
+ return y, similarity_map
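A small sketch of the patch-unfolding helper and the local weighted-average smoother defined in this file, run on random tensors:

import torch
from src.models.CNN.NonlocalNet import find_local_patch, WeightedAverage

x = torch.randn(2, 4, 32, 32)
patches = find_local_patch(x, patch_size=3)
print(patches.shape)                     # torch.Size([2, 36, 32, 32]) -- 4 channels x 3x3 neighbourhood

smoother = WeightedAverage()
x_lab = torch.randn(1, 3, 64, 64)        # random stand-in for a Lab image
ab_smooth = smoother(x_lab, patch_size=3, alpha=1)
print(ab_smooth.shape)                   # torch.Size([1, 2, 64, 64])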
src/models/CNN/__init__.py ADDED
File without changes
src/models/__init__.py ADDED
File without changes
src/models/vit/__init__.py ADDED
File without changes
src/models/vit/blocks.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch.nn as nn
2
+ from timm.models.layers import DropPath
3
+
4
+
5
+ class FeedForward(nn.Module):
6
+ def __init__(self, dim, hidden_dim, dropout, out_dim=None):
7
+ super().__init__()
8
+ self.fc1 = nn.Linear(dim, hidden_dim)
9
+ self.act = nn.GELU()
10
+ if out_dim is None:
11
+ out_dim = dim
12
+ self.fc2 = nn.Linear(hidden_dim, out_dim)
13
+ self.drop = nn.Dropout(dropout)
14
+
15
+ @property
16
+ def unwrapped(self):
17
+ return self
18
+
19
+ def forward(self, x):
20
+ x = self.fc1(x)
21
+ x = self.act(x)
22
+ x = self.drop(x)
23
+ x = self.fc2(x)
24
+ x = self.drop(x)
25
+ return x
26
+
27
+
28
+ class Attention(nn.Module):
29
+ def __init__(self, dim, heads, dropout):
30
+ super().__init__()
31
+ self.heads = heads
32
+ head_dim = dim // heads
33
+ self.scale = head_dim**-0.5
34
+ self.attn = None
35
+
36
+ self.qkv = nn.Linear(dim, dim * 3)
37
+ self.attn_drop = nn.Dropout(dropout)
38
+ self.proj = nn.Linear(dim, dim)
39
+ self.proj_drop = nn.Dropout(dropout)
40
+
41
+ @property
42
+ def unwrapped(self):
43
+ return self
44
+
45
+ def forward(self, x, mask=None):
46
+ B, N, C = x.shape
47
+ qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
48
+ q, k, v = (
49
+ qkv[0],
50
+ qkv[1],
51
+ qkv[2],
52
+ )
53
+
54
+ attn = (q @ k.transpose(-2, -1)) * self.scale
55
+ attn = attn.softmax(dim=-1)
56
+ attn = self.attn_drop(attn)
57
+
58
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
59
+ x = self.proj(x)
60
+ x = self.proj_drop(x)
61
+
62
+ return x, attn
63
+
64
+
65
+ class Block(nn.Module):
66
+ def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
67
+ super().__init__()
68
+ self.norm1 = nn.LayerNorm(dim)
69
+ self.norm2 = nn.LayerNorm(dim)
70
+ self.attn = Attention(dim, heads, dropout)
71
+ self.mlp = FeedForward(dim, mlp_dim, dropout)
72
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
73
+
74
+ def forward(self, x, mask=None, return_attention=False):
75
+ y, attn = self.attn(self.norm1(x), mask)
76
+ if return_attention:
77
+ return attn
78
+ x = x + self.drop_path(y)
79
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
80
+ return x
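For reference, a dummy forward pass through one transformer block at DeiT-tiny width:

import torch
from src.models.vit.blocks import Block

block = Block(dim=192, heads=3, mlp_dim=768, dropout=0.1, drop_path=0.0)
tokens = torch.randn(2, 197, 192)        # 196 patch tokens + 1 cls token
out = block(tokens)
print(out.shape)                         # torch.Size([2, 197, 192])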
src/models/vit/config.py ADDED
@@ -0,0 +1,22 @@
1
+ import yaml
2
+ from pathlib import Path
3
+
4
+ import os
5
+
6
+
7
+ def load_config():
8
+ return yaml.load(
9
+ open(Path(__file__).parent / "config.yml", "r"), Loader=yaml.FullLoader
10
+ )
11
+
12
+
13
+ def check_os_environ(key, use):
14
+ if key not in os.environ:
15
+ raise ValueError(
16
+ f"{key} is not defined in the os variables, it is required for {use}."
17
+ )
18
+
19
+
20
+ def dataset_dir():
21
+ check_os_environ("DATASET", "data loading")
22
+ return os.environ["DATASET"]
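Example: reading one backbone entry from the config.yml added below:

from src.models.vit.config import load_config

cfg = load_config()
vit_cfg = cfg["model"]["vit_base_patch16_384"]
print(vit_cfg["d_model"], vit_cfg["n_heads"], vit_cfg["n_layers"])   # 768 12 12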
src/models/vit/config.yml ADDED
@@ -0,0 +1,132 @@
1
+ model:
2
+ # deit
3
+ deit_tiny_distilled_patch16_224:
4
+ image_size: 224
5
+ patch_size: 16
6
+ d_model: 192
7
+ n_heads: 3
8
+ n_layers: 12
9
+ normalization: deit
10
+ distilled: true
11
+ deit_small_distilled_patch16_224:
12
+ image_size: 224
13
+ patch_size: 16
14
+ d_model: 384
15
+ n_heads: 6
16
+ n_layers: 12
17
+ normalization: deit
18
+ distilled: true
19
+ deit_base_distilled_patch16_224:
20
+ image_size: 224
21
+ patch_size: 16
22
+ d_model: 768
23
+ n_heads: 12
24
+ n_layers: 12
25
+ normalization: deit
26
+ distilled: true
27
+ deit_base_distilled_patch16_384:
28
+ image_size: 384
29
+ patch_size: 16
30
+ d_model: 768
31
+ n_heads: 12
32
+ n_layers: 12
33
+ normalization: deit
34
+ distilled: true
35
+ # vit
36
+ vit_base_patch8_384:
37
+ image_size: 384
38
+ patch_size: 8
39
+ d_model: 768
40
+ n_heads: 12
41
+ n_layers: 12
42
+ normalization: vit
43
+ distilled: false
44
+ vit_tiny_patch16_384:
45
+ image_size: 384
46
+ patch_size: 16
47
+ d_model: 192
48
+ n_heads: 3
49
+ n_layers: 12
50
+ normalization: vit
51
+ distilled: false
52
+ vit_small_patch16_384:
53
+ image_size: 384
54
+ patch_size: 16
55
+ d_model: 384
56
+ n_heads: 6
57
+ n_layers: 12
58
+ normalization: vit
59
+ distilled: false
60
+ vit_base_patch16_384:
61
+ image_size: 384
62
+ patch_size: 16
63
+ d_model: 768
64
+ n_heads: 12
65
+ n_layers: 12
66
+ normalization: vit
67
+ distilled: false
68
+ vit_large_patch16_384:
69
+ image_size: 384
70
+ patch_size: 16
71
+ d_model: 1024
72
+ n_heads: 16
73
+ n_layers: 24
74
+ normalization: vit
75
+ vit_small_patch32_384:
76
+ image_size: 384
77
+ patch_size: 32
78
+ d_model: 384
79
+ n_heads: 6
80
+ n_layers: 12
81
+ normalization: vit
82
+ distilled: false
83
+ vit_base_patch32_384:
84
+ image_size: 384
85
+ patch_size: 32
86
+ d_model: 768
87
+ n_heads: 12
88
+ n_layers: 12
89
+ normalization: vit
90
+ vit_large_patch32_384:
91
+ image_size: 384
92
+ patch_size: 32
93
+ d_model: 1024
94
+ n_heads: 16
95
+ n_layers: 24
96
+ normalization: vit
97
+ decoder:
98
+ linear: {}
99
+ deeplab_dec:
100
+ encoder_layer: -1
101
+ mask_transformer:
102
+ drop_path_rate: 0.0
103
+ dropout: 0.1
104
+ n_layers: 2
105
+ dataset:
106
+ ade20k:
107
+ epochs: 64
108
+ eval_freq: 2
109
+ batch_size: 8
110
+ learning_rate: 0.001
111
+ im_size: 512
112
+ crop_size: 512
113
+ window_size: 512
114
+ window_stride: 512
115
+ pascal_context:
116
+ epochs: 256
117
+ eval_freq: 8
118
+ batch_size: 16
119
+ learning_rate: 0.001
120
+ im_size: 520
121
+ crop_size: 480
122
+ window_size: 480
123
+ window_stride: 320
124
+ cityscapes:
125
+ epochs: 216
126
+ eval_freq: 4
127
+ batch_size: 8
128
+ learning_rate: 0.01
129
+ im_size: 1024
130
+ crop_size: 768
131
+ window_size: 768
132
+ window_stride: 512
src/models/vit/decoder.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch.nn as nn
2
+ from einops import rearrange
3
+ from src.models.vit.utils import init_weights
4
+
5
+
6
+ class DecoderLinear(nn.Module):
7
+ def __init__(
8
+ self,
9
+ n_cls,
10
+ d_encoder,
11
+ scale_factor,
12
+ dropout_rate=0.3,
13
+ ):
14
+ super().__init__()
15
+ self.scale_factor = scale_factor
16
+ self.head = nn.Linear(d_encoder, n_cls)
17
+ self.upsampling = nn.Upsample(scale_factor=scale_factor**2, mode="linear")
18
+ self.norm = nn.LayerNorm((n_cls, 24 * scale_factor, 24 * scale_factor))
19
+ self.dropout = nn.Dropout(dropout_rate)
20
+ self.gelu = nn.GELU()
21
+ self.apply(init_weights)
22
+
23
+ def forward(self, x, img_size):
24
+ H, _ = img_size
25
+ x = self.head(x) ####### (2, 577, 64)
26
+ x = x.transpose(2, 1) ## (2, 64, 576)
27
+ x = self.upsampling(x) # (2, 64, 576*scale_factor*scale_factor)
28
+ x = x.transpose(2, 1) ## (2, 576*scale_factor*scale_factor, 64)
29
+ x = rearrange(x, "b (h w) c -> b c h w", h=H // (16 // self.scale_factor)) # (2, 64, 24*scale_factor, 24*scale_factor)
30
+ x = self.norm(x)
31
+ x = self.dropout(x)
32
+ x = self.gelu(x)
33
+
34
+ return x # (2, 64, a, a)
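A shape walk-through on dummy tokens, assuming 576 patch tokens (no cls token) from a 384x384 image at patch size 16 and scale_factor=2:

import torch
from src.models.vit.decoder import DecoderLinear

dec = DecoderLinear(n_cls=64, d_encoder=768, scale_factor=2)
tokens = torch.randn(2, 576, 768)        # patch tokens only
out = dec(tokens, img_size=(384, 384))
print(out.shape)                         # torch.Size([2, 64, 48, 48])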
src/models/vit/embed.py ADDED
@@ -0,0 +1,52 @@
1
+ from torch import nn
2
+ from timm import create_model
3
+ from torchvision.transforms import Normalize
4
+
5
+ class SwinModel(nn.Module):
6
+ def __init__(self, pretrained_model="swinv2-cr-t-224", device="cuda") -> None:
7
+ """
8
+ vit_tiny_patch16_224.augreg_in21k_ft_in1k
9
+ swinv2_cr_tiny_ns_224.sw_in1k
10
+ """
11
+ super().__init__()
12
+ self.device = device
13
+ self.pretrained_model = pretrained_model
14
+ if pretrained_model == "swinv2-cr-t-224":
15
+ self.pretrained = create_model(
16
+ "swinv2_cr_tiny_ns_224.sw_in1k",
17
+ pretrained=True,
18
+ features_only=True,
19
+ out_indices=[-4, -3, -2, -1],
20
+ ).to(device)
21
+ elif pretrained_model == "swinv2-t-256":
22
+ self.pretrained = create_model(
23
+ "swinv2_tiny_window16_256.ms_in1k",
24
+ pretrained=True,
25
+ features_only=True,
26
+ out_indices=[-4, -3, -2, -1],
27
+ ).to(device)
28
+ elif pretrained_model == "swinv2-cr-s-224":
29
+ self.pretrained = create_model(
30
+ "swinv2_cr_small_ns_224.sw_in1k",
31
+ pretrained=True,
32
+ features_only=True,
33
+ out_indices=[-4, -3, -2, -1],
34
+ ).to(device)
35
+ else:
36
+ raise NotImplementedError
37
+
38
+ self.pretrained.eval()
39
+ self.normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
40
+ self.upsample = nn.Upsample(scale_factor=2)
41
+
42
+ for params in self.pretrained.parameters():
43
+ params.requires_grad = False
44
+
45
+ def forward(self, x):
46
+ outputs = self.pretrained(x)
47
+ if self.pretrained_model in ["swinv2-t-256"]:
48
+ for i in range(len(outputs)):
49
+ outputs[i] = outputs[i].permute(0, 3, 1, 2) # Change channel-last to channel-first
50
+ outputs = [self.upsample(feat) for feat in outputs]
51
+
52
+ return outputs
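A CPU smoke test of the Swin feature extractor (downloads the timm weights on first use; the forward pass itself does not normalize, so the input is normalized here with self.normalizer):

import torch
from src.models.vit.embed import SwinModel

embed_net = SwinModel(pretrained_model="swinv2-cr-t-224", device="cpu")
x = embed_net.normalizer(torch.rand(1, 3, 224, 224))   # ImageNet-normalized RGB in [0, 1]
feats = embed_net(x)
print([tuple(f.shape) for f in feats])                  # four feature maps, stride 4 to stride 32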
src/models/vit/factory.py ADDED
@@ -0,0 +1,45 @@
1
+ import os
2
+ import torch
3
+ from timm.models.vision_transformer import default_cfgs
4
+ from timm.models.helpers import load_pretrained, load_custom_pretrained
5
+ from src.models.vit.utils import checkpoint_filter_fn
6
+ from src.models.vit.vit import VisionTransformer
7
+
8
+
9
+ def create_vit(model_cfg):
10
+ model_cfg = model_cfg.copy()
11
+ backbone = model_cfg.pop("backbone")
12
+
13
+ model_cfg.pop("normalization")
14
+ model_cfg["n_cls"] = 1000
15
+ mlp_expansion_ratio = 4
16
+ model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]
17
+
18
+ if backbone in default_cfgs:
19
+ default_cfg = default_cfgs[backbone]
20
+ else:
21
+ default_cfg = dict(
22
+ pretrained=False,
23
+ num_classes=1000,
24
+ drop_rate=0.0,
25
+ drop_path_rate=0.0,
26
+ drop_block_rate=None,
27
+ )
28
+
29
+ default_cfg["input_size"] = (
30
+ 3,
31
+ model_cfg["image_size"][0],
32
+ model_cfg["image_size"][1],
33
+ )
34
+ model = VisionTransformer(**model_cfg)
35
+ if backbone == "vit_base_patch8_384":
36
+ path = os.path.expandvars("$TORCH_HOME/hub/checkpoints/vit_base_patch8_384.pth")
37
+ state_dict = torch.load(path, map_location="cpu")
38
+ filtered_dict = checkpoint_filter_fn(state_dict, model)
39
+ model.load_state_dict(filtered_dict, strict=True)
40
+ elif "deit" in backbone:
41
+ load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
42
+ else:
43
+ load_custom_pretrained(model, default_cfg)
44
+
45
+ return model
src/models/vit/utils.py ADDED
@@ -0,0 +1,71 @@
1
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import trunc_normal_
+from collections import OrderedDict
+
+
+def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
+    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+    posemb_tok, posemb_grid = (
+        posemb[:, :num_extra_tokens],
+        posemb[0, num_extra_tokens:],
+    )
+    if grid_old_shape is None:
+        gs_old_h = int(math.sqrt(len(posemb_grid)))
+        gs_old_w = gs_old_h
+    else:
+        gs_old_h, gs_old_w = grid_old_shape
+
+    gs_h, gs_w = grid_new_shape
+    posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2)
+    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+    return posemb
+
+
+def init_weights(m):
+    if isinstance(m, nn.Linear):
+        trunc_normal_(m.weight, std=0.02)
+        if isinstance(m, nn.Linear) and m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+    elif isinstance(m, nn.LayerNorm):
+        nn.init.constant_(m.bias, 0)
+        nn.init.constant_(m.weight, 1.0)
+
+
+def checkpoint_filter_fn(state_dict, model):
+    """convert patch embedding weight from manual patchify + linear proj to conv"""
+    out_dict = {}
+    if "model" in state_dict:
+        # For deit models
+        state_dict = state_dict["model"]
+    num_extra_tokens = 1 + ("dist_token" in state_dict.keys())
+    patch_size = model.patch_size
+    image_size = model.patch_embed.image_size
+    for k, v in state_dict.items():
+        if k == "pos_embed" and v.shape != model.pos_embed.shape:
+            # To resize pos embedding when using model at different size from pretrained weights
+            v = resize_pos_embed(
+                v,
+                None,
+                (image_size[0] // patch_size, image_size[1] // patch_size),
+                num_extra_tokens,
+            )
+        out_dict[k] = v
+    return out_dict
+
+def load_params(ckpt_file, device):
+    # params = torch.load(ckpt_file, map_location=f'cuda:{local_rank}')
+    # new_params = []
+    # for key, value in params.items():
+    #     new_params.append(("module."+key if has_module else key, value))
+    # return OrderedDict(new_params)
+    params = torch.load(ckpt_file, map_location=device)
+    new_params = []
+    for key, value in params.items():
+        new_params.append((key, value))
+    return OrderedDict(new_params)
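The comment in resize_pos_embed points at the JAX reference implementation; the sketch below (an assumption, not part of the upload) shows the intended call pattern: interpolating a pretrained 24x24 position-embedding grid up to 48x48 while keeping the class token untouched.

    # Hypothetical sketch: growing a pretrained positional-embedding grid.
    import torch
    from src.models.vit.utils import resize_pos_embed

    posemb = torch.randn(1, 1 + 24 * 24, 768)  # [cls] token + 576 patch positions (e.g. a 384/16 grid)
    posemb_big = resize_pos_embed(posemb, None, (48, 48), num_extra_tokens=1)
    print(posemb_big.shape)  # torch.Size([1, 2305, 768]) -> [cls] + 48*48 positions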
src/models/vit/vit.py ADDED
@@ -0,0 +1,199 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.vision_transformer import _load_weights
+from timm.models.layers import trunc_normal_
+from typing import List
+
+from src.models.vit.utils import init_weights, resize_pos_embed
+from src.models.vit.blocks import Block
+from src.models.vit.decoder import DecoderLinear
+
+
+class PatchEmbedding(nn.Module):
+    def __init__(self, image_size, patch_size, embed_dim, channels):
+        super().__init__()
+
+        self.image_size = image_size
+        if image_size[0] % patch_size != 0 or image_size[1] % patch_size != 0:
+            raise ValueError("image dimensions must be divisible by the patch size")
+        self.grid_size = image_size[0] // patch_size, image_size[1] // patch_size
+        self.num_patches = self.grid_size[0] * self.grid_size[1]
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, im):
+        B, C, H, W = im.shape
+        x = self.proj(im).flatten(2).transpose(1, 2)
+        return x
+
+
+class VisionTransformer(nn.Module):
+    def __init__(
+        self,
+        image_size,
+        patch_size,
+        n_layers,
+        d_model,
+        d_ff,
+        n_heads,
+        n_cls,
+        dropout=0.1,
+        drop_path_rate=0.0,
+        distilled=False,
+        channels=3,
+    ):
+        super().__init__()
+        self.patch_embed = PatchEmbedding(
+            image_size,
+            patch_size,
+            d_model,
+            channels,
+        )
+        self.patch_size = patch_size
+        self.n_layers = n_layers
+        self.d_model = d_model
+        self.d_ff = d_ff
+        self.n_heads = n_heads
+        self.dropout = nn.Dropout(dropout)
+        self.n_cls = n_cls
+
+        # cls and pos tokens
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
+        self.distilled = distilled
+        if self.distilled:
+            self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
+            self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 2, d_model))
+            self.head_dist = nn.Linear(d_model, n_cls)
+        else:
+            self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, d_model))
+
+        # transformer blocks
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
+        self.blocks = nn.ModuleList([Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)])
+
+        # output head
+        self.norm = nn.LayerNorm(d_model)
+        self.head = nn.Linear(d_model, n_cls)
+
+        trunc_normal_(self.pos_embed, std=0.02)
+        trunc_normal_(self.cls_token, std=0.02)
+        if self.distilled:
+            trunc_normal_(self.dist_token, std=0.02)
+        self.pre_logits = nn.Identity()
+
+        self.apply(init_weights)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {"pos_embed", "cls_token", "dist_token"}
+
+    @torch.jit.ignore()
+    def load_pretrained(self, checkpoint_path, prefix=""):
+        _load_weights(self, checkpoint_path, prefix)
+
+    def forward(self, im, head_out_idx: List[int], n_dim_output=3, return_features=False):
+        B, _, H, W = im.shape
+        PS = self.patch_size
+        assert n_dim_output == 3 or n_dim_output == 4, "n_dim_output must be 3 or 4"
+        x = self.patch_embed(im)
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        if self.distilled:
+            dist_tokens = self.dist_token.expand(B, -1, -1)
+            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
+        else:
+            x = torch.cat((cls_tokens, x), dim=1)
+
+        pos_embed = self.pos_embed
+        num_extra_tokens = 1 + self.distilled
+        if x.shape[1] != pos_embed.shape[1]:
+            pos_embed = resize_pos_embed(
+                pos_embed,
+                self.patch_embed.grid_size,
+                (H // PS, W // PS),
+                num_extra_tokens,
+            )
+        x = x + pos_embed
+        x = self.dropout(x)
+        device = x.device
+
+        if n_dim_output == 3:
+            heads_out = torch.zeros(size=(len(head_out_idx), B, (H // PS) * (W // PS) + 1, self.d_model)).to(device)
+        else:
+            heads_out = torch.zeros(size=(len(head_out_idx), B, self.d_model, H // PS, W // PS)).to(device)
+        self.register_buffer("heads_out", heads_out)
+
+        head_idx = 0
+        for idx_layer, blk in enumerate(self.blocks):
+            x = blk(x)
+            if idx_layer in head_out_idx:
+                if n_dim_output == 3:
+                    heads_out[head_idx] = x
+                else:
+                    heads_out[head_idx] = x[:, 1:, :].reshape((-1, H // PS, W // PS, self.d_model)).permute(0, 3, 1, 2)
+                head_idx += 1
+
+        x = self.norm(x)
+
+        if return_features:
+            return heads_out
+
+        if self.distilled:
+            x, x_dist = x[:, 0], x[:, 1]
+            x = self.head(x)
+            x_dist = self.head_dist(x_dist)
+            x = (x + x_dist) / 2
+        else:
+            x = x[:, 0]
+            x = self.head(x)
+        return x
+
+    def get_attention_map(self, im, layer_id):
+        if layer_id >= self.n_layers or layer_id < 0:
+            raise ValueError(f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}.")
+        B, _, H, W = im.shape
+        PS = self.patch_size
+
+        x = self.patch_embed(im)
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        if self.distilled:
+            dist_tokens = self.dist_token.expand(B, -1, -1)
+            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
+        else:
+            x = torch.cat((cls_tokens, x), dim=1)
+
+        pos_embed = self.pos_embed
+        num_extra_tokens = 1 + self.distilled
+        if x.shape[1] != pos_embed.shape[1]:
+            pos_embed = resize_pos_embed(
+                pos_embed,
+                self.patch_embed.grid_size,
+                (H // PS, W // PS),
+                num_extra_tokens,
+            )
+        x = x + pos_embed
+
+        for i, blk in enumerate(self.blocks):
+            if i < layer_id:
+                x = blk(x)
+            else:
+                return blk(x, return_attention=True)
+
+
+class FeatureTransform(nn.Module):
+    def __init__(self, img_size, d_encoder, nls_list=[128, 256, 512, 512], scale_factor_list=[8, 4, 2, 1]):
+        super(FeatureTransform, self).__init__()
+        self.img_size = img_size
+
+        self.decoder_0 = DecoderLinear(n_cls=nls_list[0], d_encoder=d_encoder, scale_factor=scale_factor_list[0])
+        self.decoder_1 = DecoderLinear(n_cls=nls_list[1], d_encoder=d_encoder, scale_factor=scale_factor_list[1])
+        self.decoder_2 = DecoderLinear(n_cls=nls_list[2], d_encoder=d_encoder, scale_factor=scale_factor_list[2])
+        self.decoder_3 = DecoderLinear(n_cls=nls_list[3], d_encoder=d_encoder, scale_factor=scale_factor_list[3])
+
+    def forward(self, x_list):
+        feat_3 = self.decoder_3(x_list[3][:, 1:, :], self.img_size)  # (2, 512, 24, 24)
+        feat_2 = self.decoder_2(x_list[2][:, 1:, :], self.img_size)  # (2, 512, 48, 48)
+        feat_1 = self.decoder_1(x_list[1][:, 1:, :], self.img_size)  # (2, 256, 96, 96)
+        feat_0 = self.decoder_0(x_list[0][:, 1:, :], self.img_size)  # (2, 128, 192, 192)
+        return feat_0, feat_1, feat_2, feat_3
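Finally, a hedged sketch (assumed shapes, dimensions, and layer index, not part of the upload) of get_attention_map, which runs the blocks up to a chosen layer and returns whatever that Block yields with return_attention=True:

    # Hypothetical sketch: inspecting one block's attention for a 384x384 input.
    import torch
    from src.models.vit.vit import VisionTransformer

    vit = VisionTransformer(image_size=(384, 384), patch_size=16, n_layers=12,
                            d_model=768, d_ff=3072, n_heads=12, n_cls=1000)
    im = torch.randn(1, 3, 384, 384)
    attn = vit.get_attention_map(im, layer_id=11)
    # Expected (assumption, depends on Block.forward): per-head attention weights
    # over the 576 patch tokens plus the cls token, e.g. shape (1, n_heads, 577, 577).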