piyushgrover committed
Commit f7915f2 · Parent: 896d4b0

Uploaded app code

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+images/cat.jpeg filter=lfs diff=lfs merge=lfs -text
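The added line routes images/cat.jpeg through Git LFS, so the repository stores a small pointer file while the JPEG itself lives in LFS storage (see the Git LFS details under images/cat.jpeg below).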
app.py ADDED
@@ -0,0 +1,247 @@
+import gradio as gr
+import torch
+from torchvision import transforms
+import numpy as np
+from pytorch_grad_cam import GradCAM
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from pytorch_grad_cam.utils.image import show_cam_on_image
+from models import custom_resnet_lightning_s10
+from utils import load_model_from_checkpoint, denormalize, get_data_label_name, get_dataset_labels
+
+device = torch.device('cpu')
+dataset_mean, dataset_std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
+model = custom_resnet_lightning_s10.S10LightningModel(64)
+
+checkpoint = load_model_from_checkpoint(device)
+model.load_state_dict(checkpoint['model'], strict=False)
+
+test_incorrect_pred = checkpoint['test_incorrect_pred']
+
+sample_images = [
+    ['images/aeroplane.jpeg', 0],
+    ['images/bird.jpeg', 2],
+    ['images/car.jpeg', 1],
+    ['images/cat.jpeg', 3],
+    ['images/deer.jpeg', 4],
+    ['images/dog.jpeg', 5],
+    ['images/frog.jpeg', 6],
+    ['images/horse.jpeg', 7],
+    ['images/ship.jpeg', 8],
+    ['images/truck.jpeg', 9]
+]
+
+with gr.Blocks() as app:
+    '''
+    Feature-selection interface
+    '''
+    with gr.Row() as input_radio_group:
+        radio_btn = gr.Radio(
+            choices=['Top Prediction Classes', 'GradCAM Images', 'Misclassified Images'],
+            type="index",
+            label='Feature options',
+            info="Choose which feature you want to explore",
+            value='Top Prediction Classes'
+        )
+
+    '''
+    Options for the GradCAM feature
+    '''
+    with gr.Row():
+        with gr.Column(visible=False) as grad_cam_col:
+            grad_cam_count = gr.Slider(1, 20, value=5, step=1, label="Choose image count",
+                                       info="How many images do you want to view?")
+            grad_cam_layer = gr.Slider(-4, -1, value=-3, step=1, label="Choose model layer",
+                                       info="Which layer do you want to view GradCAM on? [-4 => last layer]")
+            grad_cam_opacity = gr.Slider(0, 1, value=0.4, step=0.1, label="Choose opacity of the gradient")
+
+            with gr.Column():
+                grad_cam_btn = gr.Button("Yes, Go Ahead")
+
+        with gr.Column(visible=False) as grad_cam_output:
+            grad_cam_output_gallery = gr.Gallery(value=[], columns=3, label='Output')
+
+    '''
+    Options for the misclassified-images feature
+    '''
+    with gr.Row(visible=False) as missclassified_col:
+        with gr.Row():
+            missclassified_img_count = gr.Slider(1, 20, value=5, step=1, label="Choose image count",
+                                                 info="How many misclassified images do you want to view?")
+            missclassified_btn = gr.Button("Click to Continue")
+    with gr.Row(visible=False) as missclassified_img_output:
+        missclassified_img_output_gallery = gr.Gallery(value=[], columns=5, label='Output')
+
+    '''
+    Options for the top-prediction-classes feature
+    '''
+    with gr.Row(visible=True) as top_pred_cls_col:
+        with gr.Column():
+            example_images = gr.Gallery(allow_preview=False, label='Select image',
+                                        value=[img[0] for img in sample_images], columns=3, rows=2,
+                                        object_fit='scale_down')
+
+        with gr.Column():
+            with gr.Row():
+                top_pred_image = gr.Image(shape=(32, 32), label='Upload Image or Select from the gallery')
+                top_class_count = gr.Slider(1, 10, value=5, step=1, label="Number of classes to predict")
+            top_class_btn = gr.Button("Submit")
+
+    with gr.Row(visible=True) as top_class_output:
+        top_class_output_labels = gr.Label(num_top_classes=top_class_count.value, label='Output')
+
+
+    def on_select(evt: gr.SelectData):
+        return {
+            top_pred_image: sample_images[evt.index][0]
+        }
+
+
+    example_images.select(on_select, None, top_pred_image)
+
+
+    def top_class_img_upload(input_img, top_class_count):
+        if input_img is not None:
+            transform = transforms.ToTensor()
+            input_img = transform(input_img)
+            input_img = input_img.to(device)
+            input_img = input_img.unsqueeze(0)
+            outputs = model(input_img, no_softmax=True)
+            softmax = torch.nn.Softmax(dim=0)
+            o = softmax(outputs.flatten())
+            confidences = {get_dataset_labels()[i]: float(o[i]) for i in range(10)}
+            top_class_output_labels.num_top_classes = top_class_count
+            return {
+                top_class_output: gr.update(visible=True),
+                top_class_output_labels: confidences
+            }
+
+
+    top_class_btn.click(
+        top_class_img_upload,
+        [top_pred_image, top_class_count],
+        [top_class_output, top_class_output_labels]
+    )
+
+    '''
+    Misclassified-images feature
+    '''
+
+
+    def show_missclassified_images(img_count):
+        imgs = []
+        for i in range(img_count):
+            img = test_incorrect_pred['images'][i].cpu()
+            img = denormalize(img, dataset_mean, dataset_std)
+            img = np.array(255 * img, np.int16).transpose(1, 2, 0)
+            label = '✅ ' + get_data_label_name(
+                test_incorrect_pred['ground_truths'][i].item()) + ' ❌ ' + get_data_label_name(
+                test_incorrect_pred['predicted_vals'][i].item())
+            imgs.append((img, label))
+
+        return {
+            missclassified_img_output: gr.update(visible=True),
+            missclassified_img_output_gallery: imgs
+        }
+
+
+    missclassified_btn.click(
+        show_missclassified_images,
+        [missclassified_img_count],
+        [missclassified_img_output_gallery, missclassified_img_output]
+    )
+
+    '''
+    GradCAM feature
+    '''
+
+
+    def grad_cam_submit(img_count, layer_idx, grad_opacity):
+        target_layers = [model.get_layer(-1 * (layer_idx + 1))]
+        cam = GradCAM(model=model, target_layers=target_layers)
+
+        visual_arr = []
+        for i in range(img_count):
+            pred_dict = test_incorrect_pred
+
+            targets = [ClassifierOutputTarget(pred_dict['ground_truths'][i].cpu().item())]
+            grayscale_cam = cam(input_tensor=pred_dict['images'][i][None, :].cpu(), targets=targets)
+
+            x = denormalize(pred_dict['images'][i].cpu(), dataset_mean, dataset_std)
+            img_tensor = np.array(x, np.float16).transpose(1, 2, 0)
+
+            visualization = show_cam_on_image(img_tensor, grayscale_cam.transpose(1, 2, 0), use_rgb=True,
+                                              image_weight=(1.0 - grad_opacity))
+
+            visual_arr.append(
+                (visualization, get_data_label_name(pred_dict['ground_truths'][i].item()))
+            )
+
+        return {
+            grad_cam_output: gr.update(visible=True),
+            grad_cam_output_gallery: visual_arr
+        }
+
+
+    grad_cam_btn.click(
+        grad_cam_submit,
+        [grad_cam_count, grad_cam_layer, grad_cam_opacity],
+        [grad_cam_output_gallery, grad_cam_output]
+    )
+
+    '''
+    Show/hide the panels for the selected feature
+    '''
+
+
+    def select_feature(feature):
+        # `feature` is the index into the Radio `choices` list above
+        if feature == 0:  # Top Prediction Classes
+            return {
+                grad_cam_col: gr.update(visible=False),
+                grad_cam_output: gr.update(visible=False),
+                missclassified_col: gr.update(visible=False),
+                missclassified_img_output: gr.update(visible=False),
+                top_pred_cls_col: gr.update(visible=True),
+                top_class_output: gr.update(visible=True)
+            }
+        elif feature == 1:  # GradCAM Images
+            return {
+                grad_cam_col: gr.update(visible=True),
+                grad_cam_output: gr.update(visible=True),
+                missclassified_col: gr.update(visible=False),
+                missclassified_img_output: gr.update(visible=False),
+                top_pred_cls_col: gr.update(visible=False),
+                top_class_output: gr.update(visible=False)
+            }
+        else:  # Misclassified Images
+            return {
+                grad_cam_col: gr.update(visible=False),
+                grad_cam_output: gr.update(visible=False),
+                missclassified_col: gr.update(visible=True),
+                missclassified_img_output: gr.update(visible=True),
+                top_pred_cls_col: gr.update(visible=False),
+                top_class_output: gr.update(visible=False)
+            }
+
+
+    radio_btn.change(select_feature,
+                     [radio_btn],
+                     [grad_cam_col, grad_cam_output, missclassified_col, missclassified_img_output,
+                      top_pred_cls_col, top_class_output])
+
+'''
+Launch the app
+'''
+app.launch()
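
For reference, here is a minimal sketch of the same top-class inference path outside the UI. It is an illustration, not part of the commit: it assumes ckpt.pth and this commit's models/ and utils modules are importable from the working directory, and the 32x32 resize stands in for what the gr.Image(shape=(32, 32)) component does to uploads.

import torch
from PIL import Image
from torchvision import transforms

from models import custom_resnet_lightning_s10
from utils import load_model_from_checkpoint, get_dataset_labels

device = torch.device('cpu')
model = custom_resnet_lightning_s10.S10LightningModel(64)
model.load_state_dict(load_model_from_checkpoint(device)['model'], strict=False)
model.eval()

# 'images/cat.jpeg' is one of the sample images added in this commit
img = Image.open('images/cat.jpeg').convert('RGB').resize((32, 32))
x = transforms.ToTensor()(img).unsqueeze(0).to(device)

with torch.no_grad():
    logits = model(x, no_softmax=True)           # raw class scores
    probs = torch.softmax(logits.flatten(), 0)   # same softmax the app applies

for p, i in zip(*torch.topk(probs, k=5)):
    print(f'{get_dataset_labels()[i.item()]}: {p.item():.3f}')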
ckpt.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c5cef3f797917b1f454d5538e8d39af1dea5a0dd880a148e5c19a1b1c746263
+size 88712703
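
These three lines are the Git LFS pointer committed in place of the binary: the actual ~89 MB checkpoint is kept in LFS storage and materialized on clone (for example via git lfs pull).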
images/aeroplane.jpeg ADDED
images/bird.jpeg ADDED
images/car.jpeg ADDED
images/cat.jpeg ADDED

Git LFS Details

  • SHA256: 2743ac102aca5d2aec6870e1a127041d97d1fd5b0be0900e58ec9e179f33a442
  • Pointer size: 132 Bytes
  • Size of remote file: 4.63 MB
images/deer.jpeg ADDED
images/dog.jpeg ADDED
images/frog.jpeg ADDED
images/horse.jpeg ADDED
images/ship.jpeg ADDED
images/truck.jpeg ADDED
models/__pycache__/custom_resnet_lightning_s10.cpython-38.pyc ADDED
Binary file (8.26 kB).
 
models/custom_resnet_lightning_s10.py ADDED
@@ -0,0 +1,324 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import pytorch_lightning as pl
+from pytorch_grad_cam import GradCAM
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from pytorch_grad_cam.utils.image import show_cam_on_image
+
+import matplotlib.pyplot as plt
+from torch_lr_finder import LRFinder
+import numpy as np
+from utils import get_correct_pred_count, add_predictions, test_incorrect_pred, test_correct_pred, denormalize
+
+NO_GROUPS = 4
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, input_channel, output_channel, padding=1, norm='bn', drop=0.01):
+        super(ResnetBlock, self).__init__()
+
+        self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=padding)
+
+        if norm == 'bn':
+            self.n1 = nn.BatchNorm2d(output_channel)
+        elif norm == 'gn':
+            self.n1 = nn.GroupNorm(NO_GROUPS, output_channel)
+        elif norm == 'ln':
+            self.n1 = nn.GroupNorm(1, output_channel)
+
+        self.drop1 = nn.Dropout2d(drop)
+
+        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=padding)
+
+        if norm == 'bn':
+            self.n2 = nn.BatchNorm2d(output_channel)
+        elif norm == 'gn':
+            self.n2 = nn.GroupNorm(NO_GROUPS, output_channel)
+        elif norm == 'ln':
+            self.n2 = nn.GroupNorm(1, output_channel)
+
+        self.drop2 = nn.Dropout2d(drop)
+
+    '''
+    Residual-block body: two conv -> norm -> ReLU -> dropout stages applied to the input feature map
+    '''
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.n1(x)
+        x = F.relu(x)
+        x = self.drop1(x)
+
+        x = self.conv2(x)
+        x = self.n2(x)
+        x = F.relu(x)
+        x = self.drop2(x)
+
+        return x
+
+
+class S10LightningModel(pl.LightningModule):
+    def __init__(self, base_channels, drop=0.01, loss_function=F.cross_entropy, is_find_max_lr=False, max_lr=3.20E-04):
+        super(S10LightningModel, self).__init__()
+
+        self.is_find_max_lr = is_find_max_lr
+        self.max_lr = max_lr
+        self.criterion = loss_function
+
+        self.metric = dict(train=0,
+                           val=0,
+                           train_total=0,
+                           val_total=0,
+                           epoch_train_loss=[],
+                           epoch_val_loss=[],
+                           train_loss=[],
+                           val_loss=[],
+                           train_acc=[],
+                           val_acc=[])
+
+        self.base_channels = base_channels
+
+        self.prep_layer = nn.Sequential(
+            nn.Conv2d(3, base_channels, 3, stride=1, padding=1),
+            nn.BatchNorm2d(base_channels),
+            nn.ReLU(),
+            nn.Dropout2d(drop)
+        )
+
+        # layer1
+        self.x1 = nn.Sequential(
+            nn.Conv2d(base_channels, 2 * base_channels, 3, stride=1, padding=1),
+            nn.MaxPool2d(2, 2),
+            nn.BatchNorm2d(2 * base_channels),
+            nn.ReLU(),
+            nn.Dropout2d(drop)
+        )
+
+        self.R1 = ResnetBlock(2 * base_channels, 2 * base_channels, padding=1, drop=drop)
+
+        # layer2
+        self.layer2 = nn.Sequential(
+            nn.Conv2d(2 * base_channels, 4 * base_channels, 3, stride=1, padding=1),
+            nn.MaxPool2d(2, 2),
+            nn.BatchNorm2d(4 * base_channels),
+            nn.ReLU(),
+            nn.Dropout2d(drop)
+        )
+
+        # layer3
+        self.x2 = nn.Sequential(
+            nn.Conv2d(4 * base_channels, 8 * base_channels, 3, stride=1, padding=1),
+            nn.MaxPool2d(2, 2),
+            nn.BatchNorm2d(8 * base_channels),
+            nn.ReLU(),
+            nn.Dropout2d(drop)
+        )
+
+        self.R2 = ResnetBlock(8 * base_channels, 8 * base_channels, padding=1, drop=drop)
+
+        self.pool = nn.MaxPool2d(4)
+
+        self.fc = nn.Linear(8 * base_channels, 10)
+
+    def forward(self, x, no_softmax=False):
+        x = self.prep_layer(x)
+        x = self.x1(x)
+        x = self.R1(x) + x
+        x = self.layer2(x)
+        x = self.x2(x)
+        x = self.R2(x) + x
+        x = self.pool(x)
+        x = x.view(x.size(0), 8 * self.base_channels)
+        x = self.fc(x)
+
+        if no_softmax:
+            return x
+
+        return F.log_softmax(x, dim=1)
+
+    def get_layer(self, idx):
+        layers = [self.prep_layer, self.x1, self.layer2, self.x2, self.pool]
+
+        if 0 <= idx < len(layers):
+            return layers[idx]
+
+    def training_step(self, train_batch, batch_idx):
+        x, target = train_batch
+        output = self.forward(x)
+        loss = self.criterion(output, target)
+
+        self.metric['train'] += get_correct_pred_count(output, target)
+        self.metric['train_total'] += len(x)
+        self.metric['epoch_train_loss'].append(loss)
+
+        acc = 100 * self.metric['train'] / self.metric['train_total']
+
+        self.log_dict({'train_loss': loss, 'train_acc': acc})
+        return loss
+
+    def validation_step(self, val_batch, batch_idx):
+        x, target = val_batch
+        output = self.forward(x)
+        loss = self.criterion(output, target)
+
+        self.metric['val'] += get_correct_pred_count(output, target)
+        self.metric['val_total'] += len(x)
+        self.metric['epoch_val_loss'].append(loss)
+
+        acc = 100 * self.metric['val'] / self.metric['val_total']
+
+        if self.current_epoch == self.trainer.max_epochs - 1:
+            add_predictions(x, output, target)
+
+        self.log_dict({'val_loss': loss, 'val_acc': acc})
+
+    def test_step(self, test_batch, batch_idx):
+        self.validation_step(test_batch, batch_idx)
+
+    def train_dataloader(self):
+        if not self.trainer.train_dataloader:
+            self.trainer.fit_loop.setup_data()
+
+        return self.trainer.train_dataloader
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=1e-6, weight_decay=0.01)
+        self.find_lr(optimizer)
+        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
+                                                        max_lr=self.max_lr,
+                                                        epochs=self.trainer.max_epochs,
+                                                        steps_per_epoch=len(self.train_dataloader()),
+                                                        pct_start=5 / self.trainer.max_epochs,
+                                                        div_factor=100,
+                                                        final_div_factor=100,
+                                                        three_phase=False,
+                                                        verbose=False)
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                'interval': 'step',  # or 'epoch'
+                'frequency': 1
+            },
+        }
+
+    def on_validation_epoch_end(self):
+        if self.metric['train_total']:
+            print('Epoch ', self.current_epoch)
+            train_acc = 100 * self.metric['train'] / self.metric['train_total']
+            epoch_loss = sum(self.metric['epoch_train_loss']) / len(self.metric['epoch_train_loss'])
+            self.metric['train_loss'].append(epoch_loss.item())
+            self.metric['train_acc'].append(train_acc)
+
+            print('Train Loss: ', epoch_loss.item(), ' Accuracy: ', str(train_acc) + '%', ' [',
+                  self.metric['train'], '/', self.metric['train_total'], ']')
+
+            self.metric['train'] = 0
+            self.metric['train_total'] = 0
+            self.metric['epoch_train_loss'] = []
+
+        val_acc = 100 * self.metric['val'] / self.metric['val_total']
+
+        epoch_loss = sum(self.metric['epoch_val_loss']) / len(self.metric['epoch_val_loss'])
+        self.metric['val_loss'].append(epoch_loss.item())
+        self.metric['val_acc'].append(val_acc)
+
+        print('Validation Loss: ', epoch_loss.item(), ' Accuracy: ', str(val_acc) + '%', ' [', self.metric['val'],
+              '/', self.metric['val_total'], ']\n')
+
+        self.metric['val'] = 0
+        self.metric['val_total'] = 0
+        self.metric['epoch_val_loss'] = []
+
+    def find_lr(self, optimizer):
+        if not self.is_find_max_lr:
+            return
+
+        lr_finder = LRFinder(self, optimizer, self.criterion)
+        lr_finder.range_test(self.train_dataloader(), end_lr=100, num_iter=100)
+        _, best_lr = lr_finder.plot()  # inspect the loss vs. learning-rate curve
+        lr_finder.reset()
+        self.max_lr = best_lr
+
+    def plot_model_performance(self):
+        fig, axs = plt.subplots(2, 2, figsize=(15, 10))
+        axs[0, 0].plot(self.metric['train_loss'])
+        axs[0, 0].set_title("Training Loss")
+        axs[1, 0].plot(self.metric['train_acc'])
+        axs[1, 0].set_title("Training Accuracy")
+        axs[0, 1].plot(self.metric['val_loss'])
+        axs[0, 1].set_title("Test Loss")
+        axs[1, 1].plot(self.metric['val_acc'])
+        axs[1, 1].set_title("Test Accuracy")
+
+    def plot_grad_cam(self, mean, std, target_layers, get_data_label_name, count=10, missclassified=True, grad_opacity=1.0):
+        cam = GradCAM(model=self, target_layers=target_layers)
+
+        for i in range(count):
+            plt.subplot(int(count / 5), 5, i + 1)
+            plt.tight_layout()
+            pred_dict = test_incorrect_pred if missclassified else test_correct_pred
+
+            targets = [ClassifierOutputTarget(pred_dict['ground_truths'][i].cpu().item())]
+            grayscale_cam = cam(input_tensor=pred_dict['images'][i][None, :].cpu(), targets=targets)
+
+            x = denormalize(pred_dict['images'][i].cpu(), mean, std)
+
+            image = np.array(255 * x, np.int16).transpose(1, 2, 0)
+            img_tensor = np.array(x, np.float16).transpose(1, 2, 0)
+
+            visualization = show_cam_on_image(img_tensor, grayscale_cam.transpose(1, 2, 0), use_rgb=True,
+                                              image_weight=(1.0 - grad_opacity))
+
+            plt.imshow(image, vmin=0, vmax=255)
+            plt.imshow(visualization, vmin=0, vmax=255, alpha=grad_opacity)
+            plt.xticks([])
+            plt.yticks([])
+
+            title = get_data_label_name(pred_dict['ground_truths'][i].item()) + ' / ' + \
+                    get_data_label_name(pred_dict['predicted_vals'][i].item())
+            plt.title(title, fontsize=8)
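
A quick smoke test for the architecture. This is a sketch, assuming this commit's package layout is on the path; base_channels=64 matches app.py, and layer index 2 is what the app's default GradCAM slider value (-3) resolves to via get_layer.

import torch
from models.custom_resnet_lightning_s10 import S10LightningModel

model = S10LightningModel(64)
model.eval()

batch = torch.randn(2, 3, 32, 32)           # CIFAR-10-shaped dummy batch
with torch.no_grad():
    log_probs = model(batch)                 # log-softmax output (training path)
    logits = model(batch, no_softmax=True)   # raw scores (what app.py uses)

assert log_probs.shape == (2, 10) and logits.shape == (2, 10)

print(model.get_layer(2))                    # layer2, the app's default GradCAM target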
requirements.txt ADDED
@@ -0,0 +1,7 @@
+torch
+torch-lr-finder
+torchvision
+pillow
+gradio
+grad-cam
+numpy
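
The dependencies are unpinned, so pip install -r requirements.txt resolves to current releases; gradio drives the UI, grad-cam provides the GradCAM visualizations, and torch-lr-finder is only needed for the training-time LR search in the model module.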
utils.py ADDED
@@ -0,0 +1,62 @@
+import torch
+
+
+def get_dataset_labels():
+    return ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
+
+
+def get_data_label_name(idx):
+    if idx < 0:
+        return ''
+
+    return get_dataset_labels()[idx]
+
+
+def get_data_idx_from_name(name):
+    if not name:
+        return -1
+
+    return get_dataset_labels().index(name.lower()) if name.lower() in get_dataset_labels() else -1
+
+
+def load_model_from_checkpoint(device, file_name='ckpt.pth'):
+    checkpoint = torch.load(file_name, map_location=device)
+    return checkpoint
+
+
+def denormalize(img, mean, std):
+    MEAN = torch.tensor(mean)
+    STD = torch.tensor(std)
+
+    img = img * STD[:, None, None] + MEAN[:, None, None]
+    i_min = img.min().item()
+    i_max = img.max().item()
+
+    img_bar = (img - i_min) / (i_max - i_min)
+
+    return img_bar
+
+
+# Data to plot accuracy and loss graphs
+train_losses = []
+test_losses = []
+train_acc = []
+test_acc = []
+
+test_incorrect_pred = {'images': [], 'ground_truths': [], 'predicted_vals': []}
+test_correct_pred = {'images': [], 'ground_truths': [], 'predicted_vals': []}
+
+
+def get_correct_pred_count(pPrediction, pLabels):
+    return pPrediction.argmax(dim=1).eq(pLabels).sum().item()
+
+
+def add_predictions(data, pred, target):
+    diff_preds = pred.argmax(dim=1) - target
+    for idx, d in enumerate(diff_preds):
+        if d.item() != 0:
+            test_incorrect_pred['images'].append(data[idx])
+            test_incorrect_pred['ground_truths'].append(target[idx])
+            test_incorrect_pred['predicted_vals'].append(pred.argmax(dim=1)[idx])
+        else:
+            test_correct_pred['images'].append(data[idx])
+            test_correct_pred['ground_truths'].append(target[idx])
+            test_correct_pred['predicted_vals'].append(pred.argmax(dim=1)[idx])
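
The helpers above are what app.py leans on at inference time. A small usage sketch with illustrative tensors:

import torch
from utils import denormalize, get_correct_pred_count

mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)

img = torch.randn(3, 32, 32)                    # stand-in for a normalized CIFAR-10 image
restored = denormalize(img, mean, std)          # un-normalizes, then min-max scales into [0, 1]
assert 0.0 <= restored.min().item() and restored.max().item() <= 1.0

preds = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # two 2-class predictions
labels = torch.tensor([1, 1])
print(get_correct_pred_count(preds, labels))    # prints 1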