|
import os, time, pickle, shutil |
|
import pandas as pd |
|
import numpy as np |
|
|
|
from PIL import Image, ImageFile |
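# Tolerate truncated image files instead of raising an error while decoding.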
|
ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.cuda.amp import GradScaler |
|
from torch import autocast |
|
|
|
import torchvision.transforms as transforms |
|
|
|
import timm |
|
from timm.models import create_model |
|
from timm.utils import ModelEmaV2 |
|
|
|
from timm.optim import create_optimizer_v2 |
|
|
|
from torchmetrics import MeanMetric |
|
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score |
|
from torchmetrics import MetricCollection |
|
|
|
from pytorch_metric_learning.losses import ArcFaceLoss |
|
|
|
import wandb |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
|
|
|
torch.backends.cudnn.benchmark = True  # let cuDNN autotune conv kernels; input sizes are fixed
|
|
|
|
|
MODEL_DIR = "./convnext2b_metaEmbedding_focal05es_arcloss/" |
|
|
|
os.makedirs(MODEL_DIR, exist_ok=True)
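# Snapshot this script next to the checkpoints for reproducibility.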
|
shutil.copyfile('./convnext2b_exp4_meta_embedding_focalarcloss.py', f'{MODEL_DIR}convnext2b_exp4_meta_embedding_focalarcloss.py') |
|
|
|
TRAIN_DATA_DIR = "/SnakeCLEF2023-large_size/" |
|
ADD_TRAIN_DATA_DIR = "/HMP/" |
|
VAL_DATA_DIR = "/SnakeCLEF2023-large_size/" |
|
|
|
TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-iNat.csv" |
|
ADD_TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-HM.csv" |
|
VALIDDATA_CONFIG = "/SnakeCLEF2023-ValMetadata.csv" |
|
|
|
MISSING_FILES = "../missing_train_data.csv" |
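# MISSING_FILES lists images that are absent on disk and get filtered out of the
# train set; CCM maps each country code to the classes observed there.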
|
|
|
CCM = "../code_class_mapping_obid.csv" |
|
|
|
|
|
NUM_CLASSES = 1784 |
|
|
|
|
|
NUM_EPOCHS = 40 |
|
WARMUP_EPOCHS = 0 |
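# Resume from this epoch's checkpoints in MODEL_DIR; 0 starts from scratch.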
|
RESUME_EPOCH = 14 |
|
|
|
|
|
LEARNING_RATE = { |
|
'cnn': 1e-05, |
|
'embeddings': 1e-04, |
|
'classifier': 1e-04, |
|
} |
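# 'grad_acc' is the number of gradient-accumulation steps, so the effective
# train batch size is 32 * 4 = 128.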
|
|
|
BATCH_SIZE = { |
|
'train': 32, |
|
'valid': 48, |
|
'grad_acc': 4, |
|
} |
|
|
|
BATCH_SIZE_AFTER_WARMUP = { |
|
'train': 32, |
|
'valid': 48, |
|
'grad_acc': 4, |
|
} |
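# RandAugment: 'n' random ops per image at magnitude 'm'.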
|
|
|
TRANSFORMS = { |
|
'IMAGE_SIZE_TRAIN': 544, |
|
'IMAGE_SIZE_VAL': 544, |
|
'RandAug' : { |
|
'm': 7, |
|
'n': 2 |
|
} |
|
} |
|
|
|
|
|
|
|
FOCAL_LOSS = { |
|
'class_dist': pickle.load(open("../classDist_HMP_missedRemoved.p", "rb"))['counts'], |
|
'gamma': 0.5, |
|
} |
|
|
|
|
|
|
|
CHECKPOINTS = { |
|
'fe_cnn': None, |
|
'model': None, |
|
'optimizer': None, |
|
'scaler': None, |
|
} |
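# Vocabulary and embedding sizes for the categorical metadata ('endemic' flag,
# country 'code'); the token dicts map raw metadata values to embedding indices.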
|
|
|
|
|
META_SIZES = {'endemic': 2, 'code': 212} |
|
EMBEDDING_SIZES = {'endemic': 64, 'code': 64} |
|
|
|
CODE_TOKENS = pickle.load(open("../meta_code_tokens.p", "rb")) |
|
ENDEMIC_TOKENS = pickle.load(open("../meta_endemic_tokens.p", "rb")) |
|
|
|
|
|
WANDB = True |
|
|
|
if WANDB: |
|
    wandb.init(
        entity="snakeclef2023",
        project="exp4",
        name="focal05es_arcloss",
        config={
            "learning_rate": LEARNING_RATE,
            "focal_loss": FOCAL_LOSS,
            "architecture": "convnextv2_base.fcmae_ft_in22k_in1k_384",
            "pretrained": "iNat21",
            "dataset": f"snakeclef2023, additional train data: {bool(ADD_TRAINDATA_CONFIG)}",
            "epochs": NUM_EPOCHS,
            "transforms": TRANSFORMS,
            "checkpoints": CHECKPOINTS,
            "model_dir": MODEL_DIR
        },
        save_code=True,
        dir=MODEL_DIR
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
class SnakeTrainDataset(Dataset): |
|
def __init__(self, data, ccm, transform=None): |
|
self.data = data |
|
self.transform = transform |
|
self.code_class_mapping = ccm |
|
self.code_tokens = CODE_TOKENS |
|
self.endemic_tokens = ENDEMIC_TOKENS |
|
|
|
def __len__(self): |
|
return self.data.shape[0] |
|
|
|
def __getitem__(self, index): |
|
obj = self.data.iloc[index] |
|
label = obj.class_id |
|
        code = obj.code if obj.code in self.code_tokens else "unknown"
        endemic = obj.endemic if obj.endemic in self.endemic_tokens else False
|
|
|
img = Image.open(obj.image_path).convert("RGB") |
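        # Row of the code->class mapping: a 0/1 mask over classes recorded for
        # this country code (used for the country-masked F1 metric).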
|
ccm = torch.tensor(self.code_class_mapping[code].to_numpy()) |
|
meta = torch.tensor([self.code_tokens[code], self.endemic_tokens[endemic]]) |
|
|
|
|
|
        if self.transform is not None:
            img = self.transform(img)
|
|
|
return (img, label, ccm, meta) |
|
|
|
|
|
|
|
def get_val_preprocessing(img_size):
    print(f'IMG_SIZE_VAL: {img_size}')
    # Five-crop TTA: resize, take the four corner crops plus the center crop and
    # stack them into a (5, C, H, W) tensor; Normalize broadcasts over the
    # leading crop dimension.
    return transforms.Compose([
        transforms.Resize(int(img_size * 1.25)),
        transforms.FiveCrop((img_size, img_size)),
        transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
|
|
|
class IdentityTransform: |
|
def __call__(self, x): |
|
return x |
|
|
|
|
|
|
|
def get_train_augmentation_preprocessing(img_size, rand_aug=False): |
|
print(f'IMG_SIZE_TRAIN: {img_size}, RandAug: {rand_aug}') |
|
return transforms.Compose([ |
|
transforms.Resize(int(img_size * 1.25)), |
|
transforms.RandomHorizontalFlip(p=0.5), |
|
transforms.RandomVerticalFlip(p=0.5), |
|
transforms.RandomCrop((img_size, img_size)), |
|
transforms.RandAugment(num_ops=TRANSFORMS['RandAug']['n'], magnitude=TRANSFORMS['RandAug']['m']) if rand_aug else IdentityTransform(), |
|
transforms.ToTensor(), |
|
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) |
|
]) |
|
|
|
|
|
def get_datasets(train_transforms, val_transforms):
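    # keep_default_na=False plus this explicit list prevents the literal string
    # 'NA' (Namibia's country code) from being parsed as NaN.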
|
|
|
nan_values = ['', '#N/A', '#N/A N/A', '#NA', '-1.#IND', '-1.#QNAN', '-NaN', '-nan', '1.#IND', '1.#QNAN', '<NA>', 'N/A', 'NULL', 'NaN', 'n/a', 'nan', 'null'] |
|
train_data = pd.read_csv(TRAINDATA_CONFIG, na_values=nan_values, keep_default_na=False) |
|
missing_train_data = pd.read_csv(MISSING_FILES, na_values=nan_values, keep_default_na=False) |
|
valid_data = pd.read_csv(VALIDDATA_CONFIG, na_values=nan_values, keep_default_na=False) |
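    # Anti-join: outer merge with indicator, then keep only rows that are not in
    # the missing-files list.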
|
|
|
|
|
train_data = pd.merge(train_data, missing_train_data, how='outer', indicator=True) |
|
train_data = train_data.loc[train_data._merge == 'left_only', ["observation_id","endemic","binomial_name","code","image_path","class_id","subset"]] |
|
|
|
|
|
train_data["image_path"] = TRAIN_DATA_DIR + train_data['image_path'] |
|
valid_data["image_path"] = VAL_DATA_DIR + valid_data['image_path'] |
|
|
|
|
|
if ADD_TRAINDATA_CONFIG: |
|
add_train_data = pd.read_csv(ADD_TRAINDATA_CONFIG, na_values=nan_values, keep_default_na=False) |
|
add_train_data["image_path"] = ADD_TRAIN_DATA_DIR + add_train_data['image_path'] |
|
train_data = pd.concat([train_data, add_train_data], axis=0) |
|
|
|
|
|
|
|
|
|
print(f'train data shape: {train_data.shape}') |
|
|
|
|
|
train_data = train_data.sample(frac=1, random_state=1).reset_index(drop=True) |
|
valid_data = valid_data.sample(frac=1, random_state=1).reset_index(drop=True) |
|
|
|
|
|
ccm = pd.read_csv(CCM, na_values=nan_values, keep_default_na=False) |
|
|
|
|
|
    train_dataset = SnakeTrainDataset(train_data, ccm, transform=train_transforms)
|
valid_dataset = SnakeTrainDataset(valid_data, ccm, transform=val_transforms) |
|
|
|
return train_dataset, valid_dataset |
|
|
|
|
|
def get_dataloaders(imgsize_train, imgsize_val, rand_aug): |
|
|
|
train_aug_preprocessing = get_train_augmentation_preprocessing(imgsize_train, rand_aug) |
|
val_preprocessing = get_val_preprocessing(imgsize_val) |
|
|
|
    train_dataset, valid_dataset = get_datasets(train_transforms=train_aug_preprocessing, val_transforms=val_preprocessing)
|
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE['train'], num_workers=6, drop_last=True, pin_memory=True) |
|
valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=BATCH_SIZE['valid'], num_workers=6, drop_last=False, pin_memory=True) |
|
|
|
return train_loader, valid_loader |
|
|
|
|
|
|
|
|
|
def plot_history(logs): |
|
fig, ax = plt.subplots(3, 1, figsize=(8, 12)) |
|
|
|
ax[0].plot(logs['loss'], label="train data") |
|
ax[0].plot(logs['val_loss'], label="valid data") |
|
ax[0].legend(loc="best") |
|
ax[0].set_ylabel("loss") |
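    # Cap the loss axis at ln(NUM_CLASSES), the loss of a uniform predictor.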
|
ax[0].set_ylim([0, -np.log(1/NUM_CLASSES)]) |
|
|
|
    ax[0].set_title("train vs. valid loss")
|
|
|
ax[1].plot(logs['acc'], label="train data") |
|
ax[1].plot(logs['val_acc'], label="valid data") |
|
ax[1].legend(loc="best") |
|
ax[1].set_ylabel("accuracy") |
|
ax[1].set_ylim([0, 1.01]) |
|
|
|
    ax[1].set_title("train vs. valid accuracy")
|
|
|
ax[2].plot(logs['f1'], label="train data") |
|
ax[2].plot(logs['val_f1'], label="valid data") |
|
ax[2].legend(loc="best") |
|
ax[2].set_ylabel("f1") |
|
ax[2].set_ylim([0, 1.01]) |
|
ax[2].set_xlabel("epochs") |
|
    ax[2].set_title("train vs. valid f1")
|
|
|
fig.savefig(f'{MODEL_DIR}model_history.svg', dpi=150, format="svg") |
|
plt.show() |
|
|
|
|
|
class FocalLoss(nn.Module): |
|
''' |
|
Multi-class Focal Loss |
|
''' |
|
def __init__(self, gamma, class_dist=None, reduction='mean', device='cuda'): |
|
super(FocalLoss, self).__init__() |
|
self.gamma = gamma |
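        # Class-balanced re-weighting (Cui et al. 2019): w_c = (1 - b) / (1 - b^n_c)
        # with b = 0.999, so rare classes get proportionally larger weights;
        # uniform weights if no class distribution is given.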
|
|
|
self.weight = torch.tensor((1.0 - 0.999) / (1.0 - 0.999**class_dist), dtype=torch.float32, device=device) if class_dist is not None else torch.ones(NUM_CLASSES, device=device) |
|
self.reduction = reduction |
|
|
|
def forward(self, inputs, targets): |
|
""" |
|
        inputs: [N, C], float32 logits
        targets: [N, ], int64 class indices
|
""" |
|
logpt = torch.nn.functional.log_softmax(inputs, dim=1) |
|
pt = torch.exp(logpt) |
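        # Focal modulation: down-weight the log-likelihood of well-classified
        # examples by (1 - p_t)^gamma.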
|
logpt = (1-pt)**self.gamma * logpt |
|
loss = torch.nn.functional.nll_loss(logpt, targets, weight=self.weight, reduction=self.reduction) |
|
return loss |
|
|
|
|
|
|
|
|
|
class FeatureExtractor(nn.Module): |
|
def __init__(self): |
|
super(FeatureExtractor, self).__init__() |
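        # num_classes=0 drops the classification head; the backbone returns
        # pooled 1024-d features.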
|
self.conv_backbone = create_model('convnextv2_base.fcmae_ft_in22k_in1k_384', pretrained=True, num_classes=0, drop_path_rate=0.2) |
|
if CHECKPOINTS['fe_cnn']: |
|
self.conv_backbone.load_state_dict(torch.load(CHECKPOINTS['fe_cnn'], map_location='cpu'), strict=True) |
|
print(f"use FE_CHECKPOINTS: {CHECKPOINTS['fe_cnn']}") |
|
torch.cuda.empty_cache() |
|
|
|
def forward(self, img): |
|
conv_features = self.conv_backbone(img) |
|
return conv_features |
|
|
|
|
|
class MetaEmbeddings(nn.Module): |
|
def __init__(self, embedding_sizes: dict, meta_sizes: dict, dropout: float = None): |
|
super(MetaEmbeddings, self).__init__() |
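        # max_norm=1.0 renormalizes embedding rows to at most unit L2 norm at lookup.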
|
self.endemic_embedding = nn.Embedding(meta_sizes['endemic'], embedding_sizes['endemic'], max_norm=1.0) |
|
self.code_embedding = nn.Embedding(meta_sizes['code'], embedding_sizes['code'], max_norm=1.0) |
|
|
|
self.dim_embedding = sum(embedding_sizes.values()) |
|
self.embedding_net = nn.Sequential( |
|
nn.Linear(in_features=self.dim_embedding, out_features=self.dim_embedding, bias=True), |
|
nn.GELU(), |
|
nn.LayerNorm(self.dim_embedding, eps=1e-06), |
|
nn.Dropout(p=dropout, inplace=False) if dropout else nn.Identity(), |
|
nn.Linear(in_features=self.dim_embedding, out_features=self.dim_embedding, bias=True), |
|
nn.GELU(), |
|
nn.LayerNorm(self.dim_embedding, eps=1e-06), |
|
) |
|
|
|
def forward(self, meta): |
|
code_feature = self.code_embedding(meta[:,0]) |
|
endemic_feature = self.endemic_embedding(meta[:,1]) |
|
|
|
embeddings = torch.concat([code_feature, endemic_feature], dim=-1) |
|
embedding_features = self.embedding_net(embeddings) |
|
|
|
return embedding_features |
|
|
|
|
|
class Classifier(nn.Module): |
|
def __init__(self, num_classes: int, dim_embeddings: int, dropout: float = None): |
|
super(Classifier, self).__init__() |
|
self.dropout = nn.Dropout(p=dropout, inplace=False) if dropout else nn.Identity() |
|
self.classifier = nn.Linear(in_features=dim_embeddings, out_features=num_classes, bias=True) |
|
|
|
def forward(self, embeddings): |
|
dropped_feature = self.dropout(embeddings) |
|
outputs = self.classifier(dropped_feature) |
|
|
|
return outputs |
|
|
|
|
|
class Model(nn.Module): |
|
def __init__(self): |
|
super(Model, self).__init__() |
|
self.feature_extractor = FeatureExtractor() |
|
self.embedding_net = MetaEmbeddings(embedding_sizes=EMBEDDING_SIZES, meta_sizes=META_SIZES, dropout=0.25) |
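        # 1024 backbone features + 128 metadata features (64 'code' + 64 'endemic').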
|
self.classifier = Classifier(num_classes=NUM_CLASSES, dim_embeddings=1024+128, dropout=0.25) |
|
|
|
def forward(self, img, meta): |
|
img_features = self.feature_extractor(img) |
|
|
|
meta_features = self.embedding_net(meta) |
|
cat_features = torch.concat([img_features, meta_features], dim=-1) |
|
classifier_outputs = self.classifier(cat_features) |
|
|
|
return classifier_outputs, cat_features |
|
|
|
class LossLayer(nn.Module): |
|
def __init__(self): |
|
super(LossLayer, self).__init__() |
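        # ArcFace margin is in degrees here (28.6 deg ~ 0.5 rad); the total loss is
        # focal CE on the logits plus ArcFace on the concatenated image+meta embedding.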
|
self.arcloss = ArcFaceLoss(num_classes=NUM_CLASSES, embedding_size=1024+128, margin=28.6, scale=64) |
|
self.celoss = FocalLoss(gamma=FOCAL_LOSS['gamma'], class_dist=FOCAL_LOSS['class_dist']) |
|
|
|
def forward(self, classifier_outputs, cat_features, labels): |
|
classifier_loss = self.celoss(classifier_outputs, labels) |
|
embedding_loss = self.arcloss(cat_features, labels) |
|
return classifier_loss + embedding_loss |
|
|
|
|
|
def load_checkpoints(model=None, optimizer=None, scaler=None): |
|
if CHECKPOINTS['model'] and model is not None: |
|
model.load_state_dict(torch.load(CHECKPOINTS['model'], map_location='cpu')) |
|
print(f"use model checkpoints: {CHECKPOINTS['model']}") |
|
if CHECKPOINTS['optimizer'] and optimizer is not None: |
|
optimizer.load_state_dict(torch.load(CHECKPOINTS['optimizer'], map_location='cpu')) |
|
print(f"use optimizer checkpoints: {CHECKPOINTS['optimizer']}") |
|
if CHECKPOINTS['scaler'] and scaler is not None: |
|
scaler.load_state_dict(torch.load(CHECKPOINTS['scaler'], map_location='cpu')) |
|
print(f"use scaler checkpoints: {CHECKPOINTS['scaler']}") |
|
torch.cuda.empty_cache() |
|
|
|
def resume_checkpoints(model=None, optimizer=None, scaler=None): |
|
if model is not None: |
|
model.load_state_dict(torch.load(f'{MODEL_DIR}model_epoch{RESUME_EPOCH}.pth', map_location='cpu')) |
|
print(f"use model checkpoints: {MODEL_DIR}model_epoch{RESUME_EPOCH}.pth") |
|
if optimizer is not None: |
|
optimizer.load_state_dict(torch.load(f'{MODEL_DIR}optimizer_epoch{RESUME_EPOCH}.pth', map_location='cpu')) |
|
print(f"use optimizer checkpoints: {MODEL_DIR}optimizer_epoch{RESUME_EPOCH}.pth") |
|
|
|
if scaler is not None: |
|
scaler.load_state_dict(torch.load(f'{MODEL_DIR}mp_scaler_epoch{RESUME_EPOCH}.pth', map_location='cpu')) |
|
print(f"use scaler checkpoints: {MODEL_DIR}mp_scaler_epoch{RESUME_EPOCH}.pth") |
|
torch.cuda.empty_cache() |
|
|
|
|
|
def resume_logs(logs): |
|
old_logs = pd.read_csv(f"{MODEL_DIR}train_history.csv") |
|
for m in list(logs.keys()): |
|
logs[m].extend(list(old_logs[m].values)) |
|
|
|
|
|
def get_optm_group(module): |
|
""" |
|
This long function is unfortunately doing something very simple and is being very defensive: |
|
We are separating out all parameters of the model into two buckets: those that will experience |
|
weight decay for regularization and those that won't (biases, and layernorm/embedding weights). |
|
We are then returning the PyTorch optimizer object. |
|
""" |
|
|
|
|
|
decay = set() |
|
no_decay = set() |
|
whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d, timm.layers.GlobalResponseNormMlp) |
|
blacklist_weight_modules = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.LayerNorm, torch.nn.Embedding) |
|
for mn, m in module.named_modules(): |
|
for pn, p in m.named_parameters(): |
|
fpn = '%s.%s' % (mn, pn) if mn else pn |
|
|
|
if pn.endswith('bias'): |
|
|
|
no_decay.add(fpn) |
|
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): |
|
|
|
decay.add(fpn) |
|
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): |
|
|
|
no_decay.add(fpn) |
|
|
|
|
|
|
|
param_dict = {pn: p for pn, p in module.named_parameters()} |
|
inter_params = decay & no_decay |
|
union_params = decay | no_decay |
|
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) |
|
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ |
|
% (str(param_dict.keys() - union_params), ) |
|
|
|
return param_dict, decay, no_decay |
|
|
|
|
|
def get_warmup_optimizer(model): |
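    # The warmup optimizer covers only the metadata embeddings and the
    # classifier; the CNN backbone gets its own param groups after warmup.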
|
params_group = [] |
|
|
|
param_dict, decay, no_decay = get_optm_group(model.embedding_net) |
|
params_group.append({"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.05, 'lr': LEARNING_RATE['embeddings']}) |
|
params_group.append({"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0, 'lr': LEARNING_RATE['embeddings']}) |
|
|
|
param_dict, decay, no_decay = get_optm_group(model.classifier) |
|
params_group.append({"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.05, 'lr': LEARNING_RATE['classifier']}) |
|
params_group.append({"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0, 'lr': LEARNING_RATE['classifier']}) |
|
|
|
optimizer = torch.optim.AdamW(params_group) |
|
return optimizer |
|
|
|
|
|
def get_after_warmup_optimizer(model, old_opt): |
|
new_opt = create_optimizer_v2(model.feature_extractor.conv_backbone, opt='adamw', filter_bias_and_bn=True, weight_decay=1e-8, layer_decay=0.85, lr=LEARNING_RATE['cnn']) |
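    # The new optimizer covers the backbone with layer-wise lr decay; the param
    # groups from the previous (warmup) optimizer are carried over unchanged.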
|
|
|
|
|
for group in old_opt.param_groups: |
|
new_opt.add_param_group(group) |
|
|
|
return new_opt |
|
|
|
|
|
|
|
|
|
def warmup_start(model): |
|
|
|
    for param in model.feature_extractor.conv_backbone.parameters():
        param.requires_grad = False
    print('--> freeze feature_extractor.conv_backbone during warmup phase')
|
|
|
def warmup_end(model): |
|
|
|
    for param in model.feature_extractor.conv_backbone.parameters():
        param.requires_grad = True
    print('--> unfreeze feature_extractor.conv_backbone after warmup phase')
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
    device = torch.device('cuda:1')
|
torch.cuda.set_device(device) |
|
|
|
|
|
train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'], |
|
imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'], |
|
rand_aug=True) |
|
|
|
|
|
model = Model().to(device) |
|
|
|
    if RESUME_EPOCH > 0:
        resume_checkpoints(model=model)
    elif WARMUP_EPOCHS > 0:
        # The backbone is excluded from the warmup optimizer, so freeze it to
        # skip its gradient computation during warmup.
        warmup_start(model)

    # Exponential moving average of the weights, evaluated after the last epoch.
    ema_model = ModelEmaV2(model, decay=0.9998, device=device)
|
|
|
|
|
|
|
optimizer = get_warmup_optimizer(model) |
|
scaler = GradScaler() |
|
|
|
if RESUME_EPOCH > 0: |
|
optimizer = get_after_warmup_optimizer(model, optimizer) if RESUME_EPOCH > WARMUP_EPOCHS else optimizer |
|
resume_checkpoints(optimizer=optimizer, scaler=scaler) |
|
|
|
loss_fn = LossLayer().to(device) |
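    # ArcFaceLoss contains a learnable class-weight matrix, so its parameters
    # are registered with the optimizer too.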
|
optimizer.add_param_group({"params": loss_fn.arcloss.parameters(), "weight_decay": 0.0, 'lr': LEARNING_RATE['classifier']}) |
|
|
|
|
|
loss_metric = MeanMetric().to(device) |
|
metrics = MetricCollection(metrics={ |
|
'acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro'), |
|
'top3_acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro', top_k=3), |
|
'f1': MulticlassF1Score(num_classes=NUM_CLASSES, average='macro') |
|
}).to(device) |
|
metric_ccm = MulticlassF1Score(num_classes=NUM_CLASSES, average='macro').to(device) |
|
|
|
|
|
start_training = time.perf_counter() |
|
|
|
logs = {'loss': [], 'acc': [], 'acc_top3': [], 'f1': [], 'f1country': [], 'val_loss': [], 'val_acc': [], 'val_acc_top3': [], 'val_f1': [], 'val_f1country': []} |
|
if RESUME_EPOCH > 0: |
|
resume_logs(logs) |
|
|
|
|
|
start_epoch = RESUME_EPOCH+1 if RESUME_EPOCH > 0 else 0 |
|
for epoch in range(start_epoch, NUM_EPOCHS): |
|
|
|
epoch_start = time.perf_counter() |
|
print(f'Epoch {epoch+1}/{NUM_EPOCHS}') |
|
|
|
|
|
        if epoch == WARMUP_EPOCHS:
|
warmup_end(model) |
|
optimizer = get_after_warmup_optimizer(model, optimizer) |
|
global BATCH_SIZE |
|
BATCH_SIZE = BATCH_SIZE_AFTER_WARMUP |
|
train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'], |
|
imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'], |
|
rand_aug=True) |
|
|
|
        elif epoch < WARMUP_EPOCHS:
|
print(f'--> Warm Up {epoch+1}/{WARMUP_EPOCHS}') |
|
|
|
|
|
model.train() |
|
|
|
|
|
optimizer.zero_grad(set_to_none=True) |
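        # Gradient accumulation: each batch loss is divided by the number of
        # accumulation steps and the optimizer steps every 'grad_acc' batches.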
|
|
|
|
|
        grad_acc = BATCH_SIZE['grad_acc'] if BATCH_SIZE['grad_acc'] != 0 else 1.0
        loss_div = torch.tensor(grad_acc, dtype=torch.float16, device=device, requires_grad=False)
|
|
|
|
|
for batch_idx, (inputs, labels, ccm, meta) in enumerate(train_loader): |
|
inputs = inputs.to(device, non_blocking=True) |
|
meta = meta.to(device, non_blocking=True) |
|
labels = labels.to(device, non_blocking=True) |
|
ccm = ccm.to(device, non_blocking=True) |
|
|
|
|
|
with autocast(device_type='cuda', dtype=torch.float16): |
|
outputs, embeddings = model(inputs, meta) |
|
loss = loss_fn(outputs, embeddings, labels) / loss_div |
|
|
|
|
|
scaler.scale(loss).backward() |
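            # Log the unscaled loss (multiply the grad-acc division back out).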
|
|
|
|
|
loss_metric.update((loss * loss_div).detach()) |
|
|
|
preds = outputs.softmax(dim=-1).detach() |
|
metrics.update(preds, labels) |
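            # 'f1country': F1 on predictions masked by the country-code -> class map.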
|
metric_ccm.update(preds * ccm, labels) |
|
|
|
|
|
if (batch_idx+1) % BATCH_SIZE['grad_acc'] == 0: |
|
|
|
|
|
scaler.step(optimizer) |
|
scaler.update() |
|
|
|
optimizer.zero_grad(set_to_none=True) |
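                # Update the EMA weights after each optimizer step.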
|
|
|
ema_model.update(model) |
|
|
|
|
|
|
|
epoch_loss = loss_metric.compute() |
|
epoch_metrics = metrics.compute() |
|
epoch_metric_ccm = metric_ccm.compute() |
|
|
|
loss_metric.reset() |
|
metrics.reset() |
|
metric_ccm.reset() |
|
|
|
|
|
logs['loss'].append(epoch_loss.cpu().item()) |
|
logs['acc'].append(epoch_metrics['acc'].cpu().item()) |
|
logs['acc_top3'].append(epoch_metrics['top3_acc'].cpu().item()) |
|
logs['f1'].append(epoch_metrics['f1'].cpu().item()) |
|
logs['f1country'].append(epoch_metric_ccm.detach().cpu().item()) |
|
|
|
print(f"loss: {logs['loss'][epoch]:.5f}, acc: {logs['acc'][epoch]:.5f}, acc_top3: {logs['acc_top3'][epoch]:.5f}, f1: {logs['f1'][epoch]:.5f}, f1country: {logs['f1country'][epoch]:.5f}", end=' || ') |
|
|
|
|
|
optimizer.zero_grad(set_to_none=True) |
|
|
|
del inputs, labels, ccm, meta, preds, outputs, loss, loss_div, epoch_loss, epoch_metrics, epoch_metric_ccm |
|
torch.cuda.empty_cache() |
|
|
|
|
|
with torch.no_grad(): |
|
model.eval() |
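            # Five-crop TTA: flatten the crops into the batch dim, repeat the
            # metadata per crop, then average logits/embeddings over the 5 crops.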
|
|
|
|
|
for (inputs, labels, ccm, meta) in valid_loader: |
|
inputs = inputs.to(device, non_blocking=True) |
|
inputs = inputs.view(-1, 3, TRANSFORMS['IMAGE_SIZE_VAL'], TRANSFORMS['IMAGE_SIZE_VAL']) |
|
meta = meta.to(device, non_blocking=True) |
|
meta = torch.repeat_interleave(meta, repeats=5, dim=0) |
|
labels = labels.to(device, non_blocking=True) |
|
ccm = ccm.to(device, non_blocking=True) |
|
|
|
|
|
with autocast(device_type='cuda', dtype=torch.float16): |
|
outputs, embeddings = model(inputs, meta) |
|
outputs = outputs.view(-1, 5, NUM_CLASSES).mean(1) |
|
embeddings = embeddings.view(-1, 5, 1024+128).mean(1) |
|
loss = loss_fn(outputs, embeddings, labels) |
|
|
|
|
|
loss_metric.update(loss.detach()) |
|
|
|
preds = outputs.softmax(dim=-1).detach() |
|
metrics.update(preds, labels) |
|
metric_ccm.update(preds * ccm, labels) |
|
|
|
|
|
epoch_loss = loss_metric.compute() |
|
epoch_metrics = metrics.compute() |
|
epoch_metric_ccm = metric_ccm.compute() |
|
|
|
loss_metric.reset() |
|
metrics.reset() |
|
metric_ccm.reset() |
|
|
|
|
|
logs['val_loss'].append(epoch_loss.cpu().item()) |
|
logs['val_acc'].append(epoch_metrics['acc'].cpu().item()) |
|
logs['val_acc_top3'].append(epoch_metrics['top3_acc'].cpu().item()) |
|
logs['val_f1'].append(epoch_metrics['f1'].cpu().item()) |
|
logs['val_f1country'].append(epoch_metric_ccm.detach().cpu().item()) |
|
|
|
print(f"val_loss: {logs['val_loss'][epoch]:.5f}, val_acc: {logs['val_acc'][epoch]:.5f}, val_acc_top3: {logs['val_acc_top3'][epoch]:.5f}, val_f1: {logs['val_f1'][epoch]:.5f}, val_f1country: {logs['val_f1country'][epoch]:.5f}", end=' || ') |
|
|
|
del inputs, labels, ccm, meta, preds, outputs, loss, epoch_loss, epoch_metrics, epoch_metric_ccm |
|
torch.cuda.empty_cache() |
|
|
|
|
|
logs_df = pd.DataFrame(logs) |
|
logs_df.to_csv(f'{MODEL_DIR}train_history.csv', index_label='epoch', sep=',', encoding='utf-8') |
|
|
|
if WANDB: |
|
|
|
wandb.log( |
|
{k:v[epoch] for k,v in logs.items()}, |
|
step=epoch |
|
) |
|
|
|
|
|
torch.save(model.state_dict(), f'{MODEL_DIR}model_epoch{epoch}.pth') |
|
torch.save(ema_model.module.state_dict(), f'{MODEL_DIR}ema_model_epoch{epoch}.pth') |
|
torch.save(optimizer.state_dict(), f'{MODEL_DIR}optimizer_epoch{epoch}.pth') |
|
torch.save(scaler.state_dict(), f'{MODEL_DIR}mp_scaler_epoch{epoch}.pth') |
|
torch.save(loss_fn.arcloss.state_dict(), f'{MODEL_DIR}arcloss_epoch{epoch}.pth') |
|
|
|
|
|
epoch_end = time.perf_counter() |
|
print(f"epoch runtime: {epoch_end-epoch_start:5.3f} sec.") |
|
|
|
del logs_df, epoch_start, epoch_end |
|
torch.cuda.empty_cache() |
|
|
|
|
|
del model |
|
torch.cuda.empty_cache() |
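    # Final evaluation: score the EMA weights on the validation set with the
    # same five-crop TTA.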
|
|
|
ema_net = ema_model.module |
|
ema_net.eval() |
|
|
|
with torch.no_grad(): |
|
|
|
for (inputs, labels, ccm, meta) in valid_loader: |
|
inputs = inputs.to(device, non_blocking=True) |
|
inputs = inputs.view(-1, 3, TRANSFORMS['IMAGE_SIZE_VAL'], TRANSFORMS['IMAGE_SIZE_VAL']) |
|
meta = meta.to(device, non_blocking=True) |
|
meta = torch.repeat_interleave(meta, repeats=5, dim=0) |
|
labels = labels.to(device, non_blocking=True) |
|
ccm = ccm.to(device, non_blocking=True) |
|
|
|
|
|
with autocast(device_type='cuda', dtype=torch.float16): |
|
outputs, embeddings = ema_net(inputs, meta) |
|
outputs = outputs.view(-1, 5, NUM_CLASSES).mean(1) |
|
embeddings = embeddings.view(-1, 5, 1024+128).mean(1) |
|
loss = loss_fn(outputs, embeddings, labels) |
|
|
|
|
|
loss_metric.update(loss.detach()) |
|
|
|
preds = outputs.softmax(dim=-1).detach() |
|
metrics.update(preds, labels) |
|
metric_ccm.update(preds * ccm, labels) |
|
|
|
|
|
epoch_loss = loss_metric.compute() |
|
epoch_metrics = metrics.compute() |
|
epoch_metric_ccm = metric_ccm.compute() |
|
|
|
loss_metric.reset() |
|
metrics.reset() |
|
metric_ccm.reset() |
|
|
|
print(f"ema_loss: {epoch_loss.cpu().item():.5f}, ema_acc: {epoch_metrics['acc'].cpu().item():.5f}, ema_acc_top3: {epoch_metrics['top3_acc'].cpu().item():.5f}, ema_f1: {epoch_metrics['f1'].cpu().item():.5f}, ema_f1country: {epoch_metric_ccm.detach().cpu().item():.5f}") |
|
|
|
with open(f'{MODEL_DIR}ema_results.txt', 'w') as f: |
|
print(f"ema_loss: {epoch_loss.cpu().item():.5f}, ema_acc: {epoch_metrics['acc'].cpu().item():.5f}, ema_acc_top3: {epoch_metrics['top3_acc'].cpu().item():.5f}, ema_f1: {epoch_metrics['f1'].cpu().item():.5f}, ema_f1country: {epoch_metric_ccm.detach().cpu().item():.5f}", file=f) |
|
|
|
plot_history(logs) |
|
|
|
end_training = time.perf_counter() |
|
print(f'Training succeeded in {(end_training - start_training):5.3f}s') |
|
|
|
if WANDB: |
|
wandb.finish() |
|
|
|
|
|
if __name__ == "__main__":
|
main() |
|
|
|
|
|
|
|
|
|
|