# app.py — published on Hugging Face by itzjunayed
# Last commit: 2fca5ba "Update app.py" (verified), file size 2.42 kB
import functools
import tempfile

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
from keras.utils import normalize
from PIL import Image
def dice_coef(y_true, y_pred):
    """Soft Dice coefficient between ground-truth and predicted masks.

    It is registered as a custom object when the saved model is loaded
    (presumably the model was compiled with it as a metric). The inference
    path in this file never calls it directly.

    Args:
        y_true: Ground-truth mask tensor of rank 4 with the batch on axis 0
            (presumably (batch, H, W, C) — confirm against training code).
        y_pred: Predicted mask tensor with the same shape as ``y_true``.

    Returns:
        Scalar tensor: the per-sample Dice score, averaged over the batch.
    """
    smooth = 1e-5  # avoids 0/0 when both masks are empty
    # Labels are often stored as int/bool while predictions are float; cast so
    # the element-wise product and sums below are well-typed.
    y_true = tf.cast(y_true, y_pred.dtype)
    reduce_axes = [1, 2, 3]  # every axis except the batch axis
    intersection = tf.reduce_sum(y_true * y_pred, axis=reduce_axes)
    union = tf.reduce_sum(y_true, axis=reduce_axes) + tf.reduce_sum(y_pred, axis=reduce_axes)
    return tf.reduce_mean((2.0 * intersection + smooth) / (union + smooth), axis=0)
@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the segmentation model from disk once and reuse it for every request."""
    # dice_coef must be resolvable while deserializing the saved model
    # (presumably it was compiled with it as a metric).
    with tf.keras.utils.custom_object_scope({'dice_coef': dice_coef}):
        return tf.keras.models.load_model("model100.h5")


def _to_model_input(image, size=128):
    """Resize, grayscale and L2-normalize *image* into a (1, size, size, 1) batch."""
    img = cv2.resize(image, (size, size))
    if img.ndim == 3:
        # Gradio's "image" input delivers RGB(A) arrays, not OpenCV's BGR.
        if img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    img = np.expand_dims(img, axis=2)  # add the channel dimension
    img = normalize(img, axis=1)
    return np.expand_dims(img, axis=0)  # add the batch dimension


def _mask_to_rgba(mask, color=(255, 0, 0)):
    """Render a boolean *mask* as RGBA: *color* and opaque where True, transparent elsewhere."""
    rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
    rgba[mask] = (*color, 255)
    return rgba


def predict_segmentation(image):
    """Segment *image* and return the path of a PNG overlay of the result.

    Args:
        image: Array from the Gradio "image" input — an H x W x 3 RGB array by
            default. RGBA and single-channel arrays are also accepted.

    Returns:
        Path to a PNG with the same width and height as *image*. Pixels whose
        predicted class is non-zero are opaque red; all others are fully
        transparent. Each call writes a new temporary file, so concurrent
        requests never overwrite each other. This function does not delete the
        file because Gradio reads it after the function returns.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    original_size = (image.shape[1], image.shape[0])  # cv2 dsize is (width, height)

    prediction = _load_model().predict(_to_model_input(image))
    # Per-pixel class index; any class other than 0 counts as segmented.
    labels = np.argmax(prediction, axis=3)[0, :, :]
    foreground = (labels > 0).astype(np.uint8)

    # Nearest-neighbour keeps the mask binary when scaling back to the input size.
    foreground = cv2.resize(foreground, original_size, interpolation=cv2.INTER_NEAREST) > 0

    output_image = Image.fromarray(_mask_to_rgba(foreground))
    with tempfile.NamedTemporaryFile(
        prefix="segmented_output_", suffix=".png", delete=False
    ) as tmp:
        output_image.save(tmp, format="PNG")
    return tmp.name
# Gradio UI: one image in, the path of the segmentation overlay PNG out
# (offered to the user as a downloadable file). With live=False the model
# runs only when the user submits, not on every input change.
iface = gr.Interface(
    predict_segmentation,
    inputs="image",
    outputs="file",
    live=False,
)
# share=True also requests a public share link when launched locally.
iface.launch(share=True)