"""Gradio demo: run a Keras segmentation model and return the mask as a transparent PNG."""

import functools
import tempfile

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
from keras.utils import normalize
from PIL import Image

MODEL_PATH = "model100.h5"
# Spatial input size of the model (cv2 takes (width, height); both are 128 here).
SIZE_X = 128
SIZE_Y = 128
# RGB colour painted over every pixel predicted as a non-background class (red).
SEGMENTED_COLOR = (255, 0, 0)


def dice_coef(y_true, y_pred):
    """Return the batch-mean Dice coefficient between ``y_true`` and ``y_pred``.

    Registered as a custom object so the saved model, which references this
    metric by name, can be deserialized. Sums run over axes 1-3 and the
    result is averaged over axis 0 (assumes 4-D (batch, H, W, C) tensors —
    TODO confirm against training code).
    """
    smooth = 1e-5
    # Labels are frequently integer tensors; multiply in the prediction's dtype.
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
    union = tf.reduce_sum(y_true, axis=[1, 2, 3]) + tf.reduce_sum(y_pred, axis=[1, 2, 3])
    return tf.reduce_mean((2.0 * intersection + smooth) / (union + smooth), axis=0)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the segmentation model from MODEL_PATH once and reuse it for every request."""
    return tf.keras.models.load_model(MODEL_PATH, custom_objects={"dice_coef": dice_coef})


def _preprocess(image):
    """Turn an input image array into a (1, SIZE_Y, SIZE_X, 1) batch for the model."""
    img = cv2.resize(image, (SIZE_Y, SIZE_X))
    # Gradio supplies RGB(A) arrays, not OpenCV's BGR order.
    if img.ndim == 3 and img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    elif img.ndim == 3 and img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    img = np.expand_dims(img, axis=2)  # add the channel dimension -> (H, W, 1)
    # NOTE(review): L2-normalizes along axis 1 of the (H, W, 1) array; confirm
    # this matches the normalization used when the model was trained.
    img = normalize(img, axis=1)
    return np.expand_dims(img, axis=0)  # add the batch dimension


def _mask_to_rgba(mask):
    """Return an RGBA uint8 image: opaque SEGMENTED_COLOR where ``mask`` is True, transparent elsewhere."""
    rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
    rgba[mask] = (*SEGMENTED_COLOR, 255)
    return rgba


def predict_segmentation(image):
    """Segment ``image`` and return the path of a PNG overlay of the result.

    Args:
        image: numpy image array as delivered by Gradio's ``"image"`` input
            (RGB/RGBA or single-channel), or None if nothing was uploaded.

    Returns:
        Path to a newly created PNG, the same size as ``image``, where pixels
        predicted as any non-zero class are opaque red and all others are
        fully transparent.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    original_size = (image.shape[1], image.shape[0])  # (width, height) for cv2

    prediction = _load_model().predict(_preprocess(image))
    # Per-pixel class index. Cast to uint8 because cv2.resize does not accept
    # int64 input (assumes fewer than 256 classes).
    predicted_img = np.argmax(prediction, axis=3)[0].astype(np.uint8)
    # Nearest-neighbour keeps class labels intact when scaling back up.
    predicted_img = cv2.resize(predicted_img, original_size, interpolation=cv2.INTER_NEAREST)

    rgba_img = _mask_to_rgba(predicted_img > 0)

    # Unique file per request so concurrent users never overwrite each other's output.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        Image.fromarray(rgba_img).save(tmp, format="PNG")
    return tmp.name


# Gradio Interface
iface = gr.Interface(
    fn=predict_segmentation,
    inputs="image",
    outputs="file",  # Return the file path to download the PNG
    live=False,
)

if __name__ == "__main__":
    iface.launch(share=True)