import os
import re
import time

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.keras.layers import (
    Conv2D,
    Dense,
    Dropout,
    Input,
    MaxPooling2D,
    UpSampling2D,
    concatenate,
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import cv2
import gradio as gr
import numpy as np

# Side length (pixels) of the square, single-channel input the model is fed.
SIZE_IMG = 128
DATA_DIR = '/content/drive/MyDrive/Dataset/'
# Predicted-mask pixels (scaled to 0..255) strictly above this count as "inside".
MASK_THRESHOLD = 50


def dice(y_true, y_pred, smooth=1.):
    """Return the smoothed Dice coefficient between two tensors.

    Defined here so it can be passed as a custom object when loading the
    saved model below.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


model = tf.keras.models.load_model('mymodel-pretrain.h5', custom_objects={'dice': dice})


def _rescale_channel(img, selected, max_range, min_range, axis):
    """Linearly stretch channel `axis` of the selected pixels onto [min_range, max_range].

    Modifies `img` in place. If every selected pixel has the same value, all of
    them are set to the midpoint of the range. When `img` is uint8, NumPy's
    float-to-uint8 cast on assignment truncates the rescaled values.
    """
    values = img[:, :, axis][selected]
    if values.size == 0:
        return
    low, high = values.min(), values.max()
    if high == low:
        # Degenerate range: no spread to stretch, so use the target midpoint.
        img[selected, axis] = (max_range + min_range) / 2
        return
    ratio = (max_range - min_range) / (high - low)
    # Boolean-mask indexing visits pixels in the same row-major order in which
    # `values` was gathered, so each pixel receives its own rescaled value.
    img[selected, axis] = values * ratio + min_range - low * ratio


def _recolor(img, mask, r, g, b):
    """Return a uint8 copy of BGR `img` with the masked region's channels rescaled.

    Axes follow OpenCV's BGR order: axis 2 is stretched onto [0, r],
    axis 1 onto [0, g] and axis 0 onto [0, b].
    """
    result = np.array(img, np.uint8)  # copy: never mutate the caller's image
    selected = mask > MASK_THRESHOLD
    for axis, upper in ((2, r), (1, g), (0, b)):
        _rescale_channel(result, selected, upper, 0, axis)
    return result


def process(img, r, g, b):
    """Segment `img` with the model and recolour the segmented region.

    Args:
        img: H x W x 3 image in BGR channel order (OpenCV convention); both the
            grayscale conversion and the channel mapping assume BGR.
        r, g, b: upper bounds (0..255) onto which the red, green and blue
            channels of the segmented pixels are linearly stretched.

    Returns:
        A new uint8 BGR image of the same size; `img` itself is not modified.
    """
    height, width = img.shape[:2]
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    small = cv2.resize(gray, (SIZE_IMG, SIZE_IMG))
    batch = (np.asarray(small) / 255.0).reshape(-1, SIZE_IMG, SIZE_IMG, 1)
    prediction = model.predict(batch)
    # Scale the first prediction to 0..255 and resize it back to the input size.
    mask = np.array(cv2.resize(prediction[0] * 255, (width, height)), np.uint8)
    return _recolor(img, mask, r, g, b)


def _parse_color(color):
    """Return an (r, g, b) tuple of ints in 0..255 parsed from a colour string.

    Accepts '#rrggbb' (extra trailing hex digits, e.g. alpha, are ignored) and
    '#rgb'. Also accepts CSS 'rgb(...)' / 'rgba(...)' text, in case the picker
    reports functional notation.

    Raises:
        ValueError: if `color` is empty or in an unsupported format.
    """
    if not color:
        raise ValueError("No colour selected")
    text = color.strip()
    if text.startswith("#"):
        digits = text[1:]
        if len(digits) == 3:
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) < 6:
            raise ValueError(f"Unsupported colour value: {color!r}")
        return tuple(int(digits[k:k + 2], 16) for k in (0, 2, 4))
    numbers = re.findall(r"\d*\.?\d+", text)
    if text.lower().startswith("rgb") and len(numbers) >= 3:
        return tuple(min(255, max(0, int(round(float(v))))) for v in numbers[:3])
    raise ValueError(f"Unsupported colour value: {color!r}")


def _to_bgr(image):
    """Convert an RGB, RGBA or single-channel array to 3-channel BGR."""
    if image.ndim == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def image_classifier(i, color):
    """Gradio handler: recolour the model-segmented region of an uploaded image.

    Args:
        i: the uploaded image as a numpy array in RGB order (Gradio's numpy
            image convention), or None when nothing has been uploaded.
        color: the ColorPicker value, e.g. '#rrggbb'.

    Returns:
        The recoloured image in RGB order, or None when `i` is None.
    """
    if i is None:
        return None
    r, g, b = _parse_color(color)
    # `process` works in OpenCV's BGR order, so convert on the way in and out.
    # This keeps the model's grayscale input at correct luminance.
    recolored = process(_to_bgr(i), r, g, b)
    return cv2.cvtColor(recolored, cv2.COLOR_BGR2RGB)


inputs = ["image", gr.ColorPicker(label="color")]
demo = gr.Interface(fn=image_classifier, inputs=inputs, outputs="image")
demo.launch()