import torch
import torch.nn as nn
import pytorch_lightning as pl

class SimpleCNN(nn.Module):
    """Fully convolutional network: an input conv block, n_hidden_layers hidden conv blocks,
    and a 1x1 output conv with a sigmoid for per-pixel probabilities."""

    def __init__(self, n_hidden_layers, n_kernels, kernel_size):
        super().__init__()
        self.n_hidden_layers = n_hidden_layers
        # Input block: 1-channel image -> n_kernels feature maps (n_kernels must be divisible by 4 for GroupNorm)
        layers = [
            nn.Conv2d(1, n_kernels, kernel_size=kernel_size, padding='same'),
            nn.GroupNorm(4, n_kernels),
            nn.PReLU()
        ]

        # Hidden blocks: repeat conv -> GroupNorm -> PReLU at a constant channel width
        for _ in range(self.n_hidden_layers):
            layers.extend([
                nn.Conv2d(n_kernels, n_kernels, kernel_size=kernel_size, padding='same'),
                nn.GroupNorm(4, n_kernels),
                nn.PReLU(),
            ])

        # Output head: 1x1 conv down to a single channel, sigmoid for per-pixel probabilities
        layers.extend([
            nn.Conv2d(n_kernels, 1, kernel_size=1),
            nn.Sigmoid()
        ])

        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_layers(x)

class MicrographCleaner(pl.LightningModule):
    """LightningModule wrapping SimpleCNN for binary micrograph segmentation,
    trained with BCE on the sigmoid outputs."""

    def __init__(self, n_hidden_layers=12, n_kernels=16, kernel_size=5, learning_rate=0.001):
        super().__init__()
        self.save_hyperparameters()
        self.model = SimpleCNN(n_hidden_layers, n_kernels, kernel_size)
        self.lossF = nn.BCELoss()
        self.learning_rate = learning_rate
        self.val_imgs_to_log = []

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, masks = batch
        outputs = self(images)
        loss = self.lossF(outputs, masks)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        outputs = self(images)
        loss = self.lossF(outputs, masks)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
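

# --- Minimal usage sketch (not part of the original file) ---
# Assumes batches of (image, mask) pairs shaped (B, 1, H, W) with mask values in [0, 1];
# the synthetic tensors and DataLoader setup below are illustrative placeholders only.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    # Tiny random dataset just to exercise the training loop end to end
    imgs = torch.rand(8, 1, 64, 64)
    masks = (torch.rand(8, 1, 64, 64) > 0.5).float()
    train_loader = DataLoader(TensorDataset(imgs, masks), batch_size=4)
    val_loader = DataLoader(TensorDataset(imgs, masks), batch_size=4)

    # Small model configuration to keep the smoke test fast
    model = MicrographCleaner(n_hidden_layers=2, n_kernels=8)
    trainer = pl.Trainer(max_epochs=1, accelerator='cpu', logger=False, enable_checkpointing=False)
    trainer.fit(model, train_loader, val_loader)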