|
import torch |
|
import torch.nn as nn |
|
import pytorch_lightning as pl |
|
|
|
class SimpleCNN(nn.Module):
    """Fully convolutional network mapping a 1-channel image to a 1-channel
    per-pixel probability map of the same spatial size.

    Architecture: a stem conv block, ``n_hidden_layers`` identical hidden
    blocks (Conv2d -> GroupNorm -> PReLU), then a 1x1 conv + Sigmoid head.
    ``padding='same'`` keeps H x W unchanged at every layer, so input and
    output spatial shapes match.

    Args:
        n_hidden_layers: number of hidden Conv/GroupNorm/PReLU blocks.
        n_kernels: number of feature channels; must be divisible by
            ``n_groups`` (a GroupNorm requirement).
        kernel_size: spatial size of the convolution kernels.
        n_groups: number of GroupNorm groups. Defaults to 4, the value
            previously hard-coded, so existing callers are unaffected.

    Raises:
        ValueError: if ``n_kernels`` is not divisible by ``n_groups``.
    """

    def __init__(self, n_hidden_layers, n_kernels, kernel_size, n_groups=4):
        super().__init__()
        # Fail fast with a clear message instead of GroupNorm's generic error.
        if n_kernels % n_groups != 0:
            raise ValueError(
                f"n_kernels ({n_kernels}) must be divisible by n_groups ({n_groups})"
            )
        self.n_hidden_layers = n_hidden_layers

        def conv_block(in_channels):
            # One Conv -> GroupNorm -> PReLU unit; 'same' padding preserves H x W.
            return [
                nn.Conv2d(in_channels, n_kernels, kernel_size=kernel_size, padding='same'),
                nn.GroupNorm(n_groups, n_kernels),
                nn.PReLU(),
            ]

        layers = conv_block(1)  # stem: 1 input channel -> n_kernels
        for _ in range(self.n_hidden_layers):
            layers.extend(conv_block(n_kernels))

        # 1x1 projection back to a single channel, squashed into (0, 1).
        layers.extend([
            nn.Conv2d(n_kernels, 1, kernel_size=1),
            nn.Sigmoid(),
        ])

        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network: x of shape (N, 1, H, W) -> (N, 1, H, W) in (0, 1)."""
        return self.conv_layers(x)
|
|
|
class MicrographCleaner(pl.LightningModule):
    """LightningModule wrapping :class:`SimpleCNN` for per-pixel binary
    mask prediction on micrographs.

    Batches are ``(images, masks)`` pairs; training minimizes BCE between
    the sigmoid outputs and the ground-truth masks. Train loss is logged
    per step and per epoch, validation loss per epoch.
    """

    def __init__(self, n_hidden_layers=12, n_kernels=16, kernel_size=5, learning_rate=0.001):
        super().__init__()
        # Persist constructor args to self.hparams (and checkpoints).
        self.save_hyperparameters()
        self.model = SimpleCNN(n_hidden_layers, n_kernels, kernel_size)
        self.lossF = nn.BCELoss()
        self.learning_rate = learning_rate
        # Holds validation images for later logging; populated elsewhere.
        self.val_imgs_to_log = []

    def forward(self, x):
        """Delegate straight to the underlying CNN."""
        return self.model(x)

    def _batch_loss(self, batch):
        # Shared train/val computation: predict masks, score against truth.
        images, masks = batch
        predictions = self(images)
        return self.lossF(predictions, masks)

    def training_step(self, batch, batch_idx):
        """One training step; logs per-step and per-epoch train loss."""
        loss = self._batch_loss(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """One validation step; logs epoch-level validation loss."""
        loss = self._batch_loss(batch)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Plain Adam over all parameters at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)