Spaces:
Sleeping
Sleeping
# imports | |
import albumentations as A | |
import lightning as L | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import torch.optim as optim | |
from albumentations.pytorch import ToTensorV2 | |
from model import MyResNet | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image | |
from torch import nn | |
from torch.optim.lr_scheduler import OneCycleLR | |
from torch.utils.data import DataLoader | |
from torchmetrics.functional import accuracy | |
from torchvision import datasets, transforms | |
# Per-channel normalisation statistics handed to the albumentations
# ``Normalize`` transforms that LitCIFAR10 builds in its constructor.
# NOTE(review): these look like the usual CIFAR-10 training-set RGB
# mean / standard deviation on a 0-1 pixel scale -- confirm.
means = [
    0.4914,
    0.4822,
    0.4465,
]
stds = [
    0.2470,
    0.2435,
    0.2616,
]
class CustomResnetTransforms:
    """Namespace for the albumentations pipelines used for training and evaluation.

    Both factories are static methods: they take no ``self`` and can be called
    either on the class (``CustomResnetTransforms.train_transforms(m, s)``) or
    on an instance without the instance being bound as the first argument.
    """

    @staticmethod
    def train_transforms(means, stds):
        """Build the training augmentation pipeline.

        Args:
            means: Per-channel means forwarded to ``A.Normalize``.
            stds: Per-channel standard deviations forwarded to ``A.Normalize``.

        Returns:
            An ``A.Compose`` that normalises, pads to at least 36x36, takes a
            random 32x32 crop, randomly flips horizontally, cuts out one
            8x8 hole and finally converts the image to a tensor.
        """
        return A.Compose(
            [
                A.Normalize(mean=means, std=stds, always_apply=True),
                A.PadIfNeeded(min_height=36, min_width=36, always_apply=True),
                A.RandomCrop(height=32, width=32, always_apply=True),
                A.HorizontalFlip(),
                # Normalisation runs first, so a fill value of 0 corresponds to
                # the per-channel mean of the normalised image.
                # NOTE(review): newer albumentations releases replace Cutout
                # with CoarseDropout -- confirm against the pinned version.
                A.Cutout(num_holes=1, max_h_size=8, max_w_size=8, fill_value=0, p=1.0),
                ToTensorV2(),
            ]
        )

    @staticmethod
    def test_transforms(means, stds):
        """Build the evaluation pipeline: normalise, then convert to a tensor.

        Args:
            means: Per-channel means forwarded to ``A.Normalize``.
            stds: Per-channel standard deviations forwarded to ``A.Normalize``.

        Returns:
            An ``A.Compose`` with no random augmentation.
        """
        return A.Compose(
            [
                A.Normalize(mean=means, std=stds, always_apply=True),
                ToTensorV2(),
            ]
        )
