"""VCNet Gradio demo: rotate radiographs in 3D with a diffusion model.

At import time this script loads a ``mediffusion`` DiffusionModule checkpoint,
fixes a baseline noise tensor, builds a Gradio UI and launches it.
"""

import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Explicit submodule import: older scikit-image releases may not expose
# `skimage.exposure` after a bare `import skimage`.
import skimage
import skimage.exposure
from mediffusion import DiffusionModule
import monai as mn
import torch

from io_utils import LoadImageD

# Loading the model for inference
model = DiffusionModule("./diffusion_configs.yaml")
model.load_ckpt("./data/model.ckpt")
model.eval()

# Loading a baseline noise for making predictions. The same noise is reused
# for every prediction so that outputs are repeatable for a given input.
seed = 3407
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.backends.cudnn.deterministic = True
BASELINE_NOISE = torch.randn(1, 1, 256, 256).half()


# Model helper functions
def create_ds(img_paths):
    """Build a MONAI dataset that loads and preprocesses input images.

    Args:
        img_paths: A single image path (str) or a list of image paths.

    Returns:
        A ``monai.data.Dataset`` whose items are dicts holding only an
        ``"img"`` entry. The image is loaded with ``LoadImageD`` (transposed
        and normalized), given a leading channel axis, resized to 256x256
        with bicubic interpolation, rescaled to [0, 1] and converted to a
        tensor.
    """
    if isinstance(img_paths, str):
        img_paths = [img_paths]
    data_list = [{"img": img_path} for img_path in img_paths]

    # Get the transforms
    transforms = [
        LoadImageD(keys=["img"], transpose=True, normalize=True),
        mn.transforms.EnsureChannelFirstD(keys=["img"], channel_dim="no_channel"),
        mn.transforms.ResizeD(keys=["img"], spatial_size=(256, 256), mode=["bicubic"]),
        mn.transforms.ScaleIntensityD(keys=["img"], minv=0, maxv=1),
        mn.transforms.ToTensorD(keys=["img"], track_meta=None),
        mn.transforms.SelectItemsD(keys=["img"]),
    ]
    return mn.data.Dataset(data_list, transform=mn.transforms.Compose(transforms))


def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=False, sampler="DDIM100"):
    """Run the diffusion model to produce rotated versions of an image.

    Args:
        img_path: Path to the input image. When ``cls_batch`` is None this
            may also be a list of paths (one prediction per path).
        angles: ``[x, y, z]`` rotation angles placed into the conditioning
            vector. Ignored when ``cls_batch`` is given.
        cls_batch: Optional precomputed conditioning tensor with one row per
            prediction. When given, ``img_path`` is repeated
            ``len(cls_batch)`` times.
        rotate_to_standard: If True (or if ``angles`` is None), the
            conditioning vector uses leading value 2 and angles
            ``[1000, 1000, 1000]`` instead of leading value 1 and the
            requested angles. Presumably this selects a "rotate to standard
            view" mode; confirm against the model's training code.
        sampler: Name passed to ``model.predict`` as ``inference_protocol``.

    Returns:
        A list of 2-D numpy arrays, one per prediction. Each has been
        histogram-matched to its corresponding preprocessed input image.
    """
    # Create the image dataset
    if cls_batch is not None:
        ds = create_ds([img_path] * len(cls_batch))
    else:
        ds = create_ds(img_path)
    dl = mn.data.DataLoader(ds, batch_size=len(ds), num_workers=0 if len(ds) == 1 else 4, shuffle=False)
    input_batch = next(iter(dl))
    input_imgs = input_batch["img"]
    batch_size = input_imgs.shape[0]
    original_imgs = input_imgs.detach().cpu().numpy()

    # Create the classifier condition if not provided:
    # [mode, x, y, z] followed by 768 zeros.
    # NOTE(review): the meaning of the 768 zero entries is not visible here.
    if cls_batch is None:
        fp = torch.zeros(768)
        if rotate_to_standard or angles is None:
            angles = [1000, 1000, 1000]
            mode = 2
        else:
            mode = 1
        header = torch.tensor([mode, *angles], dtype=fp.dtype)
        cls_value = torch.cat([header, fp])
        cls_batch = cls_value.unsqueeze(0).repeat(batch_size, 1)

    # Generate noise
    noise = BASELINE_NOISE.repeat(batch_size, 1, 1, 1)
    model_kwargs = {
        "cls": cls_batch,
        "concat": input_imgs,
    }

    # Make predictions
    preds = model.predict(
        noise,
        model_kwargs=model_kwargs,
        classifier_cond_scale=4,
        inference_protocol=sampler,
    )

    # Match each prediction's intensity histogram to its input image.
    adjusted_preds = []
    for pred, original_img in zip(preds, original_imgs):
        adjusted_pred = pred.detach().cpu().numpy().squeeze()
        adjusted_pred = skimage.exposure.match_histograms(adjusted_pred, original_img.squeeze())
        adjusted_preds.append(adjusted_pred)
    return adjusted_preds


