File size: 3,240 Bytes
a134324
 
9dfef1f
f9fed5c
30b140a
9dfef1f
f9fed5c
 
 
 
9dfef1f
 
f9fed5c
 
9dfef1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9fed5c
 
 
 
 
 
 
9dfef1f
a0eb692
f9fed5c
 
 
 
 
 
 
 
 
 
 
 
 
30b140a
f9fed5c
30b140a
 
 
 
 
 
 
 
ac9ba9f
 
30b140a
 
 
 
 
f9fed5c
30b140a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9fed5c
 
30b140a
 
ac9ba9f
30b140a
 
 
 
 
f9fed5c
ac9ba9f
 
a134324
 
 
f9fed5c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from aicsimageio import AICSImage  # To handle .czi files

from prediction import predict_mask

# Placeholder for your 3D model
def process_3d_image(image):
    """Produce a binary mask for ``image`` by delegating to ``predict_mask``.

    Args:
        image: The volume to segment; forwarded unchanged to ``predict_mask``.

    Returns:
        Whatever ``predict_mask`` returns for the given model location and image.
    """
    # Model location handed to predict_mask as its first argument.
    # NOTE(review): this is a Hugging Face "tree" page URL, not a local file
    # path -- confirm predict_mask is able to load a model from it.
    model_location = 'https://huggingface.co/Hemaxi/3DCycleGAN/tree/main/CycleGANVesselSegmentation.h5'
    return predict_mask(model_location, image)

def auximread(filepath):
    """Read a TIFF file and move the shortest axis of a 3D stack to the end.

    For a 3D array the axis with the smallest extent is moved to the last
    position (ties resolve to the first such axis), so the result follows
    the (X, Y, Z) layout this file expects. Arrays that are not 3D are
    returned untouched so the caller can reject them with its own error
    instead of failing here with an IndexError.

    Args:
        filepath: Path of the TIFF file, passed to ``tifffile.imread``.

    Returns:
        The loaded array; when it is 3D, its shortest axis is last.
    """
    image = tifffile.imread(filepath)

    # Only 3D stacks have an axis to reorder; leave validation of other
    # ranks to the caller.
    if image.ndim != 3:
        return image

    # Equivalent to transpose(1, 2, 0) / transpose(0, 2, 1) / no-op for a
    # shortest axis of 0 / 1 / 2 respectively.
    shortest_axis = int(np.argmin(image.shape))
    return np.moveaxis(image, shortest_axis, -1)

# Function to handle file input and processing
def process_file(file):
    """
    Load an uploaded TIFF volume, run the model on it, and save the mask.

    Args:
        file: The uploaded file, given either as a path string or as an
            object whose ``name`` attribute holds the path.

    Returns:
        A tuple ``(image, binary_mask, output_path)``: the loaded image,
        the mask returned by ``process_3d_image``, and the path of the
        TIFF file the mask was written to.

    Raises:
        ValueError: If the file does not have a .tif/.tiff extension or
            the loaded image is not 3D.
    """
    # Accept both a plain path and a file-like wrapper exposing ``.name``.
    path = file if isinstance(file, str) else file.name

    # Case-insensitive check that also accepts the common ".tiff" spelling.
    # No .czi reader is implemented in this function, so the message only
    # names the formats that actually work.
    if not path.lower().endswith((".tif", ".tiff")):
        raise ValueError("Unsupported file format. Please upload a .tif or .tiff file.")

    # Load the TIFF file as a numpy array
    image = auximread(path)

    # Ensure image is 3D
    if len(image.shape) != 3:
        raise ValueError("Input image is not 3D.")

    # Process image through the model
    binary_mask = process_3d_image(image)

    # Save binary mask to a .tif file to return. The module is imported as
    # ``tifffile``; the previous ``tiff`` name was never defined.
    output_path = "output_mask.tif"
    tifffile.imwrite(output_path, binary_mask)

    return image, binary_mask, output_path

# Function to generate the slice visualization
def visualize_slice(image, mask, slice_index):
    """
    Plot one 2D slice of the image next to the matching slice of the mask.

    Slices are taken along the first axis of both arrays. An index beyond
    the available range is clamped to the nearest valid slice instead of
    raising, because the caller may not know the depth of the volume.

    Args:
        image: 3D array holding the original volume.
        mask: 3D array holding the mask; indexed the same way as ``image``.
        slice_index: Index of the slice along the first axis. Non-integer
            numbers are truncated to int; negative values index from the
            end, as with normal sequence indexing.

    Returns:
        The matplotlib figure containing the two side-by-side panels.
    """
    # NOTE(review): slicing is along axis 0, while auximread documents its
    # output as (X, Y, Z) -- confirm this is the intended viewing axis.
    depth = min(len(image), len(mask))
    index = int(slice_index)
    # Keep every previously valid index (including negatives) and pull
    # out-of-range values back to the closest valid slice.
    index = max(-depth, min(index, depth - 1))

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # Same rendering for both panels: grayscale, titled, no axis ticks.
    panels = (
        (axes[0], image[index, :, :], "Image Slice"),
        (axes[1], mask[index, :, :], "Mask Slice"),
    )
    for axis, slice_2d, title in panels:
        axis.imshow(slice_2d, cmap="gray")
        axis.set_title(title)
        axis.axis("off")

    # Lay out this specific figure rather than pyplot's "current" one.
    fig.tight_layout()
    # Drop pyplot's reference so figures do not accumulate across calls;
    # the figure object itself is still returned to the caller.
    plt.close(fig)
    return fig

# Gradio Interface function
def interface(file, slice_index):
    """Gradio callback: process the uploaded file and render one slice.

    Args:
        file: The uploaded file as delivered by the file input, or ``None``
            when nothing has been uploaded yet.
        slice_index: Index of the slice to display.

    Returns:
        A tuple ``(figure, output_path)`` for the plot and download
        outputs, or ``(None, None)`` when no file is available.
    """
    # The interface runs in live mode, so this callback can fire (e.g. when
    # the slider moves) before any file exists; return empty outputs
    # instead of failing inside process_file.
    if file is None:
        return None, None

    image, mask, output_path = process_file(file)
    fig = visualize_slice(image, mask, slice_index)
    return fig, output_path

# Gradio Interface
# Input components: the uploaded volume and the index of the slice to show.
# NOTE(review): the labels mention .czi, but process_file only accepts .tif.
_upload_input = gr.File(label="Upload 3D Image (.tif or .czi)")
# The slider range is fixed at 0-100 regardless of the uploaded volume.
_slice_input = gr.Slider(minimum=0, maximum=100, step=1, label="Slice Index")

# Output components: the slice preview and the downloadable mask file.
_plot_output = gr.Plot(label="2D Slice Visualization")
_mask_output = gr.File(label="Download Binary Mask (.tif)")

iface = gr.Interface(
    fn=interface,
    inputs=[_upload_input, _slice_input],
    outputs=[_plot_output, _mask_output],
    title="3D Image Processing with Binary Mask Output",
    description="Upload a 3D image in .tif or .czi format. The model will process the image and output a 3D binary mask. Use the slider to navigate through the 2D slices.",
    # Real-time updates are enabled on the Interface itself, not per component.
    live=True,
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()