"""Gradio demo: segment one image with a saved Keras model and show the
predicted foreground as a red, semi-transparent-background RGBA overlay."""

import functools

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
from keras.utils import normalize

# Size the input is resized to before prediction. cv2.resize takes
# (width, height), hence (SIZE_Y, SIZE_X) below. Presumably this is the
# training resolution -- confirm against the training script.
SIZE_X = 128
SIZE_Y = 128

MODEL_PATH = "model100.h5"


def dice_coef(y_true, y_pred):
    """Return the batch-mean Dice coefficient of two 4-D tensors.

    Overlap is summed per sample over axes 1-3 (all but the batch axis) and
    then averaged over the batch. ``smooth`` keeps 0/0 finite when both
    masks are empty.

    It is registered as a custom object so ``load_model`` can resolve the
    name ``dice_coef`` referenced by the saved model (presumably a compile
    metric).
    """
    smooth = 1e-5
    # Avoid dtype mismatch when labels arrive as ints.
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
    union = tf.reduce_sum(y_true, axis=[1, 2, 3]) + tf.reduce_sum(y_pred, axis=[1, 2, 3])
    return tf.reduce_mean((2.0 * intersection + smooth) / (union + smooth), axis=0)


@functools.lru_cache(maxsize=1)
def _load_model(path=MODEL_PATH):
    """Load the segmentation model on first use and reuse it afterwards."""
    with tf.keras.utils.custom_object_scope({"dice_coef": dice_coef}):
        return tf.keras.models.load_model(path)


def _to_grayscale(image):
    """Return *image* as a 2-D single-channel array.

    Accepts a file path (read with OpenCV in grayscale mode) or an array as
    delivered by Gradio's ``"image"`` input: RGB, RGBA, or already
    single-channel.

    Raises:
        ValueError: if the path cannot be read or the array shape is not an
            image.
    """
    if isinstance(image, str):
        img = cv2.imread(image, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise ValueError(f"could not read image file: {image!r}")
        return img

    img = np.asarray(image)
    if img.ndim == 2:
        return img
    if img.ndim == 3:
        channels = img.shape[2]
        if channels == 1:
            return img[:, :, 0]
        # Gradio delivers RGB(A) order, unlike cv2.imread's BGR.
        if channels == 3:
            return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        if channels == 4:
            return cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    raise ValueError(f"unsupported image shape: {img.shape}")


def _mask_to_rgba(mask):
    """Paint pixels whose class is non-zero opaque red; leave the rest transparent."""
    rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
    rgba[mask > 0] = (255, 0, 0, 255)
    return rgba


def predict_segmentation(image):
    """Segment *image* and return an RGBA overlay of the predicted foreground.

    Args:
        image: Image from the Gradio ``"image"`` input (a numpy RGB array by
            default), a path to an image file, or ``None``.

    Returns:
        A uint8 RGBA array with the model output's height/width. Pixels are
        (255, 0, 0, 255) where the argmax class is non-zero and fully
        transparent elsewhere. Returns ``None`` when no image was supplied.

    Raises:
        ValueError: if the input cannot be interpreted as an image.
    """
    if image is None:
        return None

    gray = cv2.resize(_to_grayscale(image), (SIZE_Y, SIZE_X))
    # Batch of one single-channel image: shape (1, SIZE_X, SIZE_Y, 1).
    batch = gray[np.newaxis, :, :, np.newaxis]
    # Same L2 normalisation along axis 1 as the original preprocessing;
    # presumably matches what the model was trained on.
    batch = normalize(batch, axis=1)

    prediction = _load_model().predict(batch)
    # Per-pixel class index; class 0 is presumably background.
    predicted_img = np.argmax(prediction, axis=-1)[0]
    return _mask_to_rgba(predicted_img)


# Gradio Interface
iface = gr.Interface(
    fn=predict_segmentation,
    inputs="image",
    outputs="image",
    live=False,
)
iface.launch()