Pouriarouzrokh commited on
Commit
b8c299e
·
1 Parent(s): fbc26e7

added LoadImageD from osail-utils

Browse files
Files changed (2) hide show
  1. app.py +4 -92
  2. io_utils.py +121 -0
app.py CHANGED
@@ -9,6 +9,8 @@ from mediffusion import DiffusionModule
9
  import monai as mn
10
  import torch
11
 
 
 
12
  # Loading the model for inference
13
 
14
  model = DiffusionModule("./diffusion_configs.yaml")
@@ -25,22 +27,6 @@ BASELINE_NOISE = torch.randn(1, 1, 256, 256).half()
25
 
26
  # Model helper functions
27
 
28
- class LoadImageD(mn.transforms.Transform):
29
- def __init__(self, keys, transpose=False, normalize=False):
30
- self.keys = keys
31
- self.transpose = transpose
32
- self.normalize = normalize
33
- def __call__(self, data):
34
- for key in self.keys:
35
- img = skimage.io.imread(data[key])
36
- if self.transpose:
37
- img = img.transpose(0, 1)
38
- if self.normalize:
39
- img -= img.min()
40
- img /= (img.max()+1e-6)
41
- data[key] = img
42
- return data
43
-
44
  def create_ds(img_paths):
45
  if type(img_paths) == str:
46
  img_paths = [img_paths]
@@ -125,58 +111,15 @@ def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
125
  out_img = (out_img[..., :3] * 255).astype(np.uint8)
126
  current_img = out_img
127
  return out_img
128
-
129
- def rotate_to_standard_btn_fn(img_path, add_bone_cmap=False):
130
-
131
- global current_img
132
-
133
- out_img = make_predictions(img_path, rotate_to_standard=True)[0]
134
- if not add_bone_cmap:
135
- return out_img
136
- cmap = plt.get_cmap('bone')
137
- out_img = cmap(out_img)
138
- out_img = (out_img[..., :3] * 255).astype(np.uint8)
139
- current_img = out_img
140
- return out_img
141
 
142
  def use_current_btn_fn(input_img):
143
  return input_img
144
-
145
-
146
- def make_live_btn_fn(img_path, axis, add_bone_cmap=False):
147
-
148
- global live_preds
149
-
150
- base_angles = list(range(-20, 21, 1))
151
- base_angles = [float(i) for i in base_angles]
152
- if axis.lower() == "axis x":
153
- all_angles = [[i, 0, 0] for i in base_angles]
154
- elif axis.lower() == "axis y":
155
- all_angles = [[0, i, 0] for i in base_angles]
156
- elif axis.lower() == "axis z":
157
- all_angles = [[0, 0, i] for i in base_angles]
158
- fp = torch.zeros(768)
159
- cls_batch = torch.tensor([[1, *angles, *fp] for angles in all_angles])
160
-
161
- live_preds = make_predictions(img_path, cls_batch=cls_batch)
162
- live_preds = {angle: live_preds[i] for i, angle in enumerate(base_angles)}
163
- return img_path
164
-
165
- def rotate_live_img_fn(angle, add_bone_cmap=False):
166
-
167
- global live_img
168
- global live_preds
169
-
170
- if live_img is not None:
171
- if angle == 0:
172
- return live_img
173
- return live_preds[float(angle)]
174
 
175
  css_style = "./style.css"
176
  callback = gr.CSVLogger()
177
  with gr.Blocks(css=css_style) as app:
178
- gr.HTML("VCNet: A Deep Learning Solution for Roating RadioGraphs in 3D Space", elem_classes="title")
179
- gr.HTML("Developed by the Orthopedics Surgery Artificial Intelligence Lab (OSAIL)", elem_classes="note")
180
  gr.HTML("Note: This is a proof-of-concept demo of an AI tool that is not yet finalized. Please interpret with care!", elem_classes="note")
181
 
182
  with gr.TabItem("Single Rotation"):
@@ -207,41 +150,10 @@ with gr.Blocks(css=css_style) as app:
207
  zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
208
  with gr.Row():
209
  rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
210
- with gr.Row():
211
- rotate_to_standard_btn = gr.Button("Rotate to standard view!", elem_classes='rotate_to_standard_button')
212
  with gr.Row():
