# File size: 12,188 Bytes
# Revision: abd2a81
import kornia
import torch
import torch_lydorn.torchvision
# from pytorch_memlab import profile, profile_every
from frame_field_learning import measures
import cv2 as cv
import numpy as np
def compute_distance_transform(tensor: torch.Tensor) -> torch.Tensor:
    """Compute an L2 distance transform over every 2D slice of ``tensor``.

    The last two dimensions are treated as an image; all leading dimensions
    are flattened and each slice is passed to ``cv.distanceTransform`` with
    ``DIST_L2`` and a 5x5 mask.

    Args:
        tensor: Mask tensor with at least two dimensions. Values are cast to
            ``uint8`` before the transform, so boolean / {0, 1} input is
            expected.

    Returns:
        A float32 tensor with the same shape as ``tensor``, placed on the
        same device as ``tensor``.
    """
    device = tensor.device
    shape = tuple(tensor.shape)
    # OpenCV works on host memory and needs 8-bit single-channel images.
    masks = tensor.detach().cpu().numpy().reshape(-1, *shape[-2:]).astype(np.uint8)
    distances = np.empty(masks.shape, dtype=np.float32)
    for index, mask in enumerate(masks):
        # cv.distanceTransform only documents CV_8U and CV_32F as valid
        # dstType values; request float32 so it matches the output buffer.
        distances[index] = cv.distanceTransform(
            mask, distanceType=cv.DIST_L2, maskSize=cv.DIST_MASK_5, dstType=cv.CV_32F
        )
    return torch.from_numpy(distances.reshape(shape)).to(device)
def select_crossfield(all_outputs, final_seg):
    """Pick, per batch item, the crossfield of the replicate closest to ``final_seg``.

    Closeness is measured with ``measures.dice_loss`` between channel 0 of
    every replicate's seg and channel 0 of ``final_seg`` (lower is better).

    Args:
        all_outputs: Dict holding "seg" and "crossfield" tensors whose first
            two dimensions are (replicate, batch).
        final_seg: Aggregated segmentation with batch as first dimension and
            channel as second.

    Returns:
        The "crossfield" entries of the best replicate for each batch item,
        i.e. ``all_outputs["crossfield"]`` without its replicate dimension.
    """
    all_seg = all_outputs["seg"]
    # NOTE(review): assumes dice_loss reduces to one value per
    # (replicate, batch) pair - confirm against frame_field_learning.measures.
    dice_loss = measures.dice_loss(all_seg[:, :, 0, :, :], final_seg[None, :, 0, :, :])
    # As it is a loss, the best replicate is the one with the smallest value.
    indices_best = torch.argmin(dice_loss, dim=0)
    # Build the batch index on the same device as indices_best so the
    # advanced indexing below does not need an implicit host-to-device copy.
    batch_range = torch.arange(all_seg.shape[1], device=indices_best.device)
    return all_outputs["crossfield"][indices_best, batch_range]
def aggr_mean(all_outputs):
    """Aggregate replicate outputs by averaging the segmentation.

    Args:
        all_outputs: Dict of stacked replicate outputs; "seg" is required and
            is averaged over its first (replicate) dimension.

    Returns:
        Dict with the averaged "seg" and, when "crossfield" is present in
        ``all_outputs``, the "crossfield" chosen by ``select_crossfield``.

    Raises:
        NotImplementedError: If ``all_outputs`` has no "seg" entry.
    """
    if "seg" not in all_outputs:
        raise NotImplementedError("Test Time Augmentation requires segmentation to be computed.")

    averaged_seg = all_outputs["seg"].mean(dim=0)
    aggregated = {"seg": averaged_seg}
    if "crossfield" in all_outputs:
        # The crossfield is not averaged: it is taken from a single replicate.
        aggregated["crossfield"] = select_crossfield(all_outputs, averaged_seg)
    return aggregated
def aggr_median(all_outputs):
    """Aggregate replicate outputs with the element-wise median segmentation.

    Args:
        all_outputs: Dict of stacked replicate outputs; "seg" is required and
            its median is taken over the first (replicate) dimension.

    Returns:
        Dict with the median "seg" and, when "crossfield" is present in
        ``all_outputs``, the "crossfield" chosen by ``select_crossfield``.

    Raises:
        NotImplementedError: If ``all_outputs`` has no "seg" entry.
    """
    if "seg" not in all_outputs:
        raise NotImplementedError("Test Time Augmentation requires segmentation to be computed.")

    # torch.median along a dim yields (values, indices); only values are used.
    median_seg = torch.median(all_outputs["seg"], dim=0).values
    aggregated = {"seg": median_seg}
    if "crossfield" in all_outputs:
        aggregated["crossfield"] = select_crossfield(all_outputs, median_seg)
    return aggregated
