import os |
from datetime import datetime |
from glob import glob |
from typing import Tuple, Optional |
from utils import load_image |
import random |
import cv2 |
import numpy as np |
import tensorflow as tf |
from PIL import Image |
from sklearn.model_selection import train_test_split |
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint |
from tensorflow.keras.utils import CustomObjectScope |
from utils.face_detection import get_face_keypoints_detecting_function, crop_face, get_crop_points |
from utils.architectures import UNet |
from tensorflow.keras.losses import MeanSquaredError, mean_squared_error |
from keras_vggface.vggface import VGGFace |
import tensorflow.keras.backend as K |
from tensorflow.keras.applications import VGG19 |
# Pretrained VGGFace network (ResNet-50 backbone, no classification head, average pooling) used as a
# feature extractor by ModelLoss.ms_ssim_l1_perceptual_loss to compare predicted and ground-truth faces.
# NOTE(review): this is built at import time, so merely importing this module constructs the network
# (and may download its weights) even when the perceptual loss is never used.
vgg_face_model = VGGFace(
    include_top=False,
    input_shape=(256, 256, 3),
    model='resnet50',
    pooling='avg',
)
class ModelLoss:
    """
    Loss functions that depend on the module-level `vgg_face_model` feature extractor.
    """

    @staticmethod
    @tf.function
    def ms_ssim_l1_perceptual_loss(gt, y_pred, max_val=1.0, l1_weight=1.0):
        """
        Computes a combined SSIM + pixel-wise + perceptual (VGGFace feature) loss.

        NOTE(review): despite the name, this uses single-scale SSIM (`tf.image.ssim`, not
        `tf.image.ssim_multiscale`), and the pixel-wise term is mean squared error (L2), not L1.
        The math is intentionally left unchanged so training runs stay comparable; see
        `LFUNet.ms_ssim_l1_loss` for the variant that really uses MS-SSIM and mean absolute error.

        @param gt: Ground truth image batch
        @param y_pred: Predicted image batch
        @param max_val: Dynamic range of the images (max minus min allowed value), passed to `tf.image.ssim`
        @param l1_weight: Weight of the pixel-wise (mean squared error) term
        @return: (1 - mean SSIM) + l1_weight * MSE + mean squared distance between VGGFace features
        """
        ssim_term = 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
        # NOTE(review): images are fed to VGGFace as-is (presumably in [0, 1], as produced by the data
        # pipeline), without keras_vggface preprocessing - confirm this is intended.
        pred_features = vgg_face_model(y_pred)
        gt_features = vgg_face_model(gt)
        perceptual_term = tf.reduce_mean(tf.losses.mean_squared_error(pred_features, gt_features))
        # Per-pixel MSE (reduced over the channel axis); Keras reduces the remaining axes.
        mse_term = mean_squared_error(gt, y_pred)
        return ssim_term + tf.cast(mse_term * l1_weight, tf.float32) + perceptual_term
