|
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from aicsimageio import AICSImage

from prediction import predict_mask
|
|
|
|
|
def process_3d_image(
    image,
    model_dir="https://huggingface.co/Hemaxi/3DCycleGAN/resolve/main/CycleGANVesselSegmentation.h5",
):
    """
    Segment a 3D image with the CycleGAN vessel-segmentation model.

    Args:
        image: The 3D array to segment (as loaded by ``auximread``).
        model_dir: Location of the model weights, passed unchanged to
            ``predict_mask``. Defaults to the raw-download URL of
            ``CycleGANVesselSegmentation.h5`` in the ``Hemaxi/3DCycleGAN``
            Hugging Face repository.

    Returns:
        Whatever ``predict_mask(model_dir, image)`` returns; presumably a
        binary mask with the same shape as ``image`` — TODO confirm.
    """
    # Hugging Face serves raw file bytes under ``resolve/<revision>/``; the
    # ``tree/`` and ``blob/`` URLs are HTML pages for browsing the repository.
    # NOTE(review): confirm predict_mask accepts a URL (vs. a local path).
    return predict_mask(model_dir, image)
|
|
|
def auximread(filepath):
    """
    Read a TIFF stack and move its smallest dimension to the last axis.

    Args:
        filepath: Path of the TIFF file to read.

    Returns:
        The array returned by ``tifffile.imread``. When it is 3D, its smallest
        dimension is moved to the last axis, and the other two axes keep their
        relative order. Arrays of any other rank are returned unchanged so the
        caller can reject them with a clear error message.
    """
    image = tifffile.imread(filepath)

    if image.ndim != 3:
        # Not a 3D stack: defer to the caller's dimensionality check instead
        # of failing here with an IndexError (2D) or a transpose error (4D+).
        return image

    # np.argmin picks the first minimum, so on ties the earliest axis is the
    # one moved, which matches the previous transpose-based logic.
    smallest_axis = int(np.argmin(image.shape))
    return np.moveaxis(image, smallest_axis, -1)
|
|
|
|
|
def process_file(file):
    """
    Load an uploaded 3D TIFF, segment it, and save the mask as a TIFF file.

    Args:
        file: The uploaded file. Either a path (``str`` / ``os.PathLike``),
            as newer Gradio versions pass, or an object exposing the on-disk
            path as ``.name`` (e.g. a tempfile wrapper, as older versions pass).

    Returns:
        A tuple ``(image, binary_mask, output_path)``: the image as loaded by
        ``auximread``, the mask returned by ``process_3d_image``, and the path
        of the TIFF file the mask was written to.

    Raises:
        ValueError: If the file is not a ``.tif``/``.tiff`` file, or the
            loaded image is not 3D.
    """
    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
    else:
        path = file.name

    if not path.lower().endswith((".tif", ".tiff")):
        raise ValueError("Unsupported file format. Please upload a .tif or .tiff file.")

    image = auximread(path)

    if image.ndim != 3:
        raise ValueError("Input image is not 3D.")

    binary_mask = process_3d_image(image)

    # Write each result to its own temporary file rather than one fixed name
    # in the working directory, so concurrent requests cannot overwrite each
    # other's download. The file is left in place for the UI to serve.
    fd, output_path = tempfile.mkstemp(prefix="output_mask_", suffix=".tif")
    os.close(fd)
    tifffile.imwrite(output_path, binary_mask)

    return image, binary_mask, output_path
|
|
|
|
|
def visualize_slice(image, mask, slice_index, axis=0):
    """
    Plot one 2D slice of ``image`` next to the matching slice of ``mask``.

    Args:
        image: 3D array to slice.
        mask: 3D array sliced at the same position; assumed to have the same
            shape as ``image`` — TODO confirm against predict_mask's output.
        slice_index: Index of the slice along ``axis``. It is converted to
            ``int`` and clamped to ``[0, image.shape[axis] - 1]``, because the
            UI slider's range is fixed and need not match the stack size.
        axis: Axis to slice along. The default of 0 keeps the original
            behaviour. ``auximread`` moves the smallest dimension to the last
            axis, so ``axis=-1`` steps through that dimension instead.

    Returns:
        A matplotlib Figure with the image slice (left) and mask slice
        (right). The figure is closed in pyplot before being returned, so
        pyplot does not keep it alive between calls.
    """
    image = np.asarray(image)
    mask = np.asarray(mask)

    # Clamp instead of raising IndexError when the slider exceeds the depth.
    last_index = image.shape[axis] - 1
    index = min(max(int(slice_index), 0), last_index)

    image_slice = np.take(image, index, axis=axis)
    mask_slice = np.take(mask, index, axis=axis)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    for ax, data, title in (
        (axes[0], image_slice, "Image Slice"),
        (axes[1], mask_slice, "Mask Slice"),
    ):
        ax.imshow(data, cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    # Release pyplot's reference; the Figure object itself stays renderable.
    plt.close(fig)

    return fig
|
|
|
|
|
def interface(file, slice_index):
    """
    Gradio callback: segment the uploaded file and plot the requested slice.

    Args:
        file: Value of the ``gr.File`` input, or ``None`` when nothing has
            been uploaded yet.
        slice_index: Value of the slice slider.

    Returns:
        ``(figure, output_path)`` for the plot and download outputs, or
        ``(None, None)`` (clearing both outputs) when no file is uploaded.
    """
    # The interface is live, so moving the slider triggers this callback even
    # before any upload; clear the outputs instead of crashing.
    if file is None:
        return None, None

    image, mask, output_path = process_file(file)
    fig = visualize_slice(image, mask, slice_index)
    return fig, output_path
|
|
|
|
|
iface = gr.Interface(
    fn=interface,
    inputs=[
        # process_file only accepts TIFF input, so do not advertise .czi.
        gr.File(label="Upload 3D Image (.tif)"),
        # NOTE(review): the upper bound is fixed because the stack depth is
        # unknown until a file is processed; it may exceed the actual depth.
        gr.Slider(minimum=0, maximum=100, step=1, label="Slice Index"),
    ],
    outputs=[
        gr.Plot(label="2D Slice Visualization"),
        gr.File(label="Download Binary Mask (.tif)"),
    ],
    title="3D Image Processing with Binary Mask Output",
    description=(
        "Upload a 3D image in .tif format. The model will process the image "
        "and output a 3D binary mask. Use the slider to navigate through the "
        "2D slices."
    ),
    # NOTE(review): live=True re-runs `interface` (and therefore the full
    # model prediction in process_file) on every input change, including each
    # slider move; consider caching the prediction per uploaded file.
    live=True,
)


if __name__ == "__main__":
    iface.launch()
|
|