# AIML / model.py
# Uploaded by cconsti in commit a6fa489 ("Upload 10 files").
import torch
import torch.nn as nn
import pytorch_lightning as pl
class SimpleCNN(nn.Module):
    """Fully convolutional network producing a single-channel map in (0, 1).

    Layout, in order: an input block that lifts 1 channel to ``n_kernels``,
    ``n_hidden_layers`` further blocks at ``n_kernels`` channels (each block is
    Conv2d -> GroupNorm(4 groups) -> PReLU, with ``padding='same'`` so spatial
    size is preserved), then a 1x1 convolution down to one channel and a
    Sigmoid.

    Args:
        n_hidden_layers: Number of hidden conv blocks after the input block.
        n_kernels: Channel width of every block. ``GroupNorm`` uses 4 groups,
            so this must be divisible by 4.
        kernel_size: Spatial kernel size of the block convolutions.
    """

    def __init__(self, n_hidden_layers, n_kernels, kernel_size):
        super().__init__()
        self.n_hidden_layers = n_hidden_layers

        def conv_block(in_channels):
            # One Conv -> GroupNorm -> PReLU triple mapping in_channels -> n_kernels.
            return [
                nn.Conv2d(in_channels, n_kernels, kernel_size=kernel_size, padding='same'),
                nn.GroupNorm(4, n_kernels),
                nn.PReLU(),
            ]

        # Modules are created in the same order as they run, so parameter
        # initialisation order and Sequential indices (state_dict keys) are fixed.
        modules = conv_block(1)
        for _ in range(n_hidden_layers):
            modules += conv_block(n_kernels)
        modules += [nn.Conv2d(n_kernels, 1, kernel_size=1), nn.Sigmoid()]
        self.conv_layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run the conv stack on single-channel input, e.g. shaped (N, 1, H, W)."""
        return self.conv_layers(x)
class MicrographCleaner(pl.LightningModule):
    """Lightning module that trains a ``SimpleCNN`` to predict per-pixel masks.

    Each batch is an ``(images, masks)`` pair. The network output is compared
    with ``masks`` using binary cross-entropy and optimised with Adam.

    Args:
        n_hidden_layers: Number of hidden conv blocks in the ``SimpleCNN``.
        n_kernels: Channel width of the ``SimpleCNN`` conv blocks (it is
            normalised with ``GroupNorm(4, ...)``, so it must be divisible by 4).
        kernel_size: Spatial kernel size of the ``SimpleCNN`` convolutions.
        learning_rate: Adam learning rate.
    """

    def __init__(self, n_hidden_layers=12, n_kernels=16, kernel_size=5, learning_rate=0.001):
        super().__init__()
        # Stores all constructor arguments in self.hparams (and in checkpoints).
        self.save_hyperparameters()
        self.model = SimpleCNN(n_hidden_layers, n_kernels, kernel_size)
        # SimpleCNN ends in a Sigmoid, so the loss receives probabilities, not logits.
        # NOTE(review): BCELoss on sigmoid outputs is less numerically stable than
        # BCEWithLogitsLoss and is rejected inside CUDA autocast (16-bit mixed
        # precision); switching would require SimpleCNN to return logits.
        self.lossF = nn.BCELoss()
        # Plain attribute (besides self.hparams) so tools such as Lightning's
        # learning-rate finder can read and overwrite it.
        self.learning_rate = learning_rate
        # Not used in the visible code; presumably filled by logging hooks
        # elsewhere — TODO confirm.
        self.val_imgs_to_log = []

    def forward(self, x):
        """Return per-pixel probabilities in (0, 1) for ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Compute the BCE loss for one ``(images, masks)`` batch."""
        images, masks = batch
        outputs = self(images)
        # BCELoss requires masks shaped like outputs, float, with values in [0, 1].
        return self.lossF(outputs, masks)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._shared_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._shared_step(batch)
        # sync_dist reduces the epoch-level value across processes, so under
        # distributed training every rank monitors the same val_loss
        # (e.g. for checkpointing / early stopping). It is a no-op on one device.
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)