Initial commit
Browse files- +75 -0
- configuration.json +25 -0
- environment.yaml +216 -0
- examples/1.png +0 -0
- examples/2.png +0 -0
- examples/3.png +0 -0
- examples/4.png +0 -0
- examples/5.png +0 -0
- examples/6.png +0 -0
- examples/7.png +0 -0
- examples/8.png +0 -0
- model_weights/checkpoint +2 -0
- model_weights/model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt.index +0 -0
- utils/ +34 -0
- utils/__pycache__/__init__.cpython-37.pyc +0 -0
- utils/__pycache__/architectures.cpython-37.pyc +0 -0
- utils/__pycache__/configuration.cpython-37.pyc +0 -0
- utils/__pycache__/face_detection.cpython-37.pyc +0 -0
- utils/__pycache__/model.cpython-37.pyc +0 -0
- utils/ +344 -0
- utils/ +22 -0
- utils/ +151 -0
- utils/ +111 -0
- utils/ +495 -0
@@ -0,0 +1,75 @@
1 |
from utils.configuration import Configuration
2 |
import tensorflow as tf
3 |
from utils.model import ModelLoss
4 |
from utils.model import LFUNet
5 |
from utils.architectures import UNet
6 |
7 |
import gradio as gr
8 |
9 |
configuration = Configuration()
10 |
filters = (64, 128, 128, 256, 256, 512)
11 |
kernels = (7, 7, 7, 3, 3, 3)
12 |
input_image_size = (256, 256, 3)
13 |
14 |
15 |
trained_model = LFUNet.build_model(architecture=architecture, input_size=input_image_size, filters=filters,
16 |
kernels=kernels, configuration=configuration)
17 |
18 |
19 |
20 |
metrics=["acc", tf.keras.metrics.Recall(), tf.keras.metrics.Precision()]
21 |
22 |
23 |
weights_path = "model_weights/model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt"
24 |
25 |
26 |
def main(input_img):
27 |
28 |
29 |
predicted_image = trained_model.predict(input_img)
30 |
return predicted_image
31 |
except Exception as e:
32 |
raise gr.Error("Sorry, something went wrong. Please try again!")
33 |
34 |
demo = gr.Interface(
35 |
title= "Lightweight network for face unmasking",
36 |
description= "This is a demo of a <b>Lightweight network for face unmasking</b> \
37 |
designed to provide a powerful and efficient solution for restoring facial details obscured by masks.<br> \
38 |
To use it, simply upload your image, or click one of the examples to load them. Inference needs some time since this demo uses CPU.",
39 |
fn = main,
40 |
inputs= gr.Image(type="filepath").style(height=256),
41 |
outputs=gr.Image(type='numpy',shape=(256, 256, 3)).style(height=256),
42 |
# allow_flagging='never',
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
css = """
54 |
.svelte-mppz8v {
55 |
text-align: -webkit-center;
56 |
57 |
58 |
.gallery {
59 |
display: flex;
60 |
flex-wrap: wrap;
61 |
width: 100%;
62 |
63 |
64 |
p {
65 |
font-size: medium;
66 |
67 |
68 |
h1 {
69 |
font-size: xx-large;
70 |
71 |
72 |
theme= 'EveryPizza/Cartoony-Gradio-Theme',
73 |
# article = "<p style='text-align: center'><a href='' target='_blank'>Simple Baselines for Image Restoration</a> | <a href='' target='_blank'>NAFSSR: Stereo Image Super-Resolution Using NAFNet</a> | <a href='' target='_blank'> Github Repo</a></p>"
74 |
75 |
demo.launch(show_error=True, share= True)
@@ -0,0 +1,25 @@
1 |
2 |
"input_images_path": "data/lfw-deepfunneled",
3 |
"dataset_archive_download_url": "",
4 |
"path_to_patterns": "data/mask_patterns",
5 |
"train_data_path": "data/train",
6 |
"test_data_path": "data/test",
7 |
"landmarks_predictor_path": "shape_predictor_68_face_landmarks.dat",
8 |
"landmarks_predictor_download_url": "",
9 |
"minimal_confidence": 0.8,
10 |
"hyp_ratio": 0.3333,
11 |
"coordinates_range": [-10, 10],
12 |
"test_image_count": 100,
13 |
"train_image_count": 15000,
14 |
"image_size": [256, 256],
15 |
"mask_type" : "random",
16 |
"mask_color" : null,
17 |
"mask_patter" : null,
18 |
"mask_pattern_weight" : 0.9,
19 |
"mask_color_weight" : 0.8,
20 |
"mask_filter_output" : false,
21 |
"mask_filter_radius" : 2,
22 |
"test_results_dir": "data/results/",
23 |
"train_data_limit": 10000,
24 |
"test_data_limit": 1000
25 |
@@ -0,0 +1,216 @@
1 |
name: unmask3
2 |
3 |
- conda-forge
4 |
- defaults
5 |
6 |
- _libgcc_mutex=0.1=main
7 |
- _openmp_mutex=5.1=1_gnu
8 |
- atk-1.0=2.36.0=ha1a6a79_0
9 |
- backcall=0.2.0=pyhd3eb1b0_0
10 |
- ca-certificates=2022.10.11=h06a4308_0
11 |
- cairo=1.16.0=h19f5f5c_2
12 |
- certifi=2022.9.24=py37h06a4308_0
13 |
- cudatoolkit=11.2.2=hbe64b41_10
14 |
- cudnn=
15 |
- debugpy=1.5.1=py37h295c915_0
16 |
- decorator=5.1.1=pyhd3eb1b0_0
17 |
- entrypoints=0.4=py37h06a4308_0
18 |
- expat=2.4.9=h6a678d5_0
19 |
- font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0
20 |
- font-ttf-inconsolata=2.001=hcb22688_0
21 |
- font-ttf-source-code-pro=2.030=hd3eb1b0_0
22 |
- font-ttf-ubuntu=0.83=h8b1ccd4_0
23 |
- fontconfig=2.13.1=h6c09931_0
24 |
- fonts-anaconda=1=h8fa9717_0
25 |
- fonts-conda-ecosystem=1=hd3eb1b0_0
26 |
- freetype=2.11.0=h70c0345_0
27 |
- fribidi=1.0.10=h7b6447c_0
28 |
- gdk-pixbuf=2.42.8=h433bba3_0
29 |
- glib=2.69.1=h4ff587b_1
30 |
- gobject-introspection=1.72.0=py37hbb6d50b_0
31 |
- graphite2=1.3.14=h295c915_1
32 |
- graphviz=2.50.0=h3cd0ef9_0
33 |
- gtk2=2.24.33=h73c1081_2
34 |
- gts=0.7.6=hb67d8dd_3
35 |
- harfbuzz=4.3.0=hd55b92a_0
36 |
- icu=58.2=he6710b0_3
37 |
- ipython=7.31.1=py37h06a4308_1
38 |
- jedi=0.18.1=py37h06a4308_1
39 |
- jpeg=9e=h7f8727e_0
40 |
- jupyter_client=7.3.4=py37h06a4308_0
41 |
- ld_impl_linux-64=2.38=h1181459_1
42 |
- lerc=3.0=h295c915_0
43 |
- libdeflate=1.8=h7f8727e_5
44 |
- libffi=3.3=he6710b0_2
45 |
- libgcc-ng=11.2.0=h1234567_1
46 |
- libgd=2.3.3=h695aa2c_1
47 |
- libgomp=11.2.0=h1234567_1
48 |
- libpng=1.6.37=hbc83047_0
49 |
- librsvg=2.54.4=h19fe530_0
50 |
- libsodium=1.0.18=h7b6447c_0
51 |
- libstdcxx-ng=11.2.0=h1234567_1
52 |
- libtiff=4.4.0=hecacb30_0
53 |
- libtool=2.4.6=h295c915_1008
54 |
- libuuid=1.0.3=h7f8727e_2
55 |
- libwebp-base=1.2.4=h5eee18b_0
56 |
- libxcb=1.15=h7f8727e_0
57 |
- libxml2=2.9.14=h74e7548_0
58 |
- lz4-c=1.9.3=h295c915_1
59 |
- matplotlib-inline=0.1.6=py37h06a4308_0
60 |
- ncurses=6.3=h5eee18b_3
61 |
- nest-asyncio=1.5.5=py37h06a4308_0
62 |
- ninja=1.10.2=h06a4308_5
63 |
- ninja-base=1.10.2=hd09550d_5
64 |
- openssl=1.1.1q=h7f8727e_0
65 |
- pango=1.50.7=h05da053_0
66 |
- parso=0.8.3=pyhd3eb1b0_0
67 |
- pcre=8.45=h295c915_0
68 |
- pexpect=4.8.0=pyhd3eb1b0_3
69 |
- pickleshare=0.7.5=pyhd3eb1b0_1003
70 |
- pixman=0.40.0=h7f8727e_1
71 |
- ptyprocess=0.7.0=pyhd3eb1b0_2
72 |
- pygments=2.11.2=pyhd3eb1b0_0
73 |
- python=3.7.13=h12debd9_0
74 |
- python-dateutil=2.8.2=pyhd3eb1b0_0
75 |
- pyzmq=23.2.0=py37h6a678d5_0
76 |
- readline=8.1.2=h7f8727e_1
77 |
- setuptools=63.4.1=py37h06a4308_0
78 |
- six=1.16.0=pyhd3eb1b0_1
79 |
- sqlite=3.39.2=h5082296_0
80 |
- tk=8.6.12=h1ccaba5_0
81 |
- tornado=6.2=py37h5eee18b_0
82 |
- wcwidth=0.2.5=pyhd3eb1b0_0
83 |
- wheel=0.37.1=pyhd3eb1b0_0
84 |
- xz=5.2.5=h7f8727e_1
85 |
- zeromq=4.3.4=h2531618_0
86 |
- zlib=1.2.12=h7f8727e_2
87 |
- zstd=1.5.2=ha4553b6_0
88 |
- pip:
89 |
- absl-py==1.2.0
90 |
- anyio==3.6.2
91 |
- argon2-cffi==21.3.0
92 |
- argon2-cffi-bindings==21.2.0
93 |
- astunparse==1.6.3
94 |
- attrs==22.2.0
95 |
- beautifulsoup4==4.11.1
96 |
- black==22.10.0
97 |
- bleach==6.0.0
98 |
- cachetools==5.2.0
99 |
- cffi==1.15.1
100 |
- charset-normalizer==2.1.1
101 |
- click==8.1.3
102 |
- cmake==3.24.1
103 |
- cycler==0.11.0
104 |
- defusedxml==0.7.1
105 |
- dlib==19.24.1
106 |
- dotmap==1.3.30
107 |
- fastjsonschema==2.16.2
108 |
- flatbuffers==1.12
109 |
- fonttools==4.37.1
110 |
- gast==0.4.0
111 |
- google-auth==2.11.0
112 |
- google-auth-oauthlib==0.4.6
113 |
- google-pasta==0.2.0
114 |
- grpcio==1.48.1
115 |
- h5py==3.7.0
116 |
- idna==3.3
117 |
- imageio==2.27.0
118 |
- importlib-metadata==4.12.0
119 |
- importlib-resources==5.10.2
120 |
- imutils==0.5.4
121 |
- ipykernel==6.16.2
122 |
- ipython-genutils==0.2.0
123 |
- ipywidgets==8.0.2
124 |
- jinja2==3.1.2
125 |
- joblib==1.1.0
126 |
- jsonschema==4.17.3
127 |
- jupyter-console==6.6.3
128 |
- jupyter-core==4.12.0
129 |
- jupyter-server==1.23.5
130 |
- jupyterlab-pygments==0.2.2
131 |
- jupyterlab-widgets==3.0.3
132 |
- jupyterthemes==0.20.0
133 |
- keras==2.9.0
134 |
- keras-applications==1.0.8
135 |
- keras-preprocessing==1.1.2
136 |
- keras-vggface==0.6
137 |
- kiwisolver==1.4.4
138 |
- lesscpy==0.15.1
139 |
- libclang==14.0.6
140 |
- markdown==3.4.1
141 |
- markupsafe==2.1.1
142 |
- matplotlib==3.5.3
143 |
- mistune==2.0.4
144 |
- mtcnn==0.1.1
145 |
- mypy-extensions==0.4.3
146 |
- nbclassic==0.4.8
147 |
- nbclient==0.7.2
148 |
- nbconvert==7.2.9
149 |
- nbformat==5.7.3
150 |
- networkx==2.6.3
151 |
- notebook==6.5.2
152 |
- notebook-shim==0.2.2
153 |
- numpy==1.21.6
154 |
- oauthlib==3.2.0
155 |
- opencv-contrib-python==
156 |
- opencv-python==
157 |
- opt-einsum==3.3.0
158 |
- packaging==21.3
159 |
- pandas==1.3.5
160 |
- pandocfilters==1.5.0
161 |
- pathspec==0.10.1
162 |
- pillow==9.2.0
163 |
- pip==22.3.1
164 |
- pkgutil-resolve-name==1.3.10
165 |
- platformdirs==2.5.2
166 |
- ply==3.11
167 |
- prometheus-client==0.16.0
168 |
- prompt-toolkit==3.0.38
169 |
- protobuf==3.19.4
170 |
- psutil==5.9.5
171 |
- pyasn1==0.4.8
172 |
- pyasn1-modules==0.2.8
173 |
- pycparser==2.21
174 |
- pydot==1.4.2
175 |
- pyparsing==3.0.9
176 |
- pyrsistent==0.19.3
177 |
- pytz==2022.2.1
178 |
- pywavelets==1.3.0
179 |
- pyyaml==6.0
180 |
- requests==2.28.1
181 |
- requests-oauthlib==1.3.1
182 |
- rsa==4.9
183 |
- scikit-image==0.19.3
184 |
- scikit-learn==1.0.2
185 |
- scipy==1.7.3
186 |
- seaborn==0.11.2
187 |
- send2trash==1.8.0
188 |
- sniffio==1.3.0
189 |
- soupsieve==2.3.2.post1
190 |
- tensorboard==2.9.1
191 |
- tensorboard-data-server==0.6.1
192 |
- tensorboard-plugin-wit==1.8.1
193 |
- tensorflow==2.9.2
194 |
- tensorflow-addons==0.19.0
195 |
- tensorflow-estimator==2.9.0
196 |
- tensorflow-io-gcs-filesystem==0.26.0
197 |
- termcolor==1.1.0
198 |
- terminado==0.17.1
199 |
- threadpoolctl==3.1.0
200 |
- tifffile==2021.11.2
201 |
- tinycss2==1.2.1
202 |
- tomli==2.0.1
203 |
- tqdm==4.64.1
204 |
- traitlets==5.8.1
205 |
- trianglesolver==1.2
206 |
- typed-ast==1.5.4
207 |
- typeguard==2.13.3
208 |
- typing-extensions==4.3.0
209 |
- urllib3==1.26.12
210 |
- webencodings==0.5.1
211 |
- websocket-client==1.4.2
212 |
- werkzeug==2.2.2
213 |
- widgetsnbextension==4.0.3
214 |
- wrapt==1.14.1
215 |
- zipp==3.8.1
216 |
prefix: /home/suresh/miniconda3/envs/unmask
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
@@ -0,0 +1,2 @@
1 |
model_checkpoint_path: "model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt"
2 |
all_model_checkpoint_paths: "model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt"
Binary file (32 kB). View file
@@ -0,0 +1,34 @@
1 |
import numpy as np
2 |
from PIL import Image
3 |
import requests
4 |
import functools
5 |
from tqdm.notebook import tqdm
6 |
import shutil
7 |
8 |
def image_to_array(image: "Image.Image") -> np.ndarray:
    """Convert an image to a ``uint8`` NumPy array.

    Args:
        image: Any object accepted by ``np.asarray``. The annotation is the
            PIL image class; it is quoted because ``Image`` imported from
            ``PIL`` is a module, not a type.

    Returns:
        A new array with dtype ``uint8``. ``astype`` copies by default, so the
        result does not alias the source buffer.
    """
    return np.asarray(image).astype('uint8')
11 |
12 |
13 |
def load_image(img_path: str) -> Image:
14 |
"""Load image to array"""
15 |
16 |
17 |
18 |
def download_data(url, save_path, file_size=None):
19 |
"""Downloads data from `url` to `save_path`"""
20 |
r = requests.get(url, stream=True, allow_redirects=True)
21 |
if r.status_code != 200:
22 |
23 |
raise RuntimeError(f'Request to {url} returned status code {r.status_code}')
24 |
25 |
if file_size is None:
26 |
file_size = int(r.headers.get('content-length', 0))
27 |
28 |
+ = functools.partial(, decode_content=True) # Decompress if needed
29 |
with tqdm.wrapattr(r.raw, 'read', total=file_size, desc='') as r_raw:
30 |
with open(save_path, 'wb') as f:
31 |
shutil.copyfileobj(r_raw, f)
32 |
33 |
def plot_image_triple():
34 |
Binary file (1.41 kB). View file
Binary file (9.89 kB). View file
Binary file (996 Bytes). View file
Binary file (3.38 kB). View file
Binary file (15.7 kB). View file
@@ -0,0 +1,344 @@
1 |
from abc import ABC, abstractmethod
2 |
from enum import Enum
3 |
from typing import Tuple, Optional
4 |
5 |
import tensorflow as tf
6 |
from tensorflow.keras.layers import *
7 |
from tensorflow.keras.models import *
8 |
9 |
class BaseUNet(ABC):
10 |
11 |
Base Interface for UNet
12 |
13 |
14 |
def __init__(self, model: Model):
15 |
self.model: Model = model
16 |
17 |
def get_model(self):
18 |
return self.model
19 |
20 |
21 |
22 |
def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
23 |
24 |
25 |
26 |
class UNet(Enum):
27 |
28 |
Enum class defining different architecture types available
29 |
30 |
31 |
32 |
33 |
34 |
35 |
def build_model(self, input_size: Tuple[int, int, int], filters: Optional[Tuple] = None,
36 |
kernels: Optional[Tuple] = None) -> BaseUNet:
37 |
38 |
# set default filters
39 |
if filters is None:
40 |
filters = (16, 32, 64, 128, 256)
41 |
42 |
# set default kernels
43 |
if kernels is None:
44 |
kernels = list(3 for _ in range(len(filters)))
45 |
46 |
# check kernels and filters
47 |
if len(filters) != len(kernels):
48 |
raise Exception('Kernels and filter count has to match.')
49 |
50 |
51 |
print('Using default UNet model with imagenet embedding')
52 |
return UNetDefault.build_model(input_size, filters, kernels, use_embedding=True)
53 |
elif self == UNet.RESNET:
54 |
print('Using UNet Resnet model')
55 |
return UNet_resnet.build_model(input_size, filters, kernels)
56 |
57 |
print('Using UNet Resnet model with attention mechanism and separable convolutions')
58 |
return UNet_ResNet_Attention_SeparableConv.build_model(input_size, filters, kernels)
59 |
60 |
print('Using default UNet model')
61 |
return UNetDefault.build_model(input_size, filters, kernels, use_embedding=False)
62 |
63 |
64 |
class Attention(Layer):
65 |
def __init__(self, **kwargs):
66 |
super(Attention, self).__init__(**kwargs)
67 |
68 |
def build(self, input_shape):
69 |
# Create a trainable weight variable for this layer.
70 |
self.kernel = self.add_weight(name='kernel',
71 |
shape=(input_shape[-1], 1),
72 |
73 |
74 |
self.bias = self.add_weight(name='bias',
75 |
76 |
77 |
78 |
super(Attention, self).build(input_shape) # Be sure to call this at the end
79 |
80 |
def call(self, x):
81 |
attention = tf.nn.softmax(tf.matmul(x, self.kernel) + self.bias, axis=-1)
82 |
return tf.multiply(x, attention)
83 |
84 |
def compute_output_shape(self, input_shape):
85 |
return input_shape
86 |
87 |
88 |
class UNet_ResNet_Attention_SeparableConv(BaseUNet):
89 |
90 |
UNet architecture with resnet blocks, attention mechanism and separable convolutions
91 |
92 |
93 |
def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
94 |
95 |
p0 = Input(shape=input_size)
96 |
conv_outputs = []
97 |
first_layer = SeparableConv2D(filters[0], kernels[0], padding='same')(p0)
98 |
int_layer = first_layer
99 |
for i, f in enumerate(filters):
100 |
int_layer, skip = UNet_ResNet_Attention_SeparableConv.down_block(int_layer, f, kernels[i])
101 |
102 |
103 |
int_layer = UNet_ResNet_Attention_SeparableConv.bottleneck(int_layer, filters[-1], kernels[-1])
104 |
105 |
conv_outputs = list(reversed(conv_outputs))
106 |
reversed_filter = list(reversed(filters))
107 |
reversed_kernels = list(reversed(kernels))
108 |
for i, f in enumerate(reversed_filter):
109 |
if i + 1 < len(reversed_filter):
110 |
num_filters_next = reversed_filter[i + 1]
111 |
num_kernels_next = reversed_kernels[i + 1]
112 |
113 |
num_filters_next = f
114 |
num_kernels_next = reversed_kernels[i]
115 |
int_layer = UNet_ResNet_Attention_SeparableConv.up_block(int_layer, conv_outputs[i], f, num_filters_next, num_kernels_next)
116 |
int_layer = Attention()(int_layer)
117 |
118 |
# concat. with the first layer
119 |
int_layer = Concatenate()([first_layer, int_layer])
120 |
int_layer = SeparableConv2D(filters[0], kernels[0], padding="same", activation="relu")(int_layer)
121 |
outputs = SeparableConv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
122 |
model = Model(p0, outputs)
123 |
return UNet_ResNet_Attention_SeparableConv(model)
124 |
125 |
126 |
def down_block(x, num_filters: int = 64, kernel: int = 3):
    """Down-sampling residual block built from separable convolutions.

    Returns:
        A pair ``(activated_output, downsampled)``. The second element is the
        strided-convolution output that ``build_model`` receives as ``skip``.
    """
    # strided (stride 2), dilated separable convolution performs the down-sampling
    downsampled = SeparableConv2D(num_filters, kernel, padding='same', strides=2, dilation_rate=2)(x)

    # residual branch: conv -> ReLU -> conv
    branch = SeparableConv2D(num_filters, kernel, padding='same')(downsampled)
    branch = Activation('relu')(branch)
    branch = SeparableConv2D(num_filters, kernel, padding='same')(branch)

    # add the branch back onto the down-sampled tensor, then activate
    merged = Add()([branch, downsampled])
    return Activation('relu')(merged), downsampled
140 |
141 |
142 |
def up_block(x, skip, num_filters: int = 64, num_filters_next: int = 64, kernel: int = 3):
    """Residual block followed by 2x up-sampling, using separable convolutions.

    Args:
        x: Tensor coming from the previous (deeper) stage.
        skip: Encoder tensor concatenated with ``x`` before the residual branch.
        num_filters: Filters used inside the residual branch.
        num_filters_next: Filters of the convolution applied after up-sampling.
        kernel: Kernel size for every convolution in the block.
    """
    # U-Net skip connection is joined before up-sampling
    joined = Concatenate()([x, skip])

    # residual branch: dilated conv -> ReLU -> conv
    branch = SeparableConv2D(num_filters, kernel, padding='same', dilation_rate=2)(joined)
    branch = Activation('relu')(branch)
    branch = SeparableConv2D(num_filters, kernel, padding='same')(branch)

    # residual sum with the block input, then activate
    residual = Add()([branch, x])
    residual = Activation('relu')(residual)

    # up-sample and project to the filter count expected by the next stage
    upsampled = UpSampling2D((2, 2))(residual)
    upsampled = SeparableConv2D(num_filters_next, kernel, padding='same')(upsampled)
    return Activation('relu')(upsampled)
162 |
163 |
164 |
def bottleneck(x, num_filters: int = 64, kernel: int = 3):
    """Residual bottleneck: dilated separable conv, ReLU, separable conv, add input, ReLU."""
    branch = SeparableConv2D(num_filters, kernel, padding='same', dilation_rate=2)(x)
    branch = Activation('relu')(branch)
    branch = SeparableConv2D(num_filters, kernel, padding='same')(branch)

    # residual sum with the bottleneck input
    residual = Add()([branch, x])
    return Activation('relu')(residual)
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
# Class for UNet with Resnet blocks
182 |
183 |
class UNet_resnet(BaseUNet):
184 |
185 |
UNet architecture with resnet blocks
186 |
187 |
188 |
189 |
def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple):
190 |
191 |
p0 = Input(shape=input_size)
192 |
conv_outputs = []
193 |
first_layer = Conv2D(filters[0], kernels[0], padding='same')(p0)
194 |
int_layer = first_layer
195 |
for i, f in enumerate(filters):
196 |
int_layer, skip = UNet_resnet.down_block(int_layer, f, kernels[i])
197 |
198 |
199 |
int_layer = UNet_resnet.bottleneck(int_layer, filters[-1], kernels[-1])
200 |
201 |
conv_outputs = list(reversed(conv_outputs))
202 |
reversed_filter = list(reversed(filters))
203 |
reversed_kernels = list(reversed(kernels))
204 |
for i, f in enumerate(reversed_filter):
205 |
if i + 1 < len(reversed_filter):
206 |
num_filters_next = reversed_filter[i + 1]
207 |
num_kernels_next = reversed_kernels[i + 1]
208 |
209 |
num_filters_next = f
210 |
num_kernels_next = reversed_kernels[i]
211 |
int_layer = UNet_resnet.up_block(int_layer, conv_outputs[i], f, num_filters_next, num_kernels_next)
212 |
213 |
# concat. with the first layer
214 |
int_layer = Concatenate()([first_layer, int_layer])
215 |
int_layer = Conv2D(filters[0], kernels[0], padding="same", activation="relu")(int_layer)
216 |
outputs = Conv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
217 |
model = Model(p0, outputs)
218 |
return UNet_resnet(model)
219 |
220 |
221 |
def down_block(x, num_filters: int = 64, kernel: int = 3):
    """Down-sampling residual block built from regular convolutions.

    Returns:
        A pair ``(activated_output, downsampled)``. The second element is the
        strided-convolution output that ``build_model`` receives as ``skip``.
    """
    # stride-2 convolution performs the down-sampling
    downsampled = Conv2D(num_filters, kernel, padding='same', strides=2)(x)

    # residual branch: conv -> ReLU -> conv
    branch = Conv2D(num_filters, kernel, padding='same')(downsampled)
    branch = Activation('relu')(branch)
    branch = Conv2D(num_filters, kernel, padding='same')(branch)

    # add the branch back onto the down-sampled tensor, then activate
    merged = Add()([branch, downsampled])
    return Activation('relu')(merged), downsampled
235 |
236 |
237 |
def up_block(x, skip, num_filters: int = 64, num_filters_next: int = 64, kernel: int = 3):
    """Residual block followed by a transposed-convolution up-sampling step.

    ``skip`` is concatenated twice: once with the block input and once with
    the residual output, right before up-sampling.

    Args:
        x: Tensor coming from the previous (deeper) stage.
        skip: Encoder tensor used for both concatenations.
        num_filters: Filters used inside the residual branch.
        num_filters_next: Filters of the up-sampling layers.
        kernel: Kernel size for every convolution in the block.
    """
    # first skip concatenation, on the block input
    joined = Concatenate()([x, skip])

    # residual branch: conv -> ReLU -> conv
    branch = Conv2D(num_filters, kernel, padding='same')(joined)
    branch = Activation('relu')(branch)
    branch = Conv2D(num_filters, kernel, padding='same')(branch)

    # residual sum with the block input, then activate
    residual = Add()([branch, x])
    residual = Activation('relu')(residual)

    # second skip concatenation, on the residual output
    rejoined = Concatenate()([residual, skip])

    # stride-2 transposed convolution up-samples, followed by a regular conv
    upsampled = Conv2DTranspose(num_filters_next, kernel, padding='same', strides=2)(rejoined)
    upsampled = Conv2D(num_filters_next, kernel, padding='same')(upsampled)
    return Activation('relu')(upsampled)
262 |
263 |
264 |
def bottleneck(x, filters, kernel: int = 3):
    """Apply one convolution (layer name ``'bottleneck'``) followed by ReLU."""
    conv_out = Conv2D(filters, kernel, padding='same', name='bottleneck')(x)
    return Activation('relu')(conv_out)
268 |
269 |
270 |
class UNetDefault(BaseUNet):
271 |
272 |
UNet architecture from following github notebook for image segmentation:
273 |
274 |
275 |
276 |
277 |
278 |
def build_model(input_size: Tuple[int, int, int], filters: Tuple, kernels: Tuple, use_embedding: bool = True):
279 |
280 |
p0 = Input(input_size)
281 |
282 |
if use_embedding:
283 |
mobilenet_model = tf.keras.applications.MobileNetV2(
284 |
input_shape=input_size, include_top=False, weights='imagenet'
285 |
286 |
mobilenet_model.trainable = False
287 |
mn1 = mobilenet_model(p0)
288 |
mn1 = Reshape((16, 16, 320))(mn1)
289 |
290 |
conv_outputs = []
291 |
int_layer = p0
292 |
293 |
for f in filters:
294 |
conv_output, int_layer = UNetDefault.down_block(int_layer, f)
295 |
296 |
297 |
int_layer = UNetDefault.bottleneck(int_layer, filters[-1])
298 |
299 |
if use_embedding:
300 |
int_layer = Concatenate()([int_layer, mn1])
301 |
302 |
conv_outputs = list(reversed(conv_outputs))
303 |
for i, f in enumerate(reversed(filters)):
304 |
int_layer = UNetDefault.up_block(int_layer, conv_outputs[i], f)
305 |
306 |
int_layer = Conv2D(filters[0] // 2, 3, padding="same", activation="relu")(int_layer)
307 |
outputs = Conv2D(3, (1, 1), padding="same", activation="sigmoid")(int_layer)
308 |
model = Model(p0, outputs)
309 |
return UNetDefault(model)
310 |
311 |
312 |
def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """Convolve with ReLU, then 2x2 max-pool.

    Returns:
        ``(conv_output, pooled)``: the pre-pooling tensor and the pooled tensor.
    """
    conv_out = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
    pooled = MaxPool2D((2, 2), (2, 2))(conv_out)
    return conv_out, pooled
317 |
318 |
319 |
def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1):
    """Up-sample 2x, convolve, concatenate with ``skip`` and convolve again."""
    upsampled = UpSampling2D((2, 2))(x)
    conv_out = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(upsampled)

    # U-Net skip connection is joined after the first convolution
    joined = Concatenate()([conv_out, skip])
    return Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(joined)
327 |
328 |
329 |
def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """Apply a single ReLU-activated convolution to ``x``."""
    return Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
333 |
334 |
335 |
if __name__ == "__main__":
    # Smoke test: build the default UNet with fixed hyper-parameters.
    input_image_size = (256, 256, 3)
    filters = (64, 128, 128, 256, 256, 512)
    kernels = (7, 7, 7, 3, 3, 3)
    # NOTE: the UNet wrapper classes cannot be instantiated bare, e.g.
    # ``UNet_resnet()`` fails with "__init__() missing 1 required positional
    # argument: 'model'"; go through the ``build_model`` factory instead.
    model = UNetDefault.build_model(input_size=input_image_size, filters=filters, kernels=kernels)
344 |
@@ -0,0 +1,22 @@
1 |
import os
2 |
import json
3 |
from dataclasses import dataclass
4 |
5 |
6 |
7 |
class Configuration:
8 |
def __init__(self, config_file_path: str = "configuration.json"):
9 |
self.config_file_path = config_file_path
10 |
self.config_json = None
11 |
if os.path.exists(config_file_path):
12 |
with open(self.config_file_path, 'r') as json_file:
13 |
self.config_json = json.load(json_file)
14 |
15 |
print(f'ERROR: Configuration JSON {config_file_path} does not exist.')
16 |
17 |
def get(self, key: str):
18 |
if key in self.config_json:
19 |
return self.config_json[key]
20 |
21 |
print(f'ERROR: Key \'{key}\' is not in configuration JSON.')
22 |
return None
@@ -0,0 +1,151 @@
1 |
import copy
2 |
import dlib
3 |
import os
4 |
import bz2
5 |
import random
6 |
from tqdm.notebook import tqdm
7 |
import shutil
8 |
from utils import image_to_array, load_image, download_data
9 |
from utils.face_detection import crop_face, get_face_keypoints_detecting_function
10 |
from mask_utils.mask_utils import mask_image
11 |
12 |
13 |
class DataGenerator:
14 |
def __init__(self, configuration):
15 |
self.configuration = configuration
16 |
self.path_to_data = configuration.get('input_images_path')
17 |
self.path_to_patterns = configuration.get('path_to_patterns')
18 |
self.minimal_confidence = configuration.get('minimal_confidence')
19 |
self.hyp_ratio = configuration.get('hyp_ratio')
20 |
self.coordinates_range = configuration.get('coordinates_range')
21 |
self.test_image_count = configuration.get('test_image_count')
22 |
self.train_image_count = configuration.get('train_image_count')
23 |
self.train_data_path = configuration.get('train_data_path')
24 |
self.test_data_path = configuration.get('test_data_path')
25 |
self.predictor_path = configuration.get('landmarks_predictor_path')
26 |
27 |
28 |
self.valid_image_extensions = ('png', 'jpg', 'jpeg')
29 |
self.face_keypoints_detecting_fun = get_face_keypoints_detecting_function(self.minimal_confidence)
30 |
31 |
def check_predictor(self):
    """Make sure the landmarks predictor file exists, downloading it if missing.

    When ``self.predictor_path`` does not exist, the ``.bz2`` archive is
    fetched from the configured ``landmarks_predictor_download_url`` and
    decompressed into ``self.predictor_path``.
    """
    if os.path.exists(self.predictor_path):
        return

    print('Downloading missing predictor.')
    archive_path = self.predictor_path + '.bz2'
    url = self.configuration.get('landmarks_predictor_download_url')
    # third argument is the ``file_size`` parameter of ``download_data``
    download_data(url, archive_path, 64040097)

    print(f'Decompressing downloaded file into {self.predictor_path}')
    with bz2.BZ2File(archive_path) as fr, open(self.predictor_path, 'wb') as fw:
        shutil.copyfileobj(fr, fw)
40 |
41 |
def get_face_landmarks(self, image):
    """Compute the facial landmarks of the first face detected in ``image``.

    The dlib face detector and shape predictor are created on first use and
    cached on the instance. Previously both were rebuilt on every call, which
    reloads the predictor file from disk for each image; ``generate_data``
    calls this method once per image.

    Args:
        image: Image accepted by ``image_to_array``.

    Returns:
        A list of ``[x, y]`` landmark coordinates for the first detected face,
        or ``None`` when no face is detected.
    """
    image_array = image_to_array(image)

    # Cache is keyed on the predictor path so a changed path triggers a reload.
    cached = getattr(self, '_landmark_models', None)
    if cached is None or cached[0] != self.predictor_path:
        cached = (
            self.predictor_path,
            dlib.get_frontal_face_detector(),
            dlib.shape_predictor(self.predictor_path),
        )
        self._landmark_models = cached
    _, detector, predictor = cached

    face_rectangles = detector(image_array)
    if len(face_rectangles) < 1:
        return None

    # only the first detected face is used
    dlib_shape = predictor(image_array, face_rectangles[0])
    return [
        [dlib_shape.part(i).x, dlib_shape.part(i).y]
        for i in range(dlib_shape.num_parts)
    ]
54 |
55 |
def get_files_faces(self):
56 |
"""Get paths of all images in the dataset"""
57 |
image_files = []
58 |
for dirpath, dirs, files in os.walk(self.path_to_data):
59 |
for filename in files:
60 |
fname = os.path.join(dirpath, filename)
61 |
if fname.endswith(self.valid_image_extensions):
62 |
63 |
64 |
return image_files
65 |
66 |
def generate_images(self, image_size=None, test_image_count=None, train_image_count=None):
67 |
"""Generate test and train data (images with and without the mask)"""
68 |
if image_size is None:
69 |
image_size = self.configuration.get('image_size')
70 |
if test_image_count is None:
71 |
test_image_count = self.test_image_count
72 |
if train_image_count is None:
73 |
train_image_count = self.train_image_count
74 |
75 |
if not os.path.exists(self.train_data_path):
76 |
77 |
os.mkdir(os.path.join(self.train_data_path, 'inputs'))
78 |
os.mkdir(os.path.join(self.train_data_path, 'outputs'))
79 |
80 |
if not os.path.exists(self.test_data_path):
81 |
82 |
os.mkdir(os.path.join(self.test_data_path, 'inputs'))
83 |
os.mkdir(os.path.join(self.test_data_path, 'outputs'))
84 |
85 |
print('Generating testing data')
86 |
87 |
88 |
89 |
print('Generating training data')
90 |
91 |
92 |
93 |
94 |
def generate_data(self, number_of_images, image_size=None, save_to=None):
95 |
""" Add masks on `number_of_images` images
96 |
If `save_to` is a valid path to a folder, the images are saved there; otherwise the generated data are just returned as lists.
97 |
98 |
inputs = []
99 |
outputs = []
100 |
101 |
if image_size is None:
102 |
image_size = self.configuration.get('image_size')
103 |
104 |
for i, file in tqdm(enumerate(random.sample(self.get_files_faces(), number_of_images)), total=number_of_images):
105 |
# Load images
106 |
image = load_image(file)
107 |
108 |
# Detect keypoints and landmarks on face
109 |
face_landmarks = self.get_face_landmarks(image)
110 |
if face_landmarks is None:
111 |
112 |
keypoints = self.face_keypoints_detecting_fun(image)
113 |
114 |
# Generate mask
115 |
image_with_mask = mask_image(copy.deepcopy(image), face_landmarks, self.configuration)
116 |
117 |
# Crop images
118 |
cropped_image = crop_face(image_with_mask, keypoints)
119 |
cropped_original = crop_face(image, keypoints)
120 |
121 |
# Resize all images to NN input size
122 |
res_image = cropped_image.resize(image_size)
123 |
res_original = cropped_original.resize(image_size)
124 |
125 |
# Save generated data to lists or to folder
126 |
if save_to is None:
127 |
128 |
129 |
130 |
+, 'inputs', f"{i:06d}.png"))
131 |
+, 'outputs', f"{i:06d}.png"))
132 |
133 |
if save_to is None:
134 |
return inputs, outputs
135 |
136 |
def get_dataset_examples(self, n=10, test_dataset=False):
137 |
138 |
Returns `n` random images from the dataset. If the `test_dataset` parameter
139 |
is not provided or False it will return images from training part of dataset.
140 |
If `test_dataset` parameter is True it will return images from testing part of dataset.
141 |
142 |
if test_dataset:
143 |
data_path = self.test_data_path
144 |
145 |
data_path = self.train_data_path
146 |
147 |
images = os.listdir(os.path.join(data_path, 'inputs'))
148 |
images = random.sample(images, n)
149 |
inputs = [os.path.join(data_path, 'inputs', img) for img in images]
150 |
outputs = [os.path.join(data_path, 'outputs', img) for img in images]
151 |
return inputs, outputs
@@ -0,0 +1,111 @@
1 |
"""Functions for face detection"""
2 |
from math import pi
3 |
from typing import Tuple, Optional, Dict
4 |
5 |
import tensorflow as tf
6 |
import matplotlib.patches as patches
7 |
import matplotlib.pyplot as plt
8 |
from PIL import Image
9 |
from mtcnn import MTCNN
10 |
from trianglesolver import solve
11 |
12 |
from utils import image_to_array
13 |
14 |
15 |
def compute_slacks(height, width, hyp_ratio) -> Tuple[float, float]:
    """Compute the slack to add to a bounding box on each side.

    The box diagonal (hypotenuse of the right triangle spanned by ``width``
    and ``height``) is lengthened by the factor ``1 + hyp_ratio`` while its
    angle is kept. By similar triangles both sides grow by that same factor,
    so the slack is simply ``side * hyp_ratio``. This closed form replaces
    the two triangle-solver calls the block used before.

    Args:
        height: Height of the bounding box.
        width: Width of the bounding box.
        hyp_ratio: Relative growth of the box diagonal.

    Returns:
        ``(width_slack, height_slack)``.
    """
    return float(width * hyp_ratio), float(height * hyp_ratio)
