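"""Test-time augmentation (TTA) inference utilities: run the model on flipped and
rotated copies of the input image, map each output back to the original frame,
and aggregate the replicates into a final segmentation and frame field."""
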
import kornia
import torch
import torch_lydorn.torchvision
# from pytorch_memlab import profile, profile_every
from frame_field_learning import measures
import cv2 as cv
import numpy as np


def compute_distance_transform(tensor: torch.Tensor) -> torch.Tensor:
    # Compute the L2 distance transform of a (..., H, W) binary tensor with OpenCV,
    # batching over all leading dimensions. For each nonzero pixel,
    # cv.distanceTransform returns the distance to the nearest zero pixel.
    device = tensor.device
    array = tensor.cpu().numpy()
    shape = array.shape
    array = array.reshape(-1, *shape[-2:]).astype(np.uint8)
    dist_trans = np.empty(array.shape, dtype=np.float32)
    for i in range(array.shape[0]):
        # distanceTransform only supports CV_8U and CV_32F outputs; CV_32F matches dist_trans's dtype
        dist_trans[i] = cv.distanceTransform(array[i], distanceType=cv.DIST_L2, maskSize=cv.DIST_MASK_5,
                                             dstType=cv.CV_32F)
    dist_trans = dist_trans.reshape(shape)
    dist_trans = torch.tensor(dist_trans, device=device)
    return dist_trans
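
# Example usage of compute_distance_transform (a minimal sketch; the values are
# illustrative):
#   mask = torch.zeros(1, 1, 8, 8, dtype=torch.bool)
#   mask[..., 2:6, 2:6] = True
#   dist = compute_distance_transform(mask)  # float32, same shape as mask
#   # Zero pixels map to 0; nonzero pixels hold the L2 distance to the nearest
#   # zero pixel, growing toward the center of the square.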


def select_crossfield(all_outputs, final_seg):
    # Choose frame field from the replicate that best matches the final seg interior
    dice_loss = measures.dice_loss(all_outputs["seg"][:, :, 0, :, :], final_seg[None, :, 0, :, :])
    # Get index of the replicate that achieves the min dice_loss (as it's a loss, lower is better)
    indices_best = torch.argmin(dice_loss, dim=0)
    batch_range = torch.arange(all_outputs["seg"].shape[1])  # batch size
    # For each batch select frame field from the replicate in indices_best
    final_crossfield = all_outputs["crossfield"][indices_best, batch_range]
    return final_crossfield
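
# Selection sketch for select_crossfield (illustrative shapes; C is the number of
# crossfield channels): measures.dice_loss yields one value per (replicate, batch)
# pair, so the advanced indexing picks one replicate per batch element:
#   dice_loss: (replicates, batch) -> indices_best = argmin over dim 0: (batch,)
#   all_outputs["crossfield"]: (replicates, batch, C, H, W)
#   all_outputs["crossfield"][indices_best, batch_range]: (batch, C, H, W)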


def aggr_mean(all_outputs):
    final_outputs = {}
    if "seg" in all_outputs:
        final_seg = torch.mean(all_outputs["seg"], dim=0)
        final_outputs["seg"] = final_seg  # Final seg is between min and max: positive pixels are closer to min
        if "crossfield" in all_outputs:
            final_outputs["crossfield"] = select_crossfield(all_outputs, final_seg)
    else:
        raise NotImplementedError("Test Time Augmentation requires segmentation to be computed.")
    return final_outputs


def aggr_median(all_outputs):
    final_outputs = {}
    if "seg" in all_outputs:
        final_seg, _ = torch.median(all_outputs["seg"], dim=0)
        final_outputs["seg"] = final_seg  # Final seg is between min and max: positive pixels are closer to min
        if "crossfield" in all_outputs:
            final_outputs["crossfield"] = select_crossfield(all_outputs, final_seg)
    else:
        raise NotImplementedError("Test Time Augmentation requires segmentation to be computed.")
    return final_outputs


def aggr_dist_trans(all_outputs, seg_threshold):
    final_outputs = {}
    if "seg" in all_outputs:
        min_seg, _ = torch.min(all_outputs["seg"], dim=0)
        max_seg, _ = torch.max(all_outputs["seg"], dim=0)
        # Final seg will be between the min and max segs. The idea is that we don't lose sharp corners (which taking the mean does)
        dist_ext_to_min_seg = compute_distance_transform(min_seg < seg_threshold)  # For pixels outside the min seg: distance to it
        dist_int_to_max_seg = compute_distance_transform(seg_threshold < max_seg)  # For pixels inside the max seg: distance to its exterior
        # Cast to float so the final seg has the same dtype as the per-replicate segs used downstream
        final_seg = (dist_ext_to_min_seg < dist_int_to_max_seg).float()
        final_outputs["seg"] = final_seg  # Final seg is between min and max: positive pixels are closer to the min seg
        if "crossfield" in all_outputs:
            final_outputs["crossfield"] = select_crossfield(all_outputs, final_seg)
    else:
        raise NotImplementedError("Test Time Augmentation requires segmentation to be computed.")
    return final_outputs
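
# Intuition for aggr_dist_trans in 1-D (an illustrative example, not part of the
# original code): with a min seg covering [4, 6] and a max seg covering [0, 10],
# a pixel at x = 3 is 1 away from the min seg but 3 away from the max seg's
# exterior, so it is kept; a pixel at x = 1 is 3 away from the min seg but only
# 1 away from the exterior, so it is dropped. The final contour settles halfway
# between the min and max contours, following both and thus preserving corners.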


def aggr_translated(all_outputs, seg_threshold, image_display=None):
    final_outputs = {}
    if "seg" in all_outputs:
        # Clean up all_seg by multiplying each replicate's mask with the mean seg
        all_seg = all_outputs["seg"]
        all_seg_mask: torch.Tensor = seg_threshold < all_seg
        mean_seg = torch.mean(all_seg_mask.float(), dim=0)
        all_cleaned_seg = all_seg_mask * mean_seg[None, ...]
        # all_cleaned_seg_mask = seg_threshold < all_cleaned_seg
        # all_cleaned_seg[~all_cleaned_seg_mask] = 0  # Put 0 where seg is below threshold

        # # --- DEBUG SAVE
        # image_seg_display = plot_utils.get_tensorboard_image_seg_display(image_display, mean_seg)
        # image_seg_display = image_seg_display[0].cpu().detach().numpy().transpose(1, 2, 0)
        # skimage.io.imsave(f"image_seg_display_mean_seg.png", image_seg_display)
        # for i, cleaned_seg in enumerate(all_cleaned_seg):
        #     image_seg_display = plot_utils.get_tensorboard_image_seg_display(image_display, cleaned_seg)
        #     image_seg_display = image_seg_display[0].cpu().detach().numpy().transpose(1, 2, 0)
        #     skimage.io.imsave(f"image_seg_display_replicate_cleaned_{i}.png", image_seg_display)
        # # ---

        # Compute the barycenter of each cleaned seg
        range_x = torch.arange(all_cleaned_seg.shape[4], device=all_cleaned_seg.device)
        range_y = torch.arange(all_cleaned_seg.shape[3], device=all_cleaned_seg.device)
        # (y, x) ordering yields (H, W) grids so the broadcasting below also works for non-square images
        grid_y, grid_x = torch.meshgrid([range_y, range_x])  # grid_y[i, j] = i, grid_x[i, j] = j
        grid_xy = torch.stack([grid_x, grid_y], dim=-1)  # (H, W, 2) pixel coordinates in (x, y) order

        # Average of pixel coordinates, weighted by segmentation confidence
        seg_weights = all_cleaned_seg[:, :, :, :, :, None]
        spatial_mean_xy = (torch.sum(grid_xy[None, None, None, :, :, :] * seg_weights, dim=(3, 4))
                           / torch.sum(seg_weights, dim=(3, 4)))
        # Median of all replicates' barycenters
        median_spatial_mean_xy, _ = torch.median(spatial_mean_xy, dim=0)
        # Compute the shift between each replicate and the median barycenter
        shift_xy = median_spatial_mean_xy[None, :, :, :] - spatial_mean_xy
        shift_xy *= 2  # The shift of the original segs is twice the shift between cleaned segs (assuming homogeneous shifts and enough segs)
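        # Worked 1-D example of the factor of 2 (illustrative, not part of the
        # pipeline): if the consensus seg covers [0, 10] and a replicate is
        # shifted by +4 to [4, 14], the cleaned seg is roughly their overlap
        # [4, 10], whose center 7 is only +2 off the consensus center 5. The
        # measured barycenter shift is thus about half the true shift.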
        # Average the per-channel shifts so there is one (x, y) shift per (replicate, batch) pair,
        # matching the (replicates * batch, C, H, W) view that kornia's translate is given below
        # (a no-op in value when the seg has a single channel)
        shift_xy = shift_xy.mean(dim=2)
        shift_xy = shift_xy.view(-1, shift_xy.shape[-1])
        shape = all_outputs["seg"].shape
        shifted_seg = kornia.geometry.translate(all_outputs["seg"].view(-1, *shape[-3:]), shift_xy).view(shape)

        # # --- DEBUG SAVE
        # for i, replicate_shifted_seg in enumerate(shifted_seg):
        #     image_seg_display = plot_utils.get_tensorboard_image_seg_display(image_display, replicate_shifted_seg)
        #     image_seg_display = image_seg_display[0].cpu().detach().numpy().transpose(1, 2, 0)
        #     skimage.io.imsave(f"image_seg_display_replicate_shifted_{i}.png", image_seg_display)
        # # ---

        # Compute mean shifted seg
        mean_shifted_seg = torch.mean(shifted_seg, dim=0)
        # Select replicate seg (and crossfield) that best matches mean_shifted_seg
        dice_loss = measures.dice_loss(shifted_seg[:, :, 0, :, :], mean_shifted_seg[None, :, 0, :, :])
        # Get index of the replicate that achieves the min dice_loss (as it's a loss, lower is better)
        indices_best = torch.argmin(dice_loss, dim=0)
        batch_range = torch.arange(all_outputs["seg"].shape[1])  # batch size
        # For each batch select seg and frame field from the replicate in indices_best
        final_outputs["seg"] = shifted_seg[indices_best, batch_range]
        if "crossfield" in all_outputs:
            final_outputs["crossfield"] = all_outputs["crossfield"][indices_best, batch_range]

        # if "crossfield" in all_outputs:
        #     final_outputs["crossfield"] = select_crossfield(all_outputs, final_seg)
    else:
        raise NotImplementedError("Test Time Augmentation requires segmentation to be computed.")
    return final_outputs


def tta_inference(model, xb, seg_threshold):
    # Perform inference several times on transformed copies of the input image, then aggregate the results
    replicates = 4 * 2  # 4 rotations, each with vflip/no vflip

    # Init results tensors
    notrans_outputs = model.inference(xb["image"])
    output_keys = notrans_outputs.keys()
    all_outputs = {}
    for key in output_keys:
        all_outputs[key] = torch.empty((replicates, *notrans_outputs[key].shape), dtype=notrans_outputs[key].dtype,
                                       device=notrans_outputs[key].device)
        all_outputs[key][0] = notrans_outputs[key]
    # Flip image
    flipped_image = kornia.geometry.transform.vflip(xb["image"])
    flipped_outputs = model.inference(flipped_image)
    for key in output_keys:
        reversed_output = kornia.geometry.transform.vflip(flipped_outputs[key])
        if key == "crossfield":
            # Flipping the output map back also requires flipping the frame field's per-pixel directions
            reversed_output = torch_lydorn.torchvision.transforms.functional.vflip_framefield(reversed_output)
        all_outputs[key][1] = reversed_output

    # --- Apply transforms one by one and add results to all_outputs
    for k in range(1, 4):
        rotated_image = torch.rot90(xb["image"], k=k, dims=(-2, -1))
        rotated_outputs = model.inference(rotated_image)
        for key in output_keys:
            reversed_output = torch.rot90(rotated_outputs[key], k=-k, dims=(-2, -1))
            if key == "crossfield":
                angle = -k * 90
                # TODO: use a faster implementation of rotate_framefield that only handles angles [0, 90, 180, 270]
                reversed_output = torch_lydorn.torchvision.transforms.functional.rotate_framefield(reversed_output,
                                                                                                   angle)
            all_outputs[key][2 * k] = reversed_output

        # Flip rotated image
        flipped_rotated_image = kornia.geometry.transform.vflip(rotated_image)
        flipped_rotated_outputs = model.inference(flipped_rotated_image)
        for key in output_keys:
            reversed_output = torch.rot90(kornia.geometry.transform.vflip(flipped_rotated_outputs[key]), k=-k,
                                          dims=(-2, -1))
            if key == "crossfield":
                angle = -k * 90
                reversed_output = torch_lydorn.torchvision.transforms.functional.vflip_framefield(reversed_output)
                reversed_output = torch_lydorn.torchvision.transforms.functional.rotate_framefield(reversed_output,
                                                                                                   angle)
            all_outputs[key][2 * k + 1] = reversed_output
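
    # Replicate layout in all_outputs: index 0 = identity, index 1 = vflip,
    # index 2k = rot90(k) and index 2k + 1 = vflip of rot90(k), for k in {1, 2, 3}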

    # --- DEBUG
    # all_outputs["seg"] *= 0
    # for i in range(all_outputs["seg"].shape[0]):
    #     center = 512
    #     size = 100
    #     shift_x = random.randint(-20, 20)
    #     shift_y = random.randint(-20, 20)
    #     all_outputs["seg"][i][..., center + shift_y - size:center + shift_y + size, center + shift_x - size:center + shift_x + size] = 1
    #     # Add noise
    #     noise_center_x = random.randint(100, 1024-100)
    #     noise_center_y = random.randint(100, 1024-100)
    #     noise_size = 10
    #     all_outputs["seg"][i][..., noise_center_y - noise_size:noise_center_y + noise_size, noise_center_x - noise_size:noise_center_x + noise_size] = 1
    #     # Add more noise
    #     all_outputs["seg"][i] += 0.25 * torch.rand(all_outputs["seg"][i].shape, device=all_outputs["seg"][i].device)
    #     all_outputs["seg"][i] = torch.clamp(all_outputs["seg"][i], 0, 1)


    # # --- DEBUG SAVE
    # image_display = torch_lydorn.torchvision.transforms.functional.batch_denormalize(xb["image"],
    #                                                                                  xb["image_mean"],
    #                                                                                  xb["image_std"])
    # for i, replicate_seg in enumerate(all_outputs["seg"]):
    #     image_seg_display = plot_utils.get_tensorboard_image_seg_display(image_display, replicate_seg)
    #     image_seg_display = image_seg_display[0].cpu().detach().numpy().transpose(1, 2, 0)
    #     skimage.io.imsave(f"image_seg_display_replicate_{i}.png", image_seg_display)
    # # ---


    # --- Aggregate results
    # final_outputs = aggr_dist_trans(all_outputs, seg_threshold)
    # final_outputs = aggr_translated(all_outputs, seg_threshold, image_display=image_display)
    # final_outputs = aggr_translated(all_outputs, seg_threshold)
    final_outputs = aggr_mean(all_outputs)
    # final_outputs = aggr_median(all_outputs)

    # # --- DEBUG SAVE
    # image_seg_display = plot_utils.get_tensorboard_image_seg_display(image_display, final_outputs["seg"])
    # image_seg_display = image_seg_display[0].cpu().detach().numpy().transpose(1, 2, 0)
    # skimage.io.imsave("image_seg_display_final.png", image_seg_display)
    # # ---

    # input("Press <Enter>...")

    return final_outputs
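

if __name__ == "__main__":
    # Minimal smoke test of the aggregation path on synthetic replicates (an
    # illustrative sketch, not part of the original pipeline; shapes follow the
    # (replicates, batch, channels, H, W) convention used above).
    replicates, batch, height, width = 8, 2, 32, 32
    all_outputs = {
        "seg": torch.rand(replicates, batch, 1, height, width),
        "crossfield": torch.rand(replicates, batch, 4, height, width),
    }
    final_outputs = aggr_mean(all_outputs)
    print(final_outputs["seg"].shape)         # torch.Size([2, 1, 32, 32])
    print(final_outputs["crossfield"].shape)  # torch.Size([2, 4, 32, 32])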