"""Gradio demo app for RadRotator: rotate a radiograph via a 3D volume.

The pipeline is: reconstruct a 3D volume from the input image, rotate the
volume, then project it back to a 2D image. The reconstruction model is a
placeholder that returns an all-zero volume.
"""

import os

import gradio as gr
import numpy as np
import skimage
from skimage import io
import torch
import monai
from monai.transforms import Rotate


# Piecewise-linear (position, value) breakpoints reproducing matplotlib's
# 'bone' colormap, one table per RGB channel. Kept local so the optional
# colormap path does not need matplotlib, which this file never imports.
_BONE_BREAKPOINTS = (
    # red
    ((0.0, 0.746032, 1.0), (0.0, 0.652778, 1.0)),
    # green
    ((0.0, 0.365079, 0.746032, 1.0), (0.0, 0.319444, 0.777778, 1.0)),
    # blue
    ((0.0, 0.365079, 1.0), (0.0, 0.444444, 1.0)),
)


def _apply_bone_cmap(image):
    """Map a grayscale image to uint8 RGB using a 'bone'-style colormap.

    Args:
        image: Array-like of grayscale intensities. Float values are
            interpreted on a 0..1 scale (values outside are clamped to the
            end colors); integer values are interpreted on a 0..255 scale.

    Returns:
        A ``uint8`` array with the shape of ``image`` plus a trailing RGB
        axis of length 3.
    """
    values = np.asarray(image)
    if np.issubdtype(values.dtype, np.integer):
        values = values / 255.0
    values = values.astype(float)
    channels = [np.interp(values, xs, ys) for xs, ys in _BONE_BREAKPOINTS]
    rgb = np.stack(channels, axis=-1)
    return (rgb * 255).astype(np.uint8)


# Placeholder for the 3D reconstruction model
class Simple3DReconstructionModel:
    """Placeholder model: 2D image -> 3D volume -> rotated 2D projection."""

    def __init__(self):
        # Load your pre-trained model here
        self.model = None  # replace with actual model loading

    def reconstruct_3d(self, image):
        """Return a 3D volume for ``image``.

        This is a placeholder: the input is ignored and an all-zero
        ``(128, 128, 128)`` array is returned.
        """
        # Implement the 3D reconstruction logic here
        return np.zeros((128, 128, 128))

    def rotate_3d(self, volume, angles):
        """Rotate a 3D volume using MONAI.

        Args:
            volume: 3D array without a channel axis.
            angles: Three rotation angles, interpreted as degrees (the UI
                sliders that feed this method range over -15..15).

        Returns:
            The rotated volume as a NumPy array, without a channel axis.
        """
        # MONAI's Rotate takes radians, not degrees.
        radians = tuple(float(np.deg2rad(angle)) for angle in angles)
        rotate = Rotate(radians, mode='bilinear')
        # Rotate expects channel-first input; without the extra axis a 3D
        # volume is read as a multi-channel 2D image and the three-angle
        # tuple is rejected.
        channel_first = np.asarray(volume)[np.newaxis, ...]
        rotated = rotate(channel_first)
        if isinstance(rotated, torch.Tensor):
            rotated = rotated.detach().cpu().numpy()
        return np.asarray(rotated)[0]

    def project_2d(self, volume):
        """Project a 3D volume to 2D by taking the maximum along axis 0."""
        # This is a placeholder example
        projection = np.max(volume, axis=0)
        return projection


# Initialize the model
model = Simple3DReconstructionModel()


# Gradio helper functions
def process_image(img, xt, yt, zt):
    """Run reconstruct -> rotate -> project on ``img``.

    Args:
        img: Input image passed to the reconstruction model.
        xt, yt, zt: Rotation angles about the x, y and z axes.

    Returns:
        The 2D projection of the rotated volume.
    """
    # Reconstruct the 3D volume
    volume = model.reconstruct_3d(img)
    # Rotate the 3D volume
    rotated_volume = model.rotate_3d(volume, (xt, yt, zt))
    # Project the rotated volume back to 2D
    return model.project_2d(rotated_volume)


def rotate_btn_fn(img, xt, yt, zt, add_bone_cmap=False):
    """Gradio callback for the Rotate button.

    Args:
        img: Either a NumPy image array or a path to an existing image file.
        xt, yt, zt: Rotation angles taken from the sliders.
        add_bone_cmap: When true, colorize the output with a bone colormap
            and return a uint8 RGB image instead of the raw projection.

    Returns:
        The output image, or ``None`` if anything failed (the error is
        printed rather than raised so the UI stays responsive).
    """
    try:
        angles = (xt, yt, zt)
        print(f"Rotating with angles: {angles}")

        if isinstance(img, np.ndarray):
            input_img_path = "uploaded_image.png"
            skimage.io.imsave(input_img_path, img)
        elif isinstance(img, str) and os.path.exists(img):
            input_img_path = img
            img = skimage.io.imread(input_img_path)
        else:
            raise ValueError("Invalid input image")

        # Process the image with the model
        out_img = process_image(img, xt, yt, zt)

        if not add_bone_cmap:
            return out_img

        # The original referenced an undefined ``plt`` here; the colormap
        # is applied with a NumPy-only helper instead.
        return _apply_bone_cmap(out_img)
    except Exception as e:
        # Top-level UI boundary: report and return an empty output.
        print(f"Error in rotate_btn_fn: {e}")
        return None


css_style = "./style.css"
callback = gr.CSVLogger()

with gr.Blocks(css=css_style, title="RadRotator") as app:
    gr.HTML(
        "RadRotator: 3D \n"
        "Rotation of Radiographs with Diffusion Models",
        elem_classes="title",
    )
    gr.HTML(
        "Developed by:\n"
        "Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles\n"
        "[Our website], [arXiv Paper]",
        elem_classes="note",
    )
    gr.HTML(
        "Note: The demo operates on a CPU, and since diffusion models require more computational capacity to function, all predictions are precomputed.",
        elem_classes="note",
    )

    with gr.TabItem("Demo"):
        with gr.Row():
            input_img = gr.Image(type='numpy', label='Input image', interactive=True, elem_classes='imgs')
            output_img = gr.Image(type='numpy', label='Output image', interactive=False, elem_classes='imgs')

        with gr.Row():
            # Empty side columns act as horizontal padding around the examples.
            with gr.Column(scale=0.25):
                pass
            with gr.Column(scale=1):
                gr.Examples(
                    examples=[
                        os.path.join("./data/examples", f)
                        for f in os.listdir("./data/examples")
                        if "xr" in f
                    ],
                    inputs=[input_img],
                    label="Xray Examples",
                    elem_id='examples',
                )
            with gr.Column(scale=0.25):
                pass

        with gr.Row():
            gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')

        with gr.Row():
            with gr.Column(scale=1):
                xt = gr.Slider(label='x axis (medial/lateral rotation):', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
            with gr.Column(scale=1):
                yt = gr.Slider(label='y axis (inlet/outlet rotation):', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
            with gr.Column(scale=1):
                zt = gr.Slider(label='z axis (plane rotation):', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)

        with gr.Row():
            rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')

        rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)

# Best-effort cleanup of any previously running Gradio servers before
# launching; failures here are reported but must not prevent the launch.
try:
    app.close()
    gr.close_all()
except Exception as e:
    print(f"Error closing app: {e}")

demo = app.launch(
    max_threads=4,
    share=True,
    inline=False,
    show_api=False,
    show_error=False,
)