Initial commit
- app.py +75 -0
- configuration.json +25 -0
- environment.yaml +216 -0
- examples/1.png +0 -0
- examples/2.png +0 -0
- examples/3.png +0 -0
- examples/4.png +0 -0
- examples/5.png +0 -0
- examples/6.png +0 -0
- examples/7.png +0 -0
- examples/8.png +0 -0
- model_weights/checkpoint +2 -0
- model_weights/model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt.index +0 -0
- utils/__init__.py +34 -0
- utils/__pycache__/__init__.cpython-37.pyc +0 -0
- utils/__pycache__/architectures.cpython-37.pyc +0 -0
- utils/__pycache__/configuration.cpython-37.pyc +0 -0
- utils/__pycache__/face_detection.cpython-37.pyc +0 -0
- utils/__pycache__/model.cpython-37.pyc +0 -0
- utils/architectures.py +344 -0
- utils/configuration.py +22 -0
- utils/data_generator.py +151 -0
- utils/face_detection.py +111 -0
- utils/model.py +495 -0
app.py
ADDED
@@ -0,0 +1,75 @@
from utils.configuration import Configuration
import tensorflow as tf
from utils.model import ModelLoss
from utils.model import LFUNet
from utils.architectures import UNet

import gradio as gr

configuration = Configuration()
filters = (64, 128, 128, 256, 256, 512)
kernels = (7, 7, 7, 3, 3, 3)
input_image_size = (256, 256, 3)
architecture = UNet.RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV

trained_model = LFUNet.build_model(architecture=architecture, input_size=input_image_size, filters=filters,
                                   kernels=kernels, configuration=configuration)
trained_model.compile(
    loss=ModelLoss.ms_ssim_l1_perceptual_loss,
    optimizer=tf.keras.optimizers.Adam(1e-4),
    metrics=["acc", tf.keras.metrics.Recall(), tf.keras.metrics.Precision()]
)

weights_path = "model_weights/model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt"
trained_model.load_weights(weights_path)


def main(input_img):
    try:
        print(input_img)
        predicted_image = trained_model.predict(input_img)
        return predicted_image
    except Exception as e:
        raise gr.Error("Sorry, something went wrong. Please try again!")


demo = gr.Interface(
    title="Lightweight network for face unmasking",
    description="This is a demo of a <b>Lightweight network for face unmasking</b> \
        designed to provide a powerful and efficient solution for restoring facial details obscured by masks.<br> \
        To use it, simply upload your image, or click one of the examples to load them. Inference needs some time since this demo uses CPU.",
    fn=main,
    inputs=gr.Image(type="filepath").style(height=256),
    outputs=gr.Image(type="numpy", shape=(256, 256, 3)).style(height=256),
    # allow_flagging='never',
    examples=[
        ["examples/1.png"],
        ["examples/2.png"],
        ["examples/3.png"],
        ["examples/4.png"],
        ["examples/5.png"],
        ["examples/6.png"],
        ["examples/7.png"],
        ["examples/8.png"],
    ],
    css="""
    .svelte-mppz8v {
        text-align: -webkit-center;
    }

    .gallery {
        display: flex;
        flex-wrap: wrap;
        width: 100%;
    }

    p {
        font-size: medium;
    }

    h1 {
        font-size: xx-large;
    }
    """,
    theme="EveryPizza/Cartoony-Gradio-Theme",
    # article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.04676' target='_blank'>Simple Baselines for Image Restoration</a> | <a href='https://arxiv.org/abs/2204.08714' target='_blank'>NAFSSR: Stereo Image Super-Resolution Using NAFNet</a> | <a href='https://github.com/megvii-research/NAFNet' target='_blank'> Github Repo</a></p>"
)
demo.launch(show_error=True, share=True)
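For context, a minimal sketch of running the same model outside the Gradio UI. It mirrors what app.py builds above and assumes the repository root is the working directory and the checkpoint under model_weights/ is present; LFUNet.predict (defined in utils/model.py below) accepts an image path, which is why the Gradio input uses type="filepath".

```python
# A sketch, not part of the Space itself: rebuild the network and run it on one
# of the bundled example images, mirroring the setup in app.py above.
from utils.architectures import UNet
from utils.configuration import Configuration
from utils.model import LFUNet

model = LFUNet.build_model(
    architecture=UNet.RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV,
    input_size=(256, 256, 3),
    filters=(64, 128, 128, 256, 256, 512),
    kernels=(7, 7, 7, 3, 3, 3),
    configuration=Configuration(),
)
model.load_weights(
    "model_weights/model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt"
)
unmasked = model.predict("examples/1.png")  # LFUNet.predict takes a file path
```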
configuration.json
ADDED
@@ -0,0 +1,25 @@
{
    "input_images_path": "data/lfw-deepfunneled",
    "dataset_archive_download_url": "http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz",
    "path_to_patterns": "data/mask_patterns",
    "train_data_path": "data/train",
    "test_data_path": "data/test",
    "landmarks_predictor_path": "shape_predictor_68_face_landmarks.dat",
    "landmarks_predictor_download_url": "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2",
    "minimal_confidence": 0.8,
    "hyp_ratio": 0.3333,
    "coordinates_range": [-10, 10],
    "test_image_count": 100,
    "train_image_count": 15000,
    "image_size": [256, 256],
    "mask_type": "random",
    "mask_color": null,
    "mask_patter": null,
    "mask_pattern_weight": 0.9,
    "mask_color_weight": 0.8,
    "mask_filter_output": false,
    "mask_filter_radius": 2,
    "test_results_dir": "data/results/",
    "train_data_limit": 10000,
    "test_data_limit": 1000
}
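These keys are consumed through the small Configuration wrapper (utils/configuration.py further down); a brief sketch of reading a few of the values:

```python
from utils.configuration import Configuration

cfg = Configuration("configuration.json")
print(cfg.get("image_size"))        # [256, 256]
print(cfg.get("train_data_path"))   # data/train
print(cfg.get("no_such_key"))       # prints an error message and returns None
```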
environment.yaml
ADDED
@@ -0,0 +1,216 @@
name: unmask3
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - atk-1.0=2.36.0=ha1a6a79_0
  - backcall=0.2.0=pyhd3eb1b0_0
  - ca-certificates=2022.10.11=h06a4308_0
  - cairo=1.16.0=h19f5f5c_2
  - certifi=2022.9.24=py37h06a4308_0
  - cudatoolkit=11.2.2=hbe64b41_10
  - cudnn=8.1.0.77=h90431f1_0
  - debugpy=1.5.1=py37h295c915_0
  - decorator=5.1.1=pyhd3eb1b0_0
  - entrypoints=0.4=py37h06a4308_0
  - expat=2.4.9=h6a678d5_0
  - font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0
  - font-ttf-inconsolata=2.001=hcb22688_0
  - font-ttf-source-code-pro=2.030=hd3eb1b0_0
  - font-ttf-ubuntu=0.83=h8b1ccd4_0
  - fontconfig=2.13.1=h6c09931_0
  - fonts-anaconda=1=h8fa9717_0
  - fonts-conda-ecosystem=1=hd3eb1b0_0
  - freetype=2.11.0=h70c0345_0
  - fribidi=1.0.10=h7b6447c_0
  - gdk-pixbuf=2.42.8=h433bba3_0
  - glib=2.69.1=h4ff587b_1
  - gobject-introspection=1.72.0=py37hbb6d50b_0
  - graphite2=1.3.14=h295c915_1
  - graphviz=2.50.0=h3cd0ef9_0
  - gtk2=2.24.33=h73c1081_2
  - gts=0.7.6=hb67d8dd_3
  - harfbuzz=4.3.0=hd55b92a_0
  - icu=58.2=he6710b0_3
  - ipython=7.31.1=py37h06a4308_1
  - jedi=0.18.1=py37h06a4308_1
  - jpeg=9e=h7f8727e_0
  - jupyter_client=7.3.4=py37h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libdeflate=1.8=h7f8727e_5
  - libffi=3.3=he6710b0_2
  - libgcc-ng=11.2.0=h1234567_1
  - libgd=2.3.3=h695aa2c_1
  - libgomp=11.2.0=h1234567_1
  - libpng=1.6.37=hbc83047_0
  - librsvg=2.54.4=h19fe530_0
  - libsodium=1.0.18=h7b6447c_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtiff=4.4.0=hecacb30_0
  - libtool=2.4.6=h295c915_1008
  - libuuid=1.0.3=h7f8727e_2
  - libwebp-base=1.2.4=h5eee18b_0
  - libxcb=1.15=h7f8727e_0
  - libxml2=2.9.14=h74e7548_0
  - lz4-c=1.9.3=h295c915_1
  - matplotlib-inline=0.1.6=py37h06a4308_0
  - ncurses=6.3=h5eee18b_3
  - nest-asyncio=1.5.5=py37h06a4308_0
  - ninja=1.10.2=h06a4308_5
  - ninja-base=1.10.2=hd09550d_5
  - openssl=1.1.1q=h7f8727e_0
  - pango=1.50.7=h05da053_0
  - parso=0.8.3=pyhd3eb1b0_0
  - pcre=8.45=h295c915_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pickleshare=0.7.5=pyhd3eb1b0_1003
  - pixman=0.40.0=h7f8727e_1
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - pygments=2.11.2=pyhd3eb1b0_0
  - python=3.7.13=h12debd9_0
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - pyzmq=23.2.0=py37h6a678d5_0
  - readline=8.1.2=h7f8727e_1
  - setuptools=63.4.1=py37h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.39.2=h5082296_0
  - tk=8.6.12=h1ccaba5_0
  - tornado=6.2=py37h5eee18b_0
  - wcwidth=0.2.5=pyhd3eb1b0_0
  - wheel=0.37.1=pyhd3eb1b0_0
  - xz=5.2.5=h7f8727e_1
  - zeromq=4.3.4=h2531618_0
  - zlib=1.2.12=h7f8727e_2
  - zstd=1.5.2=ha4553b6_0
  - pip:
    - absl-py==1.2.0
    - anyio==3.6.2
    - argon2-cffi==21.3.0
    - argon2-cffi-bindings==21.2.0
    - astunparse==1.6.3
    - attrs==22.2.0
    - beautifulsoup4==4.11.1
    - black==22.10.0
    - bleach==6.0.0
    - cachetools==5.2.0
    - cffi==1.15.1
    - charset-normalizer==2.1.1
    - click==8.1.3
    - cmake==3.24.1
    - cycler==0.11.0
    - defusedxml==0.7.1
    - dlib==19.24.1
    - dotmap==1.3.30
    - fastjsonschema==2.16.2
    - flatbuffers==1.12
    - fonttools==4.37.1
    - gast==0.4.0
    - google-auth==2.11.0
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - grpcio==1.48.1
    - h5py==3.7.0
    - idna==3.3
    - imageio==2.27.0
    - importlib-metadata==4.12.0
    - importlib-resources==5.10.2
    - imutils==0.5.4
    - ipykernel==6.16.2
    - ipython-genutils==0.2.0
    - ipywidgets==8.0.2
    - jinja2==3.1.2
    - joblib==1.1.0
    - jsonschema==4.17.3
    - jupyter-console==6.6.3
    - jupyter-core==4.12.0
    - jupyter-server==1.23.5
    - jupyterlab-pygments==0.2.2
    - jupyterlab-widgets==3.0.3
    - jupyterthemes==0.20.0
    - keras==2.9.0
    - keras-applications==1.0.8
    - keras-preprocessing==1.1.2
    - keras-vggface==0.6
    - kiwisolver==1.4.4
    - lesscpy==0.15.1
    - libclang==14.0.6
    - markdown==3.4.1
    - markupsafe==2.1.1
    - matplotlib==3.5.3
    - mistune==2.0.4
    - mtcnn==0.1.1
    - mypy-extensions==0.4.3
    - nbclassic==0.4.8
    - nbclient==0.7.2
    - nbconvert==7.2.9
    - nbformat==5.7.3
    - networkx==2.6.3
    - notebook==6.5.2
    - notebook-shim==0.2.2
    - numpy==1.21.6
    - oauthlib==3.2.0
    - opencv-contrib-python==4.6.0.66
    - opencv-python==4.6.0.66
    - opt-einsum==3.3.0
    - packaging==21.3
    - pandas==1.3.5
    - pandocfilters==1.5.0
    - pathspec==0.10.1
    - pillow==9.2.0
    - pip==22.3.1
    - pkgutil-resolve-name==1.3.10
    - platformdirs==2.5.2
    - ply==3.11
    - prometheus-client==0.16.0
    - prompt-toolkit==3.0.38
    - protobuf==3.19.4
    - psutil==5.9.5
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pycparser==2.21
    - pydot==1.4.2
    - pyparsing==3.0.9
    - pyrsistent==0.19.3
    - pytz==2022.2.1
    - pywavelets==1.3.0
    - pyyaml==6.0
    - requests==2.28.1
    - requests-oauthlib==1.3.1
    - rsa==4.9
    - scikit-image==0.19.3
    - scikit-learn==1.0.2
    - scipy==1.7.3
    - seaborn==0.11.2
    - send2trash==1.8.0
    - sniffio==1.3.0
    - soupsieve==2.3.2.post1
    - tensorboard==2.9.1
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.1
    - tensorflow==2.9.2
    - tensorflow-addons==0.19.0
    - tensorflow-estimator==2.9.0
    - tensorflow-io-gcs-filesystem==0.26.0
    - termcolor==1.1.0
    - terminado==0.17.1
    - threadpoolctl==3.1.0
    - tifffile==2021.11.2
    - tinycss2==1.2.1
    - tomli==2.0.1
    - tqdm==4.64.1
    - traitlets==5.8.1
    - trianglesolver==1.2
    - typed-ast==1.5.4
    - typeguard==2.13.3
    - typing-extensions==4.3.0
    - urllib3==1.26.12
    - webencodings==0.5.1
    - websocket-client==1.4.2
    - werkzeug==2.2.2
    - widgetsnbextension==4.0.3
    - wrapt==1.14.1
    - zipp==3.8.1
prefix: /home/suresh/miniconda3/envs/unmask
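This is a full conda export pinned to Python 3.7 and TensorFlow 2.9.2, with dlib, mtcnn and keras-vggface pulled in via pip. It can typically be recreated with `conda env create -f environment.yaml`; the machine-specific `prefix:` line at the end is only informational and is normally ignored when the environment is created by name on another machine.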
examples/1.png
ADDED
examples/2.png
ADDED
examples/3.png
ADDED
examples/4.png
ADDED
examples/5.png
ADDED
examples/6.png
ADDED
examples/7.png
ADDED
examples/8.png
ADDED
model_weights/checkpoint
ADDED
@@ -0,0 +1,2 @@
model_checkpoint_path: "model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt"
all_model_checkpoint_paths: "model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt"
model_weights/model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt.index
ADDED
Binary file (32 kB)
utils/__init__.py
ADDED
@@ -0,0 +1,34 @@
import numpy as np
from PIL import Image
import requests
import functools
from tqdm.notebook import tqdm
import shutil


def image_to_array(image: Image) -> np.ndarray:
    """Convert Image to array"""
    return np.asarray(image).astype('uint8')


def load_image(img_path: str) -> Image:
    """Load image from path"""
    return Image.open(img_path)


def download_data(url, save_path, file_size=None):
    """Downloads data from `url` to `save_path`"""
    r = requests.get(url, stream=True, allow_redirects=True)
    if r.status_code != 200:
        r.raise_for_status()
        raise RuntimeError(f'Request to {url} returned status code {r.status_code}')

    if file_size is None:
        file_size = int(r.headers.get('content-length', 0))

    r.raw.read = functools.partial(r.raw.read, decode_content=True)  # Decompress if needed
    with tqdm.wrapattr(r.raw, 'read', total=file_size, desc='') as r_raw:
        with open(save_path, 'wb') as f:
            shutil.copyfileobj(r_raw, f)


def plot_image_triple():
    pass
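A small usage sketch for the helpers above, assuming a Jupyter-style environment (the progress bar comes from tqdm.notebook); the download URL is the one stored in configuration.json:

```python
from utils import download_data, image_to_array, load_image

# Fetch the dlib landmarks model referenced in configuration.json
download_data(
    "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2",
    "shape_predictor_68_face_landmarks.dat.bz2",
)

# Load one of the bundled example images and convert it to a numpy array
img = load_image("examples/1.png")
print(image_to_array(img).shape)  # (height, width, 3)
```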
utils/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (1.41 kB)
utils/__pycache__/architectures.cpython-37.pyc
ADDED
Binary file (9.89 kB)
utils/__pycache__/configuration.cpython-37.pyc
ADDED
Binary file (996 Bytes)
utils/__pycache__/face_detection.cpython-37.pyc
ADDED
Binary file (3.38 kB)
utils/__pycache__/model.cpython-37.pyc
ADDED
Binary file (15.7 kB)
utils/architectures.py
ADDED
@@ -0,0 +1,344 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Tuple, Optional

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *


class BaseUNet(ABC):
    """
    Base Interface for UNet
    """

    def __init__(self, model: Model):
        self.model: Model = model

    def get_model(self):
        return self.model

    @staticmethod
    @abstractmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
        pass


class UNet(Enum):
    """
    Enum class defining different architecture types available
    """
    DEFAULT = 0
    DEFAULT_IMAGENET_EMBEDDING = 1
    RESNET = 3
    RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV = 4

    def build_model(self, input_size: Tuple[int, int, int], filters: Optional[Tuple] = None,
                    kernels: Optional[Tuple] = None) -> BaseUNet:

        # set default filters
        if filters is None:
            filters = (16, 32, 64, 128, 256)

        # set default kernels
        if kernels is None:
            kernels = list(3 for _ in range(len(filters)))

        # check kernels and filters
        if len(filters) != len(kernels):
            raise Exception('Kernels and filter count has to match.')

        if self == UNet.DEFAULT_IMAGENET_EMBEDDING:
            print('Using default UNet model with imagenet embedding')
            return UNetDefault.build_model(input_size, filters, kernels, use_embedding=True)
        elif self == UNet.RESNET:
            print('Using UNet Resnet model')
            return UNet_resnet.build_model(input_size, filters, kernels)
        elif self == UNet.RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV:
            print('Using UNet Resnet model with attention mechanism and separable convolutions')
            return UNet_ResNet_Attention_SeparableConv.build_model(input_size, filters, kernels)

        print('Using default UNet model')
        return UNetDefault.build_model(input_size, filters, kernels, use_embedding=False)


class Attention(Layer):
    def __init__(self, **kwargs):
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[-1], 1),
                                      initializer='glorot_normal',
                                      trainable=True)
        self.bias = self.add_weight(name='bias',
                                    shape=(1,),
                                    initializer='zeros',
                                    trainable=True)
        super(Attention, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        attention = tf.nn.softmax(tf.matmul(x, self.kernel) + self.bias, axis=-1)
        return tf.multiply(x, attention)

    def compute_output_shape(self, input_shape):
        return input_shape


class UNet_ResNet_Attention_SeparableConv(BaseUNet):
    """
    UNet architecture with resnet blocks, attention mechanism and separable convolutions
    """
    @staticmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):

        p0 = Input(shape=input_size)
        conv_outputs = []
        first_layer = SeparableConv2D(filters[0], kernels[0], padding='same')(p0)
        int_layer = first_layer
        for i, f in enumerate(filters):
            int_layer, skip = UNet_ResNet_Attention_SeparableConv.down_block(int_layer, f, kernels[i])
            conv_outputs.append(skip)

        int_layer = UNet_ResNet_Attention_SeparableConv.bottleneck(int_layer, filters[-1], kernels[-1])

        conv_outputs = list(reversed(conv_outputs))
        reversed_filter = list(reversed(filters))
        reversed_kernels = list(reversed(kernels))
        for i, f in enumerate(reversed_filter):
            if i + 1 < len(reversed_filter):
                num_filters_next = reversed_filter[i + 1]
                num_kernels_next = reversed_kernels[i + 1]
            else:
                num_filters_next = f
                num_kernels_next = reversed_kernels[i]
            int_layer = UNet_ResNet_Attention_SeparableConv.up_block(int_layer, conv_outputs[i], f, num_filters_next, num_kernels_next)
            int_layer = Attention()(int_layer)

        # concat. with the first layer
        int_layer = Concatenate()([first_layer, int_layer])
        int_layer = SeparableConv2D(filters[0], kernels[0], padding="same", activation="relu")(int_layer)
        outputs = SeparableConv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
        model = Model(p0, outputs)
        return UNet_ResNet_Attention_SeparableConv(model)

    @staticmethod
    def down_block(x, num_filters: int = 64, kernel: int = 3):
        # down-sample inputs
        x = SeparableConv2D(num_filters, kernel, padding='same', strides=2, dilation_rate=2)(x)

        # inner block
        out = SeparableConv2D(num_filters, kernel, padding='same')(x)
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)
        out = SeparableConv2D(num_filters, kernel, padding='same')(out)

        # merge with the skip connection
        out = Add()([out, x])
        # out = BatchNormalization()(out)
        return Activation('relu')(out), x

    @staticmethod
    def up_block(x, skip, num_filters: int = 64, num_filters_next: int = 64, kernel: int = 3):
        # add U-Net skip connection - before up-sampling
        concat = Concatenate()([x, skip])

        # inner block
        out = SeparableConv2D(num_filters, kernel, padding='same', dilation_rate=2)(concat)
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)
        out = SeparableConv2D(num_filters, kernel, padding='same')(out)

        # merge with the skip connection
        out = Add()([out, x])
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)

        # up-sample
        out = UpSampling2D((2, 2))(out)
        out = SeparableConv2D(num_filters_next, kernel, padding='same')(out)
        # out = BatchNormalization()(out)
        return Activation('relu')(out)

    @staticmethod
    def bottleneck(x, num_filters: int = 64, kernel: int = 3):
        # inner block
        out = SeparableConv2D(num_filters, kernel, padding='same', dilation_rate=2)(x)
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)
        out = SeparableConv2D(num_filters, kernel, padding='same')(out)
        out = Add()([out, x])
        # out = BatchNormalization()(out)
        return Activation('relu')(out)


# Class for UNet with Resnet blocks

class UNet_resnet(BaseUNet):
    """
    UNet architecture with resnet blocks
    """

    @staticmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):

        p0 = Input(shape=input_size)
        conv_outputs = []
        first_layer = Conv2D(filters[0], kernels[0], padding='same')(p0)
        int_layer = first_layer
        for i, f in enumerate(filters):
            int_layer, skip = UNet_resnet.down_block(int_layer, f, kernels[i])
            conv_outputs.append(skip)

        int_layer = UNet_resnet.bottleneck(int_layer, filters[-1], kernels[-1])

        conv_outputs = list(reversed(conv_outputs))
        reversed_filter = list(reversed(filters))
        reversed_kernels = list(reversed(kernels))
        for i, f in enumerate(reversed_filter):
            if i + 1 < len(reversed_filter):
                num_filters_next = reversed_filter[i + 1]
                num_kernels_next = reversed_kernels[i + 1]
            else:
                num_filters_next = f
                num_kernels_next = reversed_kernels[i]
            int_layer = UNet_resnet.up_block(int_layer, conv_outputs[i], f, num_filters_next, num_kernels_next)

        # concat. with the first layer
        int_layer = Concatenate()([first_layer, int_layer])
        int_layer = Conv2D(filters[0], kernels[0], padding="same", activation="relu")(int_layer)
        outputs = Conv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
        model = Model(p0, outputs)
        return UNet_resnet(model)

    @staticmethod
    def down_block(x, num_filters: int = 64, kernel: int = 3):
        # down-sample inputs
        x = Conv2D(num_filters, kernel, padding='same', strides=2)(x)

        # inner block
        out = Conv2D(num_filters, kernel, padding='same')(x)
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)
        out = Conv2D(num_filters, kernel, padding='same')(out)

        # merge with the skip connection
        out = Add()([out, x])
        # out = BatchNormalization()(out)
        return Activation('relu')(out), x

    @staticmethod
    def up_block(x, skip, num_filters: int = 64, num_filters_next: int = 64, kernel: int = 3):

        # add U-Net skip connection - before up-sampling
        concat = Concatenate()([x, skip])

        # inner block
        out = Conv2D(num_filters, kernel, padding='same')(concat)
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)
        out = Conv2D(num_filters, kernel, padding='same')(out)

        # merge with the skip connection
        out = Add()([out, x])
        # out = BatchNormalization()(out)
        out = Activation('relu')(out)

        # add U-Net skip connection - before up-sampling
        concat = Concatenate()([out, skip])

        # up-sample
        # out = UpSampling2D((2, 2))(concat)
        out = Conv2DTranspose(num_filters_next, kernel, padding='same', strides=2)(concat)
        out = Conv2D(num_filters_next, kernel, padding='same')(out)
        # out = BatchNormalization()(out)
        return Activation('relu')(out)

    @staticmethod
    def bottleneck(x, filters, kernel: int = 3):
        x = Conv2D(filters, kernel, padding='same', name='bottleneck')(x)
        # x = BatchNormalization()(x)
        return Activation('relu')(x)


class UNetDefault(BaseUNet):
    """
    UNet architecture from following github notebook for image segmentation:
    https://github.com/nikhilroxtomar/UNet-Segmentation-in-Keras-TensorFlow/blob/master/unet-segmentation.ipynb
    https://github.com/nikhilroxtomar/Polyp-Segmentation-using-UNET-in-TensorFlow-2.0
    """

    @staticmethod
    def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple, use_embedding: bool = True):

        p0 = Input(input_size)

        if use_embedding:
            mobilenet_model = tf.keras.applications.MobileNetV2(
                input_shape=input_size, include_top=False, weights='imagenet'
            )
            mobilenet_model.trainable = False
            mn1 = mobilenet_model(p0)
            mn1 = Reshape((16, 16, 320))(mn1)

        conv_outputs = []
        int_layer = p0

        for f in filters:
            conv_output, int_layer = UNetDefault.down_block(int_layer, f)
            conv_outputs.append(conv_output)

        int_layer = UNetDefault.bottleneck(int_layer, filters[-1])

        if use_embedding:
            int_layer = Concatenate()([int_layer, mn1])

        conv_outputs = list(reversed(conv_outputs))
        for i, f in enumerate(reversed(filters)):
            int_layer = UNetDefault.up_block(int_layer, conv_outputs[i], f)

        int_layer = Conv2D(filters[0] // 2, 3, padding="same", activation="relu")(int_layer)
        outputs = Conv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
        model = Model(p0, outputs)
        return UNetDefault(model)

    @staticmethod
    def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
        c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
        # c = BatchNormalization()(c)
        p = MaxPool2D((2, 2), (2, 2))(c)
        return c, p

    @staticmethod
    def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1):
        us = UpSampling2D((2, 2))(x)
        c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(us)
        # c = BatchNormalization()(c)
        concat = Concatenate()([c, skip])
        c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(concat)
        # c = BatchNormalization()(c)
        return c

    @staticmethod
    def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1):
        c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
        # c = BatchNormalization()(c)
        return c


if __name__ == "__main__":
    filters = (64, 128, 128, 256, 256, 512)
    kernels = (7, 7, 7, 3, 3, 3)
    input_image_size = (256, 256, 3)
    # model = UNet_resnet()
    # model = model.build_model(input_size=input_image_size,filters=filters,kernels=kernels)
    # print(model.summary())
    # __init__() missing 1 required positional argument: 'model'
    model = UNetDefault.build_model(input_size=input_image_size, filters=filters, kernels=kernels)
    # build_model returns a BaseUNet wrapper; unwrap the Keras model before printing its summary
    print(model.get_model().summary())
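The enum is the intended entry point (see the __main__ block above and app.py); a short sketch that builds the attention/separable-convolution variant used by the demo and prints its Keras summary:

```python
from utils.architectures import UNet

net = UNet.RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV.build_model(
    input_size=(256, 256, 3),
    filters=(64, 128, 128, 256, 256, 512),
    kernels=(7, 7, 7, 3, 3, 3),
)
net.get_model().summary()  # UNet.build_model returns a BaseUNet wrapper around the Keras model
```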
utils/configuration.py
ADDED
@@ -0,0 +1,22 @@
import os
import json
from dataclasses import dataclass


@dataclass
class Configuration:
    def __init__(self, config_file_path: str = "configuration.json"):
        self.config_file_path = config_file_path
        self.config_json = None
        if os.path.exists(config_file_path):
            with open(self.config_file_path, 'r') as json_file:
                self.config_json = json.load(json_file)
        else:
            print(f'ERROR: Configuration JSON {config_file_path} does not exist.')

    def get(self, key: str):
        if key in self.config_json:
            return self.config_json[key]
        else:
            print(f'ERROR: Key \'{key}\' is not in configuration JSON.')
            return None
utils/data_generator.py
ADDED
@@ -0,0 +1,151 @@
import copy
import dlib
import os
import bz2
import random
from tqdm.notebook import tqdm
import shutil
from utils import image_to_array, load_image, download_data
from utils.face_detection import crop_face, get_face_keypoints_detecting_function
from mask_utils.mask_utils import mask_image


class DataGenerator:
    def __init__(self, configuration):
        self.configuration = configuration
        self.path_to_data = configuration.get('input_images_path')
        self.path_to_patterns = configuration.get('path_to_patterns')
        self.minimal_confidence = configuration.get('minimal_confidence')
        self.hyp_ratio = configuration.get('hyp_ratio')
        self.coordinates_range = configuration.get('coordinates_range')
        self.test_image_count = configuration.get('test_image_count')
        self.train_image_count = configuration.get('train_image_count')
        self.train_data_path = configuration.get('train_data_path')
        self.test_data_path = configuration.get('test_data_path')
        self.predictor_path = configuration.get('landmarks_predictor_path')
        self.check_predictor()

        self.valid_image_extensions = ('png', 'jpg', 'jpeg')
        self.face_keypoints_detecting_fun = get_face_keypoints_detecting_function(self.minimal_confidence)

    def check_predictor(self):
        """ Check if predictor exists. If not downloads it. """
        if not os.path.exists(self.predictor_path):
            print('Downloading missing predictor.')
            url = self.configuration.get('landmarks_predictor_download_url')
            download_data(url, self.predictor_path + '.bz2', 64040097)
            print(f'Decompressing downloaded file into {self.predictor_path}')
            with bz2.BZ2File(self.predictor_path + '.bz2') as fr, open(self.predictor_path, 'wb') as fw:
                shutil.copyfileobj(fr, fw)

    def get_face_landmarks(self, image):
        """Compute 68 facial landmarks"""
        landmarks = []
        image_array = image_to_array(image)
        detector = dlib.get_frontal_face_detector()
        predictor = dlib.shape_predictor(self.predictor_path)
        face_rectangles = detector(image_array)
        if len(face_rectangles) < 1:
            return None
        dlib_shape = predictor(image_array, face_rectangles[0])
        for i in range(0, dlib_shape.num_parts):
            landmarks.append([dlib_shape.part(i).x, dlib_shape.part(i).y])
        return landmarks

    def get_files_faces(self):
        """Get path of all images in dataset"""
        image_files = []
        for dirpath, dirs, files in os.walk(self.path_to_data):
            for filename in files:
                fname = os.path.join(dirpath, filename)
                if fname.endswith(self.valid_image_extensions):
                    image_files.append(fname)

        return image_files

    def generate_images(self, image_size=None, test_image_count=None, train_image_count=None):
        """Generate test and train data (images with and without the mask)"""
        if image_size is None:
            image_size = self.configuration.get('image_size')
        if test_image_count is None:
            test_image_count = self.test_image_count
        if train_image_count is None:
            train_image_count = self.train_image_count

        if not os.path.exists(self.train_data_path):
            os.mkdir(self.train_data_path)
            os.mkdir(os.path.join(self.train_data_path, 'inputs'))
            os.mkdir(os.path.join(self.train_data_path, 'outputs'))

        if not os.path.exists(self.test_data_path):
            os.mkdir(self.test_data_path)
            os.mkdir(os.path.join(self.test_data_path, 'inputs'))
            os.mkdir(os.path.join(self.test_data_path, 'outputs'))

        print('Generating testing data')
        self.generate_data(test_image_count,
                           image_size=image_size,
                           save_to=self.test_data_path)
        print('Generating training data')
        self.generate_data(train_image_count,
                           image_size=image_size,
                           save_to=self.train_data_path)

    def generate_data(self, number_of_images, image_size=None, save_to=None):
        """ Add masks on `number_of_images` images.
        If `save_to` is a valid path to a folder, images are saved there; otherwise the generated data are returned in lists.
        """
        inputs = []
        outputs = []

        if image_size is None:
            image_size = self.configuration.get('image_size')

        for i, file in tqdm(enumerate(random.sample(self.get_files_faces(), number_of_images)), total=number_of_images):
            # Load images
            image = load_image(file)

            # Detect keypoints and landmarks on face
            face_landmarks = self.get_face_landmarks(image)
            if face_landmarks is None:
                continue
            keypoints = self.face_keypoints_detecting_fun(image)

            # Generate mask
            image_with_mask = mask_image(copy.deepcopy(image), face_landmarks, self.configuration)

            # Crop images
            cropped_image = crop_face(image_with_mask, keypoints)
            cropped_original = crop_face(image, keypoints)

            # Resize all images to NN input size
            res_image = cropped_image.resize(image_size)
            res_original = cropped_original.resize(image_size)

            # Save generated data to lists or to folder
            if save_to is None:
                inputs.append(res_image)
                outputs.append(res_original)
            else:
                res_image.save(os.path.join(save_to, 'inputs', f"{i:06d}.png"))
                res_original.save(os.path.join(save_to, 'outputs', f"{i:06d}.png"))

        if save_to is None:
            return inputs, outputs

    def get_dataset_examples(self, n=10, test_dataset=False):
        """
        Returns `n` random images from the dataset. If `test_dataset` parameter
        is not provided or False it will return images from the training part of the dataset.
        If `test_dataset` parameter is True it will return images from the testing part of the dataset.
        """
        if test_dataset:
            data_path = self.test_data_path
        else:
            data_path = self.train_data_path

        images = os.listdir(os.path.join(data_path, 'inputs'))
        images = random.sample(images, n)
        inputs = [os.path.join(data_path, 'inputs', img) for img in images]
        outputs = [os.path.join(data_path, 'outputs', img) for img in images]
        return inputs, outputs
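A sketch of regenerating the training pairs, assuming the LFW images and mask patterns referenced in configuration.json have already been downloaded and the external mask_utils package imported above is available:

```python
from utils.configuration import Configuration
from utils.data_generator import DataGenerator

generator = DataGenerator(Configuration("configuration.json"))
# Writes masked inputs and clean targets under data/test and data/train
generator.generate_images()
```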
utils/face_detection.py
ADDED
@@ -0,0 +1,111 @@
"""Functions for face detection"""
from math import pi
from typing import Tuple, Optional, Dict

import tensorflow as tf
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from PIL import Image
from mtcnn import MTCNN
from trianglesolver import solve

from utils import image_to_array


def compute_slacks(height, width, hyp_ratio) -> Tuple[float, float]:
    """Compute slacks to add to bounding box on each side"""

    # compute angle and side for hypotenuse
    _, b, _, A, _, _ = solve(c=width, a=height, B=pi / 2)

    # compute new height and width
    a, _, c, _, _, _ = solve(b=b * (1.0 + hyp_ratio), B=pi / 2, A=A)

    # compute slacks
    return c - width, a - height


def get_face_keypoints_detecting_function(minimal_confidence: float = 0.8):
    """Create function for face keypoints detection"""

    # face detector
    detector = MTCNN()

    # detect faces and their keypoints
    def get_keypoints(image: Image) -> Optional[Dict]:

        # run inference to detect faces (on CPU only)
        with tf.device("/cpu:0"):
            detection = detector.detect_faces(image_to_array(image))

        # run detection and keep results with certain confidence only
        results = [item for item in detection if item['confidence'] > minimal_confidence]

        # nothing found
        if len(results) == 0:
            return None

        # return result with highest confidence and size
        return max(results, key=lambda item: item['confidence'] * item['box'][2] * item['box'][3])

    # return function
    return get_keypoints


def plot_face_detection(image: Image, ax, face_keypoints: Optional, hyp_ratio: float = 1 / 3):
    """Plot faces with keypoints and bounding boxes"""

    # make annotations
    if face_keypoints is not None:

        # get bounding box
        x, y, width, height = face_keypoints['box']

        # add rectangle patch for detected face
        rectangle = patches.Rectangle((x, y), width, height, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rectangle)

        # add rectangle patch with slacks
        w_s, h_s = compute_slacks(height, width, hyp_ratio)
        rectangle = patches.Rectangle((x - w_s, y - h_s), width + 2 * w_s, height + 2 * h_s, linewidth=1, edgecolor='r',
                                      facecolor='none')
        ax.add_patch(rectangle)

        # add keypoints
        for coordinates in face_keypoints['keypoints'].values():
            circle = plt.Circle(coordinates, 3, color='r')
            ax.add_artist(circle)

    # add image
    ax.imshow(image)


def get_crop_points(image: Image, face_keypoints: Optional, hyp_ratio: float = 1 / 3) -> Tuple[float, float, float, float]:
    """Find position where to crop face from image"""
    if face_keypoints is None:
        return 0, 0, image.width, image.height

    # get bounding box
    x, y, width, height = face_keypoints['box']

    # compute slacks
    w_s, h_s = compute_slacks(height, width, hyp_ratio)

    # compute coordinates
    left = min(max(0, x - w_s), image.width)
    upper = min(max(0, y - h_s), image.height)
    right = min(x + width + w_s, image.width)
    lower = min(y + height + h_s, image.height)

    return left, upper, right, lower


def crop_face(image: Image, face_keypoints: Optional, hyp_ratio: float = 1 / 3) -> Image:
    """Crop input image to just the face"""
    if face_keypoints is None:
        print("No keypoints detected on image")
        return image

    left, upper, right, lower = get_crop_points(image, face_keypoints, hyp_ratio)

    return image.crop((left, upper, right, lower))
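A minimal sketch of using these helpers directly on one of the bundled example images:

```python
from utils import load_image
from utils.face_detection import crop_face, get_face_keypoints_detecting_function

detect_keypoints = get_face_keypoints_detecting_function(minimal_confidence=0.8)
image = load_image("examples/1.png")
face = crop_face(image, detect_keypoints(image))  # falls back to the full image if no face is found
face.save("face_crop.png")
```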
utils/model.py
ADDED
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from datetime import datetime
|
3 |
+
from glob import glob
|
4 |
+
from typing import Tuple, Optional
|
5 |
+
from utils import load_image
|
6 |
+
import random
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
import tensorflow as tf
|
10 |
+
from PIL import Image
|
11 |
+
from sklearn.model_selection import train_test_split
|
12 |
+
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
13 |
+
from tensorflow.keras.utils import CustomObjectScope
|
14 |
+
from utils.face_detection import get_face_keypoints_detecting_function, crop_face, get_crop_points
|
15 |
+
from utils.architectures import UNet
|
16 |
+
from tensorflow.keras.losses import MeanSquaredError, mean_squared_error
|
17 |
+
from keras_vggface.vggface import VGGFace
|
18 |
+
import tensorflow.keras.backend as K
|
19 |
+
from tensorflow.keras.applications import VGG19
|
20 |
+
|
21 |
+
# # VGG19 model for perceptual loss
|
22 |
+
# vgg = VGG19(include_top=False, weights='imagenet')
|
23 |
+
|
24 |
+
# def preprocess_image(image):
|
25 |
+
# image = tf.image.resize(image, (224, 224))
|
26 |
+
# image = tf.keras.applications.vgg19.preprocess_input(image)
|
27 |
+
# return image
|
28 |
+
|
29 |
+
# def perceptual_loss(y_true, y_pred):
|
30 |
+
# y_true = preprocess_image(y_true)
|
31 |
+
# y_pred = preprocess_image(y_pred)
|
32 |
+
# y_true_c = vgg(y_true)
|
33 |
+
# y_pred_c = vgg(y_pred)
|
34 |
+
# loss = K.mean(K.square(y_pred_c - y_true_c))
|
35 |
+
# return loss
|
36 |
+
|
37 |
+
vgg_face_model = VGGFace(model='resnet50', include_top=False, input_shape=(256, 256, 3), pooling='avg')
|
38 |
+
|
39 |
+
|
40 |
+
class ModelLoss:
|
41 |
+
@staticmethod
|
42 |
+
@tf.function
|
43 |
+
def ms_ssim_l1_perceptual_loss(gt, y_pred, max_val=1.0, l1_weight=1.0):
|
44 |
+
"""
|
45 |
+
Computes MS-SSIM and perceptual loss
|
46 |
+
@param gt: Ground truth image
|
47 |
+
@param y_pred: Predicted image
|
48 |
+
@param max_val: Maximal MS-SSIM value
|
49 |
+
@param l1_weight: Weight of L1 normalization
|
50 |
+
@return: MS-SSIM and perceptual loss
|
51 |
+
"""
|
52 |
+
|
53 |
+
# Compute SSIM loss
|
54 |
+
ssim_loss = 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
|
55 |
+
|
56 |
+
# Compute perceptual loss
|
57 |
+
vgg_face_outputs = vgg_face_model(y_pred)
|
58 |
+
vgg_face_loss = tf.reduce_mean(tf.losses.mean_squared_error(vgg_face_outputs,vgg_face_model(gt)))
|
59 |
+
|
60 |
+
|
61 |
+
# Combine both losses with l1 normalization
|
62 |
+
l1 = mean_squared_error(gt, y_pred)
|
63 |
+
l1_casted = tf.cast(l1 * l1_weight, tf.float32)
|
64 |
+
return ssim_loss + l1_casted + vgg_face_loss
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
class LFUNet(tf.keras.models.Model):
|
69 |
+
"""
|
70 |
+
Model for Mask2Face - removes mask from people faces using U-net neural network
|
71 |
+
"""
|
72 |
+
def __init__(self, model: tf.keras.models.Model, configuration=None, *args, **kwargs):
|
73 |
+
super().__init__(*args, **kwargs)
|
74 |
+
self.model: tf.keras.models.Model = model
|
75 |
+
self.configuration = configuration
|
76 |
+
self.face_keypoints_detecting_fun = get_face_keypoints_detecting_function(0.8)
|
77 |
+
self.mse = MeanSquaredError()
|
78 |
+
|
79 |
+
def call(self, x, **kwargs):
|
80 |
+
return self.model(x)
|
81 |
+
|
82 |
+
@staticmethod
|
83 |
+
@tf.function
|
84 |
+
def ssim_loss(gt, y_pred, max_val=1.0):
|
85 |
+
"""
|
86 |
+
Computes standard SSIM loss
|
87 |
+
@param gt: Ground truth image
|
88 |
+
@param y_pred: Predicted image
|
89 |
+
@param max_val: Maximal SSIM value
|
90 |
+
@return: SSIM loss
|
91 |
+
"""
|
92 |
+
return 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
|
93 |
+
|
94 |
+
@staticmethod
|
95 |
+
@tf.function
|
96 |
+
def ssim_l1_loss(gt, y_pred, max_val=1.0, l1_weight=1.0):
|
97 |
+
"""
|
98 |
+
Computes SSIM loss with L1 normalization
|
99 |
+
@param gt: Ground truth image
|
100 |
+
@param y_pred: Predicted image
|
101 |
+
@param max_val: Maximal SSIM value
|
102 |
+
@param l1_weight: Weight of L1 normalization
|
103 |
+
@return: SSIM L1 loss
|
104 |
+
"""
|
105 |
+
ssim_loss = 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
|
106 |
+
l1 = mean_squared_error(gt, y_pred)
|
107 |
+
return ssim_loss + tf.cast(l1 * l1_weight, tf.float32)
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
# @staticmethod
|
112 |
+
# @tf.function
|
113 |
+
# def ms_ssim_l1_perceptual_loss(gt, y_pred, max_val=1.0, l1_weight=1.0, perceptual_weight=1.0):
|
114 |
+
# """
|
115 |
+
# Computes MS-SSIM loss, L1 loss, and perceptual loss
|
116 |
+
# @param gt: Ground truth image
|
117 |
+
# @param y_pred: Predicted image
|
118 |
+
# @param max_val: Maximal SSIM value
|
119 |
+
# @param l1_weight: Weight of L1 normalization
|
120 |
+
# @param perceptual_weight: Weight of perceptual loss
|
121 |
+
# @return: MS-SSIM L1 perceptual loss
|
122 |
+
# """
|
123 |
+
# y_pred = tf.clip_by_value(y_pred, 0, float("inf"))
|
124 |
+
# y_pred = tf.debugging.check_numerics(y_pred, message='y_pred has NaN values')
|
125 |
+
# ms_ssim_loss = 1 - tf.reduce_mean(tf.image.ssim_multiscale(gt, y_pred, max_val=max_val))
|
126 |
+
# l1_loss = tf.losses.mean_absolute_error(gt, y_pred)
|
127 |
+
# vgg_face_outputs = vgg_face_model(y_pred)
|
128 |
+
# vgg_face_loss = tf.reduce_mean(tf.losses.mean_squared_error(vgg_face_outputs,vgg_face_model(gt)))
|
129 |
+
# return ms_ssim_loss + tf.cast(l1_loss * l1_weight, tf.float32) + perceptual_weight*vgg_face_loss
|
130 |
+
|
131 |
+
# Function for ms-ssim loss + l1 loss
|
132 |
+
@staticmethod
|
133 |
+
@tf.function
|
134 |
+
def ms_ssim_l1_loss(gt, y_pred, max_val=1.0, l1_weight=1.0):
|
135 |
+
"""
|
136 |
+
Computes MS-SSIM loss and L1 loss
|
137 |
+
@param gt: Ground truth image
|
138 |
+
@param y_pred: Predicted image
|
139 |
+
@param max_val: Maximal SSIM value
|
140 |
+
@param l1_weight: Weight of L1 normalization
|
141 |
+
@return: MS-SSIM L1 loss
|
142 |
+
"""
|
143 |
+
# Replace NaN values with 0
|
144 |
+
y_pred = tf.clip_by_value(y_pred, 0, float("inf"))
|
145 |
+
|
146 |
+
ms_ssim_loss = 1 - tf.reduce_mean(tf.image.ssim_multiscale(gt, y_pred, max_val=max_val))
|
147 |
+
l1_loss = tf.losses.mean_absolute_error(gt, y_pred)
|
148 |
+
return ms_ssim_loss + tf.cast(l1_loss * l1_weight, tf.float32)
|
149 |
+
|
150 |
+
|
151 |
+
@staticmethod
|
152 |
+
def load_model(model_path, configuration=None):
|
153 |
+
"""
|
154 |
+
Loads saved h5 file with trained model.
|
155 |
+
@param configuration: Optional instance of Configuration with config JSON
|
156 |
+
@param model_path: Path to h5 file
|
157 |
+
@return: LFUNet
|
158 |
+
"""
|
159 |
+
with CustomObjectScope({'ssim_loss': LFUNet.ssim_loss, 'ssim_l1_loss': LFUNet.ssim_l1_loss, 'ms_ssim_l1_perceptual_loss': ModelLoss.ms_ssim_l1_perceptual_loss, 'ms_ssim_l1_loss': LFUNet.ms_ssim_l1_loss}):
|
160 |
+
model = tf.keras.models.load_model(model_path)
|
161 |
+
return LFUNet(model, configuration)
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def build_model(architecture: UNet, input_size: Tuple[int, int, int], filters: Optional[Tuple] = None,
|
165 |
+
kernels: Optional[Tuple] = None, configuration=None):
|
166 |
+
"""
|
167 |
+
Builds model based on input arguments
|
168 |
+
@param architecture: utils.architectures.UNet architecture
|
169 |
+
@param input_size: Size of input images
|
170 |
+
@param filters: Tuple with sizes of filters in U-net
|
171 |
+
@param kernels: Tuple with sizes of kernels in U-net. Must be the same size as filters.
|
172 |
+
@param configuration: Optional instance of Configuration with config JSON
|
173 |
+
@return: LFUNet
|
174 |
+
"""
|
175 |
+
return LFUNet(architecture.build_model(input_size, filters, kernels).get_model(), configuration)
|
176 |
+
|
177 |
+
    def train(self, epochs=20, batch_size=20, loss_function='mse', learning_rate=1e-4,
              predict_difference: bool = False):
        """
        Train the model.
        @param epochs: Number of epochs during training
        @param batch_size: Batch size
        @param loss_function: Loss function. Either a standard tensorflow loss function or one of
                              `ssim_loss`, `ssim_l1_loss`, `ms_ssim_l1_loss`, `ms_ssim_l1_perceptual_loss`
        @param learning_rate: Learning rate
        @param predict_difference: Compute prediction on difference between input and output image
        @return: History of training
        """
        # get data
        (train_x, train_y), (valid_x, valid_y) = self.load_train_data()
        (test_x, test_y) = self.load_test_data()

        train_dataset = LFUNet.tf_dataset(train_x, train_y, batch_size, predict_difference)
        valid_dataset = LFUNet.tf_dataset(valid_x, valid_y, batch_size, predict_difference, train=False)
        test_dataset = LFUNet.tf_dataset(test_x, test_y, batch_size, predict_difference, train=False)

        # select loss function
        if loss_function == 'ssim_loss':
            loss = LFUNet.ssim_loss
        elif loss_function == 'ssim_l1_loss':
            loss = LFUNet.ssim_l1_loss
        elif loss_function == 'ms_ssim_l1_perceptual_loss':
            loss = ModelLoss.ms_ssim_l1_perceptual_loss
        elif loss_function == 'ms_ssim_l1_loss':
            loss = LFUNet.ms_ssim_l1_loss
        else:
            loss = loss_function

        # compile the model with the selected loss function
        self.model.compile(
            loss=loss,
            optimizer=tf.keras.optimizers.Adam(learning_rate),
            metrics=["acc", tf.keras.metrics.Recall(), tf.keras.metrics.Precision()]
        )

        # define callbacks
        callbacks = [
            ModelCheckpoint(
                f'models/model_epochs-{epochs}_batch-{batch_size}_loss-{loss_function}_{LFUNet.get_datetime_string()}.h5'),
            EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
        ]

        # evaluation before training
        results = self.model.evaluate(test_dataset)
        print("- TEST -> LOSS: {:10.4f}, ACC: {:10.4f}, RECALL: {:10.4f}, PRECISION: {:10.4f}".format(*results))

        # fit the model
        history = self.model.fit(train_dataset, validation_data=valid_dataset, epochs=epochs, callbacks=callbacks)

        # evaluation after training
        results = self.model.evaluate(test_dataset)
        print("- TEST -> LOSS: {:10.4f}, ACC: {:10.4f}, RECALL: {:10.4f}, PRECISION: {:10.4f}".format(*results))

        # use the model for inference on several test images
        self._test_results(test_x, test_y, predict_difference)

        # return training history
        return history

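    # Usage sketch for `train` (illustrative; assumes the configured data folders exist and contain
    # paired input/output PNGs):
    #
    #     net = LFUNet.build_model(...)
    #     history = net.train(epochs=20, batch_size=20, loss_function='ssim_l1_loss')
    #     print(history.history['val_loss'])
    #
    # Any of the four custom loss names handled above can be passed as `loss_function`; anything else
    # is forwarded to Keras unchanged.
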
    def _test_results(self, test_x, test_y, predict_difference: bool):
        """
        Tests the trained model on the testing dataset. All images in the testing dataset are processed
        and the resulting image triples (input with mask, ground truth, model output) are stored under
        `data/results` (or the configured `test_results_dir`) in a folder named with the time stamp of
        when this method was executed.
        @param test_x: List of input images
        @param test_y: List of ground truth output images
        @param predict_difference: Compute prediction on difference between input and output image
        @return: None
        """
        if self.configuration is None:
            result_dir = f'data/results/{LFUNet.get_datetime_string()}/'
        else:
            result_dir = os.path.join(self.configuration.get('test_results_dir'), LFUNet.get_datetime_string())
        os.makedirs(result_dir, exist_ok=True)

        for i, (x, y) in enumerate(zip(test_x, test_y)):
            x = LFUNet.read_image(x)
            y = LFUNet.read_image(y)

            y_pred = self.model.predict(np.expand_dims(x, axis=0))
            if predict_difference:
                y_pred = (y_pred * 2) - 1
                y_pred = np.clip(x - y_pred.squeeze(axis=0), 0.0, 1.0)
            else:
                y_pred = y_pred.squeeze(axis=0)
            h, w, _ = x.shape
            white_line = np.ones((h, 10, 3)) * 255.0

            all_images = [
                x * 255.0, white_line,
                y * 255.0, white_line,
                y_pred * 255.0
            ]
            image = np.concatenate(all_images, axis=1)
            cv2.imwrite(os.path.join(result_dir, f"{i}.png"), image)

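    # Note on the `predict_difference` branch in `_test_results` (and in `predict` below): when the
    # datasets are built with predict_difference=True (see `tf_dataset`), the training target is the
    # encoded difference d = (input - ground_truth + 1) / 2, which lies in [0, 1]. The inverse used at
    # inference time recovers the image: diff = 2 * prediction - 1, output = clip(input - diff, 0, 1).
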
    def summary(self):
        """
        Prints model summary
        """
        self.model.summary()

    def predict(self, img_path, predict_difference: bool = False):
        """
        Uses the trained model to remove the mask from an image of a person wearing one.
        @param img_path: Path to the image to be processed
        @param predict_difference: Compute prediction on difference between input and output image
        @return: Image without the mask on the face
        """
        # Load image in RGB format
        image = load_image(img_path)
        image = image.convert('RGB')

        # Find facial keypoints and crop the image to just the face
        keypoints = self.face_keypoints_detecting_fun(image)
        cropped_image = crop_face(image, keypoints)
        print(cropped_image.size)

        # Resize image to the input size expected by the neural net
        resized_image = cropped_image.resize((256, 256))
        image_array = np.array(resized_image)

        # Convert from RGB to BGR (OpenCV format) and normalize to [0, 1]
        image_array = image_array[:, :, ::-1].copy()
        image_array = image_array / 255.0

        # Remove the mask from the input image
        y_pred = self.model.predict(np.expand_dims(image_array, axis=0))
        h, w, _ = image_array.shape

        if predict_difference:
            y_pred = (y_pred * 2) - 1
            y_pred = np.clip(image_array - y_pred.squeeze(axis=0), 0.0, 1.0)
        else:
            y_pred = y_pred.squeeze(axis=0)

        # Convert the model output back to an image and scale it to the original crop size
        y_pred = y_pred * 255.0
        im = Image.fromarray(y_pred.astype(np.uint8)[:, :, ::-1])
        im = im.resize(cropped_image.size)
        left, upper, _, _ = get_crop_points(image, keypoints)

        # Combine the original image with the output from the model
        image.paste(im, (int(left), int(upper)))
        return image

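    # Usage sketch for `predict` (illustrative; 'photo_with_mask.png' is a hypothetical path):
    #
    #     unmasked = net.predict('photo_with_mask.png')
    #     unmasked.save('photo_without_mask.png')   # `predict` returns a PIL.Image
    #
    # Only the detected face crop is replaced; the final `paste` leaves the rest of the original
    # photo untouched.
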
    @staticmethod
    def get_datetime_string():
        """
        Creates date-time string
        @return: String with current date and time
        """
        now = datetime.now()
        return now.strftime("%Y%m%d_%H_%M_%S")

    def load_train_data(self, split=0.2):
        """
        Loads training data (paths to training images)
        @param split: Fraction of training data used for validation, as a float from 0.0 to 1.0. Default 0.2.
        @return: Two tuples - first with training data (tuple with (input images, output images)) and second
                 with validation data (tuple with (input images, output images))
        """
        if self.configuration is None:
            train_dir = 'data/train/'
            limit = None
        else:
            train_dir = self.configuration.get('train_data_path')
            limit = self.configuration.get('train_data_limit')
        print(f'Loading training data from {train_dir} with limit of {limit} images')
        return LFUNet.load_data(os.path.join(train_dir, 'inputs'), os.path.join(train_dir, 'outputs'), split, limit)

    def load_test_data(self):
        """
        Loads testing data (paths to testing images)
        @return: Tuple with testing data - (input images, output images)
        """
        if self.configuration is None:
            test_dir = 'data/test/'
            limit = None
        else:
            test_dir = self.configuration.get('test_data_path')
            limit = self.configuration.get('test_data_limit')
        print(f'Loading testing data from {test_dir} with limit of {limit} images')
        return LFUNet.load_data(os.path.join(test_dir, 'inputs'), os.path.join(test_dir, 'outputs'), None, limit)

    @staticmethod
    def load_data(input_path, output_path, split=0.2, limit=None):
        """
        Loads data (paths to images)
        @param input_path: Path to folder with input images
        @param output_path: Path to folder with output images
        @param split: Fraction of data used for validation, as a float from 0.0 to 1.0. Default 0.2.
                      If split is None, testing data are expected; otherwise training data.
        @param limit: Maximal number of images loaded from the data folder. Default None (no limit).
        @return: If split is not None: two tuples - first with training data (tuple with (input images, output images))
                 and second with validation data (tuple with (input images, output images))
                 Else: tuple with testing data - (input images, output images)
        """
        images = sorted(glob(os.path.join(input_path, "*.png")))
        masks = sorted(glob(os.path.join(output_path, "*.png")))
        if len(images) == 0:
            raise TypeError(f'No images found in {input_path}')
        if len(masks) == 0:
            raise TypeError(f'No images found in {output_path}')

        if limit is not None:
            images = images[:limit]
            masks = masks[:limit]

        if split is not None:
            total_size = len(images)
            valid_size = int(split * total_size)
            train_x, valid_x = train_test_split(images, test_size=valid_size, random_state=42)
            train_y, valid_y = train_test_split(masks, test_size=valid_size, random_state=42)
            return (train_x, train_y), (valid_x, valid_y)

        else:
            return images, masks

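    # Note on the expected data layout (inferred from the glob patterns above): each data directory
    # contains an `inputs/` and an `outputs/` subfolder of *.png files, and input/output pairing
    # relies on the sorted file names matching between the two folders. Splitting both lists with the
    # same `random_state` and `test_size` keeps every pair on the same side of the train/validation split.
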
    @staticmethod
    def read_image(path):
        """
        Loads an image, resizes it to 256x256 and normalizes it to float values from 0.0 to 1.0.
        @param path: Path to the image to be loaded.
        @return: Loaded image in OpenCV (BGR) format.
        """
        x = cv2.imread(path, cv2.IMREAD_COLOR)
        x = cv2.resize(x, (256, 256))
        x = x / 255.0
        return x

    @staticmethod
    def tf_parse(x, y):
        """
        Mapping function for dataset creation. Load and resize images.
        @param x: Path to input image
        @param y: Path to output image
        @return: Tuple with input and output image with shape (256, 256, 3)
        """
        def _parse(x, y):
            x = LFUNet.read_image(x.decode())
            y = LFUNet.read_image(y.decode())
            return x, y

        x, y = tf.numpy_function(_parse, [x, y], [tf.float64, tf.float64])
        x.set_shape([256, 256, 3])
        y.set_shape([256, 256, 3])
        return x, y

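    # Note on the declared output types above: `read_image` divides a uint8 OpenCV array by 255.0,
    # which produces float64, so `tf.numpy_function` is declared with `tf.float64` outputs to match.
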
    @staticmethod
    def tf_dataset(x, y, batch=8, predict_difference: bool = False, train: bool = True):
        """
        Creates standard tensorflow dataset.
        @param x: List of paths to input images
        @param y: List of paths to output images
        @param batch: Batch size
        @param predict_difference: Compute prediction on difference between input and output image
        @param train: Flag if training dataset should be generated
        @return: Dataset with loaded images
        """
        dataset = tf.data.Dataset.from_tensor_slices((x, y))
        dataset = dataset.map(LFUNet.tf_parse)
        random_seed = random.randint(0, 999999999)

        if predict_difference:
            def map_output(img_in, img_target):
                return img_in, (img_in - img_target + 1.0) / 2.0

            dataset = dataset.map(map_output)

        if train:
            # for the train set, we want to apply data augmentations and shuffle data into different batches

            # random flip
            def flip(img_in, img_out):
                return tf.image.random_flip_left_right(img_in, random_seed), \
                    tf.image.random_flip_left_right(img_out, random_seed)

            # augmentation parameters
            hue_delta = 0.05
            saturation_low = 0.2
            saturation_up = 1.3
            brightness_delta = 0.1
            contrast_low = 0.2
            contrast_up = 1.5

            # photometric augmentations
            def color(img_in, img_out):
                # Augmentations applied here are:
                # - random hue
                # - random saturation
                # - random brightness
                # - random contrast
                # (random flips are applied separately in `flip` above)
                img_in = tf.image.random_hue(img_in, hue_delta, random_seed)
                img_in = tf.image.random_saturation(img_in, saturation_low, saturation_up, random_seed)
                img_in = tf.image.random_brightness(img_in, brightness_delta, random_seed)
                img_in = tf.image.random_contrast(img_in, contrast_low, contrast_up, random_seed)
                img_out = tf.image.random_hue(img_out, hue_delta, random_seed)
                img_out = tf.image.random_saturation(img_out, saturation_low, saturation_up, random_seed)
                img_out = tf.image.random_brightness(img_out, brightness_delta, random_seed)
                img_out = tf.image.random_contrast(img_out, contrast_low, contrast_up, random_seed)
                return img_in, img_out

            # shuffle data and create batches
            dataset = dataset.shuffle(5000)
            dataset = dataset.batch(batch)

            # apply augmentations
            dataset = dataset.map(flip)
            dataset = dataset.map(color)
        else:
            dataset = dataset.batch(batch)

        return dataset.prefetch(tf.data.experimental.AUTOTUNE)
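
    # Usage sketch for `tf_dataset` (illustrative; the path lists would normally come from
    # `load_train_data` / `load_test_data`):
    #
    #     (train_x, train_y), (valid_x, valid_y) = net.load_train_data()
    #     train_ds = LFUNet.tf_dataset(train_x, train_y, batch=20)                 # shuffled + augmented
    #     valid_ds = LFUNet.tf_dataset(valid_x, valid_y, batch=20, train=False)    # batched only
    #
    # The same `random_seed` is passed for both the input and the output image in `flip` and `color`,
    # the intent being that each pair receives matching flips and color shifts.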