|
import csv |
|
import h5py |
|
import torch |
|
import torch.nn as nn |
|
import random |
|
import numpy as np |
|
import os |
|
import shutil |
|
import pandas as pd |
|
from torchvision import transforms |
|
from PIL import Image |
|
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, Subset, random_split |
|
import torch.optim as optim |
|
import time |
|
from tqdm import tqdm |
|
from torch.optim import lr_scheduler |
|
from transformers import ViTFeatureExtractor, AutoImageProcessor, ViTMAEConfig, ViTMAEModel, ViTMAEForPreTraining |
|
from torchvision.datasets import ImageFolder |
|
import lightning.pytorch as pl |
|
from lightning.pytorch import Trainer |
|
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, RichProgressBar |
|
from lightning.pytorch.loggers import TensorBoardLogger |
|
from lightning.pytorch.callbacks import RichProgressBar |
|
from lightning.pytorch.callbacks import TQDMProgressBar |
|
from lightning.pytorch.utilities import rank_zero_only |
|
|
|
# Number of visible CUDA devices; used below for the Trainer's `devices=`.
DEVICE_NUM = torch.cuda.device_count()

# NOTE(review): setting CUDA_VISIBLE_DEVICES *after* torch.cuda.device_count()
# has no effect on this process (torch has already enumerated the devices),
# and listing all indices is a no-op in any case — confirm this can be removed.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in range(DEVICE_NUM)])
|
|
|
# --- Run configuration -------------------------------------------------------
SEED = 42  # global RNG seed, passed to pl.seed_everything below

DATA_DIR = "../../0.data/pretrain_nucleus_image_all_16M.hdf5"  # HDF5 file with an 'images' dataset

BATCH_SIZE = 400 *2  # per-device training batch size (validation uses 3x)

NUM_EPOCHS = 70

LEARNINGRATE = 0.0001  # base AdamW learning rate (before warmup/decay schedule)

PROJECT_NAME = 'Nuspire_Pretraining_V5'  # prefix for the output/log directories
|
|
|
# Augmentation + normalization pipeline applied to each PIL image.
transform = transforms.Compose([
    transforms.Grayscale(),  # single-channel input (matches num_channels=1 below)
    # Random crop covering 56.25%-100% of the area, resized to the model's 112x112 input.
    transforms.RandomResizedCrop((112, 112), scale=(0.5625, 1.0), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    # Mean/std — presumably precomputed over the training corpus; verify against data prep.
    transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])
])
|
|
|
# ViT-MAE architecture: ViT-Base-sized encoder, lighter decoder,
# 112x112 single-channel inputs in 8x8 patches (14x14 = 196 patches).
configuration = ViTMAEConfig(
    # encoder (ViT-Base dimensions)
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    initializer_range=0.02,
    layer_norm_eps=1e-12,
    # input geometry — must match the transform's crop size and Grayscale()
    image_size=112,
    patch_size=8,
    num_channels=1,
    qkv_bias=True,
    # decoder (smaller than encoder, per the MAE recipe)
    decoder_num_attention_heads=16,
    decoder_hidden_size=512,
    decoder_num_hidden_layers=8,
    decoder_intermediate_size=1024,
    mask_ratio=0.75,  # fraction of patches masked during pretraining
    norm_pix_loss=False  # reconstruction loss on raw (not per-patch-normalized) pixels
)
|
|
|
class HDF5Dataset(Dataset):
    """Map-style dataset over the 'images' array of a single HDF5 file.

    The h5py file handle is opened lazily, once per worker process, instead
    of eagerly in ``__init__``: h5py file objects are not picklable and are
    not safe to use across a fork, so an eagerly-opened handle breaks
    ``DataLoader(num_workers > 0)`` (this module uses num_workers=16).

    Parameters
    ----------
    hdf5_path : str
        Path to the HDF5 file; must contain an 'images' dataset.
    transform : callable, optional
        Applied to each image (converted to PIL first) when given.
    """

    def __init__(self, hdf5_path, transform=None):
        self.hdf5_path = hdf5_path
        self.transform = transform
        self.hdf5_file = None   # opened on first __getitem__ in each worker
        self.images = None
        # Open briefly just to record the length, then close again so the
        # dataset object stays picklable for DataLoader workers.
        with h5py.File(hdf5_path, 'r') as f:
            self._length = len(f['images'])

    def _ensure_open(self):
        # Per-worker lazy open with a large (10 GiB) raw-chunk cache to
        # amortize HDF5 reads across random accesses.
        if self.hdf5_file is None:
            self.hdf5_file = h5py.File(
                self.hdf5_path, 'r',
                rdcc_nbytes=10 * 1024 ** 3, rdcc_w0=0.0, rdcc_nslots=10007)
            self.images = self.hdf5_file['images']

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        self._ensure_open()
        img = self.images[idx]
        if self.transform:
            # The torchvision pipeline expects a PIL image.
            img = Image.fromarray(img)
            img = self.transform(img)
        return img

    def __del__(self):
        # Guard: __init__ may have raised before the handle attribute existed.
        handle = getattr(self, 'hdf5_file', None)
        if handle is not None:
            handle.close()
