TongjiZhanglab committed · Commit 0f8f155 · verified · 1 Parent(s): 1ffa4fe

Upload model

Files changed (3)
  1. config.json +27 -0
  2. model.safetensors +3 -0
  3. pretraining_pl_DDP_v5.py +213 -0
config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "architectures": [
+     "ViTMAEForPreTraining"
+   ],
+   "attention_probs_dropout_prob": 0.0,
+   "decoder_hidden_size": 512,
+   "decoder_intermediate_size": 1024,
+   "decoder_num_attention_heads": 16,
+   "decoder_num_hidden_layers": 8,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.0,
+   "hidden_size": 768,
+   "image_size": 112,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "mask_ratio": 0.75,
+   "model_type": "vit_mae",
+   "norm_pix_loss": false,
+   "num_attention_heads": 12,
+   "num_channels": 1,
+   "num_hidden_layers": 12,
+   "patch_size": 8,
+   "qkv_bias": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.41.2"
+ }
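
This configuration defines a single-channel ViT-MAE: 112×112 inputs split into 8×8 patches, a 768-dim/12-layer encoder, and a 512-dim/8-layer decoder with 75% masking. A minimal sketch of instantiating it with transformers — the local directory path is an assumption, not part of this commit:

import torch
from transformers import ViTMAEConfig, ViTMAEForPreTraining

config = ViTMAEConfig.from_pretrained(".")   # directory holding the config.json above (assumed)
model = ViTMAEForPreTraining(config)
pixel_values = torch.randn(1, 1, 112, 112)   # num_channels=1, image_size=112
outputs = model(pixel_values)
# logits: (1, 196, 64) = (batch, (112/8)**2 patches, patch_size**2 * num_channels)
print(outputs.loss, outputs.logits.shape)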
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:13d8f408d176af86d658b98d009de423978dc6969adb5e63f5447ee8c99982db
+ size 410476120
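
The file above is a Git LFS pointer; the ~410 MB weights resolve through LFS. A hedged sketch of verifying a downloaded copy against the oid — the repo id placeholder is an assumption and must be replaced with this repository's actual id:

import hashlib
from huggingface_hub import hf_hub_download

path = hf_hub_download(repo_id="<repo_id>", filename="model.safetensors")  # <repo_id> is a placeholder
sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)
print(sha.hexdigest() == "13d8f408d176af86d658b98d009de423978dc6969adb5e63f5447ee8c99982db")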
pretraining_pl_DDP_v5.py ADDED
@@ -0,0 +1,213 @@
+ import os
+ import h5py
+ import torch
+ from PIL import Image
+ from torch.utils.data import Dataset, DataLoader, random_split
+ from torchvision import transforms
+ from transformers import ViTMAEConfig, ViTMAEForPreTraining
+ import lightning.pytorch as pl
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, RichProgressBar
+ from lightning.pytorch.loggers import TensorBoardLogger
+ from lightning.pytorch.utilities import rank_zero_only
+
+ # Make all detected GPUs visible to this process.
+ DEVICE_NUM = torch.cuda.device_count()
+ os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in range(DEVICE_NUM))
+
+ SEED = 42
+ DATA_DIR = "../../0.data/pretrain_nucleus_image_all_16M.hdf5"
+ BATCH_SIZE = 400 * 2  # per-device batch size
+ NUM_EPOCHS = 70
+ LEARNINGRATE = 0.0001
+ PROJECT_NAME = 'Nuspire_Pretraining_V5'
+
+ # Augmentations for single-channel 112x112 nucleus crops; the Normalize
+ # statistics are dataset-specific mean/std.
+ transform = transforms.Compose([
+     transforms.Grayscale(),
+     transforms.RandomResizedCrop((112, 112), scale=(0.5625, 1.0), ratio=(0.75, 1.33)),
+     transforms.RandomHorizontalFlip(p=0.5),
+     transforms.RandomVerticalFlip(p=0.5),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])
+ ])
+
+ configuration = ViTMAEConfig(
+     hidden_size=768,
+     num_hidden_layers=12,
+     num_attention_heads=12,
+     intermediate_size=3072,
+     hidden_act="gelu",
+     hidden_dropout_prob=0.0,
+     attention_probs_dropout_prob=0.0,
+     initializer_range=0.02,
+     layer_norm_eps=1e-12,
+     image_size=112,
+     patch_size=8,
+     num_channels=1,
+     qkv_bias=True,
+     decoder_num_attention_heads=16,
+     decoder_hidden_size=512,
+     decoder_num_hidden_layers=8,
+     decoder_intermediate_size=1024,
+     mask_ratio=0.75,
+     norm_pix_loss=False
+ )
+
+ class HDF5Dataset(Dataset):
+     """Reads nucleus crops from an HDF5 file. The handle is opened lazily in
+     each DataLoader worker, since h5py handles are not fork-safe."""
+     def __init__(self, hdf5_path, transform=None):
+         self.hdf5_path = hdf5_path
+         self.transform = transform
+         self.hdf5_file = None
+         self.images = None
+         # Open once for the length only, then close before workers fork.
+         with h5py.File(hdf5_path, 'r') as f:
+             self._length = len(f['images'])
+
+     def __len__(self):
+         return self._length
+
+     def __getitem__(self, idx):
+         if self.images is None:
+             # Large raw-data chunk cache to speed up repeated reads.
+             self.hdf5_file = h5py.File(self.hdf5_path, 'r', rdcc_nbytes=10 * 1024**3, rdcc_w0=0.0, rdcc_nslots=10007)
+             self.images = self.hdf5_file['images']
+         img = self.images[idx]
+         if self.transform:
+             img = Image.fromarray(img)
+             img = self.transform(img)
+         return img
+
+     def __del__(self):
+         if self.hdf5_file is not None:
+             self.hdf5_file.close()
+
+ class NucleusDataModule(pl.LightningDataModule):
+     def __init__(self, dataset, batch_size):
+         super().__init__()
+         self.dataset = dataset
+         self.batch_size = batch_size
+
+     def setup(self, stage=None):
+         # 80/20 train/validation split; seed_everything below keeps the
+         # split identical across DDP ranks.
+         train_size = int(0.8 * len(self.dataset))
+         test_size = len(self.dataset) - train_size
+         self.train_dataset, self.test_dataset = random_split(self.dataset, [train_size, test_size])
+
+     def train_dataloader(self):
+         return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
+                           num_workers=16, pin_memory=True, prefetch_factor=5)
+
+     def val_dataloader(self):
+         return DataLoader(self.test_dataset, batch_size=self.batch_size * 3,
+                           num_workers=16, pin_memory=True, prefetch_factor=5)
+
+ class ViTMAEPreTraining(pl.LightningModule):
+     def __init__(self, configuration):
+         super().__init__()
+         self.model = ViTMAEForPreTraining(configuration)
+         self.save_hyperparameters()
+
+     def forward(self, x):
+         return self.model(x)
+
+     def training_step(self, batch, batch_idx):
+         # Lightning already moves the batch to the right device.
+         outputs = self.model(batch)
+         loss = outputs.loss
+         self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         outputs = self.model(batch)
+         loss = outputs.loss
+         self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(self.model.parameters(), lr=LEARNINGRATE)
+         # Linear warmup for the first 10 epochs, then halve the LR every 20 epochs.
+         warmup_epochs = 10
+         warmup_factor = lambda epoch: epoch / warmup_epochs if epoch < warmup_epochs else 1
+         scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_factor)
+         scheduler_regular = torch.optim.lr_scheduler.StepLR(optimizer, 20, gamma=0.5)
+         scheduler = {
+             'scheduler': torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_regular], milestones=[warmup_epochs]),
+             'interval': 'epoch',
+             'frequency': 1
+         }
+         return [optimizer], [scheduler]
+
+ class EpochLoggingCallback(pl.Callback):
+     """Writes train/val loss into one TensorBoard scalar group per epoch."""
+     @rank_zero_only
+     def on_validation_epoch_end(self, trainer, pl_module):
+         train_loss = trainer.callback_metrics.get('train_loss')
+         val_loss = trainer.callback_metrics.get('val_loss')
+         if train_loss is not None and val_loss is not None:
+             trainer.logger.experiment.add_scalars(
+                 "Epoch/Loss",
+                 {'Train Loss': train_loss, 'Validation Loss': val_loss},
+                 trainer.current_epoch
+             )
+
+ class SaveEpochModelCallback(pl.Callback):
+     """Exports the HF model with save_pretrained() after every validation epoch."""
+     @rank_zero_only
+     def on_validation_epoch_end(self, trainer, pl_module):
+         path = trainer.checkpoint_callback.dirpath
+         epoch = trainer.current_epoch
+         pl_module.model.save_pretrained(f'{path}/epoch{epoch}')
+
+ dataset = HDF5Dataset(hdf5_path=DATA_DIR, transform=transform)
+
+ data_module = NucleusDataModule(dataset, BATCH_SIZE)
+
+ epoch_logging_callback = EpochLoggingCallback()
+
+ save_epoch_model_callback = SaveEpochModelCallback()
+
+ progress_bar = RichProgressBar()
+
+ logger = TensorBoardLogger(save_dir=f'./{PROJECT_NAME}_outputs', name="tensorboard")
+
+ # Keep the three checkpoints with the lowest validation loss.
+ best_model_callback = ModelCheckpoint(
+     dirpath=f'./{PROJECT_NAME}_outputs/model',
+     filename='{epoch:02d}-{val_loss:.2f}',
+     save_top_k=3,
+     mode='min',
+     monitor='val_loss'
+ )
+
+ lr_monitor = LearningRateMonitor(logging_interval='epoch')
+
+ trainer = Trainer(
+     max_epochs=NUM_EPOCHS,
+     devices=DEVICE_NUM,        # number of devices to use
+     accelerator='gpu',         # train on GPUs
+     strategy='ddp',
+     logger=logger,
+     callbacks=[lr_monitor,
+                progress_bar,
+                epoch_logging_callback,
+                save_epoch_model_callback,
+                best_model_callback]
+ )
+
+ # Set the global random seed (also seeds DataLoader workers).
+ pl.seed_everything(SEED, workers=True)
+
+ model = ViTMAEPreTraining(configuration)
+ trainer.fit(model, data_module)
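
Each epoch's weights are exported via save_pretrained by SaveEpochModelCallback above, so they can be reloaded with transformers. A minimal sketch of reusing the pretrained encoder for feature extraction — the epoch directory shown is illustrative (epoch69 would be the last of 70 epochs), and mask_ratio=0.0 is an assumption used here to disable random masking at inference:

import torch
from transformers import ViTMAEModel

# Loading the ViTMAEForPreTraining export as a plain encoder drops the decoder weights.
encoder = ViTMAEModel.from_pretrained(
    './Nuspire_Pretraining_V5_outputs/model/epoch69',  # illustrative path
    mask_ratio=0.0,  # keep all patches so embeddings are deterministic
)
encoder.eval()
with torch.no_grad():
    pixel_values = torch.randn(1, 1, 112, 112)  # a normalized single-channel crop
    features = encoder(pixel_values).last_hidden_state  # (1, 197, 768): CLS + 196 patches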