# Source: s12erav1/models/custom_resnet_lightning_s10.py
# Uploaded app code by piyushgrover (commit f7915f2)
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import matplotlib.pyplot as plt
from torch_lr_finder import LRFinder
import numpy as np
from utils import get_correct_pred_count, add_predictions, test_incorrect_pred, test_correct_pred, denormalize
NO_GROUPS = 4
class ResnetBlock(nn.Module):
    """Residual-style block: two 3x3 convolutions, each followed by
    normalization, ReLU and spatial dropout.

    With the default ``padding=1`` the spatial size of the output matches
    the input, so callers can add the result back to their input tensor.

    Args:
        input_channel: number of channels in the incoming tensor.
        output_channel: number of channels produced by both convolutions.
        padding: padding applied to each 3x3 convolution.
        norm: normalization type — ``'bn'`` (BatchNorm2d), ``'gn'``
            (GroupNorm with ``NO_GROUPS`` groups) or ``'ln'`` (GroupNorm
            with a single group, i.e. LayerNorm-like over channels).
        drop: Dropout2d probability applied after each conv/norm/ReLU.

    Raises:
        ValueError: if ``norm`` is not one of 'bn', 'gn', 'ln' (the
            original silently left the norm layers undefined, deferring
            the failure to an AttributeError at call time).
    """

    def __init__(self, input_channel, output_channel, padding=1, norm='bn', drop=0.01):
        super(ResnetBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=padding)
        self.n1 = self._make_norm(norm, output_channel)
        self.drop1 = nn.Dropout2d(drop)
        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=padding)
        self.n2 = self._make_norm(norm, output_channel)
        self.drop2 = nn.Dropout2d(drop)

    @staticmethod
    def _make_norm(norm, channels):
        """Build the requested normalization layer for ``channels`` channels."""
        if norm == 'bn':
            return nn.BatchNorm2d(channels)
        if norm == 'gn':
            return nn.GroupNorm(NO_GROUPS, channels)
        if norm == 'ln':
            # GroupNorm with a single group normalizes over all channels.
            return nn.GroupNorm(1, channels)
        raise ValueError(f"unknown norm type: {norm!r} (expected 'bn', 'gn' or 'ln')")

    # NOTE: defined as `forward` (not `__call__`) so that nn.Module's
    # call machinery — and therefore forward/backward hooks, as used by
    # Grad-CAM — keeps working. Overriding `__call__` directly bypasses it.
    def forward(self, x):
        """Apply conv -> norm -> ReLU -> dropout twice and return the result."""
        x = self.conv1(x)
        x = self.n1(x)
        x = F.relu(x)
        x = self.drop1(x)
        x = self.conv2(x)
        x = self.n2(x)
        x = F.relu(x)
        x = self.drop2(x)
        return x
