# itzjunayed's picture
# Update app.py
# edf8500 verified
# raw
# history blame
# 2.35 kB
import functools

import gradio as gr
import tensorflow as tf
import numpy as np
import cv2
from keras.utils import normalize
from PIL import Image
def dice_coef(y_true, y_pred):
    """Return the batch-mean Dice coefficient of two mask tensors.

    Registered as a custom object when loading ``model100.h5``. The per-sample
    sums run over axes 1, 2 and 3, so both inputs must be rank-4 -- presumably
    (batch, height, width, classes); TODO confirm against the training code.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor with the same shape as ``y_true``.

    Returns:
        Scalar tensor: the mean over the batch axis of
        ``(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``.
    """
    smooth = 1e-5  # keeps the ratio defined (== 1) when both masks are empty
    # keras.backend (`K`) is not imported in this module, so use TensorFlow's
    # reductions directly; they compute the same sums and means.
    y_pred = tf.convert_to_tensor(y_pred)
    # Cast so integer ground-truth masks can be multiplied with float predictions.
    y_true = tf.cast(y_true, y_pred.dtype)
    sample_axes = [1, 2, 3]
    intersection = tf.reduce_sum(y_true * y_pred, axis=sample_axes)
    union = tf.reduce_sum(y_true, axis=sample_axes) + tf.reduce_sum(y_pred, axis=sample_axes)
    return tf.reduce_mean((2.0 * intersection + smooth) / (union + smooth), axis=0)
@functools.lru_cache(maxsize=1)
def _load_model():
    """Load ``model100.h5`` once; later calls reuse the same model object."""
    # The saved model presumably references a 'dice_coef' metric, so it must be
    # resolvable while deserializing.
    with tf.keras.utils.custom_object_scope({'dice_coef': dice_coef}):
        return tf.keras.models.load_model("model100.h5")


def _to_model_input(image, size=(128, 128)):
    """Turn *image* into a normalized single-channel batch of one for the model.

    ``size`` is ``(width, height)``, the order ``cv2.resize`` expects.
    """
    img = cv2.resize(image, size)
    if img.ndim == 3:
        channels = img.shape[2]
        # Gradio delivers RGB(A) arrays; BGR conversion codes would swap the
        # red and blue luminance weights.
        if channels == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        elif channels == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        else:
            # e.g. grayscale + alpha: keep the intensity channel only
            img = img[:, :, 0]
    img = np.expand_dims(img, axis=2)  # add channel axis -> (H, W, 1)
    img = normalize(img, axis=1)  # keras L2 normalization along axis 1
    return np.expand_dims(img, axis=0)  # add batch axis -> (1, H, W, 1)


def _mask_to_rgba(mask, color=(255, 0, 0)):
    """Return an RGBA uint8 image: opaque *color* where *mask* is True, transparent elsewhere."""
    rgba = np.zeros(mask.shape + (4,), dtype=np.uint8)
    rgba[mask] = (*color, 255)
    return rgba


def predict_segmentation(image):
    """Segment *image* and save the foreground as a red-on-transparent PNG.

    Args:
        image: input image as a numpy array of shape (H, W), (H, W, 3) or
            (H, W, 4); colour images are treated as RGB(A).

    Returns:
        Path of the written PNG, which has the same width and height as the
        input. Pixels whose predicted class is non-zero are opaque red; all
        other pixels are fully transparent.

    Raises:
        ValueError: if no image was supplied.
    """
    if image is None:
        raise ValueError("No input image was provided.")
    height, width = image.shape[:2]

    prediction = _load_model().predict(_to_model_input(image))
    # argmax over axis 3 gives a per-pixel class index; [0] drops the batch axis.
    labels = np.argmax(prediction, axis=3)[0]

    # Threshold first, then resize: nearest-neighbour resizing picks the same
    # source pixels either way, and a uint8 mask is cheap to resize.
    foreground = (labels > 0).astype(np.uint8)
    foreground = cv2.resize(
        foreground, (width, height), interpolation=cv2.INTER_NEAREST
    ).astype(bool)

    # NOTE(review): this fixed path is shared by every request; if requests can
    # run concurrently, consider a per-request temporary file.
    output_image_path = "segmented_output.png"
    Image.fromarray(_mask_to_rgba(foreground)).save(output_image_path)
    return output_image_path
# Gradio Interface: an image input is passed through predict_segmentation and
# the PNG path it returns is offered as a downloadable file.
iface = gr.Interface(
    live=False,  # run only when the user submits, not on every input change
    inputs="image",
    outputs="file",
    fn=predict_segmentation,
)
# share=True requests a public share link in addition to the local server.
# NOTE(review): share links may not be created on every hosting platform --
# confirm where this app is deployed.
iface.launch(share=True)