File size: 4,750 Bytes
a134324
 
9dfef1f
30b140a
a06d491
f9fed5c
 
000a77e
f9fed5c
000a77e
f9fed5c
 
9dfef1f
 
 
c6a31fa
9dfef1f
 
 
 
 
 
 
c6a31fa
9dfef1f
c6a31fa
9dfef1f
 
 
f9fed5c
000a77e
f9fed5c
 
 
 
 
9dfef1f
a0eb692
f9fed5c
 
 
 
 
 
 
000a77e
f9fed5c
 
 
a63fd9b
f9fed5c
30b140a
f9fed5c
30b140a
 
 
 
 
 
 
 
c6a31fa
 
30b140a
 
 
 
 
f9fed5c
30b140a
 
 
 
 
 
 
 
 
 
c6a31fa
 
 
 
000a77e
c6a31fa
000a77e
2480ca8
 
c6a31fa
 
 
 
 
30b140a
 
c6a31fa
f7a8dfa
c6a31fa
84297cc
 
c6a31fa
000a77e
 
 
 
 
 
c6a31fa
 
 
 
 
 
 
 
 
2480ca8
c6a31fa
 
 
000a77e
 
 
c6a31fa
 
 
a134324
 
f9fed5c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio as gr
import numpy as np
import tifffile
import matplotlib.pyplot as plt
from prediction import predict_mask

def process_3d_image(image, resx, resy, resz):
    """
    Run the segmentation model on a 3D volume.

    Thin adapter around ``prediction.predict_mask``: the volume and the three
    voxel sizes are forwarded unchanged, and the model output is returned
    as-is.

    Args:
        image: Volume as produced by ``auximread``.
        resx, resy, resz: Voxel size along X, Y and Z (the UI labels these
            as micrometres).

    Returns:
        Whatever ``predict_mask`` returns; callers treat it as the binary mask.
    """
    return predict_mask(image, resx, resy, resz)

def auximread(filepath):
    """
    Read a TIFF volume and reorder its axes so the shortest axis comes last.

    The rest of the app takes slices along the last axis (``image[:, :, k]``).
    The on-disk axis order is not assumed. Instead, the shortest axis is taken
    to be Z and moved to position 2, and the other two axes keep their
    relative order.

    Args:
        filepath: Path to a file readable by ``tifffile.imread``.

    Returns:
        numpy.ndarray: For 3D input, the image with its shortest axis moved
        last (X, Y, Z). Non-3D input is returned unchanged so the caller can
        reject it with a clear "not 3D" error.
    """
    image = tifffile.imread(filepath)

    # Only 3D volumes can be reoriented. Indexing shape[2] on a 2D image would
    # raise IndexError before the caller's dimensionality check could run.
    if image.ndim != 3:
        return image

    # Heuristic: the shortest axis is assumed to be Z (the slice axis).
    z_axis = int(np.argmin(image.shape))
    # moveaxis(a, 0, -1) == a.transpose(1, 2, 0) and
    # moveaxis(a, 1, -1) == a.transpose(0, 2, 1); z_axis == 2 is a no-op.
    return np.moveaxis(image, z_axis, -1)

# Function to handle file input and processing
def process_file(file, resx, resy, resz):
    """
    Load an uploaded TIFF volume, segment it, and write the mask to disk.

    Args:
        file: Value of the ``gr.File`` input. Depending on the Gradio version
            this is either a file wrapper exposing the path as ``.name`` or a
            plain path string. Both are accepted.
        resx, resy, resz: Voxel sizes forwarded to the model.

    Returns:
        tuple: ``(image, binary_mask, output_path)``. ``image`` is the
        reoriented input volume, ``binary_mask`` is the model output, and
        ``output_path`` is the TIFF file the mask was written to.

    Raises:
        ValueError: If no file was uploaded, the file is not .tif/.tiff, or
            the loaded image is not 3D.
    """
    if file is None:
        raise ValueError("No file uploaded. Please upload a .tif or .tiff file.")

    # Some Gradio versions pass a wrapper with the path in ``.name``, others
    # pass the path string itself. getattr falls back to the value as-is.
    filepath = getattr(file, "name", file)

    # Case-insensitive so "scan.TIF" and "scan.tiff" are accepted too.
    if not str(filepath).lower().endswith((".tif", ".tiff")):
        raise ValueError("Unsupported file format. Please upload a .tif or .tiff file.")

    # Load the TIFF as a numpy array with the shortest axis moved last.
    image = auximread(filepath)

    # Ensure image is 3D
    if len(image.shape) != 3:
        raise ValueError("Input image is not 3D.")

    # Process image through the model
    binary_mask = process_3d_image(image, resx, resy, resz)

    # Save the mask so it can be offered for download. The fixed name means
    # every run overwrites the previous output file.
    output_path = "output_mask.tif"
    tifffile.imwrite(output_path, binary_mask)

    return image, binary_mask, output_path