class S10LightningModel(pl.LightningModule):
    """Custom ResNet (S10-style) LightningModule for 10-class classification.

    Architecture: prep conv (3 -> C) -> conv+pool (-> 2C) with residual R1 ->
    conv+pool (-> 4C) -> conv+pool (-> 8C) with residual R2 -> MaxPool(4) ->
    Linear(8C, 10) -> log_softmax.

    Args:
        base_channels: channel width C of the prep layer; deeper stages use
            2C / 4C / 8C.
        drop: Dropout2d probability used throughout the network.
        loss_function: criterion applied to ``forward`` output. ``forward``
            returns log-probabilities; ``F.cross_entropy`` still yields the
            correct value because log_softmax is idempotent, but
            ``F.nll_loss`` would be the exactly-matching criterion.
        is_find_max_lr: when True, run torch_lr_finder in
            ``configure_optimizers`` to pick ``max_lr`` for OneCycleLR.
        max_lr: peak learning rate for OneCycleLR (overwritten when
            ``is_find_max_lr`` is set).
    """

    def __init__(self, base_channels, drop=0.01, loss_function=F.cross_entropy, is_find_max_lr=False, max_lr=3.20E-04):
        super(S10LightningModel, self).__init__()
        self.is_find_max_lr = is_find_max_lr
        self.max_lr = max_lr
        self.criterion = loss_function
        # Manual metric bookkeeping: running correct/total counters plus
        # per-epoch loss/accuracy histories (consumed by
        # on_validation_epoch_end and plot_model_performance).
        self.metric = dict(train=0,
                           val=0,
                           train_total=0,
                           val_total=0,
                           epoch_train_loss=[],
                           epoch_val_loss=[],
                           train_loss=[],
                           val_loss=[],
                           train_acc=[],
                           val_acc=[])
        self.base_channels = base_channels
        # Prep layer: 3 -> C, resolution unchanged.
        self.prep_layer = nn.Sequential(
            nn.Conv2d(3, base_channels, 3, stride=1, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(),
            nn.Dropout2d(drop)
        )
        # Layer 1: conv + maxpool to 2C, then residual block R1.
        self.x1 = nn.Sequential(
            nn.Conv2d(base_channels, 2 * base_channels, 3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(2 * base_channels),
            nn.ReLU(),
            nn.Dropout2d(drop)
        )
        self.R1 = ResnetBlock(2 * base_channels, 2 * base_channels, padding=1, drop=drop)
        # Layer 2: plain conv + maxpool to 4C (no residual branch).
        self.layer2 = nn.Sequential(
            nn.Conv2d(2 * base_channels, 4 * base_channels, 3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(4 * base_channels),
            nn.ReLU(),
            nn.Dropout2d(drop)
        )
        # Layer 3: conv + maxpool to 8C, then residual block R2.
        self.x2 = nn.Sequential(
            nn.Conv2d(4 * base_channels, 8 * base_channels, 3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(8 * base_channels),
            nn.ReLU(),
            nn.Dropout2d(drop)
        )
        self.R2 = ResnetBlock(8 * base_channels, 8 * base_channels, padding=1, drop=drop)
        # Collapses the remaining spatial map (4x4 for 32x32 input) to 1x1.
        self.pool = nn.MaxPool2d(4)
        self.fc = nn.Linear(8 * base_channels, 10)

    def forward(self, x, no_softmax=False):
        """Run the network.

        Args:
            x: input batch, assumed (N, 3, 32, 32) — TODO confirm; the final
                MaxPool2d(4) and flatten require a 4x4 map at that point.
            no_softmax: when True return raw logits (useful for Grad-CAM /
                LR-finder); otherwise return log-probabilities.
        """
        x = self.prep_layer(x)
        x = self.x1(x)
        x = self.R1(x) + x  # residual connection around R1
        x = self.layer2(x)
        x = self.x2(x)
        x = self.R2(x) + x  # residual connection around R2
        x = self.pool(x)
        x = x.view(x.size(0), 8 * self.base_channels)
        x = self.fc(x)
        if no_softmax:
            return x
        return F.log_softmax(x, dim=1)

    def get_layer(self, idx):
        """Return the idx-th major layer (0-based), or None if out of range."""
        layers = [self.prep_layer, self.x1, self.layer2, self.x2, self.pool]
        if 0 <= idx < len(layers):
            return layers[idx]

    def training_step(self, train_batch, batch_idx):
        """One optimization step; logs running train loss/accuracy."""
        x, target = train_batch
        output = self.forward(x)
        loss = self.criterion(output, target)
        self.metric['train'] += get_correct_pred_count(output, target)
        self.metric['train_total'] += len(x)
        # detach() so the stored per-batch losses don't retain the whole
        # autograd graph for the entire epoch.
        self.metric['epoch_train_loss'].append(loss.detach())
        acc = 100 * self.metric['train'] / self.metric['train_total']
        self.log_dict({'train_loss': loss, 'train_acc': acc})
        return loss

    def validation_step(self, val_batch, batch_idx):
        """One validation step; collects predictions on the final epoch."""
        x, target = val_batch
        output = self.forward(x)
        loss = self.criterion(output, target)
        self.metric['val'] += get_correct_pred_count(output, target)
        self.metric['val_total'] += len(x)
        # detach() — see training_step; avoids holding graphs across the epoch.
        self.metric['epoch_val_loss'].append(loss.detach())
        acc = 100 * self.metric['val'] / self.metric['val_total']
        if self.current_epoch == self.trainer.max_epochs - 1:
            # Record correct/incorrect predictions for later visualization.
            add_predictions(x, output, target)
        self.log_dict({'val_loss': loss, 'val_acc': acc})

    def test_step(self, test_batch, batch_idx):
        """Testing reuses the validation logic."""
        self.validation_step(test_batch, batch_idx)

    def train_dataloader(self):
        """Return the trainer's train dataloader, setting it up on first use."""
        if not self.trainer.train_dataloader:
            self.trainer.fit_loop.setup_data()
        return self.trainer.train_dataloader

    def configure_optimizers(self):
        """Adam + OneCycleLR stepped per batch.

        When ``is_find_max_lr`` is set, ``find_lr`` overwrites
        ``self.max_lr`` with the LR-finder suggestion before the scheduler
        is built.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-6, weight_decay=0.01)
        self.find_lr(optimizer)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                        max_lr=self.max_lr,
                                                        epochs=self.trainer.max_epochs,
                                                        steps_per_epoch=len(self.train_dataloader()),
                                                        # warm up for 5 epochs; assumes
                                                        # max_epochs >= 5 — TODO confirm
                                                        pct_start=5 / self.trainer.max_epochs,
                                                        div_factor=100,
                                                        final_div_factor=100,
                                                        three_phase=False,
                                                        verbose=False
                                                        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                'interval': 'step',  # or 'epoch'
                'frequency': 1
            },
        }

    def on_validation_epoch_end(self):
        """Fold the per-batch counters into per-epoch history and reset them."""
        # train_total is 0 during the sanity-check validation run; skip then.
        if self.metric['train_total']:
            print('Epoch ', self.current_epoch)
            train_acc = 100 * self.metric['train'] / self.metric['train_total']
            epoch_loss = sum(self.metric['epoch_train_loss']) / len(self.metric['epoch_train_loss'])
            self.metric['train_loss'].append(epoch_loss.item())
            self.metric['train_acc'].append(train_acc)
            print('Train Loss: ', epoch_loss.item(), ' Accuracy: ', str(train_acc) + '%', ' [',
                  self.metric['train'], '/', self.metric['train_total'], ']')
            self.metric['train'] = 0
            self.metric['train_total'] = 0
            self.metric['epoch_train_loss'] = []
        val_acc = 100 * self.metric['val'] / self.metric['val_total']
        epoch_loss = sum(self.metric['epoch_val_loss']) / len(self.metric['epoch_val_loss'])
        self.metric['val_loss'].append(epoch_loss.item())
        self.metric['val_acc'].append(val_acc)
        print('Validation Loss: ', epoch_loss.item(), ' Accuracy: ', str(val_acc) + '%', ' [', self.metric['val'],
              '/', self.metric['val_total'], ']\n')
        self.metric['val'] = 0
        self.metric['val_total'] = 0
        self.metric['epoch_val_loss'] = []

    def find_lr(self, optimizer):
        """Run torch_lr_finder's range test and store the suggested max LR.

        No-op unless ``is_find_max_lr`` was set. ``range_test`` temporarily
        trains the model; ``reset()`` restores the original weights.
        """
        if not self.is_find_max_lr:
            return
        lr_finder = LRFinder(self, optimizer, self.criterion)
        lr_finder.range_test(self.train_dataloader(), end_lr=100, num_iter=100)
        _, best_lr = lr_finder.plot()  # to inspect the loss-learning rate graph
        lr_finder.reset()
        self.max_lr = best_lr

    def plot_model_performance(self):
        """Plot the recorded train/val loss and accuracy histories (2x2 grid)."""
        fig, axs = plt.subplots(2, 2, figsize=(15, 10))
        axs[0, 0].plot(self.metric['train_loss'])
        axs[0, 0].set_title("Training Loss")
        axs[1, 0].plot(self.metric['train_acc'])
        axs[1, 0].set_title("Training Accuracy")
        axs[0, 1].plot(self.metric['val_loss'])
        axs[0, 1].set_title("Test Loss")
        axs[1, 1].plot(self.metric['val_acc'])
        axs[1, 1].set_title("Test Accuracy")

    def plot_grad_cam(self, mean, std, target_layers, get_data_label_name, count=10, missclassified=True, grad_opacity=1.0):
        """Overlay Grad-CAM heatmaps on stored (mis)classified samples.

        Args:
            mean, std: normalization stats used to denormalize the images.
            target_layers: layers GradCAM hooks into.
            get_data_label_name: maps a class index to a display name.
            count: number of images to plot (laid out 5 per row).
            missclassified: plot misclassified samples if True, else correct.
            grad_opacity: alpha of the CAM overlay on top of the image.
        """
        cam = GradCAM(model=self, target_layers=target_layers)
        # Choice of prediction store is loop-invariant; hoist it.
        pred_dict = test_incorrect_pred if missclassified else test_correct_pred
        for i in range(count):
            plt.subplot(int(count / 5), 5, i + 1)
            plt.tight_layout()
            targets = [ClassifierOutputTarget(pred_dict['ground_truths'][i].cpu().item())]
            grayscale_cam = cam(input_tensor=pred_dict['images'][i][None, :].cpu(), targets=targets)
            x = denormalize(pred_dict['images'][i].cpu(), mean, std)
            # int16 (not uint8) so slightly out-of-range denormalized values
            # don't wrap around; imshow clamps via vmin/vmax instead.
            image = np.array(255 * x, np.int16).transpose(1, 2, 0)
            img_tensor = np.array(x, np.float16).transpose(1, 2, 0)
            visualization = show_cam_on_image(img_tensor, grayscale_cam.transpose(1, 2, 0), use_rgb=True,
                                              image_weight=(1.0 - grad_opacity))
            plt.imshow(image, vmin=0, vmax=255)
            plt.imshow(visualization, vmin=0, vmax=255, alpha=grad_opacity)
            plt.xticks([])
            plt.yticks([])
            # Title: "ground-truth / predicted".
            title = get_data_label_name(pred_dict['ground_truths'][i].item()) + ' / ' + \
                    get_data_label_name(pred_dict['predicted_vals'][i].item())
            plt.title(title, fontsize=8)