Pouriarouzrokh committed on
Commit 8523150 · 1 Parent(s): f9f482a

precomputed all rotations

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +1 -2
  2. __pycache__/io_utils.cpython-311.pyc +0 -0
  3. app.py +20 -116
  4. data/cached_outputs/xr_1_(-10, -10, -10).png +0 -0
  5. data/cached_outputs/xr_1_(-10, -10, -15).png +0 -0
  6. data/cached_outputs/xr_1_(-10, -10, -5).png +0 -0
  7. data/cached_outputs/xr_1_(-10, -10, 0).png +0 -0
  8. data/cached_outputs/xr_1_(-10, -10, 10).png +0 -0
  9. data/cached_outputs/xr_1_(-10, -10, 15).png +0 -0
  10. data/cached_outputs/xr_1_(-10, -10, 5).png +0 -0
  11. data/cached_outputs/xr_1_(-10, -15, -10).png +0 -0
  12. data/cached_outputs/xr_1_(-10, -15, -15).png +0 -0
  13. data/cached_outputs/xr_1_(-10, -15, -5).png +0 -0
  14. data/cached_outputs/xr_1_(-10, -15, 0).png +0 -0
  15. data/cached_outputs/xr_1_(-10, -15, 10).png +0 -0
  16. data/cached_outputs/xr_1_(-10, -15, 15).png +0 -0
  17. data/cached_outputs/xr_1_(-10, -15, 5).png +0 -0
  18. data/cached_outputs/xr_1_(-10, -5, -10).png +0 -0
  19. data/cached_outputs/xr_1_(-10, -5, -15).png +0 -0
  20. data/cached_outputs/xr_1_(-10, -5, -5).png +0 -0
  21. data/cached_outputs/xr_1_(-10, -5, 0).png +0 -0
  22. data/cached_outputs/xr_1_(-10, -5, 10).png +0 -0
  23. data/cached_outputs/xr_1_(-10, -5, 15).png +0 -0
  24. data/cached_outputs/xr_1_(-10, -5, 5).png +0 -0
  25. data/cached_outputs/xr_1_(-10, 0, -10).png +0 -0
  26. data/cached_outputs/xr_1_(-10, 0, -15).png +0 -0
  27. data/cached_outputs/xr_1_(-10, 0, -5).png +0 -0
  28. data/cached_outputs/xr_1_(-10, 0, 0).png +0 -0
  29. data/cached_outputs/xr_1_(-10, 0, 10).png +0 -0
  30. data/cached_outputs/xr_1_(-10, 0, 15).png +0 -0
  31. data/cached_outputs/xr_1_(-10, 0, 5).png +0 -0
  32. data/cached_outputs/xr_1_(-10, 10, -10).png +0 -0
  33. data/cached_outputs/xr_1_(-10, 10, -15).png +0 -0
  34. data/cached_outputs/xr_1_(-10, 10, -5).png +0 -0
  35. data/cached_outputs/xr_1_(-10, 10, 0).png +0 -0
  36. data/cached_outputs/xr_1_(-10, 10, 10).png +0 -0
  37. data/cached_outputs/xr_1_(-10, 10, 15).png +0 -0
  38. data/cached_outputs/xr_1_(-10, 10, 5).png +0 -0
  39. data/cached_outputs/xr_1_(-10, 15, -10).png +0 -0
  40. data/cached_outputs/xr_1_(-10, 15, -15).png +0 -0
  41. data/cached_outputs/xr_1_(-10, 15, -5).png +0 -0
  42. data/cached_outputs/xr_1_(-10, 15, 0).png +0 -0
  43. data/cached_outputs/xr_1_(-10, 15, 10).png +0 -0
  44. data/cached_outputs/xr_1_(-10, 15, 15).png +0 -0
  45. data/cached_outputs/xr_1_(-10, 15, 5).png +0 -0
  46. data/cached_outputs/xr_1_(-10, 5, -10).png +0 -0
  47. data/cached_outputs/xr_1_(-10, 5, -15).png +0 -0
  48. data/cached_outputs/xr_1_(-10, 5, -5).png +0 -0
  49. data/cached_outputs/xr_1_(-10, 5, 0).png +0 -0
  50. data/cached_outputs/xr_1_(-10, 5, 10).png +0 -0
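This commit replaces on-the-fly diffusion inference with a lookup of pre-rendered PNGs in data/cached_outputs/. The precompute script itself is not part of the commit; the sketch below is a hypothetical reconstruction of how such a cache could be generated offline, assuming the pre-commit make_predictions helper (removed in the app.py diff further down) and the 5-degree angle grid implied by the new slider settings and the cache filenames.

```python
# Hypothetical offline precompute script (not included in this commit).
import os
from itertools import product

import skimage.io

from app import make_predictions  # requires the pre-commit app.py (parent f9f482a)

ANGLE_GRID = range(-15, 20, 5)     # per-axis values matching the new slider range/step
EXAMPLE_DIR = "./data/examples"
CACHE_DIR = "./data/cached_outputs"

os.makedirs(CACHE_DIR, exist_ok=True)
for fname in os.listdir(EXAMPLE_DIR):
    if "xr" not in fname:          # mirror the X-ray example filter used in the app
        continue
    img_path = os.path.join(EXAMPLE_DIR, fname)
    base = fname[:-4]              # e.g. "xr_1.png" -> "xr_1"
    for angles in product(ANGLE_GRID, ANGLE_GRID, ANGLE_GRID):
        pred = make_predictions(img_path, list(angles))[0]
        out = (pred * 255).clip(0, 255).astype("uint8")  # assumes predictions scaled to [0, 1]
        skimage.io.imsave(os.path.join(CACHE_DIR, f"{base}_{angles}.png"), out)
```

The f-string of the angle tuple produces names like xr_1_(-10, -10, -15).png, which is exactly the pattern the new rotate_btn_fn looks up.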
.gitignore CHANGED
@@ -1,2 +1 @@
- flagged/
- *.ckpt
+ app_copy.py

__pycache__/io_utils.cpython-311.pyc ADDED
Binary file (6.62 kB).
 
app.py CHANGED
@@ -3,95 +3,12 @@ import os
  import gradio as gr
  import matplotlib.pyplot as plt
  import numpy as np
- import pandas as pd
  import skimage
- from mediffusion import DiffusionModule
  import monai as mn
  import torch

  from io_utils import LoadImageD

- # Loading the model for inference
-
- model = DiffusionModule("./diffusion_configs.yaml")
- model.load_ckpt("./data/model.ckpt")
- model.eval();
-
- # Loading a baseline noise for making predictions
-
- seed = 3407
- np.random.seed(seed)
- torch.random.manual_seed(seed)
- torch.backends.cudnn.deterministic = True
- BASELINE_NOISE = torch.randn(1, 1, 256, 256).half()
-
- # Model helper functions
-
- def create_ds(img_paths):
-     if type(img_paths) == str:
-         img_paths = [img_paths]
-     data_list = [{"img": img_path} for img_path in img_paths]
-
-     # Get the transforms
-     Ts_list = [
-         LoadImageD(keys=["img"], transpose=True, normalize=True),
-         mn.transforms.EnsureChannelFirstD(
-             keys=["img"], channel_dim="no_channel"
-         ),
-         mn.transforms.ResizeD(
-             keys=["img"],
-             spatial_size=(256, 256),
-             mode=["bicubic"],
-         ),
-         mn.transforms.ScaleIntensityD(keys=["img"], minv=0, maxv=1),
-         mn.transforms.ToTensorD(keys=["img"], track_meta=None),
-         mn.transforms.SelectItemsD(keys=["img"]),
-     ]
-     return mn.data.Dataset(data_list, transform=mn.transforms.Compose(Ts_list))
-
- def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=False, sampler="DDIM5"):
-
-     global model
-     global BASELINE_NOISE
-
-     # Create the image dataset
-     if cls_batch is not None:
-         ds = create_ds([img_path]*len(cls_batch))
-     else:
-         ds = create_ds(img_path)
-     dl = mn.data.DataLoader(ds, batch_size=len(ds), num_workers=0 if len(ds)==1 else 4, shuffle=False)
-     input_batch = next(iter(dl))
-     original_imgs = input_batch["img"].detach().cpu().numpy()
-
-     # Create the classifier condition if not provided
-     if cls_batch is None:
-         fp = torch.zeros(768)
-         if rotate_to_standard or angles is None:
-             angles = [1000, 1000, 1000]
-             cls_value = torch.tensor([2, *angles, *fp])
-         else:
-             cls_value = torch.tensor([1, *angles, *fp])
-         cls_batch = cls_value.unsqueeze(0).repeat(input_batch["img"].shape[0], 1)
-
-     # Generate noise
-     noise = BASELINE_NOISE.repeat(input_batch["img"].shape[0], 1, 1, 1)
-     model_kwargs = {
-         "cls": cls_batch,
-         "concat": input_batch["img"]
-     }
-
-     # Make predictions
-     preds = model.predict(
-         noise, model_kwargs=model_kwargs, classifier_cond_scale=4, inference_protocol=sampler
-     )
-     adjusted_preds = list()
-     for pred, original_img in zip(preds, original_imgs):
-         adjusted_pred = pred.detach().cpu().numpy().squeeze()
-         original_img = original_img.squeeze()
-         adjusted_pred = skimage.exposure.match_histograms(adjusted_pred, original_img)
-         adjusted_preds.append(adjusted_pred)
-     return adjusted_preds
-
  # Gradio helper functions

  current_img = None
@@ -101,65 +18,52 @@ def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):

      global current_img

-     angles = [float(xt), float(yt), float(zt)]
-     out_img = make_predictions(img_path, angles)[0]
+     angles = (xt, yt, zt)
+     out_img_path = f'data/cached_outputs/{os.path.basename(img_path)[:-4]}_{angles}.png'
+     out_img = skimage.io.imread(out_img_path)
      if not add_bone_cmap:
-         print(out_img.shape)
          return out_img
      cmap = plt.get_cmap('bone')
      out_img = cmap(out_img)
      out_img = (out_img[..., :3] * 255).astype(np.uint8)
      current_img = out_img
      return out_img
-
- def use_current_btn_fn(input_img):
-     return input_img
-
- def retrieve_examples(examples, inputs):
-     global current_img
-     if current_img is not None:
-         return current_img
-     return examples[0]

  css_style = "./style.css"
  callback = gr.CSVLogger()
  with gr.Blocks(css=css_style) as app:
      gr.HTML("VCNet: A tool for 3D Rotation of Radiographs with Diffusion Models", elem_classes="title")
-     gr.HTML("Developed by: Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles", elem_classes="note")
-     gr.HTML("Note: This is a proof-of-concept demo of an AI tool that is not yet finalized. Please interpret with care!", elem_classes="note")
+     gr.HTML("Developed by: Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles", elem_classes="subtitle")
+     gr.HTML("Note: This is a proof-of-concept demo running on CPU. All predictions are pre-computed.", elem_classes="note")

-     with gr.TabItem("Single Rotation"):
+     with gr.TabItem("Demo"):
          with gr.Row():
              input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
              output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
          with gr.Row():
-             gr.Examples(
-                 examples = [os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "xr" in f],
-                 inputs = [input_img],
-                 label = "Xray Examples",
-                 elem_id='examples',
-             )
-             gr.Examples(
-                 examples = [os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "drr" in f],
-                 inputs = [input_img],
-                 label = "DRR Examples",
-                 elem_id='examples',
-             )
+             with gr.Column(scale=0.25):
+                 pass
+             with gr.Column(scale=1):
+                 gr.Examples(
+                     examples = [os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "xr" in f],
+                     inputs = [input_img],
+                     label = "Xray Examples",
+                     elem_id='examples',
+                 )
+             with gr.Column(scale=0.25):
+                 pass
          with gr.Row():
              gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
          with gr.Row():
              with gr.Column(scale=1):
-                 xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
+                 xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
              with gr.Column(scale=1):
-                 yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
+                 yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
              with gr.Column(scale=1):
-                 zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
+                 zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-15, maximum=15, step=5)
          with gr.Row():
              rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
-         with gr.Row():
-             use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')
          rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
-         use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)

  try:
      app.close()
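After this change rotate_btn_fn never touches the model: it formats the requested angles into a filename and reads the matching cached PNG. Below is a minimal, hedged sketch of that lookup with an illustrative fallback for angle combinations that are not on the precomputed grid; the int() casts, the fallback, and the helper name are additions for illustration and are not part of the commit.

```python
import os

import skimage.io

CACHE_DIR = "data/cached_outputs"  # directory populated by this commit

def load_cached_rotation(img_path, xt, yt, zt):
    base = os.path.basename(img_path)[:-4]   # e.g. "xr_1.png" -> "xr_1"
    angles = (int(xt), int(yt), int(zt))     # tuple repr matches the "(x, y, z)" cache naming
    cached = os.path.join(CACHE_DIR, f"{base}_{angles}.png")
    if not os.path.exists(cached):           # requested angles not precomputed
        return skimage.io.imread(img_path)   # fall back to the unrotated input
    return skimage.io.imread(cached)
```

The tuple's repr, e.g. "(-10, -10, -15)", is what makes the constructed path match cached entries such as xr_1_(-10, -10, -15).png listed below.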
data/cached_outputs/xr_1_(-10, -10, -10).png ADDED
data/cached_outputs/xr_1_(-10, -10, -15).png ADDED
data/cached_outputs/xr_1_(-10, -10, -5).png ADDED
data/cached_outputs/xr_1_(-10, -10, 0).png ADDED
data/cached_outputs/xr_1_(-10, -10, 10).png ADDED
data/cached_outputs/xr_1_(-10, -10, 15).png ADDED
data/cached_outputs/xr_1_(-10, -10, 5).png ADDED
data/cached_outputs/xr_1_(-10, -15, -10).png ADDED
data/cached_outputs/xr_1_(-10, -15, -15).png ADDED
data/cached_outputs/xr_1_(-10, -15, -5).png ADDED
data/cached_outputs/xr_1_(-10, -15, 0).png ADDED
data/cached_outputs/xr_1_(-10, -15, 10).png ADDED
data/cached_outputs/xr_1_(-10, -15, 15).png ADDED
data/cached_outputs/xr_1_(-10, -15, 5).png ADDED
data/cached_outputs/xr_1_(-10, -5, -10).png ADDED
data/cached_outputs/xr_1_(-10, -5, -15).png ADDED
data/cached_outputs/xr_1_(-10, -5, -5).png ADDED
data/cached_outputs/xr_1_(-10, -5, 0).png ADDED
data/cached_outputs/xr_1_(-10, -5, 10).png ADDED
data/cached_outputs/xr_1_(-10, -5, 15).png ADDED
data/cached_outputs/xr_1_(-10, -5, 5).png ADDED
data/cached_outputs/xr_1_(-10, 0, -10).png ADDED
data/cached_outputs/xr_1_(-10, 0, -15).png ADDED
data/cached_outputs/xr_1_(-10, 0, -5).png ADDED
data/cached_outputs/xr_1_(-10, 0, 0).png ADDED
data/cached_outputs/xr_1_(-10, 0, 10).png ADDED
data/cached_outputs/xr_1_(-10, 0, 15).png ADDED
data/cached_outputs/xr_1_(-10, 0, 5).png ADDED
data/cached_outputs/xr_1_(-10, 10, -10).png ADDED
data/cached_outputs/xr_1_(-10, 10, -15).png ADDED
data/cached_outputs/xr_1_(-10, 10, -5).png ADDED
data/cached_outputs/xr_1_(-10, 10, 0).png ADDED
data/cached_outputs/xr_1_(-10, 10, 10).png ADDED
data/cached_outputs/xr_1_(-10, 10, 15).png ADDED
data/cached_outputs/xr_1_(-10, 10, 5).png ADDED
data/cached_outputs/xr_1_(-10, 15, -10).png ADDED
data/cached_outputs/xr_1_(-10, 15, -15).png ADDED
data/cached_outputs/xr_1_(-10, 15, -5).png ADDED
data/cached_outputs/xr_1_(-10, 15, 0).png ADDED
data/cached_outputs/xr_1_(-10, 15, 10).png ADDED
data/cached_outputs/xr_1_(-10, 15, 15).png ADDED
data/cached_outputs/xr_1_(-10, 15, 5).png ADDED
data/cached_outputs/xr_1_(-10, 5, -10).png ADDED
data/cached_outputs/xr_1_(-10, 5, -15).png ADDED
data/cached_outputs/xr_1_(-10, 5, -5).png ADDED
data/cached_outputs/xr_1_(-10, 5, 0).png ADDED
data/cached_outputs/xr_1_(-10, 5, 10).png ADDED
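The listing above is truncated at 50 files and shows only part of the x = -10 slice of cached rotations for xr_1. If every combination on the implied grid (plus or minus 15 degrees in 5-degree steps per axis) was precomputed, as the commit message suggests, each example image contributes 7^3 = 343 cached PNGs:

```python
# Cache size implied by the new slider settings (range -15..15, step 5).
angles_per_axis = len(range(-15, 20, 5))  # 7 values: -15, -10, -5, 0, 5, 10, 15
print(angles_per_axis ** 3)               # 343 cached outputs per example image
```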