213
  use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')
214
  rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
215
- rotate_to_standard_btn.click(fn=rotate_to_standard_btn_fn, inputs=[input_img], outputs=output_img)
216
  use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)
217
-
218
- with gr.TabItem("Live Rotation"):
219
- with gr.Row():
220
- live_img = gr.Image(type='filepath', label='Live Image', sources='upload', interactive=False, elem_classes='imgs')
221
- with gr.Row():
222
- gr.Examples(
223
- examples = [os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "xr" in f],
224
- inputs = [live_img],
225
- label = "Xray Examples",
226
- elem_id='examples'
227
- )
228
- gr.Examples(
229
- examples = [os.path.join("./data/examples", f) for f in os.listdir("./data/examples") if "drr" in f],
230
- inputs = [live_img],
231
- label = "DRR Examples",
232
- elem_id='examples'
233
- )
234
- with gr.Row():
235
- gr.Markdown('Please select an example image, an axis, and then press Make Live!', elem_classes='text')
236
- with gr.Row():
237
- axis = gr.Dropdown(choices=['Axis X', 'Axis Y', 'Axis Z'], show_label=False, elem_classes='angle', value='Axis X')
238
- live_btn = gr.Button("Make Live!", elem_classes='make_live_button')
239
- with gr.Row():
240
- gr.Markdown('You can now rotate the radiograph in your selected axis using the scaler.', elem_classes='text')
241
- with gr.Row():
242
- slider = gr.Slider(show_label=False, minimum=-20, maximum=20, step=1, value=0, elem_classes='slider', interactive=True)
243
- live_btn.click(fn=make_live_btn_fn, inputs=[live_img, axis], outputs=live_img)
244
- slider.change(fn=rotate_live_img_fn, inputs=[slider], outputs=live_img)
245
 
246
  try:
247
  app.close()
 
9
  import monai as mn
10
  import torch
11
 
12
+ from io_utils import LoadImageD
13
+
14
  # Loading the model for inference
15
 
16
  model = DiffusionModule("./diffusion_configs.yaml")
 
27
 
28
  # Model helper functions
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def create_ds(img_paths):
31
  if type(img_paths) == str:
32
  img_paths = [img_paths]
 
111
  out_img = (out_img[..., :3] * 255).astype(np.uint8)
112
  current_img = out_img
113
  return out_img
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  def use_current_btn_fn(input_img):
116
  return input_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  css_style = "./style.css"
119
  callback = gr.CSVLogger()
120
  with gr.Blocks(css=css_style) as app:
121
+ gr.HTML("VCNet: A tool for 3D Rotation of Radiographs with Diffusion Models", elem_classes="title")
122
+ gr.HTML("Developed by: Pouria Rouzrokh, Bardia Khosravi, Shahriar Faghani, Kellen Mulford, Michael J. Taunton, Bradley J. Erickson, Cody C. Wyles", elem_classes="note")
123
  gr.HTML("Note: This is a proof-of-concept demo of an AI tool that is not yet finalized. Please interpret with care!", elem_classes="note")
124
 
125
  with gr.TabItem("Single Rotation"):
 
150
  zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
151
  with gr.Row():
152
  rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
 
 
153
  with gr.Row():
154
  use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')
155
  rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
 
156
  use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  try:
159
  app.close()