26 |
27 |
28 |
def get_face_keypoints_detecting_function(minimal_confidence: float = 0.8):
    """Build a function that finds the most prominent face in an image.

    Args:
        minimal_confidence: Detections whose ``'confidence'`` is not strictly
            greater than this value are discarded.

    Returns:
        A callable taking an image and returning the best detection dict, or
        ``None`` when no detection passes the confidence filter.
    """
    # a single detector instance is shared by every call of the returned closure
    mtcnn_detector = MTCNN()

    def _score(face):
        # confidence weighted by box area (``box`` is x, y, width, height)
        box = face['box']
        return face['confidence'] * box[2] * box[3]

    def get_keypoints(image: Image) -> Optional[Dict]:
        # inference is pinned to the CPU
        with tf.device("/cpu:0"):
            faces = mtcnn_detector.detect_faces(image_to_array(image))

        confident = [face for face in faces if face['confidence'] > minimal_confidence]
        if not confident:
            return None

        return max(confident, key=_score)

    return get_keypoints
53 |
54 |
55 |
def plot_face_detection(image: "Image.Image", ax, face_keypoints: Optional[Dict], hyp_ratio: float = 1 / 3):
    """Plot a face with its keypoints and bounding boxes.

    @param image: Image to show
    @param ax: Matplotlib axes to draw on
    @param face_keypoints: Detection dict with 'box' and 'keypoints', or None to show the image only
    @param hyp_ratio: Relative enlargement of the box diagonal used for the outer rectangle
    """
    # NOTE(review): the ax.* calls below were missing from the extracted text and
    # are reconstructed from the surrounding comments - confirm against the original.

    # make annotations
    if face_keypoints is not None:

        # get bounding box
        x, y, width, height = face_keypoints['box']

        # add rectangle patch for detected face
        rectangle = patches.Rectangle((x, y), width, height, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rectangle)

        # add rectangle patch with slacks
        w_s, h_s = compute_slacks(height, width, hyp_ratio)
        rectangle = patches.Rectangle((x - w_s, y - h_s), width + 2 * w_s, height + 2 * h_s, linewidth=1, edgecolor='r',
                                      facecolor='none')
        ax.add_patch(rectangle)

        # add keypoints
        for coordinates in face_keypoints['keypoints'].values():
            circle = plt.Circle(coordinates, 3, color='r')
            ax.add_artist(circle)

    # add image
    ax.imshow(image)