class Cifar10SearchDataset(datasets.CIFAR10):
    """CIFAR-10 dataset that applies a keyword-style (albumentations) transform.

    ``__getitem__`` reads the raw entry from ``self.data`` / ``self.targets``
    (presumably populated by the torchvision base class -- verify) and, when a
    transform is set, calls it as ``transform(image=...)`` and takes the
    ``"image"`` entry of the result.
    """

    def __init__(self, root="~/data", train=True, download=True, transform=None):
        """Forward all arguments unchanged to ``datasets.CIFAR10``."""
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, transformed if a transform is set."""
        sample = self.data[index]
        label = self.targets[index]
        if self.transform is None:
            return sample, label
        augmented = self.transform(image=sample)
        return augmented["image"], label
class LitCIFAR10(L.LightningModule):
    """Lightning module that trains ``MyResNet`` on CIFAR-10.

    The module owns its datasets and dataloaders (built in ``setup``), uses
    cross-entropy loss, and optimises with Adam under a one-cycle learning-rate
    schedule that is stepped once per optimizer step.

    Args:
        data_dir: Directory the CIFAR-10 files are downloaded to / read from.
        learning_rate: Adam learning rate; also the peak (``max_lr``) of the
            one-cycle schedule.
        batch_size: Batch size for both the training and validation loaders.
    """

    def __init__(self, data_dir='./data', learning_rate=0.01, batch_size=512):
        super().__init__()
        # Keep the constructor arguments as attributes.
        self.data_dir = data_dir
        self.lr = learning_rate
        self.batch_size = batch_size
        # Dataset-specific constants and preprocessing pipelines.
        self.num_classes = 10
        self.train_transforms = CustomResnetTransforms.train_transforms(means, stds)
        self.test_transforms = CustomResnetTransforms.test_transforms(means, stds)
        # Network and loss.
        self.model = MyResNet()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the logits produced by the wrapped model for ``x``."""
        return self.model(x)

    def _shared_step(self, batch, prefix):
        """Compute loss and accuracy for ``batch`` and log them under ``prefix``.

        Logs ``{prefix}_loss`` and ``{prefix}_acc`` aggregated per epoch and
        returns the loss tensor.
        """
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task='multiclass', num_classes=self.num_classes)
        # ``enable_graph`` is left at its default: nothing back-propagates
        # through the logged values (only the returned loss is used for that),
        # so there is no reason to keep autograd graphs attached to them.
        self.log(f"{prefix}_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log(f"{prefix}_acc", acc, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch; logs ``train_loss`` / ``train_acc``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; logs ``val_loss`` / ``val_acc``."""
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """Return Adam plus a per-step ``OneCycleLR`` schedule.

        Reads ``self.trainset`` (created in ``setup``) and
        ``self.trainer.max_epochs`` to size the schedule.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)
        # Ceiling division: the training loader does not drop the final partial
        # batch, so this is the exact number of batches per epoch. (The former
        # ``n // bs + 1`` over-counted by one whenever n was a multiple of bs.)
        steps_per_epoch = (len(self.trainset) + self.batch_size - 1) // self.batch_size
        max_epochs = self.trainer.max_epochs
        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.lr,
            steps_per_epoch=steps_per_epoch,
            epochs=max_epochs,
            # NOTE(review): 5 / max_epochs is only a valid fraction when
            # max_epochs is at least 5 -- confirm shorter runs are not needed.
            pct_start=5 / max_epochs,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy='linear',
        )
        scheduler_dict = {
            "scheduler": scheduler,
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

    def setup(self, stage=None):
        """Create the training and validation datasets (downloading if needed)."""
        self.trainset = Cifar10SearchDataset(
            root=self.data_dir,
            train=True,
            download=True,
            transform=self.train_transforms,
        )
        self.valset = Cifar10SearchDataset(
            root=self.data_dir,
            train=False,
            download=True,
            transform=self.test_transforms,
        )

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return DataLoader(
            self.trainset,
            batch_size=self.batch_size,
            # Without shuffling every epoch would see identical batches in
            # identical order.
            shuffle=True,
            num_workers=0,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(
            self.valset,
            batch_size=self.batch_size,
            num_workers=0,
            pin_memory=True,
        )
def get_misclassified_images(model, testset, mu, sigma, device):
    """Collect every sample of ``testset`` that ``model`` classifies wrongly.

    The model is switched to eval mode and each sample is pushed through it
    one at a time, with gradients disabled, after ``ToTensor`` and
    ``Normalize(mu, sigma)``.

    Args:
        model: Callable network returning class scores for a batch.
        testset: Iterable of ``(image, target)`` pairs; images are passed to
            ``transforms.ToTensor`` unchanged.
        mu: Means forwarded to ``transforms.Normalize``.
        sigma: Standard deviations forwarded to ``transforms.Normalize``.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        Three parallel lists: the original (untransformed) images, the
        predicted class indices, and the true targets.
    """
    model.eval()
    to_model_input = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mu, sigma),
        ]
    )
    wrong_images = []
    wrong_predictions = []
    expected_targets = []
    with torch.no_grad():
        for raw_image, label in testset:
            # Add a leading batch dimension of size one for the forward pass.
            batch = to_model_input(raw_image).to(device).unsqueeze(0)
            predicted = model(batch).argmax(dim=1, keepdim=True).item()
            if predicted == label:
                continue
            wrong_images.append(raw_image)
            wrong_predictions.append(predicted)
            expected_targets.append(label)
    return wrong_images, wrong_predictions, expected_targets
def plot_misclassified(image, pred, target, classes):
    """Show up to 20 misclassified images in a 4x5 grid.

    Args:
        image: Sequence of images accepted by ``Axes.imshow``.
        pred: Sequence of predicted class indices, parallel to ``image``.
        target: Sequence of true class indices, parallel to ``image``.
        classes: Lookup from class index to display name.

    When fewer than 20 samples are supplied the remaining grid cells are left
    blank instead of raising ``IndexError``.
    """
    nrows = 4
    ncols = 5
    _, ax = plt.subplots(nrows, ncols, figsize=(20, 15))
    # Only as many cells can be filled as all three sequences provide.
    available = min(len(image), len(pred), len(target))
    for index in range(nrows * ncols):
        row, col = divmod(index, ncols)
        cell = ax[row, col]
        cell.axis("off")
        if index >= available:
            continue
        cell.set_title(f"Prediction: {classes[pred[index]]}\nTarget: {classes[target[index]]}")
        cell.imshow(image[index])
    plt.show()
class ClassifierOutputTarget:
    """Callable that picks the score of one class out of a model output.

    Args:
        category: Index of the class whose score is selected.
    """

    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        """Return the selected class score(s).

        A one-dimensional output is indexed directly; anything else is
        treated as batched and the class column is taken for every row.
        """
        is_unbatched = len(model_output.shape) == 1
        if not is_unbatched:
            return model_output[:, self.category]
        return model_output[self.category]
def plot_grad_cam_images(images, pred, target, classes, model):
    """Overlay Grad-CAM heat maps on up to 20 images and show them in a 4x5 grid.

    Args:
        images: Sequence of images accepted by ``preprocess_image``; pixel
            values are divided by 255 for the overlay, so a 0-255 range is
            assumed -- TODO confirm against the caller.
        pred: Sequence of predicted class indices, parallel to ``images``.
        target: Sequence of true class indices; the CAM is computed for this
            class.
        classes: Lookup from class index to display name.
        model: Module exposing ``model.model.layer3``; its last element is the
            Grad-CAM target layer.

    When fewer than 20 images are supplied the remaining grid cells are left
    blank instead of raising ``IndexError``.
    """
    nrows = 4
    ncols = 5
    _, ax = plt.subplots(nrows, ncols, figsize=(20, 15))
    # ``use_cuda`` must be a real boolean: the previous 'cuda'/'cpu' string was
    # always truthy, which requested CUDA even on CPU-only machines.
    use_cuda = torch.cuda.is_available()
    # The CAM object and its target layer do not depend on the image, so they
    # are built once rather than once per grid cell.
    target_layers = [model.model.layer3[-1]]
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
    available = min(len(images), len(pred), len(target))
    for index in range(nrows * ncols):
        row, col = divmod(index, ncols)
        cell = ax[row, col]
        cell.axis("off")
        if index >= available:
            continue
        img = images[index]
        # Normalise with the same module-level statistics the training and
        # evaluation pipelines use, not ImageNet's.
        input_tensor = preprocess_image(img, mean=means, std=stds)
        targets = [ClassifierOutputTarget(target[index])]
        # The call returns one map per input; a single image was passed.
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
        rgb_img = np.float32(img) / 255
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=0.6)
        cell.set_title(f"Prediction: {classes[pred[index]]}\nTarget: {classes[target[index]]}")
        cell.imshow(visualization)
    plt.show()