def aggr_dist_trans(all_outputs, seg_threshold):
    """Aggregate replicate segmentations using distance transforms.

    The result lies between the element-wise min and max over replicates.
    The idea is to avoid losing sharp corners, which taking the mean does.

    Args:
        all_outputs: Dict of stacked replicate outputs; "seg" is required and
            is reduced over its first (replicate) dimension.
        seg_threshold: Threshold used to binarise the min and max seg.

    Returns:
        Dict with a boolean "seg" mask and, when "crossfield" is present in
        ``all_outputs``, the "crossfield" chosen by ``select_crossfield``.

    Raises:
        NotImplementedError: If ``all_outputs`` has no "seg" entry.
    """
    if "seg" not in all_outputs:
        raise NotImplementedError("Test Time Augmentation requires segmentation to be computed.")

    all_seg = all_outputs["seg"]
    lowest_seg = all_seg.min(dim=0).values
    highest_seg = all_seg.max(dim=0).values

    # Distance transform of the region outside the thresholded min seg.
    outside_min_mask = lowest_seg < seg_threshold
    dist_outside_min = compute_distance_transform(outside_min_mask)
    # Distance transform of the region inside the thresholded max seg.
    inside_max_mask = seg_threshold < highest_seg
    dist_inside_max = compute_distance_transform(inside_max_mask)

    # A pixel is positive when its first distance is strictly smaller than
    # its second one.
    merged_seg = dist_outside_min < dist_inside_max
    aggregated = {"seg": merged_seg}
    if "crossfield" in all_outputs:
        aggregated["crossfield"] = select_crossfield(all_outputs, merged_seg)
    return aggregated
def aggr_translated(all_outputs, seg_threshold, image_display=None):
    """Aggregate replicates after re-aligning their segmentations by translation.

    Every replicate seg is thresholded and weighted by the mean mask over
    replicates, its confidence-weighted barycenter is computed, and it is
    translated towards the median barycenter. The replicate whose shifted seg
    best matches the mean shifted seg (channel 0, dice loss) is returned.

    Args:
        all_outputs: Dict of stacked replicate outputs. "seg" is required and
            is indexed as (replicate, batch, channel, height, width).
        seg_threshold: Threshold used to binarise the segmentations.
        image_display: Unused; kept so existing callers keep working.

    Returns:
        Dict with the selected shifted "seg" and, when "crossfield" is
        present in ``all_outputs``, the (unshifted) "crossfield" of the same
        replicate.

    Raises:
        NotImplementedError: If ``all_outputs`` has no "seg" entry.
    """
    if "seg" not in all_outputs:
        raise NotImplementedError("Test Time Augmentation requires segmentation to be computed.")

    all_seg = all_outputs["seg"]
    shape = all_seg.shape
    height, width = shape[3], shape[4]

    # Cleanup all_seg by multiplying each replicate mask with the mean mask.
    all_seg_mask = seg_threshold < all_seg
    mean_seg = torch.mean(all_seg_mask.float(), dim=0)
    all_cleaned_seg = all_seg_mask * mean_seg[None, ...]

    # Pixel coordinate grid of shape (height, width, 2) holding (x, y), built
    # by broadcasting so that it also matches non-square segmentation maps.
    range_y = torch.arange(height, device=all_cleaned_seg.device)
    range_x = torch.arange(width, device=all_cleaned_seg.device)
    grid_y = range_y[:, None].expand(height, width)
    grid_x = range_x[None, :].expand(height, width)
    grid_xy = torch.stack([grid_x, grid_y], dim=-1)

    # Average of coordinates, weighted by segmentation confidence.
    weights = all_cleaned_seg[:, :, :, :, :, None]
    spatial_mean_xy = torch.sum(grid_xy[None, None, None, :, :, :] * weights, dim=(3, 4)) \
        / torch.sum(weights, dim=(3, 4))
    # Median of all replicates' means.
    median_spatial_mean_xy, _ = torch.median(spatial_mean_xy, dim=0)
    # Shift between each replicate and the median. The shift for the original
    # segs is twice the shift between cleaned segs (assuming homogeneous
    # shifts and enough segs).
    shift_xy = 2 * (median_spatial_mean_xy[None, :, :, :] - spatial_mean_xy)
    # An empty cleaned seg has zero total weight, giving a 0/0 = NaN
    # barycenter; leave such maps unshifted instead of propagating NaN.
    shift_xy = torch.where(torch.isfinite(shift_xy), shift_xy, torch.zeros_like(shift_xy))

    # One shift exists per (replicate, batch, channel), so every channel map
    # is translated as its own single-channel image.
    shift_xy = shift_xy.reshape(-1, shift_xy.shape[-1])
    flat_seg = all_seg.reshape(-1, 1, height, width)
    shifted_seg = kornia.geometry.translate(flat_seg, shift_xy).reshape(shape)

    # Select the replicate seg (and crossfield) that best matches the mean
    # shifted seg; as dice_loss is a loss, lower is better.
    mean_shifted_seg = torch.mean(shifted_seg, dim=0)
    dice_loss = measures.dice_loss(shifted_seg[:, :, 0, :, :], mean_shifted_seg[None, :, 0, :, :])
    indices_best = torch.argmin(dice_loss, dim=0)
    batch_range = torch.arange(shape[1], device=indices_best.device)  # batch size

    final_outputs = {"seg": shifted_seg[indices_best, batch_range]}
    if "crossfield" in all_outputs:
        final_outputs["crossfield"] = all_outputs["crossfield"][indices_best, batch_range]
    return final_outputs