81 |
82 |
83 |
def get_crop_points(image: "Image.Image", face_keypoints: Optional[Dict],
                    hyp_ratio: float = 1 / 3) -> Tuple[float, float, float, float]:
    """Find position where to crop face from image.

    @param image: Image the face was detected on (only its width and height are read)
    @param face_keypoints: Detection dict with a 'box' entry, or None
    @param hyp_ratio: Relative enlargement of the box diagonal
    @return: Tuple (left, upper, right, lower); the whole image when face_keypoints is None
    """
    if face_keypoints is None:
        return 0, 0, image.width, image.height

    # get bounding box
    x, y, width, height = face_keypoints['box']

    # compute slacks
    w_s, h_s = compute_slacks(height, width, hyp_ratio)

    # compute coordinates, keeping them inside the image
    left = min(max(0, x - w_s), image.width)
    upper = min(max(0, y - h_s), image.height)
    right = min(x + width + w_s, image.width)
    lower = min(y + height + h_s, image.height)

    return left, upper, right, lower
101 |
102 |
103 |
def crop_face(image: "Image.Image", face_keypoints: Optional[Dict], hyp_ratio: float = 1 / 3) -> "Image.Image":
    """Crop input image to just the face.

    @param image: Image to crop
    @param face_keypoints: Detection dict with a 'box' entry, or None
    @param hyp_ratio: Relative enlargement of the box diagonal
    @return: Cropped image; the unchanged input image when face_keypoints is None
    """
    if face_keypoints is None:
        print("No keypoints detected on image")
        return image

    left, upper, right, lower = get_crop_points(image, face_keypoints, hyp_ratio)

    return image.crop((left, upper, right, lower))
