# seg_hair / app.py
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, MaxPooling2D, Conv2D, UpSampling2D, concatenate, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, ModelCheckpoint
from tensorflow.keras import backend as K
import os
import cv2
import gradio as gr
import numpy as np
import time
SIZE_IMG = 128  # the segmentation model works on 128x128 grayscale inputs
DATA_DIR = '/content/drive/MyDrive/Dataset/'  # training-time Colab path, not used at inference
def dice(y_true, y_pred, smooth=1.):
    """Dice coefficient: overlap score used as a custom metric by the pretrained model."""
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
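
# Illustrative sketch (not used by the app): the Dice coefficient above is a similarity
# score in [0, 1]; during training it is typically turned into a loss as 1 - dice.
# `dice_loss` is a hypothetical helper added here only to show that usage.
def dice_loss(y_true, y_pred):
    # Higher Dice overlap -> lower loss.
    return 1.0 - dice(y_true, y_pred)
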
# Load the pretrained hair-segmentation model; `dice` must be registered as a custom object.
model = tf.keras.models.load_model('mymodel-pretrain.h5', custom_objects={'dice': dice})
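
# The layer imports above (Conv2D, MaxPooling2D, UpSampling2D, concatenate, Input) suggest
# that 'mymodel-pretrain.h5' is a U-Net-style segmenter, but the architecture is not defined
# in this file. The function below is only a minimal, assumed sketch of such a network for
# reference; it is not called by the app and may differ from the actual pretrained model.
def build_unet_sketch(size=SIZE_IMG):
    inputs = Input((size, size, 1))
    # Encoder: two downsampling stages.
    c1 = Conv2D(16, 3, activation='relu', padding='same')(inputs)
    p1 = MaxPooling2D(2)(c1)
    c2 = Conv2D(32, 3, activation='relu', padding='same')(p1)
    p2 = MaxPooling2D(2)(c2)
    # Bottleneck.
    b = Conv2D(64, 3, activation='relu', padding='same')(p2)
    # Decoder: upsample and concatenate the skip connections.
    u2 = concatenate([UpSampling2D(2)(b), c2])
    c3 = Conv2D(32, 3, activation='relu', padding='same')(u2)
    u1 = concatenate([UpSampling2D(2)(c3), c1])
    c4 = Conv2D(16, 3, activation='relu', padding='same')(u1)
    # Single-channel sigmoid mask, monitored with the dice metric.
    outputs = Conv2D(1, 1, activation='sigmoid')(c4)
    unet = Model(inputs, outputs)
    unet.compile(optimizer=Adam(1e-4), loss='binary_crossentropy', metrics=[dice])
    return unet
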
def process(img, r, g, b):
    """Segment the hair in `img` and remap the hair pixels' colour channels towards (r, g, b)."""
    img_raw = img
    img_org = img_raw.copy()
    # Grayscale, resize to the model's 128x128 input and normalise to [0, 1].
    img = cv2.cvtColor(img_raw, cv2.COLOR_BGR2GRAY)
    h, w = img.shape
    img = cv2.resize(img, (SIZE_IMG, SIZE_IMG))
    arr_img = np.array(img) / 255.0
    arr_img = arr_img.reshape(-1, SIZE_IMG, SIZE_IMG, 1)
    # Predict the hair mask and scale it back to the original resolution.
    res = model.predict(arr_img)
    mask = np.array(cv2.resize(res[0] * 255, (w, h)), np.uint8)

    def changeValue(img, mask, max_range, min_range, axis=2):
        # Rescale one colour channel of the masked (hair) pixels into [min_range, max_range].
        value = img[:, :, axis]
        x = value[mask > 50]
        if len(x) == 0:
            # No hair pixels detected: return the image unchanged.
            return img
        index = np.where(mask > 50)
        if x.max() == x.min():
            # Flat channel: every hair pixel gets the midpoint of the target range.
            for i, j in zip(index[0], index[1]):
                img[i, j, axis] = (max_range + min_range) / 2
            return np.array(img, np.uint8)
        else:
            # Linearly map [x.min(), x.max()] onto [min_range, max_range].
            ratio = (max_range - min_range) / (x.max() - x.min())
            x = x * ratio + min_range - x.min() * ratio
            n = 0
            for i, j in zip(index[0], index[1]):
                img[i, j, axis] = x[n]
                n += 1
            return np.array(img, np.uint8)

    def changeColor(img, mask, img_org, r=0, g=0, b=0):
        # img_org is accepted for compatibility with the original signature but not used here.
        # Channel indices follow OpenCV's BGR layout: 0 = blue, 1 = green, 2 = red.
        img = changeValue(img, mask, r, 0, 2)
        img = changeValue(img, mask, g, 0, 1)
        img = changeValue(img, mask, b, 0, 0)
        return img

    return changeColor(img_raw, mask, img_org, r, g, b)
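
# Minimal offline usage sketch, not wired into the app: 'face.jpg' and 'face_recolored.jpg'
# are hypothetical paths. It runs the same pipeline as the Gradio callback without the UI.
def recolor_file_sketch(path='face.jpg', r=180, g=60, b=60, out_path='face_recolored.jpg'):
    img = cv2.imread(path)  # BGR uint8 array; process() maps r/g/b to channel indices 2/1/0
    if img is None:
        raise FileNotFoundError(path)
    result = process(img, r, g, b)  # segment the hair, then remap the three channels
    cv2.imwrite(out_path, result)
    return out_path
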
def image_classifier(img, color):
    # gr.ColorPicker returns a hex string like '#rrggbb'; offsets 1, 3 and 5 hold the
    # R, G and B bytes. They are unpacked in reverse because the Gradio image is RGB
    # while process() addresses channels with BGR-style names (r -> index 2, b -> index 0).
    (b, g, r) = tuple(int(color[k:k + 2], 16) for k in (1, 3, 5))
    return process(img, r, g, b)
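
# Quick self-test sketch (hypothetical helper, not invoked by the app): runs the callback
# on a random RGB image with a fixed colour string to confirm the shape round-trips.
def smoke_test_sketch():
    dummy = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
    out = image_classifier(dummy, '#aa3366')
    assert out.shape == dummy.shape
    return out
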
# Gradio UI: an image input and a colour picker feed the callback; the recoloured image is the output.
inputs = ["image", gr.ColorPicker(label="color")]
demo = gr.Interface(fn=image_classifier, inputs=inputs, outputs="image")
demo.launch()