def tta_inference(model, xb, seg_threshold):
    """Run inference on 8 transformed copies of the input and aggregate them.

    The replicates are the 4 rotations by multiples of 90 degrees, each with
    and without a vertical flip. Every output is transformed back to the
    original frame before being stored, and the stack is aggregated with
    ``aggr_mean``.

    Args:
        model: Object exposing ``inference(image)`` that returns a dict of
            tensors (keys such as "seg" and "crossfield").
        xb: Batch dict; only ``xb["image"]`` is read.
        seg_threshold: Not used by the currently selected ``aggr_mean``; it is
            what the alternative aggregators ``aggr_dist_trans`` and
            ``aggr_translated`` take as threshold.

    Returns:
        The aggregated outputs dict produced by ``aggr_mean``.
    """
    replicates = 4 * 2  # 4 rotations, each with vflip/no vflip

    # Init result tensors with the untransformed prediction in slot 0.
    notrans_outputs = model.inference(xb["image"])
    output_keys = notrans_outputs.keys()
    all_outputs = {}
    for key in output_keys:
        all_outputs[key] = torch.empty((replicates, *notrans_outputs[key].shape), dtype=notrans_outputs[key].dtype,
                                       device=notrans_outputs[key].device)
        all_outputs[key][0] = notrans_outputs[key]

    # Flipped image goes in slot 1.
    flipped_image = kornia.geometry.transform.vflip(xb["image"])
    flipped_outputs = model.inference(flipped_image)
    for key in output_keys:
        reversed_output = kornia.geometry.transform.vflip(flipped_outputs[key])
        if key == "crossfield":
            # Apply the same frame field correction as the flipped+rotated
            # replicates below; the spatial flip alone does not undo the
            # effect of the flip on the frame field values.
            reversed_output = torch_lydorn.torchvision.transforms.functional.vflip_framefield(reversed_output)
        all_outputs[key][1] = reversed_output

    # Rotated replicates go in slots 2k, their flipped versions in 2k + 1.
    for k in range(1, 4):
        angle = -k * 90
        rotated_image = torch.rot90(xb["image"], k=k, dims=(-2, -1))
        rotated_outputs = model.inference(rotated_image)
        for key in output_keys:
            reversed_output = torch.rot90(rotated_outputs[key], k=-k, dims=(-2, -1))
            if key == "crossfield":
                # TODO: use a faster implementation of rotate_framefield that only handles angles [0, 90, 180, 270]
                reversed_output = torch_lydorn.torchvision.transforms.functional.rotate_framefield(reversed_output,
                                                                                                   angle)
            all_outputs[key][2 * k] = reversed_output

        # Flip rotated image
        flipped_rotated_image = kornia.geometry.transform.vflip(rotated_image)
        flipped_rotated_outputs = model.inference(flipped_rotated_image)
        for key in output_keys:
            reversed_output = torch.rot90(kornia.geometry.transform.vflip(flipped_rotated_outputs[key]), k=-k,
                                          dims=(-2, -1))
            if key == "crossfield":
                reversed_output = torch_lydorn.torchvision.transforms.functional.vflip_framefield(reversed_output)
                reversed_output = torch_lydorn.torchvision.transforms.functional.rotate_framefield(reversed_output,
                                                                                                   angle)
            all_outputs[key][2 * k + 1] = reversed_output

    # Aggregate results. Alternatives defined in this module:
    # aggr_median(all_outputs), aggr_dist_trans(all_outputs, seg_threshold),
    # aggr_translated(all_outputs, seg_threshold).
    final_outputs = aggr_mean(all_outputs)
    return final_outputs