@@ -0,0 +1,495 @@
1 |
import os
2 |
from datetime import datetime
3 |
from glob import glob
4 |
from typing import Tuple, Optional
5 |
from utils import load_image
6 |
import random
7 |
import cv2
8 |
import numpy as np
9 |
import tensorflow as tf
10 |
from PIL import Image
11 |
from sklearn.model_selection import train_test_split
12 |
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
13 |
from tensorflow.keras.utils import CustomObjectScope
14 |
from utils.face_detection import get_face_keypoints_detecting_function, crop_face, get_crop_points
15 |
from utils.architectures import UNet
16 |
from tensorflow.keras.losses import MeanSquaredError, mean_squared_error
17 |
from keras_vggface.vggface import VGGFace
18 |
import tensorflow.keras.backend as K
19 |
from tensorflow.keras.applications import VGG19
20 |
21 |
# # VGG19 model for perceptual loss
22 |
# vgg = VGG19(include_top=False, weights='imagenet')
23 |
24 |
# def preprocess_image(image):
25 |
# image = tf.image.resize(image, (224, 224))
26 |
# image = tf.keras.applications.vgg19.preprocess_input(image)
27 |
# return image
28 |
29 |
# def perceptual_loss(y_true, y_pred):
30 |
# y_true = preprocess_image(y_true)
31 |
# y_pred = preprocess_image(y_pred)
32 |
# y_true_c = vgg(y_true)
33 |
# y_pred_c = vgg(y_pred)
34 |
# loss = K.mean(K.square(y_pred_c - y_true_c))
35 |
# return loss
36 |
37 |
# Module-level VGGFace network (ResNet-50 variant, no classification top, average
# pooling, 256x256x3 input) built once at import time. It is called on both the
# predicted and the ground-truth image in ModelLoss.ms_ssim_l1_perceptual_loss.
vgg_face_model = VGGFace(model='resnet50', include_top=False, input_shape=(256, 256, 3), pooling='avg')
38 |
39 |
40 |
class ModelLoss:
    """Container for the combined SSIM + pixel + perceptual loss."""

    # NOTE(review): two decorator lines were lost in extraction; restored after the
    # commented-out copy of this function further down the file - confirm.
    @staticmethod
    @tf.function
    def ms_ssim_l1_perceptual_loss(gt, y_pred, max_val=1.0, l1_weight=1.0):
        """
        Computes the sum of an SSIM loss, a weighted pixel loss and a perceptual loss.
        Despite the name, the structural term uses single-scale `tf.image.ssim` and the
        pixel term uses `mean_squared_error`, not MS-SSIM and L1.
        @param gt: Ground truth image
        @param y_pred: Predicted image
        @param max_val: `max_val` passed to `tf.image.ssim`
        @param l1_weight: Weight of the pixel (mean squared error) term
        @return: SSIM loss + weighted pixel loss + perceptual loss
        """

        # Compute SSIM loss
        ssim_loss = 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))

        # Compute perceptual loss: distance between VGGFace outputs of prediction and ground truth
        vgg_face_outputs = vgg_face_model(y_pred)
        vgg_face_loss = tf.reduce_mean(tf.losses.mean_squared_error(vgg_face_outputs, vgg_face_model(gt)))

        # Combine the losses with the weighted pixel term
        l1 = mean_squared_error(gt, y_pred)
        l1_casted = tf.cast(l1 * l1_weight, tf.float32)
        return ssim_loss + l1_casted + vgg_face_loss