class LFUNet(tf.keras.models.Model):
    """
    Model for Mask2Face - removes masks from people's faces using a U-net neural network.

    Thin wrapper around a Keras model (`self.model`) that adds data loading, training, evaluation
    and single-image inference helpers.
    """

    def __init__(self, model: tf.keras.models.Model, configuration=None, *args, **kwargs):
        """
        @param model: Wrapped Keras model (e.g. the model produced by `build_model`)
        @param configuration: Optional instance of Configuration with config JSON
        """
        super().__init__(*args, **kwargs)
        self.model: tf.keras.models.Model = model
        self.configuration = configuration
        # Keypoint detector used by `predict`; the meaning of 0.8 is defined in utils.face_detection
        # (presumably a detection confidence threshold - confirm there).
        self.face_keypoints_detecting_fun = get_face_keypoints_detecting_function(0.8)
        # Not referenced elsewhere in this class; kept for external users.
        self.mse = MeanSquaredError()

    def call(self, x, **kwargs):
        """
        Forward pass - delegates to the wrapped Keras model.
        """
        return self.model(x)

    @staticmethod
    @tf.function
    def ssim_loss(gt, y_pred, max_val=1.0):
        """
        Computes standard SSIM loss (1 - mean SSIM).
        @param gt: Ground truth image
        @param y_pred: Predicted image
        @param max_val: Dynamic range of the images, passed to `tf.image.ssim`
        @return: SSIM loss
        """
        return 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))

    @staticmethod
    @tf.function
    def ssim_l1_loss(gt, y_pred, max_val=1.0, l1_weight=1.0):
        """
        Computes SSIM loss plus a weighted pixel-wise term.
        NOTE(review): despite the name, the pixel-wise term is mean squared error (L2), not L1.
        It is kept unchanged so models trained with this loss stay comparable.
        @param gt: Ground truth image
        @param y_pred: Predicted image
        @param max_val: Dynamic range of the images, passed to `tf.image.ssim`
        @param l1_weight: Weight of the pixel-wise (mean squared error) term
        @return: (1 - mean SSIM) + l1_weight * MSE
        """
        ssim_term = 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
        mse_term = mean_squared_error(gt, y_pred)
        return ssim_term + tf.cast(mse_term * l1_weight, tf.float32)

    @staticmethod
    @tf.function
    def ms_ssim_l1_loss(gt, y_pred, max_val=1.0, l1_weight=1.0):
        """
        Computes MS-SSIM loss plus a weighted L1 (mean absolute error) loss.
        @param gt: Ground truth image
        @param y_pred: Predicted image
        @param max_val: Dynamic range of the images, passed to `tf.image.ssim_multiscale`
        @param l1_weight: Weight of the L1 term
        @return: (1 - mean MS-SSIM) + l1_weight * MAE
        """
        # Negative predictions are clipped to 0 (presumably to keep ssim_multiscale numerically stable).
        y_pred = tf.clip_by_value(y_pred, 0, float("inf"))
        ms_ssim_term = 1 - tf.reduce_mean(tf.image.ssim_multiscale(gt, y_pred, max_val=max_val))
        l1_term = tf.losses.mean_absolute_error(gt, y_pred)
        return ms_ssim_term + tf.cast(l1_term * l1_weight, tf.float32)

    @staticmethod
    def load_model(model_path, configuration=None):
        """
        Loads saved h5 file with trained model.
        @param configuration: Optional instance of Configuration with config JSON
        @param model_path: Path to h5 file
        @return: LFUNet
        """
        # Custom losses must be resolvable by name when the saved compile config is deserialized.
        custom_objects = {
            'ssim_loss': LFUNet.ssim_loss,
            'ssim_l1_loss': LFUNet.ssim_l1_loss,
            'ms_ssim_l1_perceptual_loss': ModelLoss.ms_ssim_l1_perceptual_loss,
            'ms_ssim_l1_loss': LFUNet.ms_ssim_l1_loss,
        }
        with CustomObjectScope(custom_objects):
            model = tf.keras.models.load_model(model_path)
        return LFUNet(model, configuration)

    @staticmethod
    def build_model(architecture: UNet, input_size: Tuple[int, int, int], filters: Optional[Tuple] = None,
                    kernels: Optional[Tuple] = None, configuration=None):
        """
        Builds model based on input arguments
        @param architecture: utils.architectures.UNet architecture
        @param input_size: Size of input images
        @param filters: Tuple with sizes of filters in U-net
        @param kernels: Tuple with sizes of kernels in U-net. Must be the same size as filters.
        @param configuration: Optional instance of Configuration with config JSON
        @return: LFUNet
        """
        return LFUNet(architecture.build_model(input_size, filters, kernels).get_model(), configuration)

    def train(self, epochs=20, batch_size=20, loss_function='mse', learning_rate=1e-4,
              predict_difference: bool = False):
        """
        Train the model.
        @param epochs: Number of epochs during training
        @param batch_size: Batch size
        @param loss_function: Loss function. Either a standard tensorflow loss (passed to `compile` as-is)
            or one of `ssim_loss`, `ssim_l1_loss`, `ms_ssim_l1_loss`, `ms_ssim_l1_perceptual_loss`
        @param learning_rate: Learning rate
        @param predict_difference: Compute prediction on difference between input and output image
        @return: History of training (as returned by `fit`)
        """
        (train_x, train_y), (valid_x, valid_y) = self.load_train_data()
        test_x, test_y = self.load_test_data()
        train_dataset = LFUNet.tf_dataset(train_x, train_y, batch_size, predict_difference)
        valid_dataset = LFUNet.tf_dataset(valid_x, valid_y, batch_size, predict_difference, train=False)
        test_dataset = LFUNet.tf_dataset(test_x, test_y, batch_size, predict_difference, train=False)

        custom_losses = {
            'ssim_loss': LFUNet.ssim_loss,
            'ssim_l1_loss': LFUNet.ssim_l1_loss,
            'ms_ssim_l1_perceptual_loss': ModelLoss.ms_ssim_l1_perceptual_loss,
            'ms_ssim_l1_loss': LFUNet.ms_ssim_l1_loss,
        }
        # Names of custom losses are resolved here; anything else goes to Keras unchanged.
        loss = custom_losses.get(loss_function, loss_function) if isinstance(loss_function, str) else loss_function

        self.model.compile(
            loss=loss,
            optimizer=tf.keras.optimizers.Adam(learning_rate),
            metrics=["acc", tf.keras.metrics.Recall(), tf.keras.metrics.Precision()]
        )

        os.makedirs('models', exist_ok=True)
        checkpoint_path = (f'models/model_epochs-{epochs}_batch-{batch_size}_loss-{loss_function}_'
                           f'{LFUNet.get_datetime_string()}.h5')
        callbacks = [
            ModelCheckpoint(checkpoint_path),
            EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
        ]

        # Baseline test metrics before training; `results` is [loss, acc, recall, precision].
        results = self.model.evaluate(test_dataset)
        print("- TEST -> LOSS: {:10.4f}, ACC: {:10.4f}, RECALL: {:10.4f}, PRECISION: {:10.4f}".format(*results))
        history = self.model.fit(train_dataset, validation_data=valid_dataset, epochs=epochs, callbacks=callbacks)
        results = self.model.evaluate(test_dataset)
        print("- TEST -> LOSS: {:10.4f}, ACC: {:10.4f}, RECALL: {:10.4f}, PRECISION: {:10.4f}".format(*results))
        self._test_results(test_x, test_y, predict_difference)
        return history

    def _test_results(self, test_x, test_y, predict_difference: bool):
        """
        Test trained model on testing dataset. All images in testing dataset are processed and result image triples
        (input with mask, ground truth, model output) are stored to `data/results` (or the configured
        `test_results_dir`) into a folder named with the time stamp of this call.
        @param test_x: List of input images
        @param test_y: List of ground truth output images
        @param predict_difference: Compute prediction on difference between input and output image
        @return: None
        """
        if self.configuration is None:
            result_dir = f'data/results/{LFUNet.get_datetime_string()}/'
        else:
            result_dir = os.path.join(self.configuration.get('test_results_dir'), LFUNet.get_datetime_string())
        os.makedirs(result_dir, exist_ok=True)

        for i, (x, y) in enumerate(zip(test_x, test_y)):
            x = LFUNet.read_image(x)
            y = LFUNet.read_image(y)
            y_pred = self.model.predict(np.expand_dims(x, axis=0))
            if predict_difference:
                # Decode the [0, 1]-encoded difference (see tf_dataset) back to [-1, 1] and subtract it from the input.
                y_pred = (y_pred * 2) - 1
                y_pred = np.clip(x - y_pred.squeeze(axis=0), 0.0, 1.0)
            else:
                y_pred = y_pred.squeeze(axis=0)
            h = x.shape[0]
            # 10 px wide white separator between the three images.
            white_line = np.ones((h, 10, 3)) * 255.0
            all_images = [
                x * 255.0, white_line,
                y * 255.0, white_line,
                y_pred * 255.0
            ]
            image = np.concatenate(all_images, axis=1)
            # NOTE(review): a float image is passed to cv2.imwrite; this relies on OpenCV converting it to 8-bit.
            cv2.imwrite(os.path.join(result_dir, f"{i}.png"), image)

    def summary(self):
        """
        Prints model summary
        """
        self.model.summary()

    def predict(self, img_path, predict_difference: bool = False):
        """
        Use trained model to take down the mask from image with person wearing the mask.
        The face is detected, cropped, resized to 256x256, passed through the model, and the result is
        resized back and pasted over the original face region.
        @param img_path: Path to image to processed
        @param predict_difference: Compute prediction on difference between input and output image
        @return: Image without the mask on the face (the loaded PIL image, modified in place)
        """
        image = load_image(img_path)
        image = image.convert('RGB')
        keypoints = self.face_keypoints_detecting_fun(image)
        cropped_image = crop_face(image, keypoints)
        resized_image = cropped_image.resize((256, 256))
        # RGB -> BGR and scale to [0, 1], matching `read_image` (OpenCV) used for training data.
        image_array = np.array(resized_image)[:, :, ::-1] / 255.0

        y_pred = self.model.predict(np.expand_dims(image_array, axis=0))
        if predict_difference:
            y_pred = (y_pred * 2) - 1
            y_pred = np.clip(image_array - y_pred.squeeze(axis=0), 0.0, 1.0)
        else:
            # Clip so out-of-range outputs saturate instead of wrapping around in the uint8 cast below.
            y_pred = np.clip(y_pred.squeeze(axis=0), 0.0, 1.0)

        # Back to 8-bit and BGR -> RGB for PIL.
        output = Image.fromarray((y_pred * 255.0).astype(np.uint8)[:, :, ::-1])
        output = output.resize(cropped_image.size)
        left, upper, _, _ = get_crop_points(image, keypoints)
        image.paste(output, (int(left), int(upper)))
        return image

    @staticmethod
    def get_datetime_string():
        """
        Creates date-time string
        @return: String with current date and time (format `%Y%m%d_%H_%M_%S`)
        """
        now = datetime.now()
        return now.strftime("%Y%m%d_%H_%M_%S")

    def load_train_data(self, split=0.2):
        """
        Loads training data (paths to training images)
        @param split: Percentage of training data used for validation as float from 0.0 to 1.0. Default 0.2.
        @return: Two tuples - first with training data (tuple with (input images, output images)) and second
        with validation data (tuple with (input images, output images))
        """
        if self.configuration is None:
            train_dir = 'data/train/'
            limit = None
        else:
            train_dir = self.configuration.get('train_data_path')
            limit = self.configuration.get('train_data_limit')
        print(f'Loading training data from {train_dir} with limit of {limit} images')
        return LFUNet.load_data(os.path.join(train_dir, 'inputs'), os.path.join(train_dir, 'outputs'), split, limit)

    def load_test_data(self):
        """
        Loads testing data (paths to testing images)
        @return: Tuple with testing data - (input images, output images)
        """
        if self.configuration is None:
            test_dir = 'data/test/'
            limit = None
        else:
            test_dir = self.configuration.get('test_data_path')
            limit = self.configuration.get('test_data_limit')
        print(f'Loading testing data from {test_dir} with limit of {limit} images')
        return LFUNet.load_data(os.path.join(test_dir, 'inputs'), os.path.join(test_dir, 'outputs'), None, limit)

    @staticmethod
    def load_data(input_path, output_path, split=0.2, limit=None):
        """
        Loads data (paths to images). Inputs and outputs are paired by their position in sorted file order.
        @param input_path: Path to folder with input images
        @param output_path: Path to folder with output images
        @param split: Percentage of data used for validation as float from 0.0 to 1.0. Default 0.2.
            If split is None it expects you are loading testing data, otherwise expects training data.
        @param limit: Maximal number of images loaded from data folder. Default None (no limit).
        @return: If split is not None: Two tuples - first with training data (tuple with (input images, output images))
            and second with validation data (tuple with (input images, output images))
            Else: Tuple with testing data - (input images, output images)
        @raise TypeError: If either folder contains no PNG images
        @raise ValueError: If the numbers of input and output images differ (after applying `limit`)
        """
        images = sorted(glob(os.path.join(input_path, "*.png")))
        masks = sorted(glob(os.path.join(output_path, "*.png")))
        if not images:
            raise TypeError(f'No images found in {input_path}')
        if not masks:
            raise TypeError(f'No images found in {output_path}')
        if limit is not None:
            images = images[:limit]
            masks = masks[:limit]
        if len(images) != len(masks):
            raise ValueError(f'Found {len(images)} input images in {input_path} but {len(masks)} output images '
                             f'in {output_path}; inputs and outputs must pair up one-to-one')

        if split is None:
            return images, masks

        valid_size = int(split * len(images))
        # Split both lists in one call so every input stays paired with its output.
        train_x, valid_x, train_y, valid_y = train_test_split(images, masks, test_size=valid_size, random_state=42)
        return (train_x, train_y), (valid_x, valid_y)

    @staticmethod
    def read_image(path):
        """
        Loads image, resize it to size 256x256 and normalize to float values from 0.0 to 1.0.
        @param path: Path to image to be loaded.
        @return: Loaded image in open CV format (BGR, float64).
        @raise FileNotFoundError: If OpenCV cannot read the image.
        """
        x = cv2.imread(path, cv2.IMREAD_COLOR)
        if x is None:
            # cv2.imread signals missing/unreadable files by returning None rather than raising.
            raise FileNotFoundError(f'Unable to read image {path}')
        x = cv2.resize(x, (256, 256))
        return x / 255.0

    @staticmethod
    def tf_parse(x, y):
        """
        Mapping function for dataset creation. Load and resize images.
        @param x: Path to input image
        @param y: Path to output image
        @return: Tuple with input and output image with shape (256, 256, 3)
        """
        def _parse(x, y):
            x = LFUNet.read_image(x.decode())
            y = LFUNet.read_image(y.decode())
            return x, y

        x, y = tf.numpy_function(_parse, [x, y], [tf.float64, tf.float64])
        # numpy_function loses static shape information; restore it for downstream layers.
        x.set_shape([256, 256, 3])
        y.set_shape([256, 256, 3])
        return x, y

    @staticmethod
    def tf_dataset(x, y, batch=8, predict_difference: bool = False, train: bool = True):
        """
        Creates standard tensorflow dataset.
        @param x: List of paths to input images
        @param y: List of paths to output images
        @param batch: Batch size
        @param predict_difference: Compute prediction on difference between input and output image.
            The target becomes (input - output + 1) / 2. It is computed after augmentation, so the target
            stays consistent with the augmented input.
        @param train: Flag if training dataset should be generated (enables shuffling, random flips and
            color jitter)
        @return: Batched, prefetched dataset with loaded images
        """
        dataset = tf.data.Dataset.from_tensor_slices((x, y))
        # Decoding runs OpenCV via numpy_function; parallelize it (element order is preserved).
        dataset = dataset.map(LFUNet.tf_parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        # The same op-level seed is passed to each random op on the input and to its counterpart on the
        # output, intended to make both images receive identical flips / color changes.
        random_seed = random.randint(0, 999999999)

        if train:
            hue_delta = 0.05
            saturation_low = 0.2
            saturation_up = 1.3
            brightness_delta = 0.1
            contrast_low = 0.2
            contrast_up = 1.5

            def flip(img_in, img_out):
                return tf.image.random_flip_left_right(img_in, random_seed), \
                    tf.image.random_flip_left_right(img_out, random_seed)

            def color(img_in, img_out):
                img_in = tf.image.random_hue(img_in, hue_delta, random_seed)
                img_in = tf.image.random_saturation(img_in, saturation_low, saturation_up, random_seed)
                img_in = tf.image.random_brightness(img_in, brightness_delta, random_seed)
                img_in = tf.image.random_contrast(img_in, contrast_low, contrast_up, random_seed)
                img_out = tf.image.random_hue(img_out, hue_delta, random_seed)
                img_out = tf.image.random_saturation(img_out, saturation_low, saturation_up, random_seed)
                img_out = tf.image.random_brightness(img_out, brightness_delta, random_seed)
                img_out = tf.image.random_contrast(img_out, contrast_low, contrast_up, random_seed)
                # Brightness/contrast jitter can leave [0, 1]; clip back to the range produced by read_image.
                return tf.clip_by_value(img_in, 0.0, 1.0), tf.clip_by_value(img_out, 0.0, 1.0)

            dataset = dataset.shuffle(5000)
            dataset = dataset.batch(batch)
            dataset = dataset.map(flip)
            dataset = dataset.map(color)
        else:
            dataset = dataset.batch(batch)

        if predict_difference:
            def map_output(img_in, img_target):
                # Encode the difference (in [-1, 1] for [0, 1] images) into [0, 1].
                return img_in, (img_in - img_target + 1.0) / 2.0

            dataset = dataset.map(map_output)

        return dataset.prefetch(tf.data.experimental.AUTOTUNE)