# 3DVascNet / app.py
# Author: Hemaxi
# Commit 9dfef1f ("Update app.py", verified); file size 3.24 kB.
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from aicsimageio import AICSImage  # To handle .czi files

from prediction import predict_mask
# Default location of the trained CycleGAN vessel-segmentation weights.
# NOTE(review): this is a Hugging Face "tree" (web UI) URL; direct file
# downloads normally use ".../resolve/main/<file>". Whether predict_mask can
# consume this URL as-is depends on its implementation -- confirm.
DEFAULT_MODEL_DIR = (
    "https://huggingface.co/Hemaxi/3DCycleGAN/tree/main/"
    "CycleGANVesselSegmentation.h5"
)


def process_3d_image(image, model_dir=DEFAULT_MODEL_DIR):
    """Run the vessel-segmentation model on a 3D volume.

    Args:
        image: 3D numpy array (the caller ``process_file`` checks it is 3D);
            expected in the (X, Y, Z) order produced by ``auximread``.
        model_dir: Path or URL of the model weights, forwarded unchanged as
            the first argument of ``predict_mask``. Defaults to
            ``DEFAULT_MODEL_DIR``.

    Returns:
        The result of ``predict_mask`` -- presumably a 3D binary mask with
        the same shape as ``image`` (TODO confirm against prediction.py).
    """
    return predict_mask(model_dir, image)
def auximread(filepath):
    """Read a 3D TIFF stack and move its smallest axis to the last position.

    The intended output order is (X, Y, Z). The depth axis is chosen
    heuristically as the shortest dimension; the relative order of the two
    remaining axes is preserved.

    Args:
        filepath: Path of the TIFF file to read with ``tifffile.imread``.

    Returns:
        numpy array with the same data, its shortest axis moved last.

    Raises:
        ValueError: If the file does not contain a 3D array.
    """
    image = tifffile.imread(filepath)
    # Validate before using three axes: a 2D image previously failed with an
    # opaque IndexError and a 4D one with a transpose error, so the caller's
    # "Input image is not 3D." check could never be reached for them.
    if image.ndim != 3:
        raise ValueError("Input image is not 3D.")
    # Presumably the stack has fewer Z-slices than pixels per side, so the
    # shortest axis is treated as Z. np.argmin picks the first axis on ties,
    # matching the previous behaviour; moveaxis(a, 0, -1) == transpose(1, 2, 0)
    # and moveaxis(a, 1, -1) == transpose(0, 2, 1).
    depth_axis = int(np.argmin(image.shape))
    return np.moveaxis(image, depth_axis, -1)
# Function to handle file input and processing
def process_file(file):
    """Load an uploaded 3D TIFF, segment it, and write the mask to disk.

    Args:
        file: Value from a ``gr.File`` input. Depending on the Gradio
            version/``type`` setting this is either a path string or a
            tempfile-like object exposing ``.name``; both are accepted.

    Returns:
        Tuple ``(image, binary_mask, output_path)``: the loaded volume, the
        model output, and the path of the written mask TIFF.

    Raises:
        ValueError: If no file was uploaded, the extension is not
            .tif/.tiff, or the image is not 3D.
    """
    if file is None:
        raise ValueError("No file uploaded. Please upload a .tif or .tiff file.")
    path = file if isinstance(file, str) else file.name
    # Case-insensitive so "STACK.TIF" / ".tiff" files are accepted as well.
    if not path.lower().endswith((".tif", ".tiff")):
        # .czi is not handled here, so the message no longer advertises it.
        raise ValueError("Unsupported file format. Please upload a .tif or .tiff file.")
    image = auximread(path)
    # Ensure image is 3D
    if len(image.shape) != 3:
        raise ValueError("Input image is not 3D.")
    # Process image through the model
    binary_mask = process_3d_image(image)
    # Save binary mask to a .tif file to return. The module is imported as
    # ``tifffile``; the previous ``tiff.imwrite`` raised NameError.
    output_path = "output_mask.tif"
    tifffile.imwrite(output_path, binary_mask)
    return image, binary_mask, output_path
# Function to generate the slice visualization
def visualize_slice(image, mask, slice_index, axis=0):
    """Plot one 2D slice of the volume next to the matching mask slice.

    Args:
        image: 3D image volume.
        mask: 3D mask volume, presumably the same shape as ``image``.
        slice_index: Position along ``axis``. It is converted to ``int`` (a
            slider may deliver a float) and clamped to each volume's valid
            range, because the UI slider's maximum is fixed and need not
            match the volume depth.
        axis: Axis to slice along; defaults to 0 (the original behaviour).
            NOTE(review): ``auximread`` returns (X, Y, Z), so ``axis=2``
            would step through Z-planes -- confirm which is intended.

    Returns:
        The matplotlib Figure. It is closed in pyplot before returning so
        figures do not accumulate across calls.
    """

    def _take_slice(volume):
        # Clamp instead of raising IndexError when the slider overshoots.
        volume = np.asarray(volume)
        index = min(max(int(slice_index), 0), volume.shape[axis] - 1)
        return np.take(volume, index, axis=axis)

    # Slice before creating the figure so a failure cannot leak an open figure.
    image_slice = _take_slice(image)
    mask_slice = _take_slice(mask)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = ((axes[0], image_slice, "Image Slice"), (axes[1], mask_slice, "Mask Slice"))
    for ax, data, title in panels:
        ax.imshow(data, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()
    plt.close(fig)
    return fig
# Gradio callback wiring the file upload and slider to the two outputs.
def interface(file, slice_index):
    """Segment the uploaded volume and render the requested slice.

    Args:
        file: Value of the ``gr.File`` input, passed through to ``process_file``.
        slice_index: Value of the slider, passed through to ``visualize_slice``.

    Returns:
        ``(figure, mask_path)`` for the Plot and File output components.
    """
    volume, mask, mask_path = process_file(file)
    # NOTE(review): with ``live=True`` on the Interface, each input change
    # (including slider moves) calls this again and re-runs the model via
    # process_file; consider caching if that is too slow.
    return visualize_slice(volume, mask, slice_index), mask_path
# Gradio Interface: file upload + slice slider in; slice plot + mask file out.
# NOTE(review): process_file only handles .tif, so the UI no longer
# advertises .czi. The AICSImage import is intended for .czi but is not
# used by process_file -- wire it in before re-advertising .czi.
iface = gr.Interface(
    fn=interface,
    inputs=[
        gr.File(label="Upload 3D Image (.tif)"),
        # The maximum is fixed at 100 and need not match the volume depth.
        gr.Slider(minimum=0, maximum=100, step=1, label="Slice Index"),
    ],
    outputs=[
        gr.Plot(label="2D Slice Visualization"),
        gr.File(label="Download Binary Mask (.tif)"),
    ],
    title="3D Image Processing with Binary Mask Output",
    description=(
        "Upload a 3D image in .tif format. The model will process the image "
        "and output a 3D binary mask. Use the slider to navigate through the "
        "2D slices."
    ),
    # Re-run automatically whenever an input changes (no Submit button).
    live=True,
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()