65 |
66 |
67 |
68 |
class LFUNet(tf.keras.models.Model):
69 |
70 |
Model for Mask2Face - removes mask from people faces using U-net neural network
71 |
72 |
def __init__(self, model: tf.keras.models.Model, configuration=None, *args, **kwargs):
    """
    Wraps an already built Keras model.
    @param model: Keras model that `call` delegates to
    @param configuration: Optional instance of Configuration with config JSON
    @param args: Forwarded to the parent constructor
    @param kwargs: Forwarded to the parent constructor
    """
    super().__init__(*args, **kwargs)
    self.model: tf.keras.models.Model = model
    self.configuration = configuration
    # detector created once (minimal confidence 0.8) and reused by `predict`
    self.face_keypoints_detecting_fun = get_face_keypoints_detecting_function(0.8)
    self.mse = MeanSquaredError()
78 |
79 |
def call(self, x, **kwargs):
    """
    Runs the wrapped model on the input.
    @param x: Input passed to the wrapped model
    @param kwargs: Accepted but not forwarded to the wrapped model
    @return: Output of the wrapped model
    """
    return self.model(x)
81 |
82 |
83 |
84 |
# NOTE(review): decorator lines were lost in extraction; restored after the
# commented-out loss further down the file, which shows this pair - confirm.
@staticmethod
@tf.function
def ssim_loss(gt, y_pred, max_val=1.0):
    """
    Computes standard SSIM loss
    @param gt: Ground truth image
    @param y_pred: Predicted image
    @param max_val: `max_val` passed to `tf.image.ssim`
    @return: SSIM loss (1 - mean SSIM)
    """
    return 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
