gupta1912 committed
Commit 4937055 · 1 Parent(s): 29933c6

Upload 3 files

Files changed (3)
  1. model.ckpt +3 -0
  2. model.py +73 -0
  3. utils.py +211 -0
model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e0ae70f8bb67a7ab0e6336f4357de5795ce471083cb74ee1db05d4bf85a03719
+ size 78956063
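
Note: model.ckpt is stored as a Git LFS pointer, so the actual ~79 MB checkpoint is only downloaded after running "git lfs pull". The sketch below shows one way it might be restored, assuming it is a Lightning checkpoint written while training the LitCIFAR10 module from utils.py (an assumption; this commit does not say how the checkpoint was produced):

# Hedged sketch: load the LFS-tracked checkpoint for inference.
# Assumes model.ckpt is a Lightning checkpoint of utils.LitCIFAR10.
from utils import LitCIFAR10

model = LitCIFAR10.load_from_checkpoint("model.ckpt", map_location="cpu")
model.eval()
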
model.py ADDED
@@ -0,0 +1,73 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(ResBlock, self).__init__()
+         self.convblock1 = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=1, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(),
+             nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=1, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU()
+         )
+
+
+     def forward(self, x):
+         x = self.convblock1(x)
+         return x
+
+ class MyResNet(nn.Module):
+     def __init__(self):
+         super(MyResNet, self).__init__()
+
+         self.prep_layer = nn.Sequential(
+             nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+         )
+
+         self.layer1 = nn.Sequential(
+             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),
+             nn.MaxPool2d(2, 2),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+         )
+
+         self.resblock1 = ResBlock(128, 128)
+
+         self.layer2 = nn.Sequential(
+             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),
+             nn.MaxPool2d(kernel_size=2),
+             nn.BatchNorm2d(256),
+             nn.ReLU(),
+         )
+
+         self.layer3 = nn.Sequential(
+             nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),
+             nn.MaxPool2d(kernel_size=2),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+         )
+
+         self.resblock2 = ResBlock(512, 512)
+
+         self.maxpool = nn.MaxPool2d(kernel_size=4)
+         self.fc = nn.Linear(512, 10)
+
+     def forward(self, x):
+         out = self.prep_layer(x)
+         out = self.layer1(out)
+         res1 = self.resblock1(out)
+         out = out + res1
+         out = self.layer2(out)
+         out = self.layer3(out)
+         res2 = self.resblock2(out)
+         out = out + res2
+         out = self.maxpool(out)
+         out = out.view(out.size(0), -1)
+         out = self.fc(out)
+
+         return F.log_softmax(out, dim=-1)
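
For a quick sanity check of model.py: MyResNet is built for 32x32 RGB inputs (CIFAR-10 sized), and because forward ends in F.log_softmax it returns per-class log-probabilities over 10 classes. A minimal sketch (the random input below is illustrative, not part of this commit):

# Shape check: a 3x32x32 input yields 10 log-probabilities per image.
import torch
from model import MyResNet

net = MyResNet()
x = torch.randn(2, 3, 32, 32)   # dummy batch of two CIFAR-10-sized images
out = net(x)
print(out.shape)                # torch.Size([2, 10])
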
utils.py ADDED
@@ -0,0 +1,211 @@
+ # imports
+ import albumentations as A
+ import lightning as L
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import torch.optim as optim
+ from albumentations.pytorch import ToTensorV2
+ from model import MyResNet
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
+ from torch import nn
+ from torch.optim.lr_scheduler import OneCycleLR
+ from torch.utils.data import DataLoader
+ from torchmetrics.functional import accuracy
+ from torchvision import datasets, transforms
+
+ means = [0.4914, 0.4822, 0.4465]
+ stds = [0.2470, 0.2435, 0.2616]
+
+ class CustomResnetTransforms:
+     def train_transforms(means, stds):
+         return A.Compose(
+             [
+                 A.Normalize(mean=means, std=stds, always_apply=True),
+                 A.PadIfNeeded(min_height=36, min_width=36, always_apply=True),
+                 A.RandomCrop(height=32, width=32, always_apply=True),
+                 A.HorizontalFlip(),
+                 A.Cutout(num_holes=1, max_h_size=8, max_w_size=8, fill_value=0, p=1.0),
+                 ToTensorV2(),
+             ]
+         )
+
+     def test_transforms(means, stds):
+         return A.Compose(
+             [
+                 A.Normalize(mean=means, std=stds, always_apply=True),
+                 ToTensorV2(),
+             ]
+         )
+
+ class Cifar10SearchDataset(datasets.CIFAR10):
+
+     def __init__(self, root="~/data", train=True, download=True, transform=None):
+         super().__init__(root=root, train=train, download=download, transform=transform)
+
+     def __getitem__(self, index):
+         image, label = self.data[index], self.targets[index]
+         if self.transform is not None:
+             transformed = self.transform(image=image)
+             image = transformed["image"]
+         return image, label
+
+ class LitCIFAR10(L.LightningModule):
+     def __init__(self, data_dir='./data', learning_rate=0.01, batch_size=512):
+
+         super().__init__()
+
+         # Set our init args as class attributes
+         self.data_dir = data_dir
+         self.lr = learning_rate
+         self.batch_size = batch_size
+
+         # Hardcode some dataset specific attributes
+         self.num_classes = 10
+         self.train_transforms = CustomResnetTransforms.train_transforms(means, stds)
+         self.test_transforms = CustomResnetTransforms.test_transforms(means, stds)
+
+         # Define PyTorch model
+         self.model = MyResNet()
+         self.criterion = nn.CrossEntropyLoss()
+
+
+     def forward(self, x):
+         return self.model(x)
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = self.criterion(logits, y)
+         preds = torch.argmax(logits, dim=1)
+         acc = accuracy(preds, y, task='multiclass',
+                        num_classes=10)
+
+         # Calling self.log will surface up scalars for you in TensorBoard
+
+         self.log("train_loss", loss, prog_bar=True, enable_graph=True, on_step=False, on_epoch=True)
+         self.log("train_acc", acc, prog_bar=True, enable_graph=True, on_step=False, on_epoch=True)
+         # print("train_loss", loss)
+         # print("train_acc", acc)
+
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = self.criterion(logits, y)
+         preds = torch.argmax(logits, dim=1)
+         acc = accuracy(preds, y, task='multiclass',
+                        num_classes=10)
+
+         # Calling self.log will surface up scalars for you in TensorBoard
+         self.log("val_loss", loss, prog_bar=True, enable_graph=True, on_step=False, on_epoch=True)
+         self.log("val_acc", acc, prog_bar=True, enable_graph=True, on_step=False, on_epoch=True)
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-4)
+
+         steps_per_epoch = (len(self.trainset) // self.batch_size) + 1
+         scheduler_dict = {
+             "scheduler": OneCycleLR(
+                 optimizer,
+                 max_lr=self.lr,
+                 steps_per_epoch=steps_per_epoch,
+                 epochs=self.trainer.max_epochs,
+                 pct_start=5/self.trainer.max_epochs,
+                 div_factor=100,
+                 three_phase=False,
+                 final_div_factor=100,
+                 anneal_strategy='linear'
+             ),
+             "interval": "step",
+         }
+         return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
+
+     def setup(self, stage=None):
+
+         # Assign train/val datasets for use in dataloaders
+         self.trainset = Cifar10SearchDataset(root=self.data_dir, train=True,
+                                              download=True, transform=self.train_transforms)
+         self.valset = Cifar10SearchDataset(root=self.data_dir, train=False,
+                                            download=True, transform=self.test_transforms)
+
+     def train_dataloader(self):
+         return DataLoader(self.trainset, batch_size=self.batch_size, num_workers=0, pin_memory=True)
+
+     def val_dataloader(self):
+         return DataLoader(self.valset, batch_size=self.batch_size, num_workers=0, pin_memory=True)
+
+ def get_misclassified_images(model, testset, mu, sigma, device):
+     model.eval()
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize(mu, sigma)
+     ])
+     misclassified_images, misclassified_predictions, true_targets = [], [], []
+     with torch.no_grad():
+         for data_, target in testset:
+             data = transform(data_).to(device)
+             data = data.unsqueeze(0)
+             output = model(data)
+             pred = output.argmax(dim=1, keepdim=True)
+
+             if pred.item() != target:
+                 misclassified_images.append(data_)
+                 misclassified_predictions.append(pred.item())
+                 true_targets.append(target)
+     return misclassified_images, misclassified_predictions, true_targets
+
+ def plot_misclassified(image, pred, target, classes):
+
+     nrows = 4
+     ncols = 5
+
+     _, ax = plt.subplots(nrows, ncols, figsize=(20, 15))
+
+     for i in range(nrows):
+         for j in range(ncols):
+             index = i * ncols + j
+             ax[i, j].axis("off")
+             ax[i, j].set_title(f"Prediction: {classes[pred[index]]}\nTarget: {classes[target[index]]}")
+             ax[i, j].imshow(image[index])
+     plt.show()
+
+
+ class ClassifierOutputTarget:
+     def __init__(self, category):
+         self.category = category
+
+     def __call__(self, model_output):
+         if len(model_output.shape) == 1:
+             return model_output[self.category]
+         return model_output[:, self.category]
+
+ def plot_grad_cam_images(images, pred, target, classes, model):
+     nrows = 4
+     ncols = 5
+     fig, ax = plt.subplots(nrows, ncols, figsize=(20, 15))
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     for i in range(nrows):
+         for j in range(ncols):
+             index = i * ncols + j
+             img = images[index]
+             input_tensor = preprocess_image(img,
+                                             mean=[0.485, 0.456, 0.406],
+                                             std=[0.229, 0.224, 0.225])
+             target_layers = [model.model.layer3[-1]]
+             targets = [ClassifierOutputTarget(target[index])]
+             cam = GradCAM(model=model, target_layers=target_layers, use_cuda=(device == 'cuda'))
+             grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
+             # grayscale_cam = cam(input_tensor=input_tensor)
+             grayscale_cam = grayscale_cam[0, :]
+             rgb_img = np.float32(img) / 255
+             visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=0.6)
+
+             index = i * ncols + j
+             ax[i, j].axis("off")
+             ax[i, j].set_title(f"Prediction: {classes[pred[index]]}\nTarget: {classes[target[index]]}")
+             ax[i, j].imshow(visualization)
+     plt.show()
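
A sketch of how these pieces could be driven end to end. The epoch count, accelerator settings, and the raw torchvision test set below are illustrative assumptions rather than values taken from this commit; note that configure_optimizers reads self.trainer.max_epochs, so the Trainer must be given max_epochs.

# Hedged end-to-end sketch: train LitCIFAR10, then report misclassified test images.
import lightning as L
import torch
from torchvision import datasets
from utils import LitCIFAR10, get_misclassified_images, plot_misclassified, means, stds

lit_model = LitCIFAR10(data_dir="./data", learning_rate=0.01, batch_size=512)
trainer = L.Trainer(max_epochs=24, accelerator="auto", devices=1)  # 24 epochs is an assumption
trainer.fit(lit_model)

# get_misclassified_images applies its own ToTensor/Normalize, so pass the raw
# (untransformed) CIFAR-10 test set rather than the already-transformed valset.
device = "cuda" if torch.cuda.is_available() else "cpu"
lit_model = lit_model.to(device)
testset = datasets.CIFAR10(root="./data", train=False, download=True)
images, preds, targets = get_misclassified_images(lit_model, testset, means, stds, device)
plot_misclassified(images[:20], preds[:20], targets[:20], testset.classes)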