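# NOTE(review): the three lines originally here ("Spaces:", "Sleeping",
# "Sleeping") appear to be status text captured from the Hugging Face Spaces
# page this file was copied from; kept as a comment so the module parses.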
import contextlib
import os

import gradio as gr
import matplotlib.pyplot as plt
import monai as mn
import numpy as np
import pandas as pd
import skimage
import torch
from mediffusion import DiffusionModule

from io_utils import LoadImageD
# ---------------------------------------------------------------------------
# Inference model: built from the YAML config, weights restored from the
# checkpoint, then switched to eval mode.
# ---------------------------------------------------------------------------
model = DiffusionModule("./diffusion_configs.yaml")
model.load_ckpt("./data/model.ckpt")
model.eval()

# ---------------------------------------------------------------------------
# Fixed starting noise. NumPy and torch are seeded just before drawing it so
# every restart produces the same tensor; make_predictions repeats this one
# tensor across the batch, so all predictions start from identical noise.
# ---------------------------------------------------------------------------
seed = 3407
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.backends.cudnn.deterministic = True
BASELINE_NOISE = torch.randn((1, 1, 256, 256)).to(torch.float16)

# Model helper functions
def create_ds(img_paths):
    """Build a MONAI dataset that loads and preprocesses one or more images.

    Args:
        img_paths: A single image path (``str`` or ``os.PathLike``) or an
            iterable of paths. A single path is wrapped in a one-item list;
            a path-like object is converted to a plain string first.

    Returns:
        A ``monai.data.Dataset`` of ``{"img": ...}`` dicts. Each image is:

        * loaded with ``LoadImageD(transpose=True, normalize=True)``;
        * given a channel axis (``channel_dim="no_channel"``);
        * resized to 256x256 with bicubic interpolation;
        * intensity-scaled to [0, 1];
        * converted to a tensor.

        All keys other than ``"img"`` are dropped.
    """
    # isinstance also covers str subclasses and pathlib.Path objects. Before,
    # a Path was not recognised and failed when iterated as a list of paths.
    if isinstance(img_paths, (str, os.PathLike)):
        img_paths = [os.fspath(img_paths)]
    data_list = [{"img": img_path} for img_path in img_paths]

    # Preprocessing pipeline applied lazily to each item.
    transforms = mn.transforms.Compose([
        LoadImageD(keys=["img"], transpose=True, normalize=True),
        mn.transforms.EnsureChannelFirstD(keys=["img"], channel_dim="no_channel"),
        mn.transforms.ResizeD(
            keys=["img"],
            spatial_size=(256, 256),
            mode=["bicubic"],
        ),
        mn.transforms.ScaleIntensityD(keys=["img"], minv=0, maxv=1),
        mn.transforms.ToTensorD(keys=["img"], track_meta=None),
        mn.transforms.SelectItemsD(keys=["img"]),
    ])
    return mn.data.Dataset(data_list, transform=transforms)


def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=False, sampler="DDIM100"):
    """Run the diffusion model on an image and return histogram-matched outputs.

    Args:
        img_path: Path of the input image.
        angles: Three rotation angles (x, y, z) written into the condition
            vector. Ignored when ``cls_batch`` is given, when
            ``rotate_to_standard`` is True, or when it is None.
        cls_batch: Optional pre-built batch of condition vectors. When given,
            the image is repeated once per row and used as-is; ``angles`` and
            ``rotate_to_standard`` are then ignored.
        rotate_to_standard: If True (or if ``angles`` is None), build the
            condition with mode flag 2 and placeholder angles ``1000``.
            Otherwise use mode flag 1 and the given angles.
        sampler: Forwarded to ``model.predict`` as ``inference_protocol``.

    Returns:
        A list with one numpy array per batch item: the squeezed prediction,
        histogram-matched to the corresponding (squeezed) input image.

    Raises:
        ValueError: If explicit ``angles`` are used but do not contain exactly
            three values.
    """
    # One copy of the image per condition row (or a single image).
    if cls_batch is not None:
        ds = create_ds([img_path] * len(cls_batch))
    else:
        ds = create_ds(img_path)
    # Whole dataset in one batch; worker processes are only spun up for
    # multi-image batches.
    dl = mn.data.DataLoader(ds, batch_size=len(ds), num_workers=0 if len(ds) == 1 else 4, shuffle=False)
    input_batch = next(iter(dl))
    input_imgs = input_batch["img"]
    batch_size = input_imgs.shape[0]
    original_imgs = input_imgs.detach().cpu().numpy()

    # Build the classifier condition if not provided. Layout:
    # [mode flag, angle_x, angle_y, angle_z, 768 zeros]. The zeros are
    # presumably a placeholder for an unused conditioning vector; confirm
    # against the model config.
    if cls_batch is None:
        if rotate_to_standard or angles is None:
            flag, angles = 2, [1000, 1000, 1000]
        else:
            flag, angles = 1, list(angles)
            if len(angles) != 3:
                raise ValueError(f"angles must contain exactly 3 values (x, y, z), got {len(angles)}")
        # torch.cat avoids unpacking 768 zero-dim tensors into torch.tensor.
        header = torch.tensor([flag, *angles], dtype=torch.get_default_dtype())
        cls_value = torch.cat([header, torch.zeros(768)])
        cls_batch = cls_value.unsqueeze(0).repeat(batch_size, 1)

    # Every item starts from the same module-level baseline noise.
    noise = BASELINE_NOISE.repeat(batch_size, 1, 1, 1)
    model_kwargs = {
        "cls": cls_batch,
        "concat": input_imgs,
    }

    preds = model.predict(
        noise, model_kwargs=model_kwargs, classifier_cond_scale=4, inference_protocol=sampler
    )

    # Map each prediction's intensity distribution onto its source image.
    return [
        skimage.exposure.match_histograms(
            pred.detach().cpu().numpy().squeeze(), original_img.squeeze()
        )
        for pred, original_img in zip(preds, original_imgs)
    ]


# Gradio helper functions
current_img = None  # most recent image returned by rotate_btn_fn
live_preds = None  # NOTE(review): not read or written anywhere in the visible code


def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
    """Click handler for the "Rotate!" button.

    Args:
        img_path: Filepath of the selected input image (None if nothing is
            selected).
        xt, yt, zt: Rotation angles for the x, y and z axes; cast to float.
        add_bone_cmap: If True, colour the prediction with matplotlib's
            ``bone`` colormap and return it as an 8-bit RGB array.

    Returns:
        The first prediction from ``make_predictions``: either the raw array,
        or a ``uint8`` RGB array when ``add_bone_cmap`` is True. The result is
        also stored in the module-level ``current_img``.

    Raises:
        gr.Error: If no input image has been selected, so Gradio shows the
            user a clear message instead of an opaque loader failure.
    """
    global current_img
    if not img_path:
        raise gr.Error("Please select an input image before pressing Rotate!")

    angles = [float(xt), float(yt), float(zt)]
    out_img = make_predictions(img_path, angles)[0]
    if add_bone_cmap:
        # The colormap returns RGBA floats; drop alpha and convert to 8-bit RGB.
        rgba = plt.get_cmap('bone')(out_img)
        out_img = (rgba[..., :3] * 255).astype(np.uint8)
    # Record the latest output on both paths (previously this only happened
    # when the colormap was applied).
    current_img = out_img
    return out_img


def use_current_btn_fn(input_img):
    """Click handler that feeds the shown output image back as the new input.

    The value is returned unchanged; Gradio routes it from the output
    component to the input component.
    """
    passthrough = input_img
    return passthrough


css_style = "./style.css"
callback = gr.CSVLogger()  # NOTE(review): never set up or used in the visible code

# Example images live in one folder and are split into the X-ray and DRR
# galleries by a filename substring. List the folder once and sort it, so the
# gallery order no longer depends on os.listdir's arbitrary ordering.
EXAMPLES_DIR = "./data/examples"
_example_names = sorted(os.listdir(EXAMPLES_DIR))
xray_examples = [os.path.join(EXAMPLES_DIR, f) for f in _example_names if "xr" in f]
drr_examples = [os.path.join(EXAMPLES_DIR, f) for f in _example_names if "drr" in f]

with gr.Blocks(css=css_style) as app:
    gr.HTML("VCNet: A tool for 3D Rotation of Radiographs with Diffusion Models", elem_classes="title")
    gr.HTML("Developed by: Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles", elem_classes="note")
    gr.HTML("Note: This is a proof-of-concept demo of an AI tool that is not yet finalized. Please interpret with care!", elem_classes="note")

    with gr.TabItem("Single Rotation"):
        with gr.Row():
            # The input is non-interactive: it is filled from the example
            # galleries or from the "use current output" button.
            input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
            output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
        with gr.Row():
            # NOTE(review): both galleries share elem_id='examples'. HTML ids
            # should be unique, but style.css may target #examples, so the id
            # is left unchanged here.
            gr.Examples(
                examples=xray_examples,
                inputs=[input_img],
                label="Xray Examples",
                elem_id='examples',
            )
            gr.Examples(
                examples=drr_examples,
                inputs=[input_img],
                label="DRR Examples",
                elem_id='examples',
            )
        with gr.Row():
            gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
        with gr.Row():
            with gr.Column(scale=1):
                xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
        with gr.Row():
            rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
        with gr.Row():
            use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')

    # Event wiring: Rotate runs the model; the second button copies the
    # output image back into the input slot.
    rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
    use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)


# Best-effort cleanup before launching: close this app and any other running
# Gradio apps. These are presumably leftovers from re-running the script in
# an interactive session. Failures are deliberately ignored. Unlike the old
# bare ``except:``, this no longer swallows KeyboardInterrupt or SystemExit.
with contextlib.suppress(Exception):
    app.close()
    gr.close_all()

demo = app.launch(
    max_threads=4,
    share=True,
    inline=False,
    show_api=False,
    show_error=False,
)