import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import pytorch_lightning as pl | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
import matplotlib.pyplot as plt | |
from torch_lr_finder import LRFinder | |
import numpy as np | |
from utils import get_correct_pred_count, add_predictions, test_incorrect_pred, test_correct_pred, denormalize | |
# Number of groups passed to nn.GroupNorm when a block is built with norm='gn'.
NO_GROUPS = 4


class ResnetBlock(nn.Module):
    """Two stacked ``conv3x3 -> norm -> ReLU -> Dropout2d`` stages.

    The block only computes the transformed branch; any residual addition is
    left to the caller.

    Args:
        input_channel: Number of channels of the incoming feature map.
        output_channel: Number of channels produced by both convolutions.
        padding: Padding applied by both 3x3 convolutions.
        norm: Normalisation flavour: ``'bn'`` (BatchNorm2d), ``'gn'``
            (GroupNorm with ``NO_GROUPS`` groups) or ``'ln'`` (GroupNorm with
            a single group).
        drop: Drop probability handed to both ``nn.Dropout2d`` layers.

    Raises:
        ValueError: If ``norm`` is not one of ``'bn'``, ``'gn'``, ``'ln'``.
    """

    def __init__(self, input_channel, output_channel, padding=1, norm='bn', drop=0.01):
        super().__init__()
        # Attribute names and registration order are kept stable so existing
        # state_dict checkpoints continue to load.
        self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=padding)
        self.n1 = self._make_norm(norm, output_channel)
        self.drop1 = nn.Dropout2d(drop)

        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=padding)
        self.n2 = self._make_norm(norm, output_channel)
        self.drop2 = nn.Dropout2d(drop)

    @staticmethod
    def _make_norm(norm, channels):
        """Build the normalisation layer selected by ``norm`` for ``channels``."""
        if norm == 'bn':
            return nn.BatchNorm2d(channels)
        if norm == 'gn':
            return nn.GroupNorm(NO_GROUPS, channels)
        if norm == 'ln':
            return nn.GroupNorm(1, channels)
        # Fail at construction time instead of leaving n1/n2 undefined and
        # crashing with an AttributeError on the first call.
        raise ValueError(
            "norm must be one of 'bn', 'gn' or 'ln', got {!r}".format(norm)
        )

    def forward(self, x):
        """Apply both conv/norm/ReLU/dropout stages to ``x`` and return the result.

        This is implemented as ``forward`` (not ``__call__``) so that
        ``nn.Module.__call__`` keeps running registered hooks around it.
        """
        x = self.drop1(F.relu(self.n1(self.conv1(x))))
        x = self.drop2(F.relu(self.n2(self.conv2(x))))
        return x