io_utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ # This files contains OSAIL utils to read and write files.
3
+ ################################################################################
4
+
5
+ from .data import pad_to_square
6
+ import copy
7
+ import monai as mn
8
+ import numpy as np
9
+ import os
10
+ import skimage
11
+
12
+ ################################################################################
13
+ # -F: load_image
14
+
15
+ def load_image(input_object, pad=False, normalize=True, standardize=False,
16
+ dtype=np.float32, percentile_clip=None, target_shape=None,
17
+ transpose=False, ensure_grayscale=True, LoadImage_args=[], LoadImage_kwargs={}):
18
+ """A helper function to load different input types.
19
+
20
+ Args:
21
+ input_object (Union[np.ndarray, str]):
22
+ a 2D NumPy array of X-ray an image, a DICOM file of an X-ray image,
23
+ or a string path to a .npy, any regular image file format
24
+ saved on disk that skimage.io can load.
25
+ pad (bool, optional): whether to pad the image to square shape.
26
+ Defaults to True.
27
+ normalize (bool, optional): whether to normalize the image.
28
+ Defaults to True.
29
+ standardize (bool, optional): whether to standardize the image.
30
+ Defaults to False.
31
+ dtype (np.dtype, optional): the data type of the output image.
32
+ Defaults to np.float32.
33
+ percentile_clip (float, optional): the percentile to clip the image.
34
+ Defaults to 2.5.
35
+ target_shape (tuple, optional): the target shape of the output image.
36
+ Defaults to None, which means no resizing.
37
+ transpose (bool, optional): whether to transpose the image.
38
+ Defaults to False.
39
+ ensure_grayscale (bool, optional): whether to make the image grayscale.
40
+ Defaults to True.
41
+ LoadImg_args: a list of keyword arguments to pass to mn.transforms.LoadImage.
42
+ LoadImg_kwargs: a dictionary of keyword arguments to pass to mn.transforms.LoadImage.
43
+
44
+ Returns:
45
+ the loaded image array.
46
+ """
47
+ # Load the image.
48
+ if isinstance(input_object, np.ndarray):
49
+ image = input_object
50
+ elif isinstance(input_object, str):
51
+ assert os.path.exists(input_object), f"File not found: {input_object}"
52
+ reader = mn.transforms.LoadImage(image_only=True, *LoadImage_args, **LoadImage_kwargs)
53
+ image = reader(input_object)
54
+
55
+ # Make the image 2D.
56
+ if ensure_grayscale:
57
+ if image.shape[-1] == 3:
58
+ image = np.mean(image, axis=-1)
59
+ elif image.shape[0] == 3:
60
+ image = np.mean(image, axis=0)
61
+ elif image.shape[-1] == 4:
62
+ image = np.mean(image[...,:3], axis=-1)
63
+ elif image.shape[0] == 4:
64
+ image = np.mean(image[:3,...], axis=0)
65
+ assert len(image.shape) == 2, f"Image must be 2D: {image.shape}"
66
+
67
+ # Transpose the image.
68
+ if transpose:
69
+ image = np.transpose(image, axes=(1,0))
70
+
71
+ # Clip the image.
72
+ if percentile_clip is not None:
73
+ percentile_low = np.percentile(image, percentile_clip)
74
+ percentile_high = np.percentile(image, 100-percentile_clip)
75
+ image = np.clip(image, percentile_low, percentile_high)
76
+
77
+ # Standardize the image.
78
+ if standardize:
79
+ image = image.astype(np.float32)
80
+ image -= image.mean()
81
+ image /= (image.std() + 1e-8)
82
+
83
+ # Normalize the image.
84
+ if normalize:
85
+ image = image.astype(np.float32)
86
+ image -= image.min()
87
+ image /= (image.max() + 1e-8)
88
+
89
+ # Pad the image to square shape.
90
+ if pad:
91
+ image = pad_to_square(image)
92
+
93
+ # Resize the image.
94
+ if target_shape is not None:
95
+ image = skimage.transform.resize(image, target_shape, preserve_range=True)
96
+
97
+ # Cast the image to the target data type.
98
+ if dtype is np.uint8:
99
+ image = (image * 255).astype(np.uint8)
100
+ else:
101
+ image = image.astype(dtype)
102
+
103
+ return image
104
+
105
+ ################################################################################
106
+ # -C: LoadImageD
107
+
108
+ class LoadImageD(mn.transforms.Transform):
109
+ """A MONAI transform to load input image using load_image function.
110
+ """
111
+ def __init__(self, keys, *to_pass_keys, **to_pass_kwargs) -> None:
112
+ super().__init__()
113
+ self.keys = keys
114
+ self.to_pass_keys = to_pass_keys
115
+ self.to_pass_kwargs = to_pass_kwargs
116
+
117
+ def __call__(self, data):
118
+ data_copy = copy.deepcopy(data)
119
+ for key in self.keys:
120
+ data_copy[key] = load_image(data[key], *self.to_pass_keys, **self.to_pass_kwargs)
121
+ return data_copy