|
import logging

import gradio as gr
import tensorflow as tf

from utils.architectures import UNet
from utils.configuration import Configuration
from utils.model import LFUNet
from utils.model import ModelLoss
|
|
|
# ---------------------------------------------------------------------------
# Model construction.
# The architecture hyper-parameters below are used to rebuild the network
# before the checkpoint is restored; presumably they must match the values the
# checkpoint was trained with (TODO confirm), otherwise load_weights would
# not line up with the saved variables.
# ---------------------------------------------------------------------------
configuration = Configuration()
architecture = UNet.RESIDUAL_ATTENTION_UNET_SEPARABLE_CONV
input_image_size = (256, 256, 3)
filters = (64, 128, 128, 256, 256, 512)
kernels = (7, 7, 7, 3, 3, 3)

trained_model = LFUNet.build_model(
    architecture=architecture,
    input_size=input_image_size,
    filters=filters,
    kernels=kernels,
    configuration=configuration,
)

# NOTE(review): loss/optimizer/metrics only matter for training. Compiling may
# be unnecessary for this inference-only demo; confirm whether LFUNet relies
# on it before removing.
_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
_metrics = ["acc", tf.keras.metrics.Recall(), tf.keras.metrics.Precision()]
trained_model.compile(
    loss=ModelLoss.ms_ssim_l1_perceptual_loss,
    optimizer=_optimizer,
    metrics=_metrics,
)

# Checkpoint path is relative, so it resolves against the current working
# directory. The file name appears to encode the training run
# (epochs, batch size, loss, timestamp).
weights_path = "model_weights/model_epochs-40_batch-20_loss-ms_ssim_l1_perceptual_loss_20230210_15_45_38.ckpt"
trained_model.load_weights(weights_path)
|
|
|
def main(input_img):
    """Unmask the face in a single uploaded image using ``trained_model``.

    Args:
        input_img: Path of the uploaded image (the input component is created
            with ``type="filepath"``), or ``None`` when nothing was provided.

    Returns:
        The value of ``trained_model.predict(input_img)``. The output component
        is declared with ``type='numpy'``, so this is presumably an image
        array; TODO confirm against ``LFUNet``.

    Raises:
        gr.Error: If no image was supplied or if inference fails. Gradio shows
            the message in the UI instead of a raw traceback.
    """
    log = logging.getLogger(__name__)

    # Gradio presumably passes None when the user submits without an image.
    # Fail fast with an actionable message instead of the generic failure.
    if input_img is None:
        raise gr.Error("Please upload an image or click one of the examples first.")

    log.info("Running inference on %s", input_img)
    try:
        return trained_model.predict(input_img)
    except Exception as err:
        # Keep the real cause in the server logs; the user only sees the
        # friendly message. `from err` preserves the exception chain.
        log.exception("Inference failed for %s", input_img)
        raise gr.Error("Sorry, something went wrong. Please try again!") from err
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: static texts, styling and example inputs, then the interface.
# ---------------------------------------------------------------------------
_DESCRIPTION = (
    "This is a demo of a <b>Lightweight network for face unmasking</b> "
    "designed to provide a powerful and efficient solution for restoring "
    "facial details obscured by masks.<br> "
    "To use it, simply upload your image, or click one of the examples to "
    "load them. Inference in demo may take some time because of "
    "connectivity reasons."
)

# `.svelte-mppz8v` targets a generated Svelte class name, so this rule may stop
# matching after a Gradio upgrade.
_CSS = """
.svelte-mppz8v {
    text-align: -webkit-center;
}

.gallery {
    display: flex;
    flex-wrap: wrap;
    width: 100%;
}

p {
    font-size: medium;
}

h1 {
    font-size: xx-large;
}
"""

# One single-input example row per bundled image: examples/1.png .. examples/12.png.
_EXAMPLES = [[f"examples/{index}.png"] for index in range(1, 13)]

demo = gr.Interface(
    fn=main,
    title="Lightweight network for face unmasking",
    description=_DESCRIPTION,
    inputs=gr.Image(type="filepath").style(height=256),
    outputs=gr.Image(type="numpy", shape=(256, 256, 3)).style(height=256),
    examples=_EXAMPLES,
    css=_CSS,
    theme="xiaobaiyuan/theme_brief",
    cache_examples=False,
)

# Start the server; show_error surfaces gr.Error messages in the UI.
demo.launch(show_error=True)
|
|