# Gradio helper functions
current_img = None  # most recent output produced by rotate_btn_fn
live_preds = None


def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
    """Gradio callback for the "Rotate!" button.

    Args:
        img_path: File path of the selected input image (None if nothing is
            selected).
        xt, yt, zt: Rotation angles around the x, y and z axes.
        add_bone_cmap: If True, render the output with matplotlib's "bone"
            colormap as a uint8 RGB array. Otherwise return the raw
            histogram-matched 2-D array.

    Returns:
        The rotated image. It is also stored in the module-level
        ``current_img``.

    Raises:
        gr.Error: If no input image has been selected.
    """
    global current_img
    if img_path is None:
        raise gr.Error("Please select an input image first.")
    angles = [float(xt), float(yt), float(zt)]
    out_img = make_predictions(img_path, angles)[0]
    if add_bone_cmap:
        cmap = plt.get_cmap('bone')
        out_img = (cmap(out_img)[..., :3] * 255).astype(np.uint8)
    current_img = out_img
    return out_img


def use_current_btn_fn(input_img):
    """Return the given image unchanged.

    The UI wires this callback so the current output becomes the new input.
    """
    return input_img


css_style = "./style.css"
callback = gr.CSVLogger()

examples_dir = "./data/examples"
# Sorted so the example gallery order is deterministic across platforms.
example_files = sorted(os.listdir(examples_dir))

with gr.Blocks(css=css_style) as app:
    gr.HTML("VCNet: A tool for 3D Rotation of Radiographs with Diffusion Models", elem_classes="title")
    gr.HTML("Developed by: Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles", elem_classes="note")
    gr.HTML("Note: This is a proof-of-concept demo of an AI tool that is not yet finalized. Please interpret with care!", elem_classes="note")

    with gr.TabItem("Single Rotation"):
        with gr.Row():
            input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
            output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')

        with gr.Row():
            gr.Examples(
                examples=[os.path.join(examples_dir, f) for f in example_files if "xr" in f],
                inputs=[input_img],
                label="Xray Examples",
                elem_id='examples',
            )
            gr.Examples(
                examples=[os.path.join(examples_dir, f) for f in example_files if "drr" in f],
                inputs=[input_img],
                label="DRR Examples",
                elem_id='examples',
            )

        with gr.Row():
            gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')

        with gr.Row():
            with gr.Column(scale=1):
                xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)

        with gr.Row():
            rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
        with gr.Row():
            use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')

        rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
        use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)

# Best-effort cleanup of previously launched Gradio servers before launching.
# Catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
try:
    app.close()
    gr.close_all()
except Exception:
    pass

demo = app.launch(
    max_threads=4,
    share=True,
    inline=False,
    show_api=False,
    show_error=False,
)