File size: 4,690 Bytes
b8c299e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
################################################################################
# This file contains OSAIL utilities to read and write files.
################################################################################

from .data import pad_to_square
import copy
import monai as mn
import numpy as np
import os
import skimage

################################################################################
# -F: load_image

def load_image(input_object, pad=False, normalize=True, standardize=False,
               dtype=np.float32, percentile_clip=None, target_shape=None,
               transpose=False, ensure_grayscale=True, LoadImage_args=None,
               LoadImage_kwargs=None):
    """Load an image from memory or disk and apply optional preprocessing.

    The steps always run in this order: load, grayscale conversion, transpose,
    percentile clipping, standardization, min-max normalization, padding,
    resizing, and finally casting to the requested data type.

    Args:
        input_object (Union[np.ndarray, str, os.PathLike]):
            an image array, or a path to an image file on disk that
            mn.transforms.LoadImage can read.
        pad (bool, optional): whether to pad the image to square shape
            (delegated to pad_to_square). Defaults to False.
        normalize (bool, optional): whether to min-max normalize the image:
            the minimum is subtracted and the result is divided by
            (max + 1e-8). Defaults to True.
        standardize (bool, optional): whether to standardize the image by
            subtracting the mean and dividing by (std + 1e-8).
            Defaults to False.
        dtype (np.dtype, optional): the data type of the output image.
            When exactly np.uint8 is passed, the image is multiplied by 255
            before casting, which presumes values in [0, 1] (e.g. produced
            by normalize=True). Defaults to np.float32.
        percentile_clip (float, optional): if given, values are clipped to
            the [percentile_clip, 100 - percentile_clip] percentile range.
            Defaults to None, which means no clipping.
        target_shape (tuple, optional): the target shape of the output image.
            Defaults to None, which means no resizing.
        transpose (bool, optional): whether to swap the two image axes.
            Requires a 2D image. Defaults to False.
        ensure_grayscale (bool, optional): whether to make the image
            grayscale. 3D inputs with 3 or 4 channels (channel-last or
            channel-first) are averaged over the first three channels; the
            result must be 2D. Defaults to True.
        LoadImage_args: a list of positional arguments to pass to
            mn.transforms.LoadImage. Defaults to None (no extra arguments).
        LoadImage_kwargs: a dictionary of keyword arguments to pass to
            mn.transforms.LoadImage. Defaults to None (no extra arguments).

    Returns:
        the loaded image array.

    Raises:
        AssertionError: if the given path does not exist, or if
            ensure_grayscale is True and the image is not 2D after the
            grayscale conversion.
        TypeError: if input_object is neither an array nor a path.
    """
    # Load the image.
    if isinstance(input_object, np.ndarray):
        image = input_object
    elif isinstance(input_object, (str, os.PathLike)):
        path = os.fspath(input_object)
        assert os.path.exists(path), f"File not found: {path}"
        reader_args = () if LoadImage_args is None else LoadImage_args
        reader_kwargs = {} if LoadImage_kwargs is None else LoadImage_kwargs
        reader = mn.transforms.LoadImage(*reader_args, image_only=True, **reader_kwargs)
        image = reader(path)
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # further down.
        raise TypeError(
            "input_object must be a NumPy array or a file path, "
            f"got {type(input_object).__name__}"
        )

    # Make the image 2D.
    if ensure_grayscale:
        # Only 3D inputs carry a channel axis. Without this guard a 2D image
        # that happens to have 3 rows or 3/4 columns would be collapsed to 1D.
        if len(image.shape) == 3:
            if image.shape[-1] == 3:
                image = np.mean(image, axis=-1)
            elif image.shape[0] == 3:
                image = np.mean(image, axis=0)
            elif image.shape[-1] == 4:
                # Ignore the 4th channel.
                image = np.mean(image[..., :3], axis=-1)
            elif image.shape[0] == 4:
                # Ignore the 4th channel.
                image = np.mean(image[:3, ...], axis=0)
        assert len(image.shape) == 2, f"Image must be 2D: {image.shape}"

    # Transpose the image.
    if transpose:
        image = np.transpose(image, axes=(1, 0))

    # Clip the image.
    if percentile_clip is not None:
        percentile_low = np.percentile(image, percentile_clip)
        percentile_high = np.percentile(image, 100 - percentile_clip)
        image = np.clip(image, percentile_low, percentile_high)

    # Standardize the image.
    if standardize:
        # astype returns a new array, so the in-place operations below do not
        # modify the caller's input.
        image = image.astype(np.float32)
        image -= image.mean()
        image /= (image.std() + 1e-8)

    # Normalize the image.
    if normalize:
        image = image.astype(np.float32)
        image -= image.min()
        # After the shift the maximum equals the value range.
        image /= (image.max() + 1e-8)

    # Pad the image to square shape.
    if pad:
        image = pad_to_square(image)

    # Resize the image.
    if target_shape is not None:
        image = skimage.transform.resize(image, target_shape, preserve_range=True)

    # Cast the image to the target data type.
    # NOTE(review): only the exact np.uint8 type object triggers the 255
    # scaling; np.dtype("uint8") or "uint8" are cast without scaling.
    if dtype is np.uint8:
        image = (image * 255).astype(np.uint8)
    else:
        image = image.astype(dtype)

    return image

################################################################################
# -C: LoadImageD

class LoadImageD(mn.transforms.Transform):
    """A MONAI dictionary transform that loads images with load_image.

    Args:
        keys: the key, or an iterable of keys, whose values in the input
            dictionary are passed to load_image.
        *to_pass_keys: positional arguments forwarded to load_image after the
            input object.
        **to_pass_kwargs: keyword arguments forwarded to load_image.
    """
    def __init__(self, keys, *to_pass_keys, **to_pass_kwargs) -> None:
        super().__init__()
        # A bare string would otherwise be iterated character by character in
        # __call__, so treat it as a single key.
        self.keys = (keys,) if isinstance(keys, str) else keys
        self.to_pass_keys = to_pass_keys
        self.to_pass_kwargs = to_pass_kwargs

    def __call__(self, data):
        """Return a deep copy of data with each configured key replaced by
        the image that load_image produces from the original value.
        """
        # Work on a copy so the caller's dictionary is left untouched.
        data_copy = copy.deepcopy(data)
        for key in self.keys:
            data_copy[key] = load_image(data[key], *self.to_pass_keys, **self.to_pass_kwargs)
        return data_copy