93 |
94 |
95 |
96 |
# NOTE(review): decorator lines were lost in extraction; restored after the
# commented-out loss further down the file, which shows this pair - confirm.
@staticmethod
@tf.function
def ssim_l1_loss(gt, y_pred, max_val=1.0, l1_weight=1.0):
    """
    Computes SSIM loss plus a weighted pixel term.
    The pixel term is named "L1" but is computed with `mean_squared_error`.
    @param gt: Ground truth image
    @param y_pred: Predicted image
    @param max_val: `max_val` passed to `tf.image.ssim`
    @param l1_weight: Weight of the pixel (mean squared error) term
    @return: SSIM loss + weighted pixel loss
    """
    ssim_loss = 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
    l1 = mean_squared_error(gt, y_pred)
    return ssim_loss + tf.cast(l1 * l1_weight, tf.float32)
108 |
109 |
110 |
111 |
# @staticmethod
112 |
# @tf.function
113 |
# def ms_ssim_l1_perceptual_loss(gt, y_pred, max_val=1.0, l1_weight=1.0, perceptual_weight=1.0):
114 |
# """
115 |
# Computes MS-SSIM loss, L1 loss, and perceptual loss
116 |
# @param gt: Ground truth image
117 |
# @param y_pred: Predicted image
118 |
# @param max_val: Maximal SSIM value
119 |
# @param l1_weight: Weight of L1 normalization
120 |
# @param perceptual_weight: Weight of perceptual loss
121 |
# @return: MS-SSIM L1 perceptual loss
122 |
# """
123 |
# y_pred = tf.clip_by_value(y_pred, 0, float("inf"))
124 |
# y_pred = tf.debugging.check_numerics(y_pred, message='y_pred has NaN values')
125 |
# ms_ssim_loss = 1 - tf.reduce_mean(tf.image.ssim_multiscale(gt, y_pred, max_val=max_val))
126 |
# l1_loss = tf.losses.mean_absolute_error(gt, y_pred)
127 |
# vgg_face_outputs = vgg_face_model(y_pred)
128 |
# vgg_face_loss = tf.reduce_mean(tf.losses.mean_squared_error(vgg_face_outputs,vgg_face_model(gt)))
129 |
# return ms_ssim_loss + tf.cast(l1_loss * l1_weight, tf.float32) + perceptual_weight*vgg_face_loss
130 |
131 |
# Function for ms-ssim loss + l1 loss
132 |
133 |
134 |
# NOTE(review): decorator lines were lost in extraction; restored after the
# commented-out loss further up the file, which shows this pair - confirm.
@staticmethod
@tf.function
def ms_ssim_l1_loss(gt, y_pred, max_val=1.0, l1_weight=1.0):
    """
    Computes MS-SSIM loss and L1 loss
    @param gt: Ground truth image
    @param y_pred: Predicted image
    @param max_val: `max_val` passed to `tf.image.ssim_multiscale`
    @param l1_weight: Weight of the L1 (mean absolute error) term
    @return: MS-SSIM L1 loss
    """
    # Clamp negative predictions to zero before computing MS-SSIM.
    # NOTE(review): the original comment said this replaces NaN values; clipping
    # only bounds values from below and does not by itself remove NaNs.
    y_pred = tf.clip_by_value(y_pred, 0, float("inf"))

    ms_ssim_loss = 1 - tf.reduce_mean(tf.image.ssim_multiscale(gt, y_pred, max_val=max_val))
    l1_loss = tf.losses.mean_absolute_error(gt, y_pred)
    return ms_ssim_loss + tf.cast(l1_loss * l1_weight, tf.float32)