|
|
|
class NucleusDataModule(pl.LightningDataModule):
    """80/20 train/validation split over a single map-style dataset.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        The full dataset to split.
    batch_size : int
        Per-device training batch size (validation uses 3x this).
    """

    def __init__(self, dataset, batch_size):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size

    def setup(self, stage=None):
        train_size = int(0.8 * len(self.dataset))
        test_size = len(self.dataset) - train_size
        # Explicit fixed-seed generator: setup() runs on every DDP rank, so
        # the split must be deterministic and identical across ranks.
        split_gen = torch.Generator().manual_seed(SEED)
        self.train_dataset, self.test_dataset = random_split(
            self.dataset, [train_size, test_size], generator=split_gen)

    def train_dataloader(self):
        # BUG FIX: shuffle=True was missing — without it training iterated
        # the dataset in order (Lightning's DistributedSampler replacement
        # mirrors the sampler's shuffling intent).
        # persistent_workers avoids re-forking 16 workers every epoch.
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=16, pin_memory=True,
                          prefetch_factor=5, persistent_workers=True)

    def val_dataloader(self):
        # Larger batches are fine for validation (no gradients held).
        return DataLoader(self.test_dataset, batch_size=self.batch_size * 3,
                          num_workers=16, pin_memory=True, prefetch_factor=5,
                          persistent_workers=True)
|
|
|
class ViTMAEPreTraining(pl.LightningModule):
    """LightningModule wrapping HuggingFace ViTMAEForPreTraining.

    Trains the masked autoencoder with AdamW plus a linear-warmup /
    step-decay learning-rate schedule.
    """

    def __init__(self, configuration):
        super().__init__()
        self.model = ViTMAEForPreTraining(configuration)
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # Lightning has already moved the batch to self.device; the explicit
        # .to(self.device) in the original was redundant.
        outputs = self.model(batch)
        loss = outputs.loss  # MAE reconstruction loss over masked patches
        self.log('train_loss', loss, on_step=False, on_epoch=True,
                 prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self.model(batch)
        loss = outputs.loss
        self.log('val_loss', loss, on_step=False, on_epoch=True,
                 prog_bar=True, logger=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=LEARNINGRATE)
        warmup_epochs = 10
        # BUG FIX: the original factor was `epoch / warmup_epochs`, which is
        # 0 at epoch 0 — the entire first epoch trained with lr = 0. Using
        # (epoch + 1) / warmup_epochs starts warmup at lr/10 and reaches the
        # full rate by the end of warmup.
        warmup_factor = lambda epoch: min((epoch + 1) / warmup_epochs, 1.0)
        scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_factor)
        # After warmup: halve the learning rate every 20 epochs.
        scheduler_regular = torch.optim.lr_scheduler.StepLR(optimizer, 20, gamma=0.5)
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.SequentialLR(
                optimizer,
                schedulers=[scheduler_warmup, scheduler_regular],
                milestones=[warmup_epochs]),
            'interval': 'epoch',   # step once per epoch (not per batch)
            'frequency': 1
        }
        return [optimizer], [scheduler]
|
|
|
class EpochLoggingCallback(pl.Callback):
    """Writes train and validation loss onto one combined TensorBoard chart
    ("Epoch/Loss") at the end of every validation epoch, on rank 0 only."""

    def __init__(self):
        super().__init__()

    @rank_zero_only
    def on_validation_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        losses = {
            'Train Loss': metrics.get('train_loss'),
            'Validation Loss': metrics.get('val_loss'),
        }
        # Skip when either metric is missing (e.g. train_loss is absent
        # during the pre-fit sanity-check validation pass).
        if any(v is None for v in losses.values()):
            return
        trainer.logger.experiment.add_scalars(
            "Epoch/Loss", losses, trainer.current_epoch)
|
|
|
class SaveEpochModelCallback(pl.Callback):
    """Exports the wrapped HF model (save_pretrained format) after every
    validation epoch, into the checkpoint callback's directory, rank 0 only."""

    def __init__(self):
        super().__init__()

    @rank_zero_only
    def on_validation_epoch_end(self, trainer, pl_module):
        # BUG FIX: this hook also fires during the pre-fit sanity-check
        # validation pass, which exported an untrained 'epoch0' model.
        if trainer.sanity_checking:
            return
        path = trainer.checkpoint_callback.dirpath
        epoch = trainer.current_epoch
        # save_pretrained creates the target directory if needed.
        pl_module.model.save_pretrained(os.path.join(path, f'epoch{epoch}'))
|
|
|
# --- Wiring: dataset, callbacks, trainer, fit --------------------------------
dataset = HDF5Dataset(hdf5_path=DATA_DIR, transform=transform)

data_module = NucleusDataModule(dataset, BATCH_SIZE)

epoch_logging_callback = EpochLoggingCallback()

save_epoch_model_callback = SaveEpochModelCallback()

progress_bar = RichProgressBar()

logger = TensorBoardLogger(save_dir=f'./{PROJECT_NAME}_outputs', name="tensorboard")

# Keep the 3 best checkpoints ranked by validation loss.
best_model_callback = ModelCheckpoint(
    dirpath=f'./{PROJECT_NAME}_outputs/model',
    filename='{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
    monitor='val_loss'
)

lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Multi-GPU DDP trainer over all visible devices.
trainer = Trainer(
    max_epochs=NUM_EPOCHS,
    devices=DEVICE_NUM,
    accelerator='gpu',
    strategy='ddp',
    logger=logger,
    callbacks=[lr_monitor,
               progress_bar,
               epoch_logging_callback,
               save_epoch_model_callback,
               best_model_callback]
)

# NOTE(review): seeding happens after the dataset/module construction above;
# nothing earlier draws random numbers, so this is fine, but seeding first
# would be safer if the pipeline ever changes.
pl.seed_everything(SEED, workers=True)

# NOTE(review): no `if __name__ == "__main__":` guard — DDP re-executes this
# module top-to-bottom in each rank, which works for script launch but makes
# the module unsafe to import.
model = ViTMAEPreTraining(configuration,)

trainer.fit(model, data_module)
|
|
|
|
|
|