# Function to generate the slice visualization
def visualize_slice(image, mask, slice_index):
    """
    Plot one Z slice of the image next to the same slice of the mask.

    Args:
        image: 3D array; slices are taken along the last axis.
        mask: 3D array indexed the same way as ``image``.
        slice_index: Slice to show. It may arrive as a float from the Gradio
            slider. It is converted to int and clamped to
            ``[0, image.shape[2] - 1]``.

    Returns:
        matplotlib.figure.Figure: A closed figure with two panels (image,
        mask), suitable as a ``gr.Plot`` output.
    """
    # NumPy rejects float indices, and a stale slider value can exceed the
    # depth of a newly loaded volume. Convert to int and clamp first.
    last_index = image.shape[2] - 1
    index = min(max(int(slice_index), 0), last_index)

    # Extract the slices before creating the figure, so a failure here does
    # not leave an open figure registered with pyplot.
    image_slice = image[:, :, index]
    mask_slice = mask[:, :, index]

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # Plot image slice
    axes[0].imshow(image_slice, cmap="gray")
    axes[0].set_title("Image Slice")
    axes[0].axis("off")

    # Plot mask slice
    axes[1].imshow(mask_slice, cmap="gray")
    axes[1].set_title("Mask Slice")
    axes[1].axis("off")

    plt.tight_layout()
    # Close the figure in pyplot's registry so repeated slider moves don't
    # accumulate open figures. The Figure object itself is still returned.
    plt.close(fig)
    return fig

# Module-level state shared by the Segment-button callback (writer) and the
# slider callback (reader): the last loaded volume and its predicted mask.
# Both stay None until a segmentation has completed.
processed_image = processed_mask = None

def segment_button_click(file, resx, resy, resz):
    """
    Gradio callback for the Segment button.

    Runs ``process_file`` and stores the volume and mask in the module-level
    ``processed_image`` / ``processed_mask`` for the slider callback. It then
    reveals the slice slider, sized to the new volume.

    Args:
        file: Value of the upload component.
        resx, resy, resz: Voxel sizes from the resolution inputs.

    Returns:
        tuple: (status message, path of the saved mask for download,
        slider update).
    """
    global processed_image, processed_mask
    processed_image, processed_mask, output_path = process_file(file, resx, resy, resz)
    last_slice = processed_image.shape[2] - 1
    # Reset the slider to 0. A value left over from a previous, deeper volume
    # could otherwise exceed the new maximum and index past the last slice.
    slider_update = gr.update(visible=True, maximum=last_slice, value=0)
    return "Segmentation completed! Use the slider to explore slices.", output_path, slider_update

def update_visualization(slice_index):
    """
    Gradio callback for the slice slider: render the requested slice.

    Args:
        slice_index: Value of the slice slider.

    Returns:
        The figure produced by ``visualize_slice`` for the stored volume/mask.

    Raises:
        ValueError: If no segmentation has run yet, i.e. the module-level
            ``processed_image`` or ``processed_mask`` is still None.
    """
    have_results = processed_image is not None and processed_mask is not None
    if not have_results:
        raise ValueError("Please process an image first by clicking the Segment button.")
    return visualize_slice(processed_image, processed_mask, slice_index)

# Gradio Interface
# Page contents, top to bottom: description, voxel-size inputs, upload +
# Segment button, status and download outputs, then a (hidden) slice slider
# next to the slice plot.
with gr.Blocks() as iface:
    gr.Markdown("""# 3DVascNet: Retinal Blood Vessel Segmentation
    Upload a 3D image in .tif format. Click the **Segment** button to process the image and generate a 3D binary mask.
    Use the slider to navigate through the 2D slices. This is the official implementation of 3DVascNet, described in this paper: https://www.ahajournals.org/doi/10.1161/ATVBAHA.124.320672.
    The raw code is available at https://github.com/HemaxiN/3DVascNet.
    """)
    
    # Input fields for resolution in micrometers.
    # These become the resx/resy/resz arguments of segment_button_click
    # (see the click wiring below).
    with gr.Row():
        resx_input = gr.Number(value=0.333, label="Resolution in X (µm)", precision=3)
        resy_input = gr.Number(value=0.333, label="Resolution in Y (µm)", precision=3)
        resz_input = gr.Number(value=0.5, label="Resolution in Z (µm)", precision=3)

    # Upload widget and the button that starts segmentation.
    with gr.Row():
        file_input = gr.File(label="Upload 3D Image (.tif)")
        segment_button = gr.Button("Segment")

    # Outputs of the Segment callback: status text and the saved mask file.
    status_output = gr.Textbox(label="Status", interactive=False)
    download_output = gr.File(label="Download Binary Mask (.tif)")

    with gr.Row():
        # Hidden until segment_button_click returns gr.update(visible=True, ...).
        # maximum=100 is only a placeholder; the callback replaces it with the
        # last slice index of the loaded volume.
        slice_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Slice Index", interactive=True, visible=False)
        visualization_output = gr.Plot(label="2D Slice Visualization")

    # Button click triggers segmentation.
    # NOTE(review): visualization_output is not among the click outputs, so
    # the plot is only drawn by the slider's change event below. It probably
    # stays empty after segmenting until the slider value changes; confirm.
    segment_button.click(segment_button_click, 
                         inputs=[file_input, resx_input, resy_input, resz_input], 
                         outputs=[status_output, download_output, slice_slider])

    # Slider changes trigger visualization updates
    slice_slider.change(update_visualization, inputs=slice_slider, outputs=visualization_output)

# Launch the Gradio app only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    iface.launch()