amish1729 committed
Commit 232568e · 1 Parent(s): 9d9a461

Initial commit

app.py ADDED
@@ -0,0 +1,75 @@
+ from utils.configuration import Configuration
+ import tensorflow as tf
+ from utils.model import ModelLoss
+ from utils.model import LFUNet
+ from utils.architectures import UNet
+
+ import gradio as gr
+
+ configuration = Configuration()
+ filters = (64, 128, 128, 256, 256, 512)
+ kernels = (7, 7, 7, 3, 3, 3)
+ input_image_size = (256, 256, 3)
+ architecture = UNet.RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV
+
+ trained_model = LFUNet.build_model(architecture=architecture, input_size=input_image_size, filters=filters,
+                                    kernels=kernels, configuration=configuration)
+ trained_model.compile(
+     loss=ModelLoss.ms_ssim_l1_perceptual_loss,
+     optimizer=tf.keras.optimizers.Adam(1e-4),
+     metrics=["acc", tf.keras.metrics.Recall(), tf.keras.metrics.Precision()]
+ )
+
+ weights_path = "model_weights/model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt"
+ trained_model.load_weights(weights_path)
+
+ def main(input_img):
+     try:
+         print(input_img)
+         predicted_image = trained_model.predict(input_img)
+         return predicted_image
+     except Exception:
+         raise gr.Error("Sorry, something went wrong. Please try again!")
+
+ demo = gr.Interface(
+     title="Lightweight network for face unmasking",
+     description="This is a demo of a <b>Lightweight network for face unmasking</b> \
+         designed to provide a powerful and efficient solution for restoring facial details obscured by masks.<br> \
+         To use it, simply upload your image, or click one of the examples to load it. Inference needs some time since this demo runs on CPU.",
+     fn=main,
+     inputs=gr.Image(type="filepath").style(height=256),
+     outputs=gr.Image(type="numpy", shape=(256, 256, 3)).style(height=256),
+     # allow_flagging='never',
+     examples=[
+         ["examples/1.png"],
+         ["examples/2.png"],
+         ["examples/3.png"],
+         ["examples/4.png"],
+         ["examples/5.png"],
+         ["examples/6.png"],
+         ["examples/7.png"],
+         ["examples/8.png"],
+     ],
+     css="""
+     .svelte-mppz8v {
+         text-align: -webkit-center;
+     }
+
+     .gallery {
+         display: flex;
+         flex-wrap: wrap;
+         width: 100%;
+     }
+
+     p {
+         font-size: medium;
+     }
+
+     h1 {
+         font-size: xx-large;
+     }
+     """,
+     theme='EveryPizza/Cartoony-Gradio-Theme',
+     # article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.04676' target='_blank'>Simple Baselines for Image Restoration</a> | <a href='https://arxiv.org/abs/2204.08714' target='_blank'>NAFSSR: Stereo Image Super-Resolution Using NAFNet</a> | <a href='https://github.com/megvii-research/NAFNet' target='_blank'> Github Repo</a></p>"
+ )
+ demo.launch(show_error=True, share=True)
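
For a quick sanity check outside Gradio, the same prediction path can be exercised directly. A minimal, hypothetical smoke test (not part of this commit), assuming the weights and example images above are in place:

    # Run one bundled example through the loaded model; LFUNet.predict takes an image path
    # and returns a PIL Image with the reconstructed face pasted back into the original photo.
    unmasked = trained_model.predict("examples/1.png")
    unmasked.save("unmasked_1.png")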
configuration.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "input_images_path": "data/lfw-deepfunneled",
+   "dataset_archive_download_url": "http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz",
+   "path_to_patterns": "data/mask_patterns",
+   "train_data_path": "data/train",
+   "test_data_path": "data/test",
+   "landmarks_predictor_path": "shape_predictor_68_face_landmarks.dat",
+   "landmarks_predictor_download_url": "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2",
+   "minimal_confidence": 0.8,
+   "hyp_ratio": 0.3333,
+   "coordinates_range": [-10, 10],
+   "test_image_count": 100,
+   "train_image_count": 15000,
+   "image_size": [256, 256],
+   "mask_type": "random",
+   "mask_color": null,
+   "mask_patter": null,
+   "mask_pattern_weight": 0.9,
+   "mask_color_weight": 0.8,
+   "mask_filter_output": false,
+   "mask_filter_radius": 2,
+   "test_results_dir": "data/results/",
+   "train_data_limit": 10000,
+   "test_data_limit": 1000
+ }
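
These keys are consumed through utils/configuration.py (added below). A minimal sketch of how the values are read, assuming the file sits next to the code as configuration.json:

    from utils.configuration import Configuration

    cfg = Configuration("configuration.json")
    cfg.get("image_size")         # -> [256, 256]
    cfg.get("train_data_limit")   # -> 10000
    cfg.get("missing_key")        # prints an error and returns None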
environment.yaml ADDED
@@ -0,0 +1,216 @@
+ name: unmask3
+ channels:
+   - conda-forge
+   - defaults
+ dependencies:
+   - _libgcc_mutex=0.1=main
+   - _openmp_mutex=5.1=1_gnu
+   - atk-1.0=2.36.0=ha1a6a79_0
+   - backcall=0.2.0=pyhd3eb1b0_0
+   - ca-certificates=2022.10.11=h06a4308_0
+   - cairo=1.16.0=h19f5f5c_2
+   - certifi=2022.9.24=py37h06a4308_0
+   - cudatoolkit=11.2.2=hbe64b41_10
+   - cudnn=8.1.0.77=h90431f1_0
+   - debugpy=1.5.1=py37h295c915_0
+   - decorator=5.1.1=pyhd3eb1b0_0
+   - entrypoints=0.4=py37h06a4308_0
+   - expat=2.4.9=h6a678d5_0
+   - font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0
+   - font-ttf-inconsolata=2.001=hcb22688_0
+   - font-ttf-source-code-pro=2.030=hd3eb1b0_0
+   - font-ttf-ubuntu=0.83=h8b1ccd4_0
+   - fontconfig=2.13.1=h6c09931_0
+   - fonts-anaconda=1=h8fa9717_0
+   - fonts-conda-ecosystem=1=hd3eb1b0_0
+   - freetype=2.11.0=h70c0345_0
+   - fribidi=1.0.10=h7b6447c_0
+   - gdk-pixbuf=2.42.8=h433bba3_0
+   - glib=2.69.1=h4ff587b_1
+   - gobject-introspection=1.72.0=py37hbb6d50b_0
+   - graphite2=1.3.14=h295c915_1
+   - graphviz=2.50.0=h3cd0ef9_0
+   - gtk2=2.24.33=h73c1081_2
+   - gts=0.7.6=hb67d8dd_3
+   - harfbuzz=4.3.0=hd55b92a_0
+   - icu=58.2=he6710b0_3
+   - ipython=7.31.1=py37h06a4308_1
+   - jedi=0.18.1=py37h06a4308_1
+   - jpeg=9e=h7f8727e_0
+   - jupyter_client=7.3.4=py37h06a4308_0
+   - ld_impl_linux-64=2.38=h1181459_1
+   - lerc=3.0=h295c915_0
+   - libdeflate=1.8=h7f8727e_5
+   - libffi=3.3=he6710b0_2
+   - libgcc-ng=11.2.0=h1234567_1
+   - libgd=2.3.3=h695aa2c_1
+   - libgomp=11.2.0=h1234567_1
+   - libpng=1.6.37=hbc83047_0
+   - librsvg=2.54.4=h19fe530_0
+   - libsodium=1.0.18=h7b6447c_0
+   - libstdcxx-ng=11.2.0=h1234567_1
+   - libtiff=4.4.0=hecacb30_0
+   - libtool=2.4.6=h295c915_1008
+   - libuuid=1.0.3=h7f8727e_2
+   - libwebp-base=1.2.4=h5eee18b_0
+   - libxcb=1.15=h7f8727e_0
+   - libxml2=2.9.14=h74e7548_0
+   - lz4-c=1.9.3=h295c915_1
+   - matplotlib-inline=0.1.6=py37h06a4308_0
+   - ncurses=6.3=h5eee18b_3
+   - nest-asyncio=1.5.5=py37h06a4308_0
+   - ninja=1.10.2=h06a4308_5
+   - ninja-base=1.10.2=hd09550d_5
+   - openssl=1.1.1q=h7f8727e_0
+   - pango=1.50.7=h05da053_0
+   - parso=0.8.3=pyhd3eb1b0_0
+   - pcre=8.45=h295c915_0
+   - pexpect=4.8.0=pyhd3eb1b0_3
+   - pickleshare=0.7.5=pyhd3eb1b0_1003
+   - pixman=0.40.0=h7f8727e_1
+   - ptyprocess=0.7.0=pyhd3eb1b0_2
+   - pygments=2.11.2=pyhd3eb1b0_0
+   - python=3.7.13=h12debd9_0
+   - python-dateutil=2.8.2=pyhd3eb1b0_0
+   - pyzmq=23.2.0=py37h6a678d5_0
+   - readline=8.1.2=h7f8727e_1
+   - setuptools=63.4.1=py37h06a4308_0
+   - six=1.16.0=pyhd3eb1b0_1
+   - sqlite=3.39.2=h5082296_0
+   - tk=8.6.12=h1ccaba5_0
+   - tornado=6.2=py37h5eee18b_0
+   - wcwidth=0.2.5=pyhd3eb1b0_0
+   - wheel=0.37.1=pyhd3eb1b0_0
+   - xz=5.2.5=h7f8727e_1
+   - zeromq=4.3.4=h2531618_0
+   - zlib=1.2.12=h7f8727e_2
+   - zstd=1.5.2=ha4553b6_0
+   - pip:
+     - absl-py==1.2.0
+     - anyio==3.6.2
+     - argon2-cffi==21.3.0
+     - argon2-cffi-bindings==21.2.0
+     - astunparse==1.6.3
+     - attrs==22.2.0
+     - beautifulsoup4==4.11.1
+     - black==22.10.0
+     - bleach==6.0.0
+     - cachetools==5.2.0
+     - cffi==1.15.1
+     - charset-normalizer==2.1.1
+     - click==8.1.3
+     - cmake==3.24.1
+     - cycler==0.11.0
+     - defusedxml==0.7.1
+     - dlib==19.24.1
+     - dotmap==1.3.30
+     - fastjsonschema==2.16.2
+     - flatbuffers==1.12
+     - fonttools==4.37.1
+     - gast==0.4.0
+     - google-auth==2.11.0
+     - google-auth-oauthlib==0.4.6
+     - google-pasta==0.2.0
+     - grpcio==1.48.1
+     - h5py==3.7.0
+     - idna==3.3
+     - imageio==2.27.0
+     - importlib-metadata==4.12.0
+     - importlib-resources==5.10.2
+     - imutils==0.5.4
+     - ipykernel==6.16.2
+     - ipython-genutils==0.2.0
+     - ipywidgets==8.0.2
+     - jinja2==3.1.2
+     - joblib==1.1.0
+     - jsonschema==4.17.3
+     - jupyter-console==6.6.3
+     - jupyter-core==4.12.0
+     - jupyter-server==1.23.5
+     - jupyterlab-pygments==0.2.2
+     - jupyterlab-widgets==3.0.3
+     - jupyterthemes==0.20.0
+     - keras==2.9.0
+     - keras-applications==1.0.8
+     - keras-preprocessing==1.1.2
+     - keras-vggface==0.6
+     - kiwisolver==1.4.4
+     - lesscpy==0.15.1
+     - libclang==14.0.6
+     - markdown==3.4.1
+     - markupsafe==2.1.1
+     - matplotlib==3.5.3
+     - mistune==2.0.4
+     - mtcnn==0.1.1
+     - mypy-extensions==0.4.3
+     - nbclassic==0.4.8
+     - nbclient==0.7.2
+     - nbconvert==7.2.9
+     - nbformat==5.7.3
+     - networkx==2.6.3
+     - notebook==6.5.2
+     - notebook-shim==0.2.2
+     - numpy==1.21.6
+     - oauthlib==3.2.0
+     - opencv-contrib-python==4.6.0.66
+     - opencv-python==4.6.0.66
+     - opt-einsum==3.3.0
+     - packaging==21.3
+     - pandas==1.3.5
+     - pandocfilters==1.5.0
+     - pathspec==0.10.1
+     - pillow==9.2.0
+     - pip==22.3.1
+     - pkgutil-resolve-name==1.3.10
+     - platformdirs==2.5.2
+     - ply==3.11
+     - prometheus-client==0.16.0
+     - prompt-toolkit==3.0.38
+     - protobuf==3.19.4
+     - psutil==5.9.5
+     - pyasn1==0.4.8
+     - pyasn1-modules==0.2.8
+     - pycparser==2.21
+     - pydot==1.4.2
+     - pyparsing==3.0.9
+     - pyrsistent==0.19.3
+     - pytz==2022.2.1
+     - pywavelets==1.3.0
+     - pyyaml==6.0
+     - requests==2.28.1
+     - requests-oauthlib==1.3.1
+     - rsa==4.9
+     - scikit-image==0.19.3
+     - scikit-learn==1.0.2
+     - scipy==1.7.3
+     - seaborn==0.11.2
+     - send2trash==1.8.0
+     - sniffio==1.3.0
+     - soupsieve==2.3.2.post1
+     - tensorboard==2.9.1
+     - tensorboard-data-server==0.6.1
+     - tensorboard-plugin-wit==1.8.1
+     - tensorflow==2.9.2
+     - tensorflow-addons==0.19.0
+     - tensorflow-estimator==2.9.0
+     - tensorflow-io-gcs-filesystem==0.26.0
+     - termcolor==1.1.0
+     - terminado==0.17.1
+     - threadpoolctl==3.1.0
+     - tifffile==2021.11.2
+     - tinycss2==1.2.1
+     - tomli==2.0.1
+     - tqdm==4.64.1
+     - traitlets==5.8.1
+     - trianglesolver==1.2
+     - typed-ast==1.5.4
+     - typeguard==2.13.3
+     - typing-extensions==4.3.0
+     - urllib3==1.26.12
+     - webencodings==0.5.1
+     - websocket-client==1.4.2
+     - werkzeug==2.2.2
+     - widgetsnbextension==4.0.3
+     - wrapt==1.14.1
+     - zipp==3.8.1
+ prefix: /home/suresh/miniconda3/envs/unmask
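
The environment can typically be recreated with `conda env create -f environment.yaml` (producing an env named unmask3); the `prefix:` line is machine-specific and can be removed or ignored.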
examples/1.png ADDED
examples/2.png ADDED
examples/3.png ADDED
examples/4.png ADDED
examples/5.png ADDED
examples/6.png ADDED
examples/7.png ADDED
examples/8.png ADDED
model_weights/checkpoint ADDED
@@ -0,0 +1,2 @@
+ model_checkpoint_path: "model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt"
+ all_model_checkpoint_paths: "model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt"
model_weights/model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt.index ADDED
Binary file (32 kB).
utils/__init__.py ADDED
@@ -0,0 +1,34 @@
+ import numpy as np
+ from PIL import Image
+ import requests
+ import functools
+ from tqdm.notebook import tqdm
+ import shutil
+
+ def image_to_array(image: Image) -> np.ndarray:
+     """Convert Image to array"""
+     return np.asarray(image).astype('uint8')
+
+
+ def load_image(img_path: str) -> Image:
+     """Load image from the given path"""
+     return Image.open(img_path)
+
+
+ def download_data(url, save_path, file_size=None):
+     """Downloads data from `url` to `save_path`"""
+     r = requests.get(url, stream=True, allow_redirects=True)
+     if r.status_code != 200:
+         r.raise_for_status()
+         raise RuntimeError(f'Request to {url} returned status code {r.status_code}')
+
+     if file_size is None:
+         file_size = int(r.headers.get('content-length', 0))
+
+     r.raw.read = functools.partial(r.raw.read, decode_content=True)  # Decompress if needed
+     with tqdm.wrapattr(r.raw, 'read', total=file_size, desc='') as r_raw:
+         with open(save_path, 'wb') as f:
+             shutil.copyfileobj(r_raw, f)
+
+ def plot_image_triple():
+     pass
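
download_data is also what fetches the dlib landmarks predictor in utils/data_generator.py. A minimal, hypothetical usage sketch (assuming network access; the progress bar comes from tqdm.notebook):

    from utils import download_data

    # Download the compressed dlib predictor referenced in configuration.json.
    download_data("http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2",
                  "shape_predictor_68_face_landmarks.dat.bz2")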
utils/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.41 kB).
utils/__pycache__/architectures.cpython-37.pyc ADDED
Binary file (9.89 kB).
utils/__pycache__/configuration.cpython-37.pyc ADDED
Binary file (996 Bytes).
utils/__pycache__/face_detection.cpython-37.pyc ADDED
Binary file (3.38 kB).
utils/__pycache__/model.cpython-37.pyc ADDED
Binary file (15.7 kB).
utils/architectures.py ADDED
@@ -0,0 +1,344 @@
+ from abc import ABC, abstractmethod
+ from enum import Enum
+ from typing import Tuple, Optional
+
+ import tensorflow as tf
+ from tensorflow.keras.layers import *
+ from tensorflow.keras.models import *
+
+ class BaseUNet(ABC):
+     """
+     Base interface for UNet
+     """
+
+     def __init__(self, model: Model):
+         self.model: Model = model
+
+     def get_model(self):
+         return self.model
+
+     @staticmethod
+     @abstractmethod
+     def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
+         pass
+
+
+ class UNet(Enum):
+     """
+     Enum class defining the different architecture types available
+     """
+     DEFAULT = 0
+     DEFAULT_IMAGENET_EMBEDDING = 1
+     RESNET = 3
+     RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV = 4
+
+     def build_model(self, input_size: Tuple[int, int, int], filters: Optional[Tuple] = None,
+                     kernels: Optional[Tuple] = None) -> BaseUNet:
+
+         # set default filters
+         if filters is None:
+             filters = (16, 32, 64, 128, 256)
+
+         # set default kernels
+         if kernels is None:
+             kernels = list(3 for _ in range(len(filters)))
+
+         # check kernels and filters
+         if len(filters) != len(kernels):
+             raise Exception('Kernel and filter counts have to match.')
+
+         if self == UNet.DEFAULT_IMAGENET_EMBEDDING:
+             print('Using default UNet model with imagenet embedding')
+             return UNetDefault.build_model(input_size, filters, kernels, use_embedding=True)
+         elif self == UNet.RESNET:
+             print('Using UNet Resnet model')
+             return UNet_resnet.build_model(input_size, filters, kernels)
+         elif self == UNet.RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV:
+             print('Using UNet Resnet model with attention mechanism and separable convolutions')
+             return UNet_ResNet_Attention_SeparableConv.build_model(input_size, filters, kernels)
+
+         print('Using default UNet model')
+         return UNetDefault.build_model(input_size, filters, kernels, use_embedding=False)
+
+
+ class Attention(Layer):
+     def __init__(self, **kwargs):
+         super(Attention, self).__init__(**kwargs)
+
+     def build(self, input_shape):
+         # Create a trainable weight variable for this layer.
+         self.kernel = self.add_weight(name='kernel',
+                                       shape=(input_shape[-1], 1),
+                                       initializer='glorot_normal',
+                                       trainable=True)
+         self.bias = self.add_weight(name='bias',
+                                     shape=(1,),
+                                     initializer='zeros',
+                                     trainable=True)
+         super(Attention, self).build(input_shape)  # Be sure to call this at the end
+
+     def call(self, x):
+         attention = tf.nn.softmax(tf.matmul(x, self.kernel) + self.bias, axis=-1)
+         return tf.multiply(x, attention)
+
+     def compute_output_shape(self, input_shape):
+         return input_shape
+
+
+ class UNet_ResNet_Attention_SeparableConv(BaseUNet):
+     """
+     UNet architecture with resnet blocks, attention mechanism and separable convolutions
+     """
+     @staticmethod
+     def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
+
+         p0 = Input(shape=input_size)
+         conv_outputs = []
+         first_layer = SeparableConv2D(filters[0], kernels[0], padding='same')(p0)
+         int_layer = first_layer
+         for i, f in enumerate(filters):
+             int_layer, skip = UNet_ResNet_Attention_SeparableConv.down_block(int_layer, f, kernels[i])
+             conv_outputs.append(skip)
+
+         int_layer = UNet_ResNet_Attention_SeparableConv.bottleneck(int_layer, filters[-1], kernels[-1])
+
+         conv_outputs = list(reversed(conv_outputs))
+         reversed_filter = list(reversed(filters))
+         reversed_kernels = list(reversed(kernels))
+         for i, f in enumerate(reversed_filter):
+             if i + 1 < len(reversed_filter):
+                 num_filters_next = reversed_filter[i + 1]
+                 num_kernels_next = reversed_kernels[i + 1]
+             else:
+                 num_filters_next = f
+                 num_kernels_next = reversed_kernels[i]
+             int_layer = UNet_ResNet_Attention_SeparableConv.up_block(int_layer, conv_outputs[i], f, num_filters_next, num_kernels_next)
+             int_layer = Attention()(int_layer)
+
+         # concat. with the first layer
+         int_layer = Concatenate()([first_layer, int_layer])
+         int_layer = SeparableConv2D(filters[0], kernels[0], padding="same", activation="relu")(int_layer)
+         outputs = SeparableConv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
+         model = Model(p0, outputs)
+         return UNet_ResNet_Attention_SeparableConv(model)
+
+     @staticmethod
+     def down_block(x, num_filters: int = 64, kernel: int = 3):
+         # down-sample inputs
+         x = SeparableConv2D(num_filters, kernel, padding='same', strides=2, dilation_rate=2)(x)
+
+         # inner block
+         out = SeparableConv2D(num_filters, kernel, padding='same')(x)
+         # out = BatchNormalization()(out)
+         out = Activation('relu')(out)
+         out = SeparableConv2D(num_filters, kernel, padding='same')(out)
+
+         # merge with the skip connection
+         out = Add()([out, x])
+         # out = BatchNormalization()(out)
+         return Activation('relu')(out), x
+
+     @staticmethod
+     def up_block(x, skip, num_filters: int = 64, num_filters_next: int = 64, kernel: int = 3):
+         # add U-Net skip connection - before up-sampling
+         concat = Concatenate()([x, skip])
+
+         # inner block
+         out = SeparableConv2D(num_filters, kernel, padding='same', dilation_rate=2)(concat)
+         # out = BatchNormalization()(out)
+         out = Activation('relu')(out)
+         out = SeparableConv2D(num_filters, kernel, padding='same')(out)
+
+         # merge with the skip connection
+         out = Add()([out, x])
+         # out = BatchNormalization()(out)
+         out = Activation('relu')(out)
+
+         # up-sample
+         out = UpSampling2D((2, 2))(out)
+         out = SeparableConv2D(num_filters_next, kernel, padding='same')(out)
+         # out = BatchNormalization()(out)
+         return Activation('relu')(out)
+
+     @staticmethod
+     def bottleneck(x, num_filters: int = 64, kernel: int = 3):
+         # inner block
+         out = SeparableConv2D(num_filters, kernel, padding='same', dilation_rate=2)(x)
+         # out = BatchNormalization()(out)
+         out = Activation('relu')(out)
+         out = SeparableConv2D(num_filters, kernel, padding='same')(out)
+         out = Add()([out, x])
+         # out = BatchNormalization()(out)
+         return Activation('relu')(out)
+
+
+ # Class for UNet with Resnet blocks
+
+ class UNet_resnet(BaseUNet):
+     """
+     UNet architecture with resnet blocks
+     """
+
+     @staticmethod
+     def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
+
+         p0 = Input(shape=input_size)
+         conv_outputs = []
+         first_layer = Conv2D(filters[0], kernels[0], padding='same')(p0)
+         int_layer = first_layer
+         for i, f in enumerate(filters):
+             int_layer, skip = UNet_resnet.down_block(int_layer, f, kernels[i])
+             conv_outputs.append(skip)
+
+         int_layer = UNet_resnet.bottleneck(int_layer, filters[-1], kernels[-1])
+
+         conv_outputs = list(reversed(conv_outputs))
+         reversed_filter = list(reversed(filters))
+         reversed_kernels = list(reversed(kernels))
+         for i, f in enumerate(reversed_filter):
+             if i + 1 < len(reversed_filter):
+                 num_filters_next = reversed_filter[i + 1]
+                 num_kernels_next = reversed_kernels[i + 1]
+             else:
+                 num_filters_next = f
+                 num_kernels_next = reversed_kernels[i]
+             int_layer = UNet_resnet.up_block(int_layer, conv_outputs[i], f, num_filters_next, num_kernels_next)
+
+         # concat. with the first layer
+         int_layer = Concatenate()([first_layer, int_layer])
+         int_layer = Conv2D(filters[0], kernels[0], padding="same", activation="relu")(int_layer)
+         outputs = Conv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
+         model = Model(p0, outputs)
+         return UNet_resnet(model)
+
+     @staticmethod
+     def down_block(x, num_filters: int = 64, kernel: int = 3):
+         # down-sample inputs
+         x = Conv2D(num_filters, kernel, padding='same', strides=2)(x)
+
+         # inner block
+         out = Conv2D(num_filters, kernel, padding='same')(x)
+         # out = BatchNormalization()(out)
+         out = Activation('relu')(out)
+         out = Conv2D(num_filters, kernel, padding='same')(out)
+
+         # merge with the skip connection
+         out = Add()([out, x])
+         # out = BatchNormalization()(out)
+         return Activation('relu')(out), x
+
+     @staticmethod
+     def up_block(x, skip, num_filters: int = 64, num_filters_next: int = 64, kernel: int = 3):
+
+         # add U-Net skip connection - before up-sampling
+         concat = Concatenate()([x, skip])
+
+         # inner block
+         out = Conv2D(num_filters, kernel, padding='same')(concat)
+         # out = BatchNormalization()(out)
+         out = Activation('relu')(out)
+         out = Conv2D(num_filters, kernel, padding='same')(out)
+
+         # merge with the skip connection
+         out = Add()([out, x])
+         # out = BatchNormalization()(out)
+         out = Activation('relu')(out)
+
+         # add U-Net skip connection - before up-sampling
+         concat = Concatenate()([out, skip])
+
+         # up-sample
+         # out = UpSampling2D((2, 2))(concat)
+         out = Conv2DTranspose(num_filters_next, kernel, padding='same', strides=2)(concat)
+         out = Conv2D(num_filters_next, kernel, padding='same')(out)
+         # out = BatchNormalization()(out)
+         return Activation('relu')(out)
+
+     @staticmethod
+     def bottleneck(x, filters, kernel: int = 3):
+         x = Conv2D(filters, kernel, padding='same', name='bottleneck')(x)
+         # x = BatchNormalization()(x)
+         return Activation('relu')(x)
+
+
+ class UNetDefault(BaseUNet):
+     """
+     UNet architecture from the following github notebooks for image segmentation:
+     https://github.com/nikhilroxtomar/UNet-Segmentation-in-Keras-TensorFlow/blob/master/unet-segmentation.ipynb
+     https://github.com/nikhilroxtomar/Polyp-Segmentation-using-UNET-in-TensorFlow-2.0
+     """
+
+     @staticmethod
+     def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple, use_embedding: bool = True):
+
+         p0 = Input(input_size)
+
+         if use_embedding:
+             mobilenet_model = tf.keras.applications.MobileNetV2(
+                 input_shape=input_size, include_top=False, weights='imagenet'
+             )
+             mobilenet_model.trainable = False
+             mn1 = mobilenet_model(p0)
+             mn1 = Reshape((16, 16, 320))(mn1)
+
+         conv_outputs = []
+         int_layer = p0
+
+         for f in filters:
+             conv_output, int_layer = UNetDefault.down_block(int_layer, f)
+             conv_outputs.append(conv_output)
+
+         int_layer = UNetDefault.bottleneck(int_layer, filters[-1])
+
+         if use_embedding:
+             int_layer = Concatenate()([int_layer, mn1])
+
+         conv_outputs = list(reversed(conv_outputs))
+         for i, f in enumerate(reversed(filters)):
+             int_layer = UNetDefault.up_block(int_layer, conv_outputs[i], f)
+
+         int_layer = Conv2D(filters[0] // 2, 3, padding="same", activation="relu")(int_layer)
+         outputs = Conv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
+         model = Model(p0, outputs)
+         return UNetDefault(model)
+
+     @staticmethod
+     def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
+         c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
+         # c = BatchNormalization()(c)
+         p = MaxPool2D((2, 2), (2, 2))(c)
+         return c, p
+
+     @staticmethod
+     def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1):
+         us = UpSampling2D((2, 2))(x)
+         c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(us)
+         # c = BatchNormalization()(c)
+         concat = Concatenate()([c, skip])
+         c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(concat)
+         # c = BatchNormalization()(c)
+         return c
+
+     @staticmethod
+     def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1):
+         c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
+         # c = BatchNormalization()(c)
+         return c
+
+
+ if __name__ == "__main__":
+     filters = (64, 128, 128, 256, 256, 512)
+     kernels = (7, 7, 7, 3, 3, 3)
+     input_image_size = (256, 256, 3)
+     # model = UNet_resnet()
+     # model = model.build_model(input_size=input_image_size,filters=filters,kernels=kernels)
+     # print(model.summary())
+     # __init__() missing 1 required positional argument: 'model'
+     model = UNetDefault.build_model(input_size=input_image_size, filters=filters, kernels=kernels)
+     print(model.summary())
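
The UNet enum is the factory used by app.py. A minimal sketch that builds the same residual-attention separable-convolution variant as the demo (same filters and kernels):

    from utils.architectures import UNet

    arch = UNet.RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV
    unet = arch.build_model(input_size=(256, 256, 3),
                            filters=(64, 128, 128, 256, 256, 512),
                            kernels=(7, 7, 7, 3, 3, 3))
    unet.get_model().summary()  # the wrapper exposes the underlying tf.keras Model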
utils/configuration.py ADDED
@@ -0,0 +1,22 @@
+ import os
+ import json
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class Configuration:
+     def __init__(self, config_file_path: str = "configuration.json"):
+         self.config_file_path = config_file_path
+         self.config_json = None
+         if os.path.exists(config_file_path):
+             with open(self.config_file_path, 'r') as json_file:
+                 self.config_json = json.load(json_file)
+         else:
+             print(f'ERROR: Configuration JSON {config_file_path} does not exist.')
+
+     def get(self, key: str):
+         if key in self.config_json:
+             return self.config_json[key]
+         else:
+             print(f'ERROR: Key \'{key}\' is not in configuration JSON.')
+             return None
utils/data_generator.py ADDED
@@ -0,0 +1,151 @@
+ import copy
+ import dlib
+ import os
+ import bz2
+ import random
+ from tqdm.notebook import tqdm
+ import shutil
+ from utils import image_to_array, load_image, download_data
+ from utils.face_detection import crop_face, get_face_keypoints_detecting_function
+ from mask_utils.mask_utils import mask_image
+
+
+ class DataGenerator:
+     def __init__(self, configuration):
+         self.configuration = configuration
+         self.path_to_data = configuration.get('input_images_path')
+         self.path_to_patterns = configuration.get('path_to_patterns')
+         self.minimal_confidence = configuration.get('minimal_confidence')
+         self.hyp_ratio = configuration.get('hyp_ratio')
+         self.coordinates_range = configuration.get('coordinates_range')
+         self.test_image_count = configuration.get('test_image_count')
+         self.train_image_count = configuration.get('train_image_count')
+         self.train_data_path = configuration.get('train_data_path')
+         self.test_data_path = configuration.get('test_data_path')
+         self.predictor_path = configuration.get('landmarks_predictor_path')
+         self.check_predictor()
+
+         self.valid_image_extensions = ('png', 'jpg', 'jpeg')
+         self.face_keypoints_detecting_fun = get_face_keypoints_detecting_function(self.minimal_confidence)
+
+     def check_predictor(self):
+         """Check if the predictor exists. If not, download it."""
+         if not os.path.exists(self.predictor_path):
+             print('Downloading missing predictor.')
+             url = self.configuration.get('landmarks_predictor_download_url')
+             download_data(url, self.predictor_path + '.bz2', 64040097)
+             print(f'Decompressing downloaded file into {self.predictor_path}')
+             with bz2.BZ2File(self.predictor_path + '.bz2') as fr, open(self.predictor_path, 'wb') as fw:
+                 shutil.copyfileobj(fr, fw)
+
+     def get_face_landmarks(self, image):
+         """Compute 68 facial landmarks"""
+         landmarks = []
+         image_array = image_to_array(image)
+         detector = dlib.get_frontal_face_detector()
+         predictor = dlib.shape_predictor(self.predictor_path)
+         face_rectangles = detector(image_array)
+         if len(face_rectangles) < 1:
+             return None
+         dlib_shape = predictor(image_array, face_rectangles[0])
+         for i in range(0, dlib_shape.num_parts):
+             landmarks.append([dlib_shape.part(i).x, dlib_shape.part(i).y])
+         return landmarks
+
+     def get_files_faces(self):
+         """Get paths of all images in the dataset"""
+         image_files = []
+         for dirpath, dirs, files in os.walk(self.path_to_data):
+             for filename in files:
+                 fname = os.path.join(dirpath, filename)
+                 if fname.endswith(self.valid_image_extensions):
+                     image_files.append(fname)
+
+         return image_files
+
+     def generate_images(self, image_size=None, test_image_count=None, train_image_count=None):
+         """Generate test and train data (images with and without the mask)"""
+         if image_size is None:
+             image_size = self.configuration.get('image_size')
+         if test_image_count is None:
+             test_image_count = self.test_image_count
+         if train_image_count is None:
+             train_image_count = self.train_image_count
+
+         if not os.path.exists(self.train_data_path):
+             os.mkdir(self.train_data_path)
+             os.mkdir(os.path.join(self.train_data_path, 'inputs'))
+             os.mkdir(os.path.join(self.train_data_path, 'outputs'))
+
+         if not os.path.exists(self.test_data_path):
+             os.mkdir(self.test_data_path)
+             os.mkdir(os.path.join(self.test_data_path, 'inputs'))
+             os.mkdir(os.path.join(self.test_data_path, 'outputs'))
+
+         print('Generating testing data')
+         self.generate_data(test_image_count,
+                            image_size=image_size,
+                            save_to=self.test_data_path)
+         print('Generating training data')
+         self.generate_data(train_image_count,
+                            image_size=image_size,
+                            save_to=self.train_data_path)
+
+     def generate_data(self, number_of_images, image_size=None, save_to=None):
+         """Add masks to `number_of_images` images.
+         If `save_to` is a valid folder path, the images are saved there; otherwise the generated data is returned as lists.
+         """
+         inputs = []
+         outputs = []
+
+         if image_size is None:
+             image_size = self.configuration.get('image_size')
+
+         for i, file in tqdm(enumerate(random.sample(self.get_files_faces(), number_of_images)), total=number_of_images):
+             # Load images
+             image = load_image(file)
+
+             # Detect keypoints and landmarks on face
+             face_landmarks = self.get_face_landmarks(image)
+             if face_landmarks is None:
+                 continue
+             keypoints = self.face_keypoints_detecting_fun(image)
+
+             # Generate mask
+             image_with_mask = mask_image(copy.deepcopy(image), face_landmarks, self.configuration)
+
+             # Crop images
+             cropped_image = crop_face(image_with_mask, keypoints)
+             cropped_original = crop_face(image, keypoints)
+
+             # Resize all images to NN input size
+             res_image = cropped_image.resize(image_size)
+             res_original = cropped_original.resize(image_size)
+
+             # Save generated data to lists or to folder
+             if save_to is None:
+                 inputs.append(res_image)
+                 outputs.append(res_original)
+             else:
+                 res_image.save(os.path.join(save_to, 'inputs', f"{i:06d}.png"))
+                 res_original.save(os.path.join(save_to, 'outputs', f"{i:06d}.png"))
+
+         if save_to is None:
+             return inputs, outputs
+
+     def get_dataset_examples(self, n=10, test_dataset=False):
+         """
+         Returns `n` random images from the dataset. If `test_dataset` is not provided or False,
+         images come from the training part of the dataset; if True, from the testing part.
+         """
+         if test_dataset:
+             data_path = self.test_data_path
+         else:
+             data_path = self.train_data_path
+
+         images = os.listdir(os.path.join(data_path, 'inputs'))
+         images = random.sample(images, n)
+         inputs = [os.path.join(data_path, 'inputs', img) for img in images]
+         outputs = [os.path.join(data_path, 'outputs', img) for img in images]
+         return inputs, outputs
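
A hypothetical end-to-end data-generation run, assuming the LFW images from configuration.json have been downloaded and that the mask_utils package imported above is available (it is not part of this commit):

    from utils.configuration import Configuration
    from utils.data_generator import DataGenerator

    cfg = Configuration("configuration.json")
    generator = DataGenerator(cfg)   # downloads the dlib predictor on first use
    generator.generate_images()      # writes paired images to data/train and data/test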
utils/face_detection.py ADDED
@@ -0,0 +1,111 @@
+ """Functions for face detection"""
+ from math import pi
+ from typing import Tuple, Optional, Dict
+
+ import tensorflow as tf
+ import matplotlib.patches as patches
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ from mtcnn import MTCNN
+ from trianglesolver import solve
+
+ from utils import image_to_array
+
+
+ def compute_slacks(height, width, hyp_ratio) -> Tuple[float, float]:
+     """Compute slacks to add to the bounding box on each side"""
+
+     # compute angle and side for hypotenuse
+     _, b, _, A, _, _ = solve(c=width, a=height, B=pi / 2)
+
+     # compute new height and width
+     a, _, c, _, _, _ = solve(b=b * (1.0 + hyp_ratio), B=pi / 2, A=A)
+
+     # compute slacks
+     return c - width, a - height
+
+
+ def get_face_keypoints_detecting_function(minimal_confidence: float = 0.8):
+     """Create function for face keypoints detection"""
+
+     # face detector
+     detector = MTCNN()
+
+     # detect faces and their keypoints
+     def get_keypoints(image: Image) -> Optional[Dict]:
+
+         # run inference to detect faces (on CPU only)
+         with tf.device("/cpu:0"):
+             detection = detector.detect_faces(image_to_array(image))
+
+         # run detection and keep results with certain confidence only
+         results = [item for item in detection if item['confidence'] > minimal_confidence]
+
+         # nothing found
+         if len(results) == 0:
+             return None
+
+         # return result with highest confidence and size
+         return max(results, key=lambda item: item['confidence'] * item['box'][2] * item['box'][3])
+
+     # return function
+     return get_keypoints
+
+
+ def plot_face_detection(image: Image, ax, face_keypoints: Optional, hyp_ratio: float = 1 / 3):
+     """Plot faces with keypoints and bounding boxes"""
+
+     # make annotations
+     if face_keypoints is not None:
+
+         # get bounding box
+         x, y, width, height = face_keypoints['box']
+
+         # add rectangle patch for detected face
+         rectangle = patches.Rectangle((x, y), width, height, linewidth=1, edgecolor='r', facecolor='none')
+         ax.add_patch(rectangle)
+
+         # add rectangle patch with slacks
+         w_s, h_s = compute_slacks(height, width, hyp_ratio)
+         rectangle = patches.Rectangle((x - w_s, y - h_s), width + 2 * w_s, height + 2 * h_s, linewidth=1, edgecolor='r',
+                                       facecolor='none')
+         ax.add_patch(rectangle)
+
+         # add keypoints
+         for coordinates in face_keypoints['keypoints'].values():
+             circle = plt.Circle(coordinates, 3, color='r')
+             ax.add_artist(circle)
+
+     # add image
+     ax.imshow(image)
+
+
+ def get_crop_points(image: Image, face_keypoints: Optional, hyp_ratio: float = 1 / 3) -> Tuple[float, float, float, float]:
+     """Find the position where to crop the face from the image"""
+     if face_keypoints is None:
+         return 0, 0, image.width, image.height
+
+     # get bounding box
+     x, y, width, height = face_keypoints['box']
+
+     # compute slacks
+     w_s, h_s = compute_slacks(height, width, hyp_ratio)
+
+     # compute coordinates
+     left = min(max(0, x - w_s), image.width)
+     upper = min(max(0, y - h_s), image.height)
+     right = min(x + width + w_s, image.width)
+     lower = min(y + height + h_s, image.height)
+
+     return left, upper, right, lower
+
+
+ def crop_face(image: Image, face_keypoints: Optional, hyp_ratio: float = 1 / 3) -> Image:
+     """Crop the input image to just the face"""
+     if face_keypoints is None:
+         print("No keypoints detected on image")
+         return image
+
+     left, upper, right, lower = get_crop_points(image, face_keypoints, hyp_ratio)
+
+     return image.crop((left, upper, right, lower))
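
A minimal sketch of the detection-and-crop path used before inference, assuming one of the bundled example images:

    from PIL import Image
    from utils.face_detection import get_face_keypoints_detecting_function, crop_face

    detect = get_face_keypoints_detecting_function(minimal_confidence=0.8)
    image = Image.open("examples/1.png").convert("RGB")
    keypoints = detect(image)           # None if no face passes the confidence threshold
    face = crop_face(image, keypoints)  # falls back to the full image when keypoints is None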
utils/model.py ADDED
@@ -0,0 +1,495 @@
+ import os
+ from datetime import datetime
+ from glob import glob
+ from typing import Tuple, Optional
+ from utils import load_image
+ import random
+ import cv2
+ import numpy as np
+ import tensorflow as tf
+ from PIL import Image
+ from sklearn.model_selection import train_test_split
+ from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
+ from tensorflow.keras.utils import CustomObjectScope
+ from utils.face_detection import get_face_keypoints_detecting_function, crop_face, get_crop_points
+ from utils.architectures import UNet
+ from tensorflow.keras.losses import MeanSquaredError, mean_squared_error
+ from keras_vggface.vggface import VGGFace
+ import tensorflow.keras.backend as K
+ from tensorflow.keras.applications import VGG19
+
+ # # VGG19 model for perceptual loss
+ # vgg = VGG19(include_top=False, weights='imagenet')
+
+ # def preprocess_image(image):
+ #     image = tf.image.resize(image, (224, 224))
+ #     image = tf.keras.applications.vgg19.preprocess_input(image)
+ #     return image
+
+ # def perceptual_loss(y_true, y_pred):
+ #     y_true = preprocess_image(y_true)
+ #     y_pred = preprocess_image(y_pred)
+ #     y_true_c = vgg(y_true)
+ #     y_pred_c = vgg(y_pred)
+ #     loss = K.mean(K.square(y_pred_c - y_true_c))
+ #     return loss
+
+ vgg_face_model = VGGFace(model='resnet50', include_top=False, input_shape=(256, 256, 3), pooling='avg')
+
+
+ class ModelLoss:
+     @staticmethod
+     @tf.function
+     def ms_ssim_l1_perceptual_loss(gt, y_pred, max_val=1.0, l1_weight=1.0):
+         """
+         Computes a combined SSIM, pixel-wise MSE and VGGFace perceptual loss
+         @param gt: Ground truth image
+         @param y_pred: Predicted image
+         @param max_val: Maximal SSIM value
+         @param l1_weight: Weight of the pixel-wise term
+         @return: Combined loss value
+         """
+
+         # Compute SSIM loss
+         ssim_loss = 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
+
+         # Compute perceptual loss
+         vgg_face_outputs = vgg_face_model(y_pred)
+         vgg_face_loss = tf.reduce_mean(tf.losses.mean_squared_error(vgg_face_outputs, vgg_face_model(gt)))
+
+         # Combine the losses with a weighted pixel-wise (MSE) term
+         l1 = mean_squared_error(gt, y_pred)
+         l1_casted = tf.cast(l1 * l1_weight, tf.float32)
+         return ssim_loss + l1_casted + vgg_face_loss
+
+
+ class LFUNet(tf.keras.models.Model):
+     """
+     Model for Mask2Face - removes masks from faces using a U-Net neural network
+     """
+     def __init__(self, model: tf.keras.models.Model, configuration=None, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.model: tf.keras.models.Model = model
+         self.configuration = configuration
+         self.face_keypoints_detecting_fun = get_face_keypoints_detecting_function(0.8)
+         self.mse = MeanSquaredError()
+
+     def call(self, x, **kwargs):
+         return self.model(x)
+
+     @staticmethod
+     @tf.function
+     def ssim_loss(gt, y_pred, max_val=1.0):
+         """
+         Computes standard SSIM loss
+         @param gt: Ground truth image
+         @param y_pred: Predicted image
+         @param max_val: Maximal SSIM value
+         @return: SSIM loss
+         """
+         return 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
+
+     @staticmethod
+     @tf.function
+     def ssim_l1_loss(gt, y_pred, max_val=1.0, l1_weight=1.0):
+         """
+         Computes SSIM loss with L1 normalization
+         @param gt: Ground truth image
+         @param y_pred: Predicted image
+         @param max_val: Maximal SSIM value
+         @param l1_weight: Weight of L1 normalization
+         @return: SSIM L1 loss
+         """
+         ssim_loss = 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
+         l1 = mean_squared_error(gt, y_pred)
+         return ssim_loss + tf.cast(l1 * l1_weight, tf.float32)
+
+     # @staticmethod
+     # @tf.function
+     # def ms_ssim_l1_perceptual_loss(gt, y_pred, max_val=1.0, l1_weight=1.0, perceptual_weight=1.0):
+     #     """
+     #     Computes MS-SSIM loss, L1 loss, and perceptual loss
+     #     @param gt: Ground truth image
+     #     @param y_pred: Predicted image
+     #     @param max_val: Maximal SSIM value
+     #     @param l1_weight: Weight of L1 normalization
+     #     @param perceptual_weight: Weight of perceptual loss
+     #     @return: MS-SSIM L1 perceptual loss
+     #     """
+     #     y_pred = tf.clip_by_value(y_pred, 0, float("inf"))
+     #     y_pred = tf.debugging.check_numerics(y_pred, message='y_pred has NaN values')
+     #     ms_ssim_loss = 1 - tf.reduce_mean(tf.image.ssim_multiscale(gt, y_pred, max_val=max_val))
+     #     l1_loss = tf.losses.mean_absolute_error(gt, y_pred)
+     #     vgg_face_outputs = vgg_face_model(y_pred)
+     #     vgg_face_loss = tf.reduce_mean(tf.losses.mean_squared_error(vgg_face_outputs, vgg_face_model(gt)))
+     #     return ms_ssim_loss + tf.cast(l1_loss * l1_weight, tf.float32) + perceptual_weight * vgg_face_loss
+
+     # Function for MS-SSIM loss + L1 loss
+     @staticmethod
+     @tf.function
+     def ms_ssim_l1_loss(gt, y_pred, max_val=1.0, l1_weight=1.0):
+         """
+         Computes MS-SSIM loss and L1 loss
+         @param gt: Ground truth image
+         @param y_pred: Predicted image
+         @param max_val: Maximal SSIM value
+         @param l1_weight: Weight of L1 normalization
+         @return: MS-SSIM L1 loss
+         """
+         # Clip negative values before computing MS-SSIM
+         y_pred = tf.clip_by_value(y_pred, 0, float("inf"))
+
+         ms_ssim_loss = 1 - tf.reduce_mean(tf.image.ssim_multiscale(gt, y_pred, max_val=max_val))
+         l1_loss = tf.losses.mean_absolute_error(gt, y_pred)
+         return ms_ssim_loss + tf.cast(l1_loss * l1_weight, tf.float32)
+
+     @staticmethod
+     def load_model(model_path, configuration=None):
+         """
+         Loads saved h5 file with trained model.
+         @param configuration: Optional instance of Configuration with config JSON
+         @param model_path: Path to h5 file
+         @return: LFUNet
+         """
+         with CustomObjectScope({'ssim_loss': LFUNet.ssim_loss, 'ssim_l1_loss': LFUNet.ssim_l1_loss, 'ms_ssim_l1_perceptual_loss': ModelLoss.ms_ssim_l1_perceptual_loss, 'ms_ssim_l1_loss': LFUNet.ms_ssim_l1_loss}):
+             model = tf.keras.models.load_model(model_path)
+             return LFUNet(model, configuration)
+
+     @staticmethod
+     def build_model(architecture: UNet, input_size: Tuple[int, int, int], filters: Optional[Tuple] = None,
+                     kernels: Optional[Tuple] = None, configuration=None):
+         """
+         Builds model based on input arguments
+         @param architecture: utils.architectures.UNet architecture
+         @param input_size: Size of input images
+         @param filters: Tuple with sizes of filters in U-net
+         @param kernels: Tuple with sizes of kernels in U-net. Must be the same size as filters.
+         @param configuration: Optional instance of Configuration with config JSON
+         @return: LFUNet
+         """
+         return LFUNet(architecture.build_model(input_size, filters, kernels).get_model(), configuration)
+
+     def train(self, epochs=20, batch_size=20, loss_function='mse', learning_rate=1e-4,
+               predict_difference: bool = False):
+         """
+         Train the model.
+         @param epochs: Number of epochs during training
+         @param batch_size: Batch size
+         @param loss_function: Loss function. Either a standard tensorflow loss function or `ssim_loss` or `ssim_l1_loss`
+         @param learning_rate: Learning rate
+         @param predict_difference: Compute prediction on difference between input and output image
+         @return: History of training
+         """
+         # get data
+         (train_x, train_y), (valid_x, valid_y) = self.load_train_data()
+         (test_x, test_y) = self.load_test_data()
+
+         train_dataset = LFUNet.tf_dataset(train_x, train_y, batch_size, predict_difference)
+         valid_dataset = LFUNet.tf_dataset(valid_x, valid_y, batch_size, predict_difference, train=False)
+         test_dataset = LFUNet.tf_dataset(test_x, test_y, batch_size, predict_difference, train=False)
+
+         # select loss
+         if loss_function == 'ssim_loss':
+             loss = LFUNet.ssim_loss
+         elif loss_function == 'ssim_l1_loss':
+             loss = LFUNet.ssim_l1_loss
+         elif loss_function == 'ms_ssim_l1_perceptual_loss':
+             loss = ModelLoss.ms_ssim_l1_perceptual_loss
+         elif loss_function == 'ms_ssim_l1_loss':
+             loss = LFUNet.ms_ssim_l1_loss
+         else:
+             loss = loss_function
+
+         # compile model with the selected loss function
+         self.model.compile(
+             loss=loss,
+             optimizer=tf.keras.optimizers.Adam(learning_rate),
+             metrics=["acc", tf.keras.metrics.Recall(), tf.keras.metrics.Precision()]
+         )
+
+         # define callbacks
+         callbacks = [
+             ModelCheckpoint(
+                 f'models/model_epochs-{epochs}_batch-{batch_size}_loss-{loss_function}_{LFUNet.get_datetime_string()}.h5'),
+             EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
+         ]
+
+         # evaluation before training
+         results = self.model.evaluate(test_dataset)
+         print("- TEST -> LOSS: {:10.4f}, ACC: {:10.4f}, RECALL: {:10.4f}, PRECISION: {:10.4f}".format(*results))
+
+         # fit the model
+         history = self.model.fit(train_dataset, validation_data=valid_dataset, epochs=epochs, callbacks=callbacks)
+
+         # evaluation after training
+         results = self.model.evaluate(test_dataset)
+         print("- TEST -> LOSS: {:10.4f}, ACC: {:10.4f}, RECALL: {:10.4f}, PRECISION: {:10.4f}".format(*results))
+
+         # use the model for inference on several test images
+         self._test_results(test_x, test_y, predict_difference)
+
+         # return training history
+         return history
+
+     def _test_results(self, test_x, test_y, predict_difference: bool):
+         """
+         Test trained model on testing dataset. All images in the testing dataset are processed and result image triples
+         (input with mask, ground truth, model output) are stored to `data/results` into a folder with the time stamp
+         when this method was executed.
+         @param test_x: List of input images
+         @param test_y: List of ground truth output images
+         @param predict_difference: Compute prediction on difference between input and output image
+         @return: None
+         """
+         if self.configuration is None:
+             result_dir = f'data/results/{LFUNet.get_datetime_string()}/'
+         else:
+             result_dir = os.path.join(self.configuration.get('test_results_dir'), LFUNet.get_datetime_string())
+         os.makedirs(result_dir, exist_ok=True)
+
+         for i, (x, y) in enumerate(zip(test_x, test_y)):
+             x = LFUNet.read_image(x)
+             y = LFUNet.read_image(y)
+
+             y_pred = self.model.predict(np.expand_dims(x, axis=0))
+             if predict_difference:
+                 y_pred = (y_pred * 2) - 1
+                 y_pred = np.clip(x - y_pred.squeeze(axis=0), 0.0, 1.0)
+             else:
+                 y_pred = y_pred.squeeze(axis=0)
+             h, w, _ = x.shape
+             white_line = np.ones((h, 10, 3)) * 255.0
+
+             all_images = [
+                 x * 255.0, white_line,
+                 y * 255.0, white_line,
+                 y_pred * 255.0
+             ]
+             image = np.concatenate(all_images, axis=1)
+             cv2.imwrite(os.path.join(result_dir, f"{i}.png"), image)
+
+     def summary(self):
+         """
+         Prints model summary
+         """
+         self.model.summary()
+
+     def predict(self, img_path, predict_difference: bool = False):
+         """
+         Use the trained model to remove the mask from an image of a person wearing one.
+         @param img_path: Path to the image to be processed
+         @param predict_difference: Compute prediction on difference between input and output image
+         @return: Image without the mask on the face
+         """
+         # Load image into RGB format
+         image = load_image(img_path)
+         image = image.convert('RGB')
+
+         # Find facial keypoints and crop the image to just the face
+         keypoints = self.face_keypoints_detecting_fun(image)
+         cropped_image = crop_face(image, keypoints)
+         print(cropped_image.size)
+
+         # Resize image to the input size recognized by the neural net
+         resized_image = cropped_image.resize((256, 256))
+         image_array = np.array(resized_image)
+
+         # Convert from RGB to BGR (OpenCV format)
+         image_array = image_array[:, :, ::-1].copy()
+         image_array = image_array / 255.0
+
+         # Remove mask from input image
+         y_pred = self.model.predict(np.expand_dims(image_array, axis=0))
+         h, w, _ = image_array.shape
+
+         if predict_difference:
+             y_pred = (y_pred * 2) - 1
+             y_pred = np.clip(image_array - y_pred.squeeze(axis=0), 0.0, 1.0)
+         else:
+             y_pred = y_pred.squeeze(axis=0)
+
+         # Convert output from model to image and scale it back to original size
+         y_pred = y_pred * 255.0
+         im = Image.fromarray(y_pred.astype(np.uint8)[:, :, ::-1])
+         im = im.resize(cropped_image.size)
+         left, upper, _, _ = get_crop_points(image, keypoints)
+
+         # Combine original image with output from model
+         image.paste(im, (int(left), int(upper)))
+         return image
+
+     @staticmethod
+     def get_datetime_string():
+         """
+         Creates date-time string
+         @return: String with current date and time
+         """
+         now = datetime.now()
+         return now.strftime("%Y%m%d_%H_%M_%S")
+
+     def load_train_data(self, split=0.2):
+         """
+         Loads training data (paths to training images)
+         @param split: Percentage of training data used for validation as float from 0.0 to 1.0. Default 0.2.
+         @return: Two tuples - first with training data (tuple with (input images, output images)) and second
+             with validation data (tuple with (input images, output images))
+         """
+         if self.configuration is None:
+             train_dir = 'data/train/'
+             limit = None
+         else:
+             train_dir = self.configuration.get('train_data_path')
+             limit = self.configuration.get('train_data_limit')
+         print(f'Loading training data from {train_dir} with limit of {limit} images')
+         return LFUNet.load_data(os.path.join(train_dir, 'inputs'), os.path.join(train_dir, 'outputs'), split, limit)
+
+     def load_test_data(self):
+         """
+         Loads testing data (paths to testing images)
+         @return: Tuple with testing data - (input images, output images)
+         """
+         if self.configuration is None:
+             test_dir = 'data/test/'
+             limit = None
+         else:
+             test_dir = self.configuration.get('test_data_path')
+             limit = self.configuration.get('test_data_limit')
+         print(f'Loading testing data from {test_dir} with limit of {limit} images')
+         return LFUNet.load_data(os.path.join(test_dir, 'inputs'), os.path.join(test_dir, 'outputs'), None, limit)
+
+     @staticmethod
+     def load_data(input_path, output_path, split=0.2, limit=None):
+         """
+         Loads data (paths to images)
+         @param input_path: Path to folder with input images
+         @param output_path: Path to folder with output images
+         @param split: Percentage of data used for validation as float from 0.0 to 1.0. Default 0.2.
+             If split is None it expects you are loading testing data, otherwise expects training data.
+         @param limit: Maximal number of images loaded from data folder. Default None (no limit).
+         @return: If split is not None: Two tuples - first with training data (tuple with (input images, output images))
+             and second with validation data (tuple with (input images, output images))
+             Else: Tuple with testing data - (input images, output images)
+         """
+         images = sorted(glob(os.path.join(input_path, "*.png")))
+         masks = sorted(glob(os.path.join(output_path, "*.png")))
+         if len(images) == 0:
+             raise TypeError(f'No images found in {input_path}')
+         if len(masks) == 0:
+             raise TypeError(f'No images found in {output_path}')
+
+         if limit is not None:
+             images = images[:limit]
+             masks = masks[:limit]
+
+         if split is not None:
+             total_size = len(images)
+             valid_size = int(split * total_size)
+             train_x, valid_x = train_test_split(images, test_size=valid_size, random_state=42)
+             train_y, valid_y = train_test_split(masks, test_size=valid_size, random_state=42)
+             return (train_x, train_y), (valid_x, valid_y)
+
+         else:
+             return images, masks
+
+     @staticmethod
+     def read_image(path):
+         """
+         Loads an image, resizes it to 256x256 and normalizes it to float values from 0.0 to 1.0.
+         @param path: Path to image to be loaded.
+         @return: Loaded image in OpenCV format.
+         """
+         x = cv2.imread(path, cv2.IMREAD_COLOR)
+         x = cv2.resize(x, (256, 256))
+         x = x / 255.0
+         return x
+
+     @staticmethod
+     def tf_parse(x, y):
+         """
+         Mapping function for dataset creation. Load and resize images.
+         @param x: Path to input image
+         @param y: Path to output image
+         @return: Tuple with input and output image with shape (256, 256, 3)
+         """
+         def _parse(x, y):
+             x = LFUNet.read_image(x.decode())
+             y = LFUNet.read_image(y.decode())
+             return x, y
+
+         x, y = tf.numpy_function(_parse, [x, y], [tf.float64, tf.float64])
+         x.set_shape([256, 256, 3])
+         y.set_shape([256, 256, 3])
+         return x, y
+
+     @staticmethod
+     def tf_dataset(x, y, batch=8, predict_difference: bool = False, train: bool = True):
+         """
+         Creates standard tensorflow dataset.
+         @param x: List of paths to input images
+         @param y: List of paths to output images
+         @param batch: Batch size
+         @param predict_difference: Compute prediction on difference between input and output image
+         @param train: Flag if training dataset should be generated
+         @return: Dataset with loaded images
+         """
+         dataset = tf.data.Dataset.from_tensor_slices((x, y))
+         dataset = dataset.map(LFUNet.tf_parse)
+         random_seed = random.randint(0, 999999999)
+
+         if predict_difference:
+             def map_output(img_in, img_target):
+                 return img_in, (img_in - img_target + 1.0) / 2.0
+
+             dataset = dataset.map(map_output)
+
+         if train:
+             # for the train set, we want to apply data augmentations and shuffle data to different batches
+
+             # random flip
+             def flip(img_in, img_out):
+                 return tf.image.random_flip_left_right(img_in, random_seed), \
+                        tf.image.random_flip_left_right(img_out, random_seed)
+
+             # augmenting quality - parameters
+             hue_delta = 0.05
+             saturation_low = 0.2
+             saturation_up = 1.3
+             brightness_delta = 0.1
+             contrast_low = 0.2
+             contrast_up = 1.5
+
+             # augmenting quality
+             def color(img_in, img_out):
+                 # Augmentations applied are:
+                 # - random hue
+                 # - random saturation
+                 # - random brightness
+                 # - random contrast
+                 img_in = tf.image.random_hue(img_in, hue_delta, random_seed)
+                 img_in = tf.image.random_saturation(img_in, saturation_low, saturation_up, random_seed)
+                 img_in = tf.image.random_brightness(img_in, brightness_delta, random_seed)
+                 img_in = tf.image.random_contrast(img_in, contrast_low, contrast_up, random_seed)
+                 img_out = tf.image.random_hue(img_out, hue_delta, random_seed)
+                 img_out = tf.image.random_saturation(img_out, saturation_low, saturation_up, random_seed)
+                 img_out = tf.image.random_brightness(img_out, brightness_delta, random_seed)
+                 img_out = tf.image.random_contrast(img_out, contrast_low, contrast_up, random_seed)
+                 return img_in, img_out
+
+             # shuffle data and create batches
+             dataset = dataset.shuffle(5000)
+             dataset = dataset.batch(batch)
+
+             # apply augmentations
+             dataset = dataset.map(flip)
+             dataset = dataset.map(color)
+         else:
+             dataset = dataset.batch(batch)
+
+         return dataset.prefetch(tf.data.experimental.AUTOTUNE)
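
For completeness, a hypothetical training run using the pieces above, assuming data/train and data/test were produced by DataGenerator and that a models/ directory exists for the checkpoint callback:

    from utils.configuration import Configuration
    from utils.architectures import UNet
    from utils.model import LFUNet

    cfg = Configuration("configuration.json")
    net = LFUNet.build_model(architecture=UNet.RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV,
                             input_size=(256, 256, 3),
                             filters=(64, 128, 128, 256, 256, 512),
                             kernels=(7, 7, 7, 3, 3, 3),
                             configuration=cfg)
    history = net.train(epochs=40, batch_size=20, loss_function='ms_ssim_l1_perceptual_loss')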