import base64
import io
import zlib
from typing import Tuple

import numpy as np
import torch
import torchvision.transforms.v2 as transforms
from torch.utils.data import Dataset


def decode_array(encoded_base64_str: str) -> np.ndarray:
    """Decode a base64 string produced by encode_array back into a NumPy array."""
    decoded = base64.b64decode(encoded_base64_str)
    decompressed = zlib.decompress(decoded)
    return np.load(io.BytesIO(decompressed))


def encode_array(array: np.ndarray) -> str:
    """Serialize a NumPy array to .npy bytes, zlib-compress, and base64-encode."""
    bytes_io = io.BytesIO()
    np.save(bytes_io, array, allow_pickle=False)
    compressed = zlib.compress(bytes_io.getvalue(), level=9)
    return base64.b64encode(compressed).decode('utf-8')
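
# Round-trip sketch of the two helpers above (illustrative only; the array
# contents are arbitrary):
#
#     original = np.arange(12, dtype=np.float32).reshape(3, 4)
#     restored = decode_array(encode_array(original))
#     assert np.array_equal(original, restored)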


class BaseMicrographDataset(Dataset):
    """Shared loading, normalization, and padding logic for micrograph datasets."""

    def __init__(self, df, window_size: int):
        self.df = df
        self.window_size = window_size

    def __len__(self) -> int:
        return len(self.df)

    def load_and_normalize_image(self, encoded_image: str) -> torch.Tensor:
        image = decode_array(encoded_image).astype(np.float32)
        # Min-max normalize to [0, 1]; the epsilon guards against division by
        # zero on constant images.
        image = (image - image.min()) / (image.max() - image.min() + 1e-8)
        if image.ndim == 2:
            image = image[np.newaxis, ...]  # add a channel axis: (H, W) -> (1, H, W)
        return torch.from_numpy(image)

    def load_mask(self, encoded_mask: str) -> torch.Tensor:
        mask = decode_array(encoded_mask).astype(np.float32)
        if mask.ndim == 2:
            mask = mask[np.newaxis, ...]
        return torch.from_numpy(mask)

    def pad_to_min_size(self, image: torch.Tensor, min_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Reflect-pad the bottom/right edges so both spatial dims are at least min_size."""
        _, h, w = image.shape
        pad_h = max(0, min_size - h)
        pad_w = max(0, min_size - w)
        # Note: reflect padding requires pad_h < h and pad_w < w.
        padded = torch.nn.functional.pad(image, (0, pad_w, 0, pad_h), mode="reflect")
        return padded, (pad_h, pad_w)
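
# Padding sketch (illustrative numbers): a (1, 20, 20) tensor padded with
# min_size=32 comes back as (1, 32, 32) with padding (12, 12); reflect mode
# mirrors existing pixels, which is why each pad amount must stay below the
# corresponding input dimension.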


class TrainMicrographDataset(BaseMicrographDataset):
    """Dataset for training with random augmentations."""

    def __init__(self, df, window_size: int):
        super().__init__(df, window_size)
        # Geometric transforms applied jointly to image and mask so they stay aligned.
        self.shared_transform = transforms.Compose([
            transforms.RandomCrop(window_size),
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip(),
        ])
        # Photometric transforms applied to the image only; blurring the mask
        # would corrupt the labels.
        self.image_only_transforms = transforms.Compose([
            transforms.GaussianBlur(7, sigma=(0.1, 2.0)),
        ])

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        row = self.df.iloc[idx]

        image = self.load_and_normalize_image(row['image'])
        image, _ = self.pad_to_min_size(image, self.window_size)
        image = self.image_only_transforms(image)

        mask = self.load_mask(row['mask'])
        mask, _ = self.pad_to_min_size(mask, self.window_size)

        # Stack image and mask along the channel axis so the random crop and
        # flips are applied identically to both, then split them back apart.
        stacked = torch.cat([image, mask], dim=0)
        stacked = self.shared_transform(stacked)
        image, mask = torch.split(stacked, [1, 1], dim=0)

        return image, mask
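
# Typical training usage (a sketch; assumes a pandas DataFrame `df` with
# base64-encoded 'image' and 'mask' columns, and batch/window sizes chosen
# purely for illustration):
#
#     from torch.utils.data import DataLoader
#     train_ds = TrainMicrographDataset(df, window_size=256)
#     loader = DataLoader(train_ds, batch_size=8, shuffle=True)
#     images, masks = next(iter(loader))  # both of shape (8, 1, 256, 256)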


class ValidationMicrographDataset(BaseMicrographDataset):
    """Dataset for validation using five deterministic crops: the four corners
    plus the center. Corner crops matter because regions of interest can sit
    at the edges of the image."""

    def __init__(self, df, window_size: int):
        super().__init__(df, window_size)
        self.n_crops = 5

    def __len__(self) -> int:
        return len(self.df) * self.n_crops

    def get_crop_coordinates(self, image_shape: Tuple[int, int], crop_idx: int) -> Tuple[int, int]:
        """Map a crop index to its top-left corner: 0-3 are the corners
        (top-left, top-right, bottom-left, bottom-right), 4 is the center."""
        h, w = image_shape

        if crop_idx == 4:
            h_start = (h - self.window_size) // 2
            w_start = (w - self.window_size) // 2
        else:
            h_start = 0 if crop_idx < 2 else h - self.window_size
            w_start = 0 if crop_idx % 2 == 0 else w - self.window_size

        return h_start, w_start

    def crop_tensors(self, image: torch.Tensor, mask: torch.Tensor,
                     h_start: int, w_start: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract the same window from both image and mask."""
        h_end = h_start + self.window_size
        w_end = w_start + self.window_size

        return (
            image[:, h_start:h_end, w_start:w_end],
            mask[:, h_start:h_end, w_start:w_end],
        )

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Each source image yields n_crops consecutive samples.
        image_idx = idx // self.n_crops
        crop_idx = idx % self.n_crops
        row = self.df.iloc[image_idx]

        image = self.load_and_normalize_image(row['image'])
        image, _ = self.pad_to_min_size(image, self.window_size)

        mask = self.load_mask(row['mask'])
        mask, _ = self.pad_to_min_size(mask, self.window_size)

        h_start, w_start = self.get_crop_coordinates(image.shape[1:], crop_idx)
        image, mask = self.crop_tensors(image, mask, h_start, w_start)

        return image, mask


class InferenceMicrographDataset(BaseMicrographDataset):
    """Dataset for inference without augmentations; returns the padded image,
    its Id, and the padding applied so predictions can be cropped back."""

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, Tuple[int, int]]:
        row = self.df.iloc[idx]

        image = self.load_and_normalize_image(row['image'])
        image, padding = self.pad_to_min_size(image, self.window_size)

        return image, row['Id'], padding
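

if __name__ == "__main__":
    # Smoke-test sketch with synthetic data; the column names ('Id', 'image',
    # 'mask') mirror what the datasets expect, but the sizes and values below
    # are arbitrary. Assumes pandas is installed.
    import pandas as pd

    rng = np.random.default_rng(0)
    df = pd.DataFrame({
        "Id": ["example"],
        "image": [encode_array(rng.random((64, 64)).astype(np.float32))],
        "mask": [encode_array((rng.random((64, 64)) > 0.5).astype(np.float32))],
    })

    train_image, train_mask = TrainMicrographDataset(df, window_size=32)[0]
    val_image, val_mask = ValidationMicrographDataset(df, window_size=32)[0]
    infer_image, image_id, padding = InferenceMicrographDataset(df, window_size=32)[0]

    print("train:", train_image.shape, train_mask.shape)       # (1, 32, 32) each
    print("val:", val_image.shape, val_mask.shape)             # (1, 32, 32) each
    print("inference:", infer_image.shape, image_id, padding)  # (1, 64, 64), 'example', (0, 0)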