 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import base64
import io
import zlib
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms.v2 as transforms
from typing import Optional, Tuple

def decode_array(encoded_base64_str):
    """Inverse of encode_array: base64 text -> zlib-decompressed numpy array."""
    raw = zlib.decompress(base64.b64decode(encoded_base64_str))
    return np.load(io.BytesIO(raw))

def encode_array(array):
    """Serialize a numpy array to a zlib-compressed, base64-encoded string."""
    with io.BytesIO() as buffer:
        np.save(buffer, array, allow_pickle=False)
        payload = zlib.compress(buffer.getvalue(), level=9)
    return base64.b64encode(payload).decode('utf-8')

class BaseMicrographDataset(Dataset):
    """Shared functionality for micrograph datasets backed by a dataframe.

    Each row of ``df`` is expected to hold base64/zlib-encoded arrays
    (see ``decode_array``) under at least an ``'image'`` column.
    """

    def __init__(self, df, window_size: int):
        self.df = df                    # dataframe-like: needs len() and .iloc
        self.window_size = window_size  # side length of the square crop window

    def __len__(self) -> int:
        return len(self.df)

    def load_and_normalize_image(self, encoded_image: str) -> torch.Tensor:
        """Decode an image and min-max scale it to [0, 1].

        Returns a float32 tensor with an explicit channel dim (C, H, W).
        """
        image = decode_array(encoded_image).astype(np.float32)
        lo, hi = image.min(), image.max()
        if hi > lo:
            image = (image - lo) / (hi - lo)
        else:
            # BUGFIX: a constant image made (max - min) zero, so the original
            # division produced NaN/Inf; map constant images to all-zeros.
            image = np.zeros_like(image)
        if image.ndim == 2:
            image = image[np.newaxis, ...]
        return torch.from_numpy(image)

    def load_mask(self, encoded_mask: str) -> torch.Tensor:
        """Decode a segmentation mask as float32 with a channel dim (C, H, W)."""
        mask = decode_array(encoded_mask).astype(np.float32)
        if mask.ndim == 2:
            mask = mask[np.newaxis, ...]
        return torch.from_numpy(mask)

    def pad_to_min_size(self, image: torch.Tensor, min_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Reflect-pad bottom/right so both spatial dims are >= min_size.

        Returns the padded tensor and the (pad_h, pad_w) amounts applied.
        NOTE(review): reflect padding requires pad < dim, so inputs smaller
        than roughly half the window in either dimension will raise — confirm
        upstream data guarantees this.
        """
        _, h, w = image.shape
        pad_h = max(0, min_size - h)
        pad_w = max(0, min_size - w)
        padded = torch.nn.functional.pad(image, (0, pad_w, 0, pad_h), mode="reflect")
        return padded, (pad_h, pad_w)

class TrainMicrographDataset(BaseMicrographDataset):
    """Training dataset: random crop/flips shared by image and mask, plus an
    image-only Gaussian blur augmentation."""

    def __init__(self, df, window_size: int):
        super().__init__(df, window_size)

        # Geometric augmentations must hit image and mask identically, so
        # they are applied once to the pair stacked along the channel axis.
        self.shared_transform = transforms.Compose([
            transforms.RandomCrop(window_size),
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip()
        ])
        # Photometric augmentation: meaningful for the image only.
        self.image_only_transforms = transforms.Compose([
            transforms.GaussianBlur(7, sigma=(0.1, 2.))
        ])

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        record = self.df.iloc[idx]

        # Image: decode, scale to [0, 1], pad up to the crop size, then blur.
        img = self.load_and_normalize_image(record['image'])
        img, _ = self.pad_to_min_size(img, self.window_size)
        img = self.image_only_transforms(img)

        # Mask: decode and pad the same way (no photometric transform).
        msk = self.load_mask(record['mask'])
        msk, _ = self.pad_to_min_size(msk, self.window_size)

        # Run the geometric transforms once on the stacked pair, then split
        # back into single-channel image and mask tensors.
        combined = self.shared_transform(torch.cat([img, msk], dim=0))
        img, msk = torch.split(combined, [1, 1], dim=0)

        return img, msk


class ValidationMicrographDataset(BaseMicrographDataset):
    """Validation dataset yielding five deterministic crops per micrograph.

    Crops 0-3 are the four corners and crop 4 is the center; corner crops
    are useful because regions of interest can sit at the image edges.
    """

    def __init__(self, df, window_size: int):
        super().__init__(df, window_size)
        # Five fixed crops per image: four corners plus the center.
        self.n_crops = 5

    def __len__(self) -> int:
        # Each source image contributes one sample per fixed crop.
        return self.n_crops * len(self.df)

    def get_crop_coordinates(self, image_shape: Tuple[int, int], crop_idx: int) -> Tuple[int, int]:
        """Return (h_start, w_start) for the crop with the given index."""
        h, w = image_shape

        if crop_idx == 4:
            # Center crop.
            return (h - self.window_size) // 2, (w - self.window_size) // 2

        # Corner crops: indices 0/1 use the top edge, 2/3 the bottom;
        # even indices use the left edge, odd indices the right.
        h_start = 0 if crop_idx < 2 else h - self.window_size
        w_start = 0 if crop_idx % 2 == 0 else w - self.window_size
        return h_start, w_start

    def crop_tensors(self, image: torch.Tensor, mask: torch.Tensor,
                     h_start: int, w_start: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Slice the same window_size x window_size window out of both tensors."""
        rows = slice(h_start, h_start + self.window_size)
        cols = slice(w_start, w_start + self.window_size)
        return image[:, rows, cols], mask[:, rows, cols]

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Flat index encodes (image, crop) pair.
        image_idx, crop_idx = divmod(idx, self.n_crops)
        row = self.df.iloc[image_idx]

        # Load and pad image and mask identically.
        image = self.load_and_normalize_image(row['image'])
        image, _ = self.pad_to_min_size(image, self.window_size)

        mask = self.load_mask(row['mask'])
        mask, _ = self.pad_to_min_size(mask, self.window_size)

        # Extract the requested corner/center crop from both.
        h_start, w_start = self.get_crop_coordinates(image.shape[1:], crop_idx)
        return self.crop_tensors(image, mask, h_start, w_start)


class InferenceMicrographDataset(BaseMicrographDataset):
    """Inference dataset: padded image plus its id and the padding applied."""

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, Tuple[int, int]]:
        record = self.df.iloc[idx]

        # Decode, normalize to [0, 1], and pad to the model's window size.
        tensor = self.load_and_normalize_image(record['image'])
        tensor, pad_amounts = self.pad_to_min_size(tensor, self.window_size)

        # Padding is returned so callers can strip it from the predictions.
        return tensor, record['Id'], pad_amounts