|
|
|
import os |
|
import pandas as pd |
|
import pytorch_lightning as pl |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
from pytorch_lightning.loggers import TensorBoardLogger |
|
from torch.utils.data import DataLoader |
|
from sklearn.model_selection import train_test_split |
|
|
|
from model import MicrographCleaner |
|
from dataset import TrainMicrographDataset, ValidationMicrographDataset |
|
|
|
|
|
def main(
    *,
    csv_path: str = "train.csv",
    window_size: int = 512,
    batch_size: int = 8,
    n_epochs: int = 3,
    val_fraction: float = 0.2,
    split_seed: int = 42,
    num_workers: int = 4,
    seed: "int | None" = None,
    final_checkpoint_path: str = "final_checkpoint.pt",
):
    """Train a MicrographCleaner on the micrographs listed in a CSV file.

    The CSV is split into training and validation subsets. Each subset is
    wrapped in the project's Train/ValidationMicrographDataset class and
    fitted with a single-device Lightning Trainer. The three checkpoints
    with the lowest ``val_loss`` are kept under ``checkpoints/``, and the
    final trainer state is written to ``final_checkpoint_path``.

    All arguments are keyword-only. Each defaults to the value this script
    previously hard-coded, so a bare ``main()`` behaves as before.

    Args:
        csv_path: CSV read with ``pd.read_csv``. Its rows are handed to the
            dataset classes; the required columns are defined by those
            classes (presumably in ``dataset.py``).
        window_size: Forwarded as ``window_size`` to both dataset classes.
        batch_size: Batch size for the training and validation loaders.
        n_epochs: ``max_epochs`` for the Trainer.
        val_fraction: Fraction of CSV rows held out for validation.
        split_seed: ``random_state`` for the train/validation split, so the
            split is the same on every run.
        num_workers: DataLoader worker processes for each loader.
        seed: If not ``None``, passed to ``pl.seed_everything`` before any
            data or model is built. This makes loader shuffling and model
            initialisation reproducible. ``None`` leaves the RNGs unseeded
            (the original behavior).
        final_checkpoint_path: Path passed to ``trainer.save_checkpoint``
            after training finishes.
    """
    if seed is not None:
        pl.seed_everything(seed)

    full_df = pd.read_csv(csv_path)
    train_df, val_df = train_test_split(
        full_df, test_size=val_fraction, random_state=split_seed
    )
    # train_test_split keeps the original, shuffled, non-contiguous index
    # labels. Resetting them means a dataset that looks rows up by label
    # 0..len-1 (e.g. ``df.loc[i]``) does not hit KeyError; positional
    # ``.iloc`` access is unaffected.
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)

    train_loader = DataLoader(
        TrainMicrographDataset(train_df, window_size=window_size),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        ValidationMicrographDataset(val_df, window_size=window_size),
        batch_size=batch_size,
        num_workers=num_workers,
    )

    model = MicrographCleaner()

    logger = TensorBoardLogger("lightning_logs", name="micrograph_cleaner")
    # NOTE(review): ranking checkpoints requires MicrographCleaner to log a
    # metric named "val_loss" (e.g. via self.log in its validation step);
    # confirm this in model.py.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath="checkpoints",
        filename="micrograph-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )

    trainer = pl.Trainer(
        max_epochs=n_epochs,
        accelerator="auto",
        devices=1,
        logger=logger,
        callbacks=[checkpoint_callback],
        log_every_n_steps=10,
    )
    trainer.fit(model, train_loader, val_loader)

    # Saved to the working directory, separately from the top-k
    # checkpoints under "checkpoints/".
    trainer.save_checkpoint(final_checkpoint_path)
|
|
|
|
|
# Start training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()