"""Gradio app: segment a 3D microscopy volume and browse it slice by slice."""

import functools
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from aicsimageio import AICSImage  # To handle .czi files
from prediction import predict_mask

# NOTE(review): this is the URL of a Hugging Face web page ("tree" view), not a
# local file or a direct download link. Whether ``predict_mask`` can load a
# model from it depends on its implementation -- confirm, or download the
# weights first and pass the local path instead.
MODEL_DIR = 'https://huggingface.co/Hemaxi/3DCycleGAN/tree/main/CycleGANVesselSegmentation.h5'

# Where the binary mask is written before being offered for download.
OUTPUT_PATH = "output_mask.tif"

TIF_EXTENSIONS = (".tif", ".tiff")
CZI_EXTENSIONS = (".czi",)


def process_3d_image(image):
    """Run the segmentation model on a 3D volume and return its binary mask.

    Args:
        image: 3D numpy array with the z axis last (as returned by ``auximread``).

    Returns:
        The result of ``prediction.predict_mask`` -- presumably a binary mask
        with the same shape as ``image`` (TODO confirm).
    """
    return predict_mask(MODEL_DIR, image)


def _move_smallest_axis_last(image):
    """Transpose a 3D array so its smallest axis is last; return others unchanged."""
    if image.ndim != 3:
        # Leave the rank check (and its clear error message) to the caller.
        return image
    index_min = int(np.argmin(image.shape))
    if index_min == 0:
        return image.transpose(1, 2, 0)
    if index_min == 1:
        return image.transpose(0, 2, 1)
    return image


def auximread(filepath):
    """Read a TIFF volume and reorder it to (X, Y, Z).

    The smallest axis is assumed to be z and is moved to the end. Arrays that
    are not 3D are returned as read, so the caller can reject them.
    """
    image = tifffile.imread(filepath)
    return _move_smallest_axis_last(image)


def _read_czi(filepath):
    """Read the first channel/timepoint of a CZI file as an (X, Y, Z) array."""
    # AICSImage normalises every file to TCZYX; pick a single 3D stack.
    # NOTE(review): only channel 0 is used -- confirm it is the vessel channel.
    volume = AICSImage(filepath).get_image_data("ZYX", T=0, C=0)
    return np.transpose(volume, (1, 2, 0))


def _upload_path(file):
    """Return the filesystem path of a Gradio upload.

    Older Gradio versions pass a tempfile-like object exposing ``.name``;
    newer ones pass the path itself.
    """
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


@functools.lru_cache(maxsize=1)
def _process_path(path):
    """Implementation of ``process_file``, memoised on the upload path.

    The interface runs with ``live=True``, so every slider movement re-invokes
    the pipeline. Caching the most recent upload avoids re-running the model
    just to display a different slice.
    """
    lowered = path.lower()
    if lowered.endswith(TIF_EXTENSIONS):
        image = auximread(path)
    elif lowered.endswith(CZI_EXTENSIONS):
        image = _read_czi(path)
    else:
        raise ValueError("Unsupported file format. Please upload a .tif or .czi file.")

    # Ensure image is 3D
    if image.ndim != 3:
        raise ValueError("Input image is not 3D.")

    binary_mask = process_3d_image(image)

    # Save the mask so Gradio can offer it for download.
    tifffile.imwrite(OUTPUT_PATH, binary_mask)
    return image, binary_mask, OUTPUT_PATH


def process_file(file):
    """Load an uploaded 3D image, segment it and save the binary mask.

    Args:
        file: Gradio upload (tempfile-like object with ``.name`` or a path).

    Returns:
        Tuple ``(image, binary_mask, output_path)``.

    Raises:
        ValueError: if the extension is not supported or the image is not 3D.
    """
    return _process_path(_upload_path(file))


def visualize_slice(image, mask, slice_index):
    """Plot one z-slice of the image next to the same slice of the mask.

    Slices are taken along the last axis, which is where ``auximread`` puts
    the z axis. ``slice_index`` is converted to int and clamped to the valid
    range because the slider's maximum is fixed and may exceed the stack depth.
    """
    # NOTE(review): assumes mask is laid out like image (z last) -- confirm.
    depth = min(image.shape[-1], mask.shape[-1])
    index = min(max(int(slice_index), 0), depth - 1)

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = ((axes[0], image, "Image Slice"), (axes[1], mask, "Mask Slice"))
    for ax, volume, title in panels:
        ax.imshow(volume[:, :, index], cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    fig.tight_layout()
    # Release pyplot's reference so repeated calls don't accumulate open
    # figures; the returned Figure object is still rendered by gr.Plot.
    plt.close(fig)
    return fig


def interface(file, slice_index):
    """Gradio callback: segment the upload and show the requested slice."""
    # With live=True the callback also fires before anything is uploaded.
    if file is None:
        return None, None
    image, mask, output_path = process_file(file)
    fig = visualize_slice(image, mask, slice_index)
    return fig, output_path


iface = gr.Interface(
    fn=interface,
    inputs=[
        gr.File(label="Upload 3D Image (.tif or .czi)"),
        gr.Slider(minimum=0, maximum=100, step=1, label="Slice Index"),
    ],
    outputs=[
        gr.Plot(label="2D Slice Visualization"),
        gr.File(label="Download Binary Mask (.tif)"),
    ],
    title="3D Image Processing with Binary Mask Output",
    description="Upload a 3D image in .tif or .czi format. The model will process the image and output a 3D binary mask. Use the slider to navigate through the 2D slices.",
    live=True,  # Re-run on every input change (model results are cached above).
)

if __name__ == "__main__":
    iface.launch()