cconsti committed
Commit a6fa489 · verified · 1 Parent(s): b4d264b

Upload 10 files

Browse files
Files changed (10)
  1. dataset.py +150 -0
  2. final_checkpoint.ckpt +3 -0
  3. infer.py +76 -0
  4. inference_utils.py +54 -0
  5. kaggle_id.txt +1 -0
  6. model.py +59 -0
  7. report.pdf +0 -0
  8. report_template.md +38 -0
  9. requirements.txt +8 -0
  10. train.py +62 -0
dataset.py ADDED
import base64
import io
import zlib
from typing import Tuple

import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms.v2 as transforms


def decode_array(encoded_base64_str):
    """Decode a base64 string back into a NumPy array (base64 -> zlib -> np.load)."""
    decoded = base64.b64decode(encoded_base64_str)
    decompressed = zlib.decompress(decoded)
    return np.load(io.BytesIO(decompressed))


def encode_array(array):
    """Encode a NumPy array as a compact base64 string (np.save -> zlib -> base64)."""
    bytes_io = io.BytesIO()
    np.save(bytes_io, array, allow_pickle=False)
    compressed = zlib.compress(bytes_io.getvalue(), level=9)
    return base64.b64encode(compressed).decode('utf-8')


class BaseMicrographDataset(Dataset):
    def __init__(self, df, window_size: int):
        self.df = df
        self.window_size = window_size

    def __len__(self) -> int:
        return len(self.df)

    def load_and_normalize_image(self, encoded_image: str) -> torch.Tensor:
        """Decode an image and min-max normalize it to [0, 1]."""
        image = decode_array(encoded_image).astype(np.float32)
        image = (image - image.min()) / (image.max() - image.min())
        if image.ndim == 2:
            image = image[np.newaxis, ...]  # add channel dimension
        return torch.from_numpy(image)

    def load_mask(self, encoded_mask: str) -> torch.Tensor:
        mask = decode_array(encoded_mask).astype(np.float32)
        if mask.ndim == 2:
            mask = mask[np.newaxis, ...]  # add channel dimension
        return torch.from_numpy(mask)

    def pad_to_min_size(self, image: torch.Tensor, min_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Reflect-pad the bottom/right edges so both spatial dims are at least min_size."""
        _, h, w = image.shape
        pad_h = max(0, min_size - h)
        pad_w = max(0, min_size - w)
        padded = torch.nn.functional.pad(image, (0, pad_w, 0, pad_h), mode="reflect")
        return padded, (pad_h, pad_w)


class TrainMicrographDataset(BaseMicrographDataset):
    """Dataset for training with random augmentations."""

    def __init__(self, df, window_size: int):
        super().__init__(df, window_size)

        # Geometric transforms shared by image and mask
        self.shared_transform = transforms.Compose([
            transforms.RandomCrop(window_size),
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip()
        ])
        # Photometric transforms applied to the image only
        self.image_only_transforms = transforms.Compose([
            transforms.GaussianBlur(7, sigma=(0.1, 2.))
        ])

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        row = self.df.iloc[idx]

        # Load and preprocess image
        image = self.load_and_normalize_image(row['image'])
        image, _ = self.pad_to_min_size(image, self.window_size)
        image = self.image_only_transforms(image)

        # Load and preprocess mask
        mask = self.load_mask(row['mask'])
        mask, _ = self.pad_to_min_size(mask, self.window_size)

        # Stack so the shared transforms crop/flip image and mask identically
        stacked = torch.cat([image, mask], dim=0)
        stacked = self.shared_transform(stacked)
        image, mask = torch.split(stacked, [1, 1], dim=0)

        return image, mask


class ValidationMicrographDataset(BaseMicrographDataset):
    """Dataset for validation using corner crops. This is a good idea because the
    regions of interest can be at the edges of the image."""

    def __init__(self, df, window_size: int):
        super().__init__(df, window_size)
        # Five fixed crops per image: 4 corners + center
        self.n_crops = 5

    def __len__(self) -> int:
        return len(self.df) * self.n_crops

    def get_crop_coordinates(self, image_shape: Tuple[int, int], crop_idx: int) -> Tuple[int, int]:
        """Get the top-left coordinates for a specific crop index."""
        h, w = image_shape

        if crop_idx == 4:  # center crop
            h_start = (h - self.window_size) // 2
            w_start = (w - self.window_size) // 2
        else:  # corners: 0=top-left, 1=top-right, 2=bottom-left, 3=bottom-right
            h_start = 0 if crop_idx < 2 else h - self.window_size
            w_start = 0 if crop_idx % 2 == 0 else w - self.window_size

        return h_start, w_start

    def crop_tensors(self, image: torch.Tensor, mask: torch.Tensor,
                     h_start: int, w_start: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract the same crop from both image and mask."""
        h_end = h_start + self.window_size
        w_end = w_start + self.window_size

        return (
            image[:, h_start:h_end, w_start:w_end],
            mask[:, h_start:h_end, w_start:w_end]
        )

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image_idx = idx // self.n_crops
        crop_idx = idx % self.n_crops
        row = self.df.iloc[image_idx]

        # Load and preprocess image and mask
        image = self.load_and_normalize_image(row['image'])
        image, _ = self.pad_to_min_size(image, self.window_size)

        mask = self.load_mask(row['mask'])
        mask, _ = self.pad_to_min_size(mask, self.window_size)

        # Get the specific corner/center crop
        h_start, w_start = self.get_crop_coordinates(image.shape[1:], crop_idx)
        image, mask = self.crop_tensors(image, mask, h_start, w_start)

        return image, mask


class InferenceMicrographDataset(BaseMicrographDataset):
    """Dataset for inference without any augmentations."""

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, Tuple[int, int]]:
        row = self.df.iloc[idx]

        # Load, normalize, and pad; return the padding so it can be stripped later
        image = self.load_and_normalize_image(row['image'])
        image, padding = self.pad_to_min_size(image, self.window_size)

        return image, row['Id'], padding
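A quick round-trip check of the encode/decode helpers above (a sketch, not part of the upload; the 4x4 random array is arbitrary):

import numpy as np
from dataset import encode_array, decode_array

arr = np.random.rand(4, 4).astype(np.float32)
encoded = encode_array(arr)           # np.save -> zlib -> base64 string
restored = decode_array(encoded)      # exact inverse
assert np.array_equal(arr, restored)  # lossless, so it can live in a CSV cell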
final_checkpoint.ckpt ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:24ceee5d5db945c0d25ecfde13508b40f039842ab864aafa714f048ccc17a881
size 1016916
infer.py ADDED
#!/usr/bin/env python3
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch
import tqdm

from dataset import InferenceMicrographDataset, decode_array
from inference_utils import sliding_window_inference
from model import MicrographCleaner


def main():
    # Create predictions directory if it doesn't exist
    os.makedirs('predictions', exist_ok=True)

    # Parameters
    WINDOW_SIZE = 512
    THRESHOLD = 0.5
    OVERLAP = 0.5

    # Load model
    model = MicrographCleaner.load_from_checkpoint('final_checkpoint.ckpt', map_location='cpu')
    model.eval()

    # Load test data
    test_df = pd.read_csv('test.csv')
    test_dataset = InferenceMicrographDataset(test_df, window_size=WINDOW_SIZE)

    # Process each image
    unique_ids = set()
    with torch.inference_mode():
        for idx in tqdm.tqdm(range(len(test_dataset))):
            image, image_id, (pad_h, pad_w) = test_dataset[idx]

            # Skip duplicate ids, if any
            if image_id in unique_ids:
                continue
            unique_ids.add(image_id)

            # Perform sliding-window inference
            pred = sliding_window_inference(
                model,
                image,
                window_size=WINDOW_SIZE,
                overlap=OVERLAP
            )

            # Remove the bottom/right padding added by the dataset, if any
            if pad_h > 0:
                pred = pred[..., :-pad_h, :]
            if pad_w > 0:
                pred = pred[..., :-pad_w]

            # Convert to a binary mask
            pred_mask = (pred > THRESHOLD).cpu().numpy()[0]

            # Create a side-by-side visualization
            orig_image = decode_array(test_df.iloc[idx]['image'])

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
            ax1.imshow(orig_image, cmap='gray')
            ax1.set_title('Original Image')
            ax1.axis('off')

            ax2.imshow(pred_mask, cmap='gray')
            ax2.set_title('Predicted Mask')
            ax2.axis('off')

            plt.tight_layout()
            plt.savefig(f'predictions/{image_id}_prediction.png')
            plt.close()


if __name__ == "__main__":
    main()
inference_utils.py ADDED
import numpy as np
import torch


def sliding_window_inference(model, image, window_size, overlap=0.5):
    """Perform sliding-window inference on large images, averaging
    predictions in overlapping regions."""
    model.eval()

    # Get dimensions
    _, height, width = image.shape
    stride = int(window_size * (1 - overlap))

    # Number of windows needed to cover each dimension
    n_h = int(np.ceil((height - window_size) / stride) + 1)
    n_w = int(np.ceil((width - window_size) / stride) + 1)

    # Empty prediction map and count map for averaging
    pred_map = torch.zeros((1, height, width)).to(model.device)
    count_map = torch.zeros((1, height, width)).to(model.device)

    # Slide the window over the image
    with torch.no_grad():
        for i in range(n_h):
            for j in range(n_w):
                # Window boundaries, clamped so the last window stays inside the image
                h_start = min(i * stride, height - window_size)
                w_start = min(j * stride, width - window_size)
                h_end = h_start + window_size
                w_end = w_start + window_size

                # Extract window
                window = image[:, h_start:h_end, w_start:w_end]

                # If the window is smaller than window_size, pad it
                if window.shape[1:] != (window_size, window_size):
                    pad_h = window_size - window.shape[1]
                    pad_w = window_size - window.shape[2]
                    window = torch.nn.functional.pad(window, (0, pad_w, 0, pad_h))

                # Make prediction
                window = window.unsqueeze(0)  # add batch dimension
                pred = model(window)
                pred = pred.squeeze(0)  # remove batch dimension

                # If the window was padded, crop the prediction back to the true window size
                if window.shape[2] - h_end + h_start > 0 or window.shape[3] - w_end + w_start > 0:
                    pred = pred[:, :h_end - h_start, :w_end - w_start]

                # Accumulate prediction and visit count
                pred_map[:, h_start:h_end, w_start:w_end] += pred
                count_map[:, h_start:h_end, w_start:w_end] += 1

    # Average overlapping predictions
    final_pred = pred_map / count_map
    return final_pred.cpu()
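A minimal sanity check for sliding_window_inference (a sketch, not part of the upload): EchoModel is a hypothetical stand-in providing the `.device` attribute the function expects; because it echoes its input, averaging the overlapping windows should reproduce the image exactly.

import torch
import torch.nn as nn
from inference_utils import sliding_window_inference

class EchoModel(nn.Module):
    # Hypothetical stand-in that returns its input unchanged.
    def forward(self, x):
        return x

    @property
    def device(self):
        return torch.device('cpu')

image = torch.rand(1, 700, 900)  # single-channel image larger than one 512px window
pred = sliding_window_inference(EchoModel(), image, window_size=512, overlap=0.5)
assert pred.shape == image.shape
assert torch.allclose(pred, image)  # identical overlapping copies average to the original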
kaggle_id.txt ADDED
rsancg00
model.py ADDED
import torch
import torch.nn as nn
import pytorch_lightning as pl


class SimpleCNN(nn.Module):
    def __init__(self, n_hidden_layers, n_kernels, kernel_size):
        super().__init__()
        self.n_hidden_layers = n_hidden_layers
        # Input block: 1-channel micrograph -> n_kernels feature maps
        layers = [
            nn.Conv2d(1, n_kernels, kernel_size=kernel_size, padding='same'),
            nn.GroupNorm(4, n_kernels),  # n_kernels must be divisible by 4
            nn.PReLU()
        ]

        # Hidden blocks
        for _ in range(self.n_hidden_layers):
            layers.extend([
                nn.Conv2d(n_kernels, n_kernels, kernel_size=kernel_size, padding='same'),
                nn.GroupNorm(4, n_kernels),
                nn.PReLU(),
            ])

        # Output head: 1x1 conv to a single channel, sigmoid for per-pixel probabilities
        layers.extend([
            nn.Conv2d(n_kernels, 1, kernel_size=1),
            nn.Sigmoid()
        ])

        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_layers(x)


class MicrographCleaner(pl.LightningModule):
    def __init__(self, n_hidden_layers=12, n_kernels=16, kernel_size=5, learning_rate=0.001):
        super().__init__()
        self.save_hyperparameters()
        self.model = SimpleCNN(n_hidden_layers, n_kernels, kernel_size)
        self.lossF = nn.BCELoss()
        self.learning_rate = learning_rate
        self.val_imgs_to_log = []

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, masks = batch
        outputs = self(images)
        loss = self.lossF(outputs, masks)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        outputs = self(images)
        loss = self.lossF(outputs, masks)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
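A shape sanity check for MicrographCleaner (a sketch; the 64px batch below is arbitrary, it just has to be single-channel):

import torch
from model import MicrographCleaner

model = MicrographCleaner(n_hidden_layers=2, n_kernels=16, kernel_size=5)
x = torch.rand(2, 1, 64, 64)   # (batch, channels, height, width)
y = model(x)
assert y.shape == x.shape      # 'same' padding preserves spatial size
assert 0.0 <= float(y.min()) <= float(y.max()) <= 1.0  # sigmoid output suits BCELoss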
report.pdf ADDED
File without changes
report_template.md ADDED
# Cryo-EM Image Segmentation Report

## Phase 1: Manual Implementation

### Approach
[Describe your approach to solving the problem, including the model architecture, loss functions, and training strategy]

### Experiments

| Experiment | Description | Training Loss | Validation Loss | Public Score | Private Score |
|------------|-------------|---------------|-----------------|--------------|---------------|
| Baseline   | Simple CNN  | 0.XX          | 0.XX            | 0.XX         | 0.XX          |
| Exp 1      | [Change 1]  | 0.XX          | 0.XX            | 0.XX         | 0.XX          |
| Exp 2      | [Change 2]  | 0.XX          | 0.XX            | 0.XX         | 0.XX          |

### Training Curves
[Insert training and validation loss curves for your final solution]

### Analysis
[Analyze the results of your experiments, discussing what worked and what didn't]

## Phase 2: Open Resources

### Approach
[Describe the tools and pre-implemented solutions you used]

### Results

| Method   | Description   | Public Score | Private Score |
|----------|---------------|--------------|---------------|
| Method 1 | [Description] | 0.XX         | 0.XX          |
| Method 2 | [Description] | 0.XX         | 0.XX          |

### Comparison
[Compare the results between Phase 1 and Phase 2, discussing the benefits and drawbacks of each approach]

## Conclusions
[Summarize your findings and discuss potential future improvements]
requirements.txt ADDED
torch>=2.0.0
torchvision>=0.15.0
pytorch-lightning>=2.0.0
pandas>=1.5.0
numpy>=1.23.0
matplotlib>=3.5.0
scikit-learn>=1.0.0
tqdm>=4.65.0
train.py ADDED
#!/usr/bin/env python3
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from model import MicrographCleaner
from dataset import TrainMicrographDataset, ValidationMicrographDataset


def main():
    # Training parameters
    WINDOW_SIZE = 512
    BATCH_SIZE = 8
    N_EPOCHS = 3  # TODO: increase to many more epochs for a real run

    # Load and split data
    train_df = pd.read_csv('train.csv')
    train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42)

    # Create datasets and dataloaders
    train_dataset = TrainMicrographDataset(train_df, window_size=WINDOW_SIZE)
    val_dataset = ValidationMicrographDataset(val_df, window_size=WINDOW_SIZE)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=4)

    # Initialize model
    model = MicrographCleaner()

    # Set up logging and checkpointing
    logger = TensorBoardLogger('lightning_logs', name='micrograph_cleaner')
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath='checkpoints',
        filename='micrograph-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        mode='min'
    )

    # Initialize trainer
    trainer = pl.Trainer(
        max_epochs=N_EPOCHS,
        accelerator='auto',
        devices=1,
        logger=logger,
        callbacks=[checkpoint_callback],
        log_every_n_steps=10
    )

    # Train model
    trainer.fit(model, train_loader, val_loader)

    # Save the final checkpoint under the name infer.py expects
    trainer.save_checkpoint("final_checkpoint.ckpt")


if __name__ == "__main__":
    main()