Spaces:
Sleeping
Sleeping
Commit
·
8523150
1
Parent(s):
f9f482a
precomputed all rotations
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +1 -2
- __pycache__/io_utils.cpython-311.pyc +0 -0
- app.py +20 -116
- data/cached_outputs/xr_1_(-10, -10, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, -10, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, -10, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, -10, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, -10, 10).png +0 -0
- data/cached_outputs/xr_1_(-10, -10, 15).png +0 -0
- data/cached_outputs/xr_1_(-10, -10, 5).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, 10).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, 15).png +0 -0
- data/cached_outputs/xr_1_(-10, -15, 5).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, 10).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, 15).png +0 -0
- data/cached_outputs/xr_1_(-10, -5, 5).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, 10).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, 15).png +0 -0
- data/cached_outputs/xr_1_(-10, 0, 5).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, 10).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, 15).png +0 -0
- data/cached_outputs/xr_1_(-10, 10, 5).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, 10).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, 15).png +0 -0
- data/cached_outputs/xr_1_(-10, 15, 5).png +0 -0
- data/cached_outputs/xr_1_(-10, 5, -10).png +0 -0
- data/cached_outputs/xr_1_(-10, 5, -15).png +0 -0
- data/cached_outputs/xr_1_(-10, 5, -5).png +0 -0
- data/cached_outputs/xr_1_(-10, 5, 0).png +0 -0
- data/cached_outputs/xr_1_(-10, 5, 10).png +0 -0
.gitignore
CHANGED
@@ -1,2 +1 @@
|
|
1 |
-
|
2 |
-
*.ckpt
|
|
|
1 |
+
app_copy.py
|
|
__pycache__/io_utils.cpython-311.pyc
ADDED
Binary file (6.62 kB). View file
|
|
app.py
CHANGED
@@ -3,95 +3,12 @@ import os
|
|
3 |
import gradio as gr
|
4 |
import matplotlib.pyplot as plt
|
5 |
import numpy as np
|
6 |
-
import pandas as pd
|
7 |
import skimage
|
8 |
-
from mediffusion import DiffusionModule
|
9 |
import monai as mn
|
10 |
import torch
|
11 |
|
12 |
from io_utils import LoadImageD
|
13 |
|
14 |
-
# Loading the model for inference
|
15 |
-
|
16 |
-
model = DiffusionModule("./diffusion_configs.yaml")
|
17 |
-
model.load_ckpt("./data/model.ckpt")
|
18 |
-
model.eval();
|
19 |
-
|
20 |
-
# Loading a baseline noise for making predictions
|
21 |
-
|
22 |
-
seed = 3407
|
23 |
-
np.random.seed(seed)
|
24 |
-
torch.random.manual_seed(seed)
|
25 |
-
torch.backends.cudnn.deterministic = True
|
26 |
-
BASELINE_NOISE = torch.randn(1, 1, 256, 256).half()
|
27 |
-
|
28 |
-
# Model helper functions
|
29 |
-
|
30 |
-
def create_ds(img_paths):
|
31 |
-
if type(img_paths) == str:
|
32 |
-
img_paths = [img_paths]
|
33 |
-
data_list = [{"img": img_path} for img_path in img_paths]
|
34 |
-
|
35 |
-
# Get the transforms
|
36 |
-
Ts_list = [
|
37 |
-
LoadImageD(keys=["img"], transpose=True, normalize=True),
|
38 |
-
mn.transforms.EnsureChannelFirstD(
|
39 |
-
keys=["img"], channel_dim="no_channel"
|
40 |
-
),
|
41 |
-
mn.transforms.ResizeD(
|
42 |
-
keys=["img"],
|
43 |
-
spatial_size=(256, 256),
|
44 |
-
mode=["bicubic"],
|
45 |
-
),
|
46 |
-
mn.transforms.ScaleIntensityD(keys=["img"], minv=0, maxv=1),
|
47 |
-
mn.transforms.ToTensorD(keys=["img"], track_meta=None),
|
48 |
-
mn.transforms.SelectItemsD(keys=["img"]),
|
49 |
-
]
|
50 |
-
return mn.data.Dataset(data_list, transform=mn.transforms.Compose(Ts_list))
|
51 |
-
|
52 |
-
def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=False, sampler="DDIM5"):
|
53 |
-
|
54 |
-
global model
|
55 |
-
global BASELINE_NOISE
|
56 |
-
|
57 |
-
# Create the image dataset
|
58 |
-
if cls_batch is not None:
|
59 |
-
ds = create_ds([img_path]*len(cls_batch))
|
60 |
-
else:
|
61 |
-
ds = create_ds(img_path)
|
62 |
-
dl = mn.data.DataLoader(ds, batch_size=len(ds), num_workers=0 if len(ds)==1 else 4, shuffle=False)
|
63 |
-
input_batch = next(iter(dl))
|
64 |
-
original_imgs = input_batch["img"].detach().cpu().numpy()
|
65 |
-
|
66 |
-
# Create the classifier condition if not provided
|
67 |
-
if cls_batch is None:
|
68 |
-
fp = torch.zeros(768)
|
69 |
-
if rotate_to_standard or angles is None:
|
70 |
-
angles = [1000, 1000, 1000]
|
71 |
-
cls_value = torch.tensor([2, *angles, *fp])
|
72 |
-
else:
|
73 |
-
cls_value = torch.tensor([1, *angles, *fp])
|
74 |
-
cls_batch = cls_value.unsqueeze(0).repeat(input_batch["img"].shape[0], 1)
|
75 |
-
|
76 |
-
# Generate noise
|
77 |
-
noise = BASELINE_NOISE.repeat(input_batch["img"].shape[0], 1, 1, 1)
|
78 |
-
model_kwargs = {
|
79 |
-
"cls": cls_batch,
|
80 |
-
"concat": input_batch["img"]
|
81 |
-
}
|
82 |
-
|
83 |
-
# Make predictions
|
84 |
-
preds = model.predict(
|
85 |
-
noise, model_kwargs=model_kwargs, classifier_cond_scale=4, inference_protocol=sampler
|
86 |
-
)
|
87 |
-
adjusted_preds = list()
|
88 |
-
for pred, original_img in zip(preds, original_imgs):
|
89 |
-
adjusted_pred = pred.detach().cpu().numpy().squeeze()
|
90 |
-
original_img = original_img.squeeze()
|
91 |
-
adjusted_pred = skimage.exposure.match_histograms(adjusted_pred, original_img)
|
92 |
-
adjusted_preds.append(adjusted_pred)
|
93 |
-
return adjusted_preds
|
94 |
-
|
95 |
# Gradio helper functions
|
96 |
|
97 |
current_img = None
|
@@ -101,65 +18,52 @@ def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
|
|
101 |
|
102 |
global current_img
|
103 |
|
104 |
-
angles =
|
105 |
-
|
|
|
106 |
if not add_bone_cmap:
|
107 |
-
print(out_img.shape)
|
108 |
return out_img
|
109 |
cmap = plt.get_cmap('bone')
|
110 |
out_img = cmap(out_img)
|
111 |
out_img = (out_img[..., :3] * 255).astype(np.uint8)
|
112 |
current_img = out_img
|
113 |
return out_img
|
114 |
-
|
115 |
-
def use_current_btn_fn(input_img):
|
116 |
-
return input_img
|
117 |
-
|
118 |
-
def retrieve_examples(examples, inputs):
|
119 |
-
global current_img
|
120 |
-
if current_img is not None:
|
121 |
-
return current_img
|
122 |
-
return examples[0]
|
123 |
|
124 |
css_style = "./style.css"
|
125 |
callback = gr.CSVLogger()
|
126 |
with gr.Blocks(css=css_style) as app:
|
127 |
gr.HTML("VCNet: A tool for 3D Rotation of Radiographs with Diffusion Models", elem_classes="title")
|
128 |
-
gr.HTML("Developed by: Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles", elem_classes="
|
129 |
-
gr.HTML("Note: This is a proof-of-concept demo
|
130 |
|
131 |
-
with gr.TabItem("
|
132 |
with gr.Row():
|
133 |
input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
|
134 |
output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
|
135 |
with gr.Row():
|
136 |
-
gr.
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
)
|
148 |
with gr.Row():
|
149 |
gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
|
150 |
with gr.Row():
|
151 |
with gr.Column(scale=1):
|
152 |
-
xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-
|
153 |
with gr.Column(scale=1):
|
154 |
-
yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-
|
155 |
with gr.Column(scale=1):
|
156 |
-
zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-
|
157 |
with gr.Row():
|
158 |
rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
|
159 |
-
with gr.Row():
|
160 |
-
use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')
|
161 |
rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
|
162 |
-
use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)
|
163 |
|
164 |
try:
|
165 |
app.close()
|
|
|
3 |
import gradio as gr
|
4 |
import matplotlib.pyplot as plt
|
5 |
import numpy as np
|
|
|
6 |
import skimage
|
|
|
7 |
import monai as mn
|
8 |
import torch
|
9 |
|
10 |
from io_utils import LoadImageD
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
# Gradio helper functions
|
13 |
|
14 |
current_img = None
|
|
|
18 |
|
19 |
global current_img
|
20 |
|
21 |
+
angles = (xt, yt, zt)
|
22 |
+
out_img_path = f'data/cached_outputs/{os.path.basename(img_path)[:-4]}_{angles}.png'
|
23 |
+
out_img = skimage.io.imread(out_img_path)
|
24 |
if not add_bone_cmap:
|
|
|
25 |
return out_img
|
26 |
cmap = plt.get_cmap('bone')
|
27 |
out_img = cmap(out_img)
|
28 |
out_img = (out_img[..., :3] * 255).astype(np.uint8)
|
29 |
current_img = out_img
|
30 |
return out_img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
css_style = "./style.css"
|
33 |
callback = gr.CSVLogger()
|
34 |
with gr.Blocks(css=css_style) as app:
|
35 |
gr.HTML("VCNet: A tool for 3D Rotation of Radiographs with Diffusion Models", elem_classes="title")
|
36 |
+
gr.HTML("Developed by: Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles", elem_classes="subtitle")
|
37 |
+
gr.HTML("Note: This is a proof-of-concept demo running on CPU. All predictions are pre-computed.", elem_classes="note")
|
38 |
|
39 |
+
with gr.TabItem("Demo"):
|
40 |
with gr.Row():
|
41 |
input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
|
42 |
output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
|
43 |
with gr.Row():
|
44 |
+
with gr.Column(scale=0.25):
|
45 |
+
pass
|
46 |
+
with gr.Column(scale=1):
|
47 |
+
gr.Examples(
|
48 |
+
examples = [os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "xr" in f],
|
49 |
+
inputs = [input_img],
|
50 |
+
label = "Xray Examples",
|
51 |
+
elem_id='examples',
|
52 |
+
)
|
53 |
+
with gr.Column(scale=0.25):
|
54 |
+
pass
|
|
|
55 |
with gr.Row():
|
56 |
gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
|
57 |
with gr.Row():
|
58 |
with gr.Column(scale=1):
|
59 |
+
xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
|
60 |
with gr.Column(scale=1):
|
61 |
+
yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
|
62 |
with gr.Column(scale=1):
|
63 |
+
zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
|
64 |
with gr.Row():
|
65 |
rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
|
|
|
|
|
66 |
rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
|
|
|
67 |
|
68 |
try:
|
69 |
app.close()
|
data/cached_outputs/xr_1_(-10, -10, -10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -10, -15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -10, -5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -10, 0).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -10, 10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -10, 15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -10, 5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -15, -10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -15, -15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -15, -5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -15, 0).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -15, 10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -15, 15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -15, 5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -5, -10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -5, -15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -5, -5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -5, 0).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -5, 10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -5, 15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, -5, 5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 0, -10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 0, -15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 0, -5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 0, 0).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 0, 10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 0, 15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 0, 5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 10, -10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 10, -15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 10, -5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 10, 0).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 10, 10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 10, 15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 10, 5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 15, -10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 15, -15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 15, -5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 15, 0).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 15, 10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 15, 15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 15, 5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 5, -10).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 5, -15).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 5, -5).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 5, 0).png
ADDED
![]() |
data/cached_outputs/xr_1_(-10, 5, 10).png
ADDED
![]() |