# AIML / dataset.py
# (Hugging Face upload header preserved as a comment: uploaded by cconsti,
#  "Upload 10 files", commit a6fa489 verified)
import base64
import io
import zlib
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms.v2 as transforms
from typing import Optional, Tuple
def decode_array(encoded_base64_str):
    """Inverse of encode_array: base64-decode, zlib-inflate, then np.load."""
    raw = zlib.decompress(base64.b64decode(encoded_base64_str))
    return np.load(io.BytesIO(raw))
def encode_array(array):
    """Serialize a NumPy array as text: np.save -> zlib (level 9) -> base64."""
    buffer = io.BytesIO()
    np.save(buffer, array, allow_pickle=False)
    compressed = zlib.compress(buffer.getvalue(), level=9)
    return base64.b64encode(compressed).decode('utf-8')
class BaseMicrographDataset(Dataset):
    """Shared decoding, normalization, and padding logic for micrograph datasets.

    Args:
        df: table with base64+zlib-encoded ``'image'`` (and usually ``'mask'``)
            columns — presumably a pandas DataFrame, since subclasses use
            ``df.iloc`` (TODO confirm against callers).
        window_size: side length (pixels) of the square windows the model consumes.
    """

    def __init__(self, df, window_size: int):
        self.df = df
        self.window_size = window_size

    def __len__(self) -> int:
        return len(self.df)

    def load_and_normalize_image(self, encoded_image: str) -> torch.Tensor:
        """Decode an image and min-max normalize it to [0, 1].

        Returns a float32 tensor with an explicit channel dimension (C, H, W);
        a 2-D input gains a leading singleton channel axis.
        """
        image = decode_array(encoded_image).astype(np.float32)
        lo, hi = image.min(), image.max()
        if hi > lo:
            image = (image - lo) / (hi - lo)
        else:
            # Constant image: the original (hi - lo) == 0 case produced NaNs
            # via 0/0 division; map it to all zeros instead.
            image = np.zeros_like(image)
        if image.ndim == 2:
            image = image[np.newaxis, ...]
        return torch.from_numpy(image)

    def load_mask(self, encoded_mask: str) -> torch.Tensor:
        """Decode a segmentation mask to a float32 (C, H, W) tensor (no normalization)."""
        mask = decode_array(encoded_mask).astype(np.float32)
        if mask.ndim == 2:
            mask = mask[np.newaxis, ...]
        return torch.from_numpy(mask)

    def pad_to_min_size(self, image: torch.Tensor, min_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Reflect-pad the bottom/right edges so both spatial dims are >= min_size.

        Returns the (possibly) padded tensor and the applied (pad_h, pad_w)
        amounts, so callers can undo the padding after inference.

        NOTE(review): reflect padding requires each pad amount to be smaller
        than the corresponding input dimension, i.e. inputs must be larger
        than roughly half of min_size — confirm this holds for the data.
        """
        _, h, w = image.shape
        pad_h = max(0, min_size - h)
        pad_w = max(0, min_size - w)
        if pad_h == 0 and pad_w == 0:
            # Already large enough; skip the no-op pad call.
            return image, (0, 0)
        padded = torch.nn.functional.pad(image, (0, pad_w, 0, pad_h), mode="reflect")
        return padded, (pad_h, pad_w)
class TrainMicrographDataset(BaseMicrographDataset):
    """Training dataset: random crop + flips shared by image and mask, blur on image only."""

    def __init__(self, df, window_size: int):
        super().__init__(df, window_size)
        # Geometric transforms must hit image and mask with identical random
        # parameters, so they are applied to the channel-stacked pair below.
        self.shared_transform = transforms.Compose([
            transforms.RandomCrop(window_size),
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip()
        ])
        # Photometric transforms must not touch the mask.
        self.image_only_transforms = transforms.Compose([
            transforms.GaussianBlur(7, sigma=(0.1, 2.))
        ])

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        row = self.df.iloc[idx]
        # Image: decode, min-max normalize, pad up to the crop size, blur.
        image = self.load_and_normalize_image(row['image'])
        image, _ = self.pad_to_min_size(image, self.window_size)
        image = self.image_only_transforms(image)
        # Mask: decode and pad only — no photometric augmentation.
        mask = self.load_mask(row['mask'])
        mask, _ = self.pad_to_min_size(mask, self.window_size)
        # Stack so crop/flips use the same parameters for both, then split
        # back. Splitting on the actual channel counts (instead of the
        # original hard-coded [1, 1]) also supports multi-channel data.
        n_img, n_mask = image.shape[0], mask.shape[0]
        stacked = self.shared_transform(torch.cat([image, mask], dim=0))
        image, mask = torch.split(stacked, [n_img, n_mask], dim=0)
        return image, mask
class ValidationMicrographDataset(BaseMicrographDataset):
    """Deterministic five-crop validation dataset (four corners + center).

    Corner crops are used because the regions of interest can sit at the
    edges of the image. Each source row expands into ``n_crops`` items.
    """

    def __init__(self, df, window_size: int):
        super().__init__(df, window_size)
        # Five fixed crops per image: four corners plus the center.
        self.n_crops = 5

    def __len__(self) -> int:
        return self.n_crops * len(self.df)

    def get_crop_coordinates(self, image_shape: Tuple[int, int], crop_idx: int) -> Tuple[int, int]:
        """Top-left (h, w) origin of crop ``crop_idx`` within ``image_shape``."""
        h, w = image_shape
        if crop_idx == 4:
            # Center crop.
            return (h - self.window_size) // 2, (w - self.window_size) // 2
        # Crops 0-3 are the corners: indices 0/1 sit on the top row, 2/3 on
        # the bottom; even indices on the left, odd on the right.
        on_top = crop_idx < 2
        on_left = crop_idx % 2 == 0
        return (0 if on_top else h - self.window_size,
                0 if on_left else w - self.window_size)

    def crop_tensors(self, image: torch.Tensor, mask: torch.Tensor,
                     h_start: int, w_start: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Cut the same window_size x window_size window from image and mask."""
        rows = slice(h_start, h_start + self.window_size)
        cols = slice(w_start, w_start + self.window_size)
        return image[:, rows, cols], mask[:, rows, cols]

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image_idx, crop_idx = divmod(idx, self.n_crops)
        row = self.df.iloc[image_idx]
        # Decode and pad both tensors, then take the requested fixed crop.
        image, _ = self.pad_to_min_size(
            self.load_and_normalize_image(row['image']), self.window_size)
        mask, _ = self.pad_to_min_size(self.load_mask(row['mask']), self.window_size)
        origin = self.get_crop_coordinates(image.shape[1:], crop_idx)
        return self.crop_tensors(image, mask, *origin)
class InferenceMicrographDataset(BaseMicrographDataset):
    """Inference dataset: no augmentation; yields image, its Id, and the padding applied."""

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, Tuple[int, int]]:
        sample = self.df.iloc[idx]
        # The (pad_h, pad_w) amounts are returned so predictions can be
        # cropped back to the original image size downstream.
        padded, padding = self.pad_to_min_size(
            self.load_and_normalize_image(sample['image']), self.window_size)
        return padded, sample['Id'], padding