import os
from collections import defaultdict

import torch
import torch.optim
import torch.nn.functional as F
from torchvision.models import resnet18, resnet34, resnet50
from torchvision.utils import make_grid, save_image

import pytorch_lightning as pl
import numpy as np

import mlflow.pytorch


def resnet_model(model='resnet18', pretrained=True, in_channels=3, fc_out_features=2):
    """Build a torchvision ResNet backbone and replace its classification head."""
    if model.lower() == 'resnet18':
        resnet = resnet18(pretrained=pretrained)
    elif model.lower() == 'resnet34':
        resnet = resnet34(pretrained=pretrained)
    elif model.lower() == 'resnet50':
        resnet = resnet50(pretrained=pretrained)
    else:
        raise ValueError(f"Unsupported model: {model}")
    # Use the backbone's own feature width (512 for resnet18/34, 2048 for resnet50)
    # rather than a hardcoded 512, so the resnet50 head is sized correctly.
    resnet.fc = torch.nn.Linear(in_features=resnet.fc.in_features,
                                out_features=fc_out_features, bias=True)
    return resnet
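
# Minimal usage sketch (illustrative only; the batch and image sizes are assumptions):
#   backbone = resnet_model('resnet50', pretrained=True, fc_out_features=2)
#   logits = backbone(torch.randn(4, 3, 224, 224))  # -> tensor of shape (4, 2)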


class LitModel(pl.LightningModule):
    """Lightning module that runs inputs through a `processor` and then a
    `classifier`, with optional augmentation, auxiliary loss, metric logging,
    and selective freezing (e.g. for adversarial training of the processor)."""

    def __init__(self,
                 classifier,
                 loss,
                 lr=1e-3,
                 weight_decay=0,
                 loss_aux=None,
                 adv_training=False,
                 adv_parameters='all',
                 metrics=None,
                 processor=None,
                 augmentation=None,
                 is_segmentation_task=False,
                 augmentation_on_eval=False,
                 metrics_on_training=True,
                 freeze_classifier=False,
                 freeze_processor=False,
                 ):
        super().__init__()

        self.classifier = classifier
        self.processor = processor

        self.lr = lr
        self.weight_decay = weight_decay
        self.loss_fn = loss
        self.loss_aux_fn = loss_aux
        self.adv_training = adv_training
        self.metrics = metrics
        self.augmentation = augmentation
        self.is_segmentation_task = is_segmentation_task
        self.augmentation_on_eval = augmentation_on_eval
        self.metrics_on_training = metrics_on_training

        self.freeze_classifier = freeze_classifier
        self.freeze_processor = freeze_processor

        # Start fully trainable, then freeze the requested submodules.
        # LightningModule.freeze() only iterates parameters and calls eval(),
        # so it can be applied to the plain nn.Module children here.
        self.unfreeze()
        if freeze_classifier:
            pl.LightningModule.freeze(self.classifier)
        if freeze_processor:
            pl.LightningModule.freeze(self.processor)

        if adv_training and adv_parameters != 'all':
            # For adversarial training on a subset of the processor, freeze it
            # entirely and re-enable only the parameters whose names match.
            pl.LightningModule.freeze(self.processor)
            for name, p in self.processor.named_parameters():
                if adv_parameters in name:
                    p.requires_grad = True

    def forward(self, x):
        x = self.processor(x)
        apply_augmentation_step = self.training or self.augmentation_on_eval
        if self.augmentation is not None and apply_augmentation_step:
            x = self.augmentation(x, retain_state=self.is_segmentation_task)
        x = self.classifier(x)
        return x

    def update_step(self, batch, step_name):
        x, y = batch

        logits = self(x)

        # For segmentation, apply the same (stateful) spatial augmentation to the
        # target mask that forward() applied to the input.
        apply_augmentation_mask = self.is_segmentation_task and (self.training or self.augmentation_on_eval)
        if self.augmentation is not None and apply_augmentation_mask:
            y = self.augmentation(y, mask_transform=True).contiguous()

        loss = self.loss_fn(logits, y)

        if self.loss_aux_fn is not None:
            loss_aux = self.loss_aux_fn(x)
            loss += loss_aux

        self.log(f'{step_name}_loss', loss, on_step=False, on_epoch=True)
        if self.loss_aux_fn is not None:
            self.log(f'{step_name}_loss_aux', loss_aux, on_step=False, on_epoch=True)

        if self.is_segmentation_task:
            # logsigmoid().exp() is sigmoid() computed via log-space for stability.
            y_hat = F.logsigmoid(logits).exp().squeeze()
        else:
            y_hat = torch.argmax(logits, dim=1)

        if self.metrics is not None:
            for metric in self.metrics:
                metric_name = metric.__name__ if hasattr(metric, '__name__') else type(metric).__name__
                is_main_metric = metric_name in ('accuracy', 'iou_score')
                # The main metric is always logged; the remaining metrics are skipped
                # during training unless metrics_on_training is enabled.
                if is_main_metric or not self.training or self.metrics_on_training:
                    m = metric(y_hat.cpu().detach(), y.cpu())
                    self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
                             prog_bar=self.training or is_main_metric)

        return loss

    def training_step(self, batch, batch_idx):
        return self.update_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self.update_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self.update_step(batch, 'test')

    def train(self, mode=True):
        # Override nn.Module.train() so frozen submodules stay in eval mode
        # (e.g. keeping normalization statistics fixed) while the rest trains.
        self.training = mode
        self.processor.train(mode=mode and not self.freeze_processor and not self.adv_training)
        self.classifier.train(mode=mode and not self.freeze_classifier)
        return self

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(self.parameters(), self.lr, weight_decay=self.weight_decay)
        return self.optimizer

    def get_progress_bar_dict(self):
        # Hide the experiment version number in the progress bar.
        items = super().get_progress_bar_dict()
        items.pop('v_num', None)
        return items

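
# Example wiring (illustrative sketch; the data loaders and the identity processor
# are assumptions, not part of this module):
#   model = LitModel(classifier=resnet_model('resnet18', fc_out_features=2),
#                    loss=torch.nn.CrossEntropyLoss(),
#                    processor=torch.nn.Identity())
#   pl.Trainer(max_epochs=10).fit(model, train_loader, val_loader)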


class TrackImagesCallback(pl.callbacks.Callback):
    """Callback that, after training (or after every epoch), runs a tracking
    data loader through the model and logs intermediate processing stages,
    gradients, and predictions as mlflow artifacts."""

    def __init__(self, data_loader, reference_processor=None, track_every_epoch=False,
                 track_processing=True, track_gradients=True, track_predictions=True,
                 save_tensors=True):
        super().__init__()
        self.data_loader = data_loader

        self.track_every_epoch = track_every_epoch

        self.track_processing = track_processing
        self.track_gradients = track_gradients
        self.track_predictions = track_predictions
        self.save_tensors = save_tensors

        self.reference_processor = reference_processor

    def callback_track_images(self, model, save_loc):
        track_images(model,
                     self.data_loader,
                     reference_processor=self.reference_processor,
                     track_processing=self.track_processing,
                     track_gradients=self.track_gradients,
                     track_predictions=self.track_predictions,
                     save_tensors=self.save_tensors,
                     save_loc=save_loc,
                     )

    def on_fit_end(self, trainer, pl_module):
        if not self.track_every_epoch:
            save_loc = 'results'
            self.callback_track_images(trainer.model, save_loc)

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        if self.track_every_epoch:
            save_loc = f'results/epoch_{trainer.current_epoch + 1:04d}'
            self.callback_track_images(trainer.model, save_loc)

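
# Example (illustrative): the callback expects `model.processor` to expose a
# `.stages` dict of intermediate tensors; otherwise track_images() returns early.
#   cb = TrackImagesCallback(track_loader, track_every_epoch=True)
#   trainer = pl.Trainer(callbacks=[cb])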


def log_tensor(batch, path, save_tensors=True, nrow=8):
    if save_tensors:
        torch.save(batch, path)
        mlflow.log_artifact(path, os.path.dirname(path))

    # Also save an image grid next to the tensor, prefixing the file name with 'img_'.
    img_path = path.replace('.pt', '.png')
    split = img_path.split('/')
    img_path = '/'.join(split[:-1]) + '/img_' + split[-1]

    grid = make_grid(batch, nrow=nrow).squeeze()
    save_image(grid, img_path)
    mlflow.log_artifact(img_path, os.path.dirname(path))
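
# For example, a call with path 'results/processing_0_stage.pt' (stage name here is
# hypothetical) writes the tensor when save_tensors=True, plus an image grid
# 'results/img_processing_0_stage.png', and logs both as mlflow artifacts.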


def track_images(model, data_loader, reference_processor=None, track_processing=True,
                 track_gradients=True, track_predictions=True, save_tensors=True,
                 save_loc='results'):

    device = model.device
    processor = model.processor
    classifier = model.classifier

    # Nothing to track if the processor does not expose its intermediate stages.
    if not hasattr(processor, 'stages'):
        return

    os.makedirs(save_loc, exist_ok=True)

    labels_full = []
    logits_full = []
    stages_full = defaultdict(list)
    grads_full = defaultdict(list)
    diffs_full = defaultdict(list)

    track_differences = reference_processor is not None

    for inputs, labels in data_loader:

        inputs, labels = inputs.to(device), labels.to(device)
        # Input gradients are needed to visualize per-stage gradients below.
        inputs.requires_grad = True

        processed_rgb = processor(inputs)

        if track_differences:
            # Running the reference processor populates its .stages for the
            # difference tracking below.
            processed_rgb_ref = reference_processor(inputs)

        if track_gradients or track_predictions:
            logits = classifier(processed_rgb)

            loss = model.loss_fn(logits, labels)
            loss.backward()

        if track_predictions:
            labels_full.append(labels.cpu().detach())
            logits_full.append(logits.cpu().detach())

        for stage, batch in processor.stages.items():
            stages_full[stage].append(batch.cpu().detach())
            if track_differences:
                diffs_full[stage].append((reference_processor.stages[stage] - batch).cpu().detach())
            if track_gradients:
                grads_full[stage].append(batch.grad.cpu().detach())

    with torch.no_grad():

        stages = stages_full
        grads = grads_full
        diffs = diffs_full

        if track_processing:
            for stage, batch in stages.items():
                stages[stage] = torch.cat(batch)

        if track_differences:
            for stage, batch in diffs.items():
                diffs[stage] = torch.cat(batch)

        if track_gradients:
            for stage, batch in grads.items():
                grads[stage] = torch.cat(batch)

        for stage_nr, stage_name in enumerate(stages):
            if track_processing:
                batch = stages[stage_name]
                log_tensor(batch, os.path.join(save_loc, f'processing_{stage_nr}_{stage_name}.pt'), save_tensors)

            if track_differences:
                batch = diffs[stage_name]
                log_tensor(batch, os.path.join(save_loc, f'diffs_{stage_nr}_{stage_name}.pt'), False)

            if track_gradients:
                # Visualize gradient magnitudes, min-max normalized per stage.
                batch_grad = grads[stage_name]
                batch_grad = batch_grad.abs()
                batch_grad = (batch_grad - batch_grad.min()) / (batch_grad.max() - batch_grad.min())
                log_tensor(batch_grad, os.path.join(
                    save_loc, f'gradients_{stage_nr}_{stage_name}.pt'), save_tensors)

        if track_predictions:
            labels = torch.cat(labels_full)
            logits = torch.cat(logits_full)
            masks = labels.unsqueeze(1)
            predictions = logits

            log_tensor(masks, os.path.join(save_loc, 'targets.pt'), save_tensors)
            log_tensor(predictions, os.path.join(save_loc, 'preds.pt'), save_tensors)