class S10LightningModel(pl.LightningModule):
    """ResNet-style 10-class image classifier wrapped as a LightningModule.

    Architecture (see ``forward``): a prep conv stage, a conv+pool stage with a
    residual ``ResnetBlock`` (``x1`` / ``R1``), a plain conv+pool stage
    (``layer2``), a second conv+pool stage with a residual block
    (``x2`` / ``R2``), a 4x4 max-pool and a linear layer with 10 outputs.

    Besides Lightning logging, running counters and per-epoch loss/accuracy
    histories are kept in ``self.metric`` so they can be printed and plotted.

    Args:
        base_channels: Channel count of the prep layer; later stages use
            2x, 4x and 8x this value.
        drop: Drop probability for every ``nn.Dropout2d`` in the network.
        loss_function: Callable ``(output, target) -> loss tensor``. ``output``
            is whatever ``forward`` returns, i.e. log-probabilities by default.
            NOTE: the default ``F.cross_entropy`` applies ``log_softmax``
            itself; on already-normalised log-probabilities that is a no-op,
            so the result equals an NLL loss.
        is_find_max_lr: When true, ``configure_optimizers`` runs an LR range
            test and overwrites ``max_lr`` with its suggestion.
        max_lr: Peak learning rate handed to ``OneCycleLR``.
    """

    def __init__(self, base_channels, drop=0.01, loss_function=F.cross_entropy, is_find_max_lr=False, max_lr=3.20E-04):
        super().__init__()

        self.is_find_max_lr = is_find_max_lr
        self.max_lr = max_lr
        self.criterion = loss_function

        # 'train'/'val' count correct predictions and '*_total' count samples
        # in the current epoch; 'epoch_*_loss' collects per-batch losses of the
        # current epoch; the remaining lists hold one entry per finished epoch.
        self.metric = dict(
            train=0,
            val=0,
            train_total=0,
            val_total=0,
            epoch_train_loss=[],
            epoch_val_loss=[],
            train_loss=[],
            val_loss=[],
            train_acc=[],
            val_acc=[],
        )

        self.base_channels = base_channels

        self.prep_layer = nn.Sequential(
            nn.Conv2d(3, base_channels, 3, stride=1, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(),
            nn.Dropout2d(drop),
        )

        # layer1
        self.x1 = self._conv_pool_stage(base_channels, 2 * base_channels, drop)
        self.R1 = ResnetBlock(2 * base_channels, 2 * base_channels, padding=1, drop=drop)

        # layer2
        self.layer2 = self._conv_pool_stage(2 * base_channels, 4 * base_channels, drop)

        # layer3
        self.x2 = self._conv_pool_stage(4 * base_channels, 8 * base_channels, drop)
        self.R2 = ResnetBlock(8 * base_channels, 8 * base_channels, padding=1, drop=drop)

        self.pool = nn.MaxPool2d(4)
        self.fc = nn.Linear(8 * base_channels, 10)

    @staticmethod
    def _conv_pool_stage(in_channels, out_channels, drop):
        """Build a conv3x3 -> maxpool(2) -> batchnorm -> ReLU -> dropout stage.

        The layer order (and hence the Sequential indices used as state_dict
        keys) matches the stages that were previously written out inline.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout2d(drop),
        )

    def forward(self, x, no_softmax=False):
        """Run the network on a batch of 3-channel images.

        Args:
            x: Input batch. The flattening step below requires the pooled
                feature map to be 1x1 per channel, which holds e.g. for 32x32
                inputs (three 2x2 pools followed by one 4x4 pool).
            no_softmax: When true, return the raw outputs of the final linear
                layer instead of log-probabilities.

        Returns:
            Tensor of shape ``(batch, 10)``: ``log_softmax`` over dim 1, or the
            raw linear outputs when ``no_softmax`` is true.
        """
        x = self.prep_layer(x)

        x = self.x1(x)
        x = self.R1(x) + x  # residual connection around the first block

        x = self.layer2(x)

        x = self.x2(x)
        x = self.R2(x) + x  # residual connection around the second block

        x = self.pool(x)
        x = x.view(x.size(0), 8 * self.base_channels)
        x = self.fc(x)

        if no_softmax:
            return x
        return F.log_softmax(x, dim=1)

    def get_layer(self, idx):
        """Return the stage at position ``idx``, or ``None`` if out of range.

        Index order: prep_layer, x1, layer2, x2, pool.
        """
        layers = [self.prep_layer, self.x1, self.layer2, self.x2, self.pool]
        if 0 <= idx < len(layers):
            return layers[idx]
        return None

    def _record_step(self, split, x, output, target, loss):
        """Update the running counters for ``split`` and return running accuracy (%).

        ``split`` is ``'train'`` or ``'val'``. The loss is stored detached so
        the per-epoch list does not keep every batch's autograd graph alive.
        """
        self.metric[split] += get_correct_pred_count(output, target)
        self.metric[split + '_total'] += len(x)
        self.metric['epoch_' + split + '_loss'].append(loss.detach())
        return 100 * self.metric[split] / self.metric[split + '_total']

    def training_step(self, train_batch, batch_idx):
        """Compute the loss for one training batch, log it and return it."""
        x, target = train_batch
        output = self.forward(x)
        loss = self.criterion(output, target)
        acc = self._record_step('train', x, output, target, loss)
        self.log_dict({'train_loss': loss, 'train_acc': acc})
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log loss/accuracy for one validation batch.

        During the last epoch (``current_epoch == max_epochs - 1``) the batch
        is also handed to ``add_predictions`` from ``utils``; presumably that
        is what fills the prediction stores read by ``plot_grad_cam``
        (verify against ``utils``).
        """
        x, target = val_batch
        output = self.forward(x)
        loss = self.criterion(output, target)
        acc = self._record_step('val', x, output, target, loss)

        if self.current_epoch == self.trainer.max_epochs - 1:
            add_predictions(x, output, target)

        self.log_dict({'val_loss': loss, 'val_acc': acc})

    def test_step(self, test_batch, batch_idx):
        """Evaluate a test batch by reusing ``validation_step`` (logs as ``val_*``)."""
        self.validation_step(test_batch, batch_idx)

    def train_dataloader(self):
        """Return the trainer's train dataloader, setting it up first if missing."""
        if not self.trainer.train_dataloader:
            self.trainer.fit_loop.setup_data()

        return self.trainer.train_dataloader

    def configure_optimizers(self):
        """Create the Adam optimizer and a per-step ``OneCycleLR`` schedule.

        If ``is_find_max_lr`` is set, ``find_lr`` replaces ``self.max_lr``
        before the scheduler is built. The warm-up phase spans five epochs.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-6, weight_decay=0.01)
        self.find_lr(optimizer)
        print(self.max_lr)

        max_epochs = self.trainer.max_epochs
        # OneCycleLR rejects pct_start > 1, which 5 / max_epochs would produce
        # for runs shorter than five epochs, so cap it at 1.0.
        warmup_fraction = min(1.0, 5 / max_epochs)

        # `verbose` is intentionally not passed: False was its default and the
        # argument is deprecated in recent PyTorch releases.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.max_lr,
            epochs=max_epochs,
            steps_per_epoch=len(self.train_dataloader()),
            pct_start=warmup_fraction,
            div_factor=100,
            final_div_factor=100,
            three_phase=False,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                'interval': 'step',  # or 'epoch'
                'frequency': 1
            },
        }

    def _reset_split(self, split):
        """Zero the running counters and per-batch losses of ``split``."""
        self.metric[split] = 0
        self.metric[split + '_total'] = 0
        self.metric['epoch_' + split + '_loss'] = []

    def _epoch_summary(self, split):
        """Record mean loss and accuracy (%) of ``split`` for the finished epoch.

        Returns ``(mean_loss, accuracy)`` after appending both to the history
        lists in ``self.metric``.
        """
        losses = self.metric['epoch_' + split + '_loss']
        accuracy = 100 * self.metric[split] / self.metric[split + '_total']
        mean_loss = (sum(losses) / len(losses)).item()
        self.metric[split + '_loss'].append(mean_loss)
        self.metric[split + '_acc'].append(accuracy)
        return mean_loss, accuracy

    def on_validation_epoch_end(self):
        """Print and store the epoch summary, then reset the running counters.

        Training numbers are only reported when training batches were seen
        since the last reset. Validation numbers are skipped when no
        validation batch ran, which would otherwise divide by zero.
        NOTE(review): a validation pass that runs before any training (e.g. a
        sanity check) still appends an entry to the validation history —
        confirm whether that is wanted.
        """
        if self.metric['train_total'] and self.metric['epoch_train_loss']:
            print('Epoch ', self.current_epoch)
            train_loss, train_acc = self._epoch_summary('train')
            print('Train Loss: ', train_loss, ' Accuracy: ', str(train_acc) + '%', ' [',
                  self.metric['train'], '/', self.metric['train_total'], ']')
            self._reset_split('train')

        if self.metric['val_total'] and self.metric['epoch_val_loss']:
            val_loss, val_acc = self._epoch_summary('val')
            print('Validation Loss: ', val_loss, ' Accuracy: ', str(val_acc) + '%', ' [', self.metric['val'],
                  '/', self.metric['val_total'], ']\n')
        self._reset_split('val')

    def find_lr(self, optimizer):
        """Run an LR range test and store its suggestion in ``self.max_lr``.

        Does nothing unless ``is_find_max_lr`` is set. The finder is reset
        afterwards (``lr_finder.reset()``).
        """
        if not self.is_find_max_lr:
            return

        lr_finder = LRFinder(self, optimizer, self.criterion)
        lr_finder.range_test(self.train_dataloader(), end_lr=100, num_iter=100)
        _, best_lr = lr_finder.plot()  # to inspect the loss-learning rate graph
        lr_finder.reset()
        self.max_lr = best_lr

    def plot_model_performance(self):
        """Plot the per-epoch loss and accuracy histories on a 2x2 grid."""
        fig, axs = plt.subplots(2, 2, figsize=(15, 10))
        panels = (
            (axs[0, 0], 'train_loss', "Training Loss"),
            (axs[1, 0], 'train_acc', "Training Accuracy"),
            (axs[0, 1], 'val_loss', "Test Loss"),
            (axs[1, 1], 'val_acc', "Test Accuracy"),
        )
        for axis, key, heading in panels:
            axis.plot(self.metric[key])
            axis.set_title(heading)

    def plot_grad_cam(self, mean, std, target_layers, get_data_label_name, count=10, missclassified=True, grad_opacity=1.0):
        """Draw Grad-CAM overlays for stored predictions, five per row.

        Args:
            mean: Passed to ``denormalize`` together with ``std``; presumably
                the per-channel statistics used to normalise the inputs
                (verify against ``utils.denormalize``).
            std: See ``mean``.
            target_layers: Layers handed to ``GradCAM``.
            get_data_label_name: Callable mapping a class index to the label
                text used in the subplot title (``truth / prediction``).
            count: Maximum number of images to draw; clamped to the number of
                stored images.
            missclassified: Use ``test_incorrect_pred`` when true, otherwise
                ``test_correct_pred``.
            grad_opacity: Alpha of the CAM overlay; ``1 - grad_opacity`` is
                used as the image weight inside the overlay itself.
        """
        # The source of predictions does not change per image, so pick it once.
        # Assumes each store maps 'images', 'ground_truths' and
        # 'predicted_vals' to equally long sequences of tensors — TODO confirm.
        pred_dict = test_incorrect_pred if missclassified else test_correct_pred
        images = pred_dict['images']
        truths = pred_dict['ground_truths']
        predictions = pred_dict['predicted_vals']

        # Never index past what has actually been collected.
        count = min(count, len(images))
        if count <= 0:
            return

        cols = 5
        # Ceiling division: int(count / 5) yields too few rows whenever count
        # is not a multiple of 5 (and zero rows for count < 5).
        rows = (count + cols - 1) // cols

        cam = GradCAM(model=self, target_layers=target_layers)

        for i in range(count):
            plt.subplot(rows, cols, i + 1)

            # Explain the ground-truth class of each sample.
            targets = [ClassifierOutputTarget(truths[i].cpu().item())]
            grayscale_cam = cam(input_tensor=images[i][None, :].cpu(), targets=targets)

            x = denormalize(images[i].cpu(), mean, std)

            # Channels-last copies: an integer image for display and a float
            # image for blending with the CAM heatmap.
            image = np.array(255 * x, np.int16).transpose(1, 2, 0)
            img_tensor = np.array(x, np.float16).transpose(1, 2, 0)

            visualization = show_cam_on_image(img_tensor, grayscale_cam.transpose(1, 2, 0), use_rgb=True,
                                              image_weight=(1.0 - grad_opacity))

            plt.imshow(image, vmin=0, vmax=255)
            plt.imshow(visualization, vmin=0, vmax=255, alpha=grad_opacity)
            plt.xticks([])
            plt.yticks([])

            title = get_data_label_name(truths[i].item()) + ' / ' + \
                    get_data_label_name(predictions[i].item())
            plt.title(title, fontsize=8)

        # Lay out once, after every subplot and title exists.
        plt.tight_layout()