149 |
150 |
151 |
152 |
@staticmethod
def load_model(model_path, configuration=None):
    """
    Loads saved h5 file with trained model.
    @param configuration: Optional instance of Configuration with config JSON
    @param model_path: Path to h5 file
    @return: LFUNet
    """
    # custom losses must be registered by name so Keras can deserialize the compiled model
    custom_objects = {
        'ssim_loss': LFUNet.ssim_loss,
        'ssim_l1_loss': LFUNet.ssim_l1_loss,
        'ms_ssim_l1_perceptual_loss': ModelLoss.ms_ssim_l1_perceptual_loss,
        'ms_ssim_l1_loss': LFUNet.ms_ssim_l1_loss,
    }
    with CustomObjectScope(custom_objects):
        model = tf.keras.models.load_model(model_path)
    return LFUNet(model, configuration)
162 |
163 |
164 |
@staticmethod
def build_model(architecture: UNet, input_size: Tuple[int, int, int], filters: Optional[Tuple] = None,
                kernels: Optional[Tuple] = None, configuration=None):
    """
    Builds model based on input arguments
    @param architecture: utils.architectures.UNet architecture
    @param input_size: Size of input images
    @param filters: Tuple with sizes of filters in U-net
    @param kernels: Tuple with sizes of kernels in U-net. Must be the same size as filters.
    @param configuration: Optional instance of Configuration with config JSON
    @return: LFUNet
    """
    keras_model = architecture.build_model(input_size, filters, kernels).get_model()
    return LFUNet(keras_model, configuration)
176 |
177 |
def train(self, epochs=20, batch_size=20, loss_function='mse', learning_rate=1e-4,
          predict_difference: bool = False):
    """
    Train the model.
    @param epochs: Number of epochs during training
    @param batch_size: Batch size
    @param loss_function: Loss function. Either standard tensorflow loss function or one of `ssim_loss`,
        `ssim_l1_loss`, `ms_ssim_l1_perceptual_loss`, `ms_ssim_l1_loss`
    @param learning_rate: Learning rate
    @param predict_difference: Compute prediction on difference between input and output image
    @return: History of training
    """
    # get data
    (train_x, train_y), (valid_x, valid_y) = self.load_train_data()
    (test_x, test_y) = self.load_test_data()

    train_dataset = LFUNet.tf_dataset(train_x, train_y, batch_size, predict_difference)
    valid_dataset = LFUNet.tf_dataset(valid_x, valid_y, batch_size, predict_difference, train=False)
    test_dataset = LFUNet.tf_dataset(test_x, test_y, batch_size, predict_difference, train=False)

    # select loss: custom losses by name, anything else is handed to Keras unchanged
    custom_losses = {
        'ssim_loss': LFUNet.ssim_loss,
        'ssim_l1_loss': LFUNet.ssim_l1_loss,
        'ms_ssim_l1_perceptual_loss': ModelLoss.ms_ssim_l1_perceptual_loss,
        'ms_ssim_l1_loss': LFUNet.ms_ssim_l1_loss,
    }
    loss = custom_losses.get(loss_function, loss_function)

    # compile loss with selected loss function
    # NOTE(review): only the `metrics=` line of this call survived extraction. The
    # call target, `loss=` and the optimizer (Adam with `learning_rate`) are
    # reconstructed - confirm against the original.
    self.model.compile(
        loss=loss,
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        metrics=["acc", tf.keras.metrics.Recall(), tf.keras.metrics.Precision()]
    )

    # define callbacks
    # NOTE(review): the first callback was lost in extraction. ModelCheckpoint is
    # imported by this file, but the checkpoint path below is a reconstruction -
    # confirm the directory, file name and save options.
    callbacks = [
        ModelCheckpoint(
            f'models/model_epochs-{epochs}_batch-{batch_size}_loss-{loss_function}_{LFUNet.get_datetime_string()}.h5'),
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
    ]

    # evaluation before training
    results = self.model.evaluate(test_dataset)
    print("- TEST -> LOSS: {:10.4f}, ACC: {:10.4f}, RECALL: {:10.4f}, PRECISION: {:10.4f}".format(*results))

    # fit the model
    # NOTE(review): the start of this call was lost in extraction; `self.model.fit(train_dataset`
    # is reconstructed from the surviving keyword arguments.
    history = self.model.fit(train_dataset, validation_data=valid_dataset, epochs=epochs, callbacks=callbacks)

    # evaluation after training
    results = self.model.evaluate(test_dataset)
    print("- TEST -> LOSS: {:10.4f}, ACC: {:10.4f}, RECALL: {:10.4f}, PRECISION: {:10.4f}".format(*results))

    # use the model for inference on several test images
    self._test_results(test_x, test_y, predict_difference)

    # return training history
    return history
238 |
239 |
def _test_results(self, test_x, test_y, predict_difference: bool):
    """
    Test trained model on testing dataset. All images in testing dataset are processed and result image triples
    (input with mask, ground truth, model output) are stored to `data/results` (or to the configured
    `test_results_dir`) into folder with time stamp when this method was executed.
    @param test_x: List of input images
    @param test_y: List of ground truth output images
    @param predict_difference: Compute prediction on difference between input and output image
    @return: None
    """
    if self.configuration is None:
        result_dir = f'data/results/{LFUNet.get_datetime_string()}/'
    else:
        result_dir = os.path.join(self.configuration.get('test_results_dir'), LFUNet.get_datetime_string())
    os.makedirs(result_dir, exist_ok=True)

    for i, (x, y) in enumerate(zip(test_x, test_y)):
        x = LFUNet.read_image(x)
        y = LFUNet.read_image(y)

        y_pred = self.model.predict(np.expand_dims(x, axis=0))
        if predict_difference:
            # undo the (diff + 1) / 2 scaling and subtract the predicted difference from the input
            y_pred = (y_pred * 2) - 1
            y_pred = np.clip(x - y_pred.squeeze(axis=0), 0.0, 1.0)
        else:
            y_pred = y_pred.squeeze(axis=0)

        # 10 px wide white separator between the three images
        white_line = np.ones((x.shape[0], 10, 3)) * 255.0

        all_images = [
            x * 255.0, white_line,
            y * 255.0, white_line,
            y_pred * 255.0
        ]
        image = np.concatenate(all_images, axis=1)
        cv2.imwrite(os.path.join(result_dir, f"{i}.png"), image)
275 |
276 |
def summary(self):
    """
    Prints model summary
    """
    # NOTE(review): the body line was lost in extraction; delegating to the
    # wrapped model is reconstructed from the docstring - confirm.
    self.model.summary()
