# imports
import albumentations as A
import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from model import MyResNet
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from torchmetrics.functional import accuracy
from torchvision import datasets, transforms

# CIFAR-10 per-channel statistics used for normalization
means = [0.4914, 0.4822, 0.4465]
stds = [0.2470, 0.2435, 0.2616]


class CustomResnetTransforms:
    @staticmethod
    def train_transforms(means, stds):
        return A.Compose(
            [
                A.Normalize(mean=means, std=stds, always_apply=True),
                A.PadIfNeeded(min_height=36, min_width=36, always_apply=True),
                A.RandomCrop(height=32, width=32, always_apply=True),
                A.HorizontalFlip(),
                # Note: A.Cutout requires an older albumentations release;
                # newer versions provide CoarseDropout instead.
                A.Cutout(num_holes=1, max_h_size=8, max_w_size=8, fill_value=0, p=1.0),
                ToTensorV2(),
            ]
        )

    @staticmethod
    def test_transforms(means, stds):
        return A.Compose(
            [
                A.Normalize(mean=means, std=stds, always_apply=True),
                ToTensorV2(),
            ]
        )


class Cifar10SearchDataset(datasets.CIFAR10):
    """CIFAR-10 dataset that applies albumentations transforms (which expect numpy images)."""

    def __init__(self, root="~/data", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            transformed = self.transform(image=image)
            image = transformed["image"]
        return image, label


class LitCIFAR10(L.LightningModule):
    def __init__(self, data_dir="./data", learning_rate=0.01, batch_size=512):
        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.lr = learning_rate
        self.batch_size = batch_size

        # Hardcode some dataset-specific attributes
        self.num_classes = 10
        self.train_transforms = CustomResnetTransforms.train_transforms(means, stds)
        self.test_transforms = CustomResnetTransforms.test_transforms(means, stds)

        # Define PyTorch model
        self.model = MyResNet()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("train_loss", loss, prog_bar=True, enable_graph=True, on_step=False, on_epoch=True)
        self.log("train_acc", acc, prog_bar=True, enable_graph=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True, enable_graph=True, on_step=False, on_epoch=True)
        self.log("val_acc", acc, prog_bar=True, enable_graph=True, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)
        # OneCycleLR is stepped once per optimizer step, so compute steps per epoch explicitly
        steps_per_epoch = (len(self.trainset) // self.batch_size) + 1
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                max_lr=self.lr,
                steps_per_epoch=steps_per_epoch,
                epochs=self.trainer.max_epochs,
                pct_start=5 / self.trainer.max_epochs,
                div_factor=100,
                three_phase=False,
                final_div_factor=100,
                anneal_strategy="linear",
            ),
            "interval": "step",
        }
{"optimizer": optimizer, "lr_scheduler": scheduler_dict} def setup(self, stage=None): # Assign train/val datasets for use in dataloaders self.trainset = Cifar10SearchDataset(root=self.data_dir, train=True, download=True, transform=self.train_transforms) self.valset = Cifar10SearchDataset(root=self.data_dir, train=False, download=True, transform=self.test_transforms) def train_dataloader(self): return DataLoader(self.trainset, batch_size=self.batch_size, num_workers=0, pin_memory=True) def val_dataloader(self): return DataLoader(self.valset, batch_size=self.batch_size, num_workers=0, pin_memory=True) def get_misclassified_images(model, testset, mu, sigma, device): model.eval() transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mu, sigma) ]) misclassified_images, misclassified_predictions, true_targets = [], [], [] with torch.no_grad(): for data_, target in testset: data = transform(data_).to(device) data = data.unsqueeze(0) output = model(data) pred = output.argmax(dim=1, keepdim=True) if pred.item()!=target: misclassified_images.append(data_) misclassified_predictions.append(pred.item()) true_targets.append(target) return misclassified_images, misclassified_predictions, true_targets def plot_misclassified(image, pred, target, classes): nrows = 4 ncols = 5 _, ax = plt.subplots(nrows, ncols, figsize=(20, 15)) for i in range(nrows): for j in range(ncols): index = i * ncols + j ax[i, j].axis("off") ax[i, j].set_title(f"Prediction: {classes[pred[index]]}\nTarget: {classes[target[index]]}") ax[i, j].imshow(image[index]) plt.show() class ClassifierOutputTarget: def __init__(self, category): self.category = category def __call__(self, model_output): if len(model_output.shape) == 1: return model_output[self.category] return model_output[:, self.category] def plot_grad_cam_images(images, pred, target, classes, model): nrows = 4 ncols = 5 fig, ax = plt.subplots(nrows, ncols, figsize=(20,15)) device = 'cuda' if torch.cuda.is_available() else 'cpu' for i in range(nrows): for j in range(ncols): index = i * ncols + j img = images[index] input_tensor = preprocess_image(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) target_layers = [model.model.layer3[-1]] targets = [ClassifierOutputTarget(target[index])] cam = GradCAM(model=model, target_layers=target_layers, use_cuda = device) grayscale_cam = cam(input_tensor=input_tensor, targets = targets) #grayscale_cam = cam(input_tensor=input_tensor) grayscale_cam = grayscale_cam[0, :] rgb_img = np.float32(img) / 255 visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight = 0.6) index = i * ncols + j ax[i, j].axis("off") ax[i, j].set_title(f"Prediction: {classes[pred[index]]}\nTarget: {classes[target[index]]}") ax[i, j].imshow(visualization) plt.show()