AIML / train.py
cconsti's picture
Upload 10 files
a6fa489 verified
#!/usr/bin/env python3
import os
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from model import MicrographCleaner
from dataset import TrainMicrographDataset, ValidationMicrographDataset
def main():
# Training parameters
WINDOW_SIZE = 512
BATCH_SIZE = 8
N_EPOCHS = 3 #TODO, change this to many more epochs
# Load and split data
train_df = pd.read_csv('train.csv')
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42)
# Create datasets and dataloaders
train_dataset = TrainMicrographDataset(train_df, window_size=WINDOW_SIZE)
val_dataset = ValidationMicrographDataset(val_df, window_size=WINDOW_SIZE)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=4)
# Initialize model
model = MicrographCleaner()
# Setup training
logger = TensorBoardLogger('lightning_logs', name='micrograph_cleaner')
checkpoint_callback = ModelCheckpoint(
monitor='val_loss',
dirpath='checkpoints',
filename='micrograph-{epoch:02d}-{val_loss:.2f}',
save_top_k=3,
mode='min'
)
# Initialize trainer
trainer = pl.Trainer(
max_epochs=N_EPOCHS,
accelerator='auto',
devices=1,
logger=logger,
callbacks=[checkpoint_callback],
log_every_n_steps=10
)
# Train model
trainer.fit(model, train_loader, val_loader)
# Save final checkpoint as final_checkpoint.pt
trainer.save_checkpoint("final_checkpoint.pt")
if __name__ == "__main__":
main()