281 |
282 |
def predict(self, img_path, predict_difference: bool = False):
    """
    Use trained model to take down the mask from image with person wearing the mask.
    @param img_path: Image source handed to `load_image`
    @param predict_difference: Compute prediction on difference between input and output image
    @return: Image without the mask on the face
    """
    # Load image into RGB format
    image = load_image(img_path)
    image = image.convert('RGB')

    # Find facial keypoints and crop the image to just the face
    keypoints = self.face_keypoints_detecting_fun(image)
    cropped_image = crop_face(image, keypoints)

    # Resize image to input recognized by neural net
    resized_image = cropped_image.resize((256, 256))
    image_array = np.array(resized_image)

    # Convert from RGB to BGR (open cv format)
    image_array = image_array[:, :, ::-1].copy()
    image_array = image_array / 255.0

    # Remove mask from input image
    y_pred = self.model.predict(np.expand_dims(image_array, axis=0))

    if predict_difference:
        # undo the (diff + 1) / 2 scaling and subtract the predicted difference from the input
        y_pred = (y_pred * 2) - 1
        y_pred = np.clip(image_array - y_pred.squeeze(axis=0), 0.0, 1.0)
    else:
        y_pred = y_pred.squeeze(axis=0)

    # Convert output from model to image (BGR back to RGB) and scale it back to original size
    y_pred = y_pred * 255.0
    im = Image.fromarray(y_pred.astype(np.uint8)[:, :, ::-1])
    im = im.resize(cropped_image.size)
    left, upper, _, _ = get_crop_points(image, keypoints)

    # Combine original image with output from model.
    # Round rather than truncate: Image.crop rounds its box to the nearest pixel, so
    # truncating here could paste the result one pixel off from where the face was cut.
    image.paste(im, (int(round(left)), int(round(upper))))
    return image
325 |
326 |
327 |
@staticmethod
def get_datetime_string():
    """
    Creates date-time string
    @return: String with current date and time in the form YYYYMMDD_HH_MM_SS
    """
    # NOTE(review): the right-hand side was lost in extraction; `datetime.now()` is
    # reconstructed from the `from datetime import datetime` import and the strftime call.
    now = datetime.now()
    return now.strftime("%Y%m%d_%H_%M_%S")
334 |
335 |
def load_train_data(self, split=0.2):
    """
    Loads training data (paths to training images)
    @param split: Percentage of training data used for validation as float from 0.0 to 1.0. Default 0.2.
    @return: Two tuples - first with training data (tuple with (input images, output images)) and second
        with validation data (tuple with (input images, output images))
    """
    if self.configuration is None:
        train_dir = 'data/train/'
        limit = None
    else:
        train_dir = self.configuration.get('train_data_path')
        limit = self.configuration.get('train_data_limit')
    print(f'Loading training data from {train_dir} with limit of {limit} images')
    return LFUNet.load_data(os.path.join(train_dir, 'inputs'), os.path.join(train_dir, 'outputs'), split, limit)
350 |
351 |
def load_test_data(self):
    """
    Loads testing data (paths to testing images)
    @return: Tuple with testing data - (input images, output images)
    """
    if self.configuration is None:
        test_dir = 'data/test/'
        limit = None
    else:
        test_dir = self.configuration.get('test_data_path')
        limit = self.configuration.get('test_data_limit')
    print(f'Loading testing data from {test_dir} with limit of {limit} images')
    # split=None makes load_data return a single (inputs, outputs) tuple
    return LFUNet.load_data(os.path.join(test_dir, 'inputs'), os.path.join(test_dir, 'outputs'), None, limit)
364 |
365 |
366 |
@staticmethod
def load_data(input_path, output_path, split=0.2, limit=None):
    """
    Loads data (paths to images)
    @param input_path: Path to folder with input images
    @param output_path: Path to folder with output images
    @param split: Percentage of data used for validation as float from 0.0 to 1.0. Default 0.2.
        If split is None it expects you are loading testing data, otherwise expects training data.
    @param limit: Maximal number of images loaded from data folder. Default None (no limit).
    @return: If split is not None: Two tuples - first with training data (tuple with (input images, output images))
        and second with validation data (tuple with (input images, output images))
        Else: Tuple with testing data - (input images, output images)
    @raise TypeError: When one of the folders contains no PNG file
    @raise ValueError: When the numbers of input and output images differ
    """
    # inputs and outputs are paired by position in the sorted file lists
    images = sorted(glob(os.path.join(input_path, "*.png")))
    masks = sorted(glob(os.path.join(output_path, "*.png")))
    if not images:
        raise TypeError(f'No images found in {input_path}')
    if not masks:
        raise TypeError(f'No images found in {output_path}')

    if limit is not None:
        images = images[:limit]
        masks = masks[:limit]

    # unequal counts would silently pair inputs with the wrong outputs
    if len(images) != len(masks):
        raise ValueError(f'Found {len(images)} input images in {input_path} '
                         f'but {len(masks)} output images in {output_path}')

    if split is None:
        return images, masks

    valid_size = int(split * len(images))
    # one call splits both lists with the same indices, so pairs stay aligned
    train_x, valid_x, train_y, valid_y = train_test_split(images, masks, test_size=valid_size, random_state=42)
    return (train_x, train_y), (valid_x, valid_y)
398 |
399 |
400 |
@staticmethod
def read_image(path):
    """
    Loads image, resize it to size 256x256 and normalize to float values from 0.0 to 1.0.
    @param path: Path to image to be loaded.
    @return: Loaded image in open CV format.
    @raise FileNotFoundError: When the image cannot be read
    """
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    # imread signals failure by returning None; report it here instead of letting resize fail obscurely
    if x is None:
        raise FileNotFoundError(f'Cannot read image {path}')
    x = cv2.resize(x, (256, 256))
    x = x / 255.0
    return x
410 |
411 |
412 |
@staticmethod
def tf_parse(x, y):
    """
    Mapping function for dataset creation. Load and resize images.
    @param x: Path to input image
    @param y: Path to output image
    @return: Tuple with input and output image with shape (256, 256, 3)
    """
    def _parse(x, y):
        # paths arrive as bytes inside numpy_function, hence decode()
        x = LFUNet.read_image(x.decode())
        y = LFUNet.read_image(y.decode())
        return x, y

    x, y = tf.numpy_function(_parse, [x, y], [tf.float64, tf.float64])
    # numpy_function loses static shape information, so set it explicitly
    x.set_shape([256, 256, 3])
    y.set_shape([256, 256, 3])
    return x, y
428 |
429 |
430 |
def tf_dataset(x, y, batch=8, predict_difference: bool = False, train: bool = True):
431 |
432 |
Creates standard tensorflow dataset.
433 |
@param x: List of paths to input images
434 |
@param y: List of paths to output images
435 |
@param batch: Batch size
436 |
@param predict_difference: Compute prediction on difference between input and output image
437 |
@param train: Flag if training dataset should be generated
438 |
@return: Dataset with loaded images
439 |
440 |
dataset =, y))
441 |
dataset =
442 |
random_seed = random.randint(0, 999999999)
443 |
444 |
if predict_difference:
445 |
def map_output(img_in, img_target):
446 |
return img_in, (img_in - img_target + 1.0) / 2.0
447 |
448 |
dataset =
449 |
450 |
if train:
451 |
# for the train set, we want to apply data augmentations and shuffle data to different batches
452 |
453 |
# random flip
454 |
def flip(img_in, img_out):
455 |
return tf.image.random_flip_left_right(img_in, random_seed), \
456 |
tf.image.random_flip_left_right(img_out, random_seed)
457 |
458 |
# augmenting quality - parameters
459 |
hue_delta = 0.05
460 |
saturation_low = 0.2
461 |
saturation_up = 1.3
462 |
brightness_delta = 0.1
463 |
contrast_low = 0.2
464 |
contrast_up = 1.5
465 |
466 |
# augmenting quality
467 |
def color(img_in, img_out):
468 |
# Augmentations applied are:
469 |
# - random hue
470 |
# - random saturation
471 |
# - random brightness
472 |
# - random contrast
473 |
# - random flip left right
474 |
# - random flip up down
475 |
img_in = tf.image.random_hue(img_in, hue_delta, random_seed)
476 |
img_in = tf.image.random_saturation(img_in, saturation_low, saturation_up, random_seed)
477 |
img_in = tf.image.random_brightness(img_in, brightness_delta, random_seed)
478 |
img_in = tf.image.random_contrast(img_in, contrast_low, contrast_up, random_seed)
479 |
img_out = tf.image.random_hue(img_out, hue_delta, random_seed)
480 |
img_out = tf.image.random_saturation(img_out, saturation_low, saturation_up, random_seed)
481 |
img_out = tf.image.random_brightness(img_out, brightness_delta, random_seed)
482 |
img_out = tf.image.random_contrast(img_out, contrast_low, contrast_up, random_seed)
483 |
return img_in, img_out
484 |
485 |
# shuffle data and create batches
486 |
dataset = dataset.shuffle(5000)
487 |
dataset = dataset.batch(batch)
488 |
489 |
# apply augmentations
490 |
dataset =
491 |
dataset =
492 |
493 |
dataset = dataset.batch(batch)
494 |
495 |
return dataset.prefetch(