File size: 4,750 Bytes
a134324 9dfef1f 30b140a a06d491 f9fed5c 000a77e f9fed5c 000a77e f9fed5c 9dfef1f c6a31fa 9dfef1f c6a31fa 9dfef1f c6a31fa 9dfef1f f9fed5c 000a77e f9fed5c 9dfef1f a0eb692 f9fed5c 000a77e f9fed5c a63fd9b f9fed5c 30b140a f9fed5c 30b140a c6a31fa 30b140a f9fed5c 30b140a c6a31fa 000a77e c6a31fa 000a77e 2480ca8 c6a31fa 30b140a c6a31fa f7a8dfa c6a31fa 84297cc c6a31fa 000a77e c6a31fa 2480ca8 c6a31fa 000a77e c6a31fa a134324 f9fed5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tifffile

from prediction import predict_mask
# Model entry point: thin adapter over prediction.predict_mask.
def process_3d_image(image, resx, resy, resz):
    """Segment vessels in a 3D volume using the 3DVascNet predictor.

    Args:
        image: The 3D volume to segment. In this file it receives the array
            produced by ``auximread`` (shortest axis moved last).
        resx: Resolution along X (the UI collects it in micrometres).
        resy: Resolution along Y (micrometres in the UI).
        resz: Resolution along Z (micrometres in the UI).

    Returns:
        The result of ``predict_mask``, which callers treat as the binary
        mask volume.
    """
    return predict_mask(image, resx, resy, resz)
def auximread(filepath):
image = tifffile.imread(filepath)
# The output image should be (X,Y,Z)
original_0 = np.shape(image)[0]
original_1 = np.shape(image)[1]
original_2 = np.shape(image)[2]
index_min = np.argmin([original_0, original_1, original_2])
if index_min == 0:
image = image.transpose(1, 2, 0)
elif index_min == 1:
image = image.transpose(0, 2, 1)
return image
# Function to handle file input and processing
def process_file(file, resx, resy, resz):
"""
Process the uploaded file and return the binary mask.
"""
if file.name.endswith(".tif"):
# Load .tif file as a 3D numpy array
image = auximread(file.name)
else:
raise ValueError("Unsupported file format. Please upload a .tif or .czi file.")
# Ensure image is 3D
if len(image.shape) != 3:
raise ValueError("Input image is not 3D.")
# Process image through the model
binary_mask = process_3d_image(image, resx, resy, resz)
# Save binary mask to a .tif file to return
output_path = "output_mask.tif"
tifffile.imwrite(output_path, binary_mask)
return image, binary_mask, output_path
# Function to generate the slice visualization
def visualize_slice(image, mask, slice_index):
    """
    Visualizes a 2D slice of the image and the corresponding mask at the given index.

    Args:
        image: 3D array; the slice is taken along the last axis.
        mask: 3D array indexed the same way as ``image``.
        slice_index: Index along the last axis. Integral floats are accepted
            because slider values may arrive as floats.

    Returns:
        A matplotlib Figure with the image slice (left) and the mask slice
        (right), already removed from pyplot's figure registry.
    """
    # NumPy rejects float indices such as 3.0, so normalise to int.
    slice_index = int(slice_index)
    # Slice before creating the figure so an out-of-range index cannot leave
    # an unclosed figure behind in pyplot's registry.
    image_slice = image[:, :, slice_index]
    mask_slice = mask[:, :, slice_index]

    fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))
    for ax, data, title in (
        (ax_image, image_slice, "Image Slice"),
        (ax_mask, mask_slice, "Mask Slice"),
    ):
        ax.imshow(data, cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    # Lay out *this* figure explicitly. plt.tight_layout() targets pyplot's
    # global "current figure", which another request may have changed.
    fig.tight_layout()
    plt.close(fig)
    return fig
# Latest segmentation result. segment_button_click writes it and
# update_visualization reads it, so the slider can re-slice without
# re-running the model.
# NOTE(review): these are module-level globals, so every session of this app
# sees (and overwrites) the same result; consider per-session gr.State.
processed_image = processed_mask = None
def segment_button_click(file, resx, resy, resz):
    """Run segmentation for the Segment button and reveal the slice slider.

    Stores the input volume and predicted mask in the module-level
    ``processed_image`` / ``processed_mask`` for ``update_visualization``.

    Returns:
        Tuple of (status message, path of the mask .tif for download, slider
        update showing the slider with its range set to this volume's slices).
    """
    global processed_image, processed_mask
    processed_image, processed_mask, output_path = process_file(file, resx, resy, resz)
    num_slices = processed_image.shape[2]
    # Reset to the first slice. The slider may still hold an index from a
    # previously segmented, larger volume that is out of range for this one.
    slider_update = gr.update(visible=True, maximum=num_slices - 1, value=0)
    return "Segmentation completed! Use the slider to explore slices.", output_path, slider_update
def update_visualization(slice_index):
    """Redraw the slice viewer for the most recently segmented volume.

    Args:
        slice_index: Slider value; cast to int and clamped into the valid
            range of the stored volume and mask.

    Returns:
        The figure produced by ``visualize_slice``.

    Raises:
        ValueError: If no image has been segmented yet.
    """
    if processed_image is None or processed_mask is None:
        raise ValueError("Please process an image first by clicking the Segment button.")
    # Clamp: after a smaller volume is segmented, the slider can still carry
    # an index from the previous volume.
    last_slice = min(processed_image.shape[2], processed_mask.shape[2]) - 1
    slice_index = min(max(int(slice_index), 0), last_slice)
    return visualize_slice(processed_image, processed_mask, slice_index)
# Gradio Interface
with gr.Blocks() as iface:
    gr.Markdown("""# 3DVascNet: Retinal Blood Vessel Segmentation
Upload a 3D image in .tif format. Click the **Segment** button to process the image and generate a 3D binary mask.
Use the slider to navigate through the 2D slices. This is the official implementation of 3DVascNet, described in this paper: https://www.ahajournals.org/doi/10.1161/ATVBAHA.124.320672.
The raw code is available at https://github.com/HemaxiN/3DVascNet.
""")
    # Input fields for resolution in micrometers
    with gr.Row():
        resx_input = gr.Number(value=0.333, label="Resolution in X (µm)", precision=3)
        resy_input = gr.Number(value=0.333, label="Resolution in Y (µm)", precision=3)
        resz_input = gr.Number(value=0.5, label="Resolution in Z (µm)", precision=3)
    with gr.Row():
        file_input = gr.File(label="Upload 3D Image (.tif)")
        segment_button = gr.Button("Segment")
    status_output = gr.Textbox(label="Status", interactive=False)
    download_output = gr.File(label="Download Binary Mask (.tif)")
    with gr.Row():
        slice_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Slice Index", interactive=True, visible=False)
    visualization_output = gr.Plot(label="2D Slice Visualization")

    # Button click triggers segmentation
    segment_event = segment_button.click(
        segment_button_click,
        inputs=[file_input, resx_input, resy_input, resz_input],
        outputs=[status_output, download_output, slice_slider],
    )
    # Draw the current slice as soon as segmentation succeeds. Otherwise the
    # plot stays empty (or shows the previous volume) until the slider value
    # changes.
    segment_event.success(update_visualization, inputs=slice_slider, outputs=visualization_output)
    # Slider changes trigger visualization updates
    slice_slider.change(update_visualization, inputs=slice_slider, outputs=visualization_output)

if __name__ == "__main__":
    iface.launch()
|