Spaces · Commit f7915f2
Parent(s): 896d4b0
Uploaded app code
Files changed:
- .gitattributes +1 -0
- app.py +247 -0
- ckpt.pth +3 -0
- images/aeroplane.jpeg +0 -0
- images/bird.jpeg +0 -0
- images/car.jpeg +0 -0
- images/cat.jpeg +3 -0
- images/deer.jpeg +0 -0
- images/dog.jpeg +0 -0
- images/frog.jpeg +0 -0
- images/horse.jpeg +0 -0
- images/ship.jpeg +0 -0
- images/truck.jpeg +0 -0
- models/__pycache__/custom_resnet_lightning_s10.cpython-38.pyc +0 -0
- models/custom_resnet_lightning_s10.py +324 -0
- requirements.txt +7 -0
- utils.py +62 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+images/cat.jpeg filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,247 @@
import gradio as gr
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from models import custom_resnet_lightning_s10
from utils import load_model_from_checkpoint, denormalize, get_data_label_name, get_dataset_labels

device = torch.device('cpu')
dataset_mean, dataset_std = (0.4914, 0.4822, 0.4465), \
                            (0.2470, 0.2435, 0.2616)
model = custom_resnet_lightning_s10.S10LightningModel(64)

checkpoint = load_model_from_checkpoint(device)
model.load_state_dict(checkpoint['model'], strict=False)

test_incorrect_pred = checkpoint['test_incorrect_pred']

# [image path, CIFAR-10 class index]
sample_images = [
    ['images/aeroplane.jpeg', 0],
    ['images/bird.jpeg', 2],
    ['images/car.jpeg', 1],
    ['images/cat.jpeg', 3],
    ['images/deer.jpeg', 4],
    ['images/dog.jpeg', 5],
    ['images/frog.jpeg', 6],
    ['images/horse.jpeg', 7],
    ['images/ship.jpeg', 8],
    ['images/truck.jpeg', 9]
]

with gr.Blocks() as app:
    '''
    Feature-selection interface
    '''
    with gr.Row() as input_radio_group:
        radio_btn = gr.Radio(
            choices=['Top Prediction Classes', 'GradCAM Images', 'Misclassified Images'],
            type="index",
            label='Feature options',
            info="Choose which feature you want to explore",
            value='Top Prediction Classes'
        )

    '''
    Options for the GradCAM feature
    '''
    with gr.Row():
        with gr.Column(visible=False) as grad_cam_col:
            grad_cam_count = gr.Slider(1, 20, value=5, step=1, label="Choose image count",
                                       info="How many images do you want to view?")
            grad_cam_layer = gr.Slider(-4, -1, value=-3, step=1, label="Choose model layer",
                                       info="Which layer do you want to view GradCAM on? [-4 => last layer]")
            grad_cam_opacity = gr.Slider(0, 1, value=0.4, step=0.1, label="Choose opacity of the gradient")

            with gr.Column():
                grad_cam_btn = gr.Button("Yes, Go Ahead")

        with gr.Column(visible=False) as grad_cam_output:
            grad_cam_output_gallery = gr.Gallery(value=[], columns=3, label='Output')
            # prediction_title = gr.Label(value='')

    '''
    Options for the misclassified-images feature
    '''
    with gr.Row(visible=False) as missclassified_col:
        with gr.Row():
            missclassified_img_count = gr.Slider(1, 20, value=5, step=1, label="Choose image count",
                                                 info="How many misclassified images do you want to view?")
            missclassified_btn = gr.Button("Click to Continue")
    with gr.Row(visible=False) as missclassified_img_output:
        missclassified_img_output_gallery = gr.Gallery(value=[], columns=5, label='Output')

    '''
    Options for the top-prediction-classes feature
    '''
    with gr.Row(visible=True) as top_pred_cls_col:
        with gr.Column():
            example_images = gr.Gallery(allow_preview=False, label='Select image',
                                        value=[img[0] for img in sample_images], columns=3, rows=2,
                                        object_fit='scale_down')

        with gr.Column():
            with gr.Row():
                top_pred_image = gr.Image(shape=(32, 32), label='Upload Image or Select from the gallery')
                top_class_count = gr.Slider(1, 10, value=5, step=1, label="Number of classes to predict")
            top_class_btn = gr.Button("Submit")

    with gr.Row(visible=True) as top_class_output:
        # top_class_output_img = gr.Image().style(width=256, height=256)
        top_class_output_labels = gr.Label(num_top_classes=top_class_count.value, label='Output')


    def on_select(evt: gr.SelectData):
        return {
            top_pred_image: sample_images[evt.index][0]
        }


    example_images.select(on_select, None, top_pred_image)


    def top_class_img_upload(input_img, top_class_count):
        if input_img is not None:
            transform = transforms.ToTensor()
            org_img = input_img
            input_img = transform(input_img)
            input_img = input_img.to(device)
            input_img = input_img.unsqueeze(0)
            outputs = model(input_img, no_softmax=True)
            softmax = torch.nn.Softmax(dim=0)
            o = softmax(outputs.flatten())
            confidences = {get_dataset_labels()[i]: float(o[i]) for i in range(10)}
            top_class_output_labels.num_top_classes = top_class_count
            return {
                top_class_output: gr.update(visible=True),
                # top_class_output_img: org_img,
                top_class_output_labels: confidences
            }


    top_class_btn.click(
        top_class_img_upload,
        [top_pred_image, top_class_count],
        [top_class_output, top_class_output_labels]
    )

    '''
    Misclassified-images feature
    '''


    def show_missclassified_images(img_count):
        imgs = []
        for i in range(img_count):
            img = test_incorrect_pred['images'][i].cpu()
            img = denormalize(img, dataset_mean, dataset_std)
            img = np.array(255 * img, np.int16).transpose(1, 2, 0)
            # Gallery caption: ✅ ground truth, ❌ prediction
            label = '✅ ' + get_data_label_name(
                test_incorrect_pred['ground_truths'][i].item()) + ' ❌ ' + get_data_label_name(
                test_incorrect_pred['predicted_vals'][i].item())
            imgs.append((img, label))

        return {
            missclassified_img_output: gr.update(visible=True),
            missclassified_img_output_gallery: imgs
        }


    missclassified_btn.click(
        show_missclassified_images,
        [missclassified_img_count],
        [missclassified_img_output_gallery, missclassified_img_output]
    )

    '''
    GradCAM feature
    '''


    def grad_cam_submit(img_count, layer_idx, grad_opacity):
        # Map the slider value (-4..-1) onto an index into the model's layer list
        target_layers = [model.get_layer(-1 * (layer_idx + 1))]
        cam = GradCAM(model=model, target_layers=target_layers)

        visual_arr = []
        for i in range(img_count):
            pred_dict = test_incorrect_pred

            targets = [ClassifierOutputTarget(pred_dict['ground_truths'][i].cpu().item())]

            grayscale_cam = cam(input_tensor=pred_dict['images'][i][None, :].cpu(), targets=targets)

            x = denormalize(pred_dict['images'][i].cpu(), dataset_mean, dataset_std)

            img_tensor = np.array(x, np.float16).transpose(1, 2, 0)

            visualization = show_cam_on_image(img_tensor, grayscale_cam.transpose(1, 2, 0), use_rgb=True,
                                              image_weight=(1.0 - grad_opacity))

            visual_arr.append(
                (visualization, get_data_label_name(pred_dict['ground_truths'][i].item()))
            )

        return {
            grad_cam_output: gr.update(visible=True),
            grad_cam_output_gallery: visual_arr
        }


    grad_cam_btn.click(
        grad_cam_submit,
        [grad_cam_count, grad_cam_layer, grad_cam_opacity],
        [grad_cam_output_gallery, grad_cam_output]
    )

    '''
    Select which feature to showcase
    '''


    def select_feature(feature):
        # Indices follow the Radio choices: 0 = Top Prediction Classes,
        # 1 = GradCAM Images, 2 = Misclassified Images
        if feature == 0:
            return {
                grad_cam_col: gr.update(visible=False),
                grad_cam_output: gr.update(visible=False),
                missclassified_col: gr.update(visible=False),
                missclassified_img_output: gr.update(visible=False),
                top_pred_cls_col: gr.update(visible=True),
                top_class_output: gr.update(visible=True)
            }
        elif feature == 1:
            return {
                grad_cam_col: gr.update(visible=True),
                grad_cam_output: gr.update(visible=True),
                missclassified_col: gr.update(visible=False),
                missclassified_img_output: gr.update(visible=False),
                top_pred_cls_col: gr.update(visible=False),
                top_class_output: gr.update(visible=False)
            }
        else:
            return {
                grad_cam_col: gr.update(visible=False),
                grad_cam_output: gr.update(visible=False),
                missclassified_col: gr.update(visible=True),
                missclassified_img_output: gr.update(visible=True),
                top_pred_cls_col: gr.update(visible=False),
                top_class_output: gr.update(visible=False)
            }


    radio_btn.change(select_feature,
                     [radio_btn],
                     [grad_cam_col, grad_cam_output, missclassified_col, missclassified_img_output, top_pred_cls_col,
                      top_class_output])

'''
Launch the app
'''
app.launch()
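For reference, the prediction path wired into the Submit button can be exercised without the Gradio UI. A minimal sketch, assuming the same repository layout (ckpt.pth, models/, utils.py, images/) and the 32x32 input size the Image component uses; top_k is a hypothetical helper name, not part of the committed code:

import torch
from PIL import Image
from torchvision import transforms

from models import custom_resnet_lightning_s10
from utils import load_model_from_checkpoint, get_dataset_labels

device = torch.device('cpu')
model = custom_resnet_lightning_s10.S10LightningModel(64)
model.load_state_dict(load_model_from_checkpoint(device)['model'], strict=False)
model.eval()


def top_k(path, k=5):
    # Same preprocessing as top_class_img_upload: ToTensor plus a batch dim
    img = transforms.ToTensor()(Image.open(path).convert('RGB').resize((32, 32)))
    with torch.no_grad():
        logits = model(img.unsqueeze(0), no_softmax=True)
    probs = torch.softmax(logits.flatten(), dim=0)
    vals, idxs = probs.topk(k)
    return [(get_dataset_labels()[i], float(v)) for v, i in zip(vals, idxs)]


print(top_k('images/cat.jpeg'))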
ckpt.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c5cef3f797917b1f454d5538e8d39af1dea5a0dd880a148e5c19a1b1c746263
+size 88712703
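ckpt.pth is stored through Git LFS, so only the pointer file appears in the diff. From the way app.py reads it, the checkpoint is a plain dict holding at least a 'model' state_dict and the 'test_incorrect_pred' record. A sketch of the save-side counterpart under that assumption (the training script itself is not part of this commit, so the exact keys written there are inferred, not confirmed):

import torch

from models import custom_resnet_lightning_s10
from utils import test_incorrect_pred

model = custom_resnet_lightning_s10.S10LightningModel(64)

# Assumed structure, inferred from the checkpoint['model'] and
# checkpoint['test_incorrect_pred'] reads in app.py:
torch.save({
    'model': model.state_dict(),
    'test_incorrect_pred': test_incorrect_pred,
}, 'ckpt.pth')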
images/aeroplane.jpeg ADDED
images/bird.jpeg ADDED
images/car.jpeg ADDED
images/cat.jpeg ADDED (Git LFS)
images/deer.jpeg ADDED
images/dog.jpeg ADDED
images/frog.jpeg ADDED
images/horse.jpeg ADDED
images/ship.jpeg ADDED
images/truck.jpeg ADDED
models/__pycache__/custom_resnet_lightning_s10.cpython-38.pyc ADDED
Binary file (8.26 kB)
models/custom_resnet_lightning_s10.py
ADDED
@@ -0,0 +1,324 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

import matplotlib.pyplot as plt
from torch_lr_finder import LRFinder
import numpy as np
from utils import get_correct_pred_count, add_predictions, test_incorrect_pred, test_correct_pred, denormalize

NO_GROUPS = 4


class ResnetBlock(nn.Module):
    def __init__(self, input_channel, output_channel, padding=1, norm='bn', drop=0.01):
        super(ResnetBlock, self).__init__()

        self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=padding)

        # Selectable normalization: batch norm, group norm, or layer norm
        # (GroupNorm with a single group)
        if norm == 'bn':
            self.n1 = nn.BatchNorm2d(output_channel)
        elif norm == 'gn':
            self.n1 = nn.GroupNorm(NO_GROUPS, output_channel)
        elif norm == 'ln':
            self.n1 = nn.GroupNorm(1, output_channel)

        self.drop1 = nn.Dropout2d(drop)

        self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=padding)

        if norm == 'bn':
            self.n2 = nn.BatchNorm2d(output_channel)
        elif norm == 'gn':
            self.n2 = nn.GroupNorm(NO_GROUPS, output_channel)
        elif norm == 'ln':
            self.n2 = nn.GroupNorm(1, output_channel)

        self.drop2 = nn.Dropout2d(drop)

    '''
    Two conv -> norm -> ReLU -> dropout stages applied to the input feature map.
    Defined as forward (rather than __call__) so nn.Module hooks keep working.
    '''
    def forward(self, x):
        x = self.conv1(x)
        x = self.n1(x)
        x = F.relu(x)
        x = self.drop1(x)

        x = self.conv2(x)
        x = self.n2(x)
        x = F.relu(x)
        x = self.drop2(x)

        return x


class S10LightningModel(pl.LightningModule):
    def __init__(self, base_channels, drop=0.01, loss_function=F.cross_entropy, is_find_max_lr=False, max_lr=3.20E-04):
        super(S10LightningModel, self).__init__()

        self.is_find_max_lr = is_find_max_lr
        self.max_lr = max_lr
        self.criterion = loss_function

        self.metric = dict(train=0,
                           val=0,
                           train_total=0,
                           val_total=0,
                           epoch_train_loss=[],
                           epoch_val_loss=[],
                           train_loss=[],
                           val_loss=[],
                           train_acc=[],
                           val_acc=[])

        self.base_channels = base_channels

        self.prep_layer = nn.Sequential(
            nn.Conv2d(3, base_channels, 3, stride=1, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(),
            nn.Dropout2d(drop)
        )

        # layer1
        self.x1 = nn.Sequential(
            nn.Conv2d(base_channels, 2 * base_channels, 3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(2 * base_channels),
            nn.ReLU(),
            nn.Dropout2d(drop)
        )

        self.R1 = ResnetBlock(2 * base_channels, 2 * base_channels, padding=1, drop=drop)

        # layer2
        self.layer2 = nn.Sequential(
            nn.Conv2d(2 * base_channels, 4 * base_channels, 3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(4 * base_channels),
            nn.ReLU(),
            nn.Dropout2d(drop)
        )

        # layer3
        self.x2 = nn.Sequential(
            nn.Conv2d(4 * base_channels, 8 * base_channels, 3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(8 * base_channels),
            nn.ReLU(),
            nn.Dropout2d(drop)
        )

        self.R2 = ResnetBlock(8 * base_channels, 8 * base_channels, padding=1, drop=drop)

        self.pool = nn.MaxPool2d(4)

        self.fc = nn.Linear(8 * base_channels, 10)

    def forward(self, x, no_softmax=False):
        x = self.prep_layer(x)      # 3 -> 64 channels, 32x32
        x = self.x1(x)              # 64 -> 128 channels, 16x16
        x = self.R1(x) + x          # residual block 1
        x = self.layer2(x)          # 128 -> 256 channels, 8x8
        x = self.x2(x)              # 256 -> 512 channels, 4x4
        x = self.R2(x) + x          # residual block 2
        x = self.pool(x)            # 4x4 max pool -> 1x1
        x = x.view(x.size(0), 8 * self.base_channels)
        x = self.fc(x)

        if no_softmax:
            return x

        return F.log_softmax(x, dim=1)

    def get_layer(self, idx):
        layers = [self.prep_layer, self.x1, self.layer2, self.x2, self.pool]

        if 0 <= idx < len(layers):
            return layers[idx]

    def training_step(self, train_batch, batch_idx):
        x, target = train_batch
        output = self.forward(x)
        loss = self.criterion(output, target)

        self.metric['train'] += get_correct_pred_count(output, target)
        self.metric['train_total'] += len(x)
        self.metric['epoch_train_loss'].append(loss)

        acc = 100 * self.metric['train'] / self.metric['train_total']

        self.log_dict({'train_loss': loss, 'train_acc': acc})
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, target = val_batch
        output = self.forward(x)
        loss = self.criterion(output, target)

        self.metric['val'] += get_correct_pred_count(output, target)
        self.metric['val_total'] += len(x)
        self.metric['epoch_val_loss'].append(loss)

        acc = 100 * self.metric['val'] / self.metric['val_total']

        # Collect per-sample predictions on the final epoch for later inspection
        if self.current_epoch == self.trainer.max_epochs - 1:
            add_predictions(x, output, target)

        self.log_dict({'val_loss': loss, 'val_acc': acc})

    def test_step(self, test_batch, batch_idx):
        self.validation_step(test_batch, batch_idx)

    def train_dataloader(self):
        if not self.trainer.train_dataloader:
            self.trainer.fit_loop.setup_data()

        return self.trainer.train_dataloader

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-6, weight_decay=0.01)
        self.find_lr(optimizer)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                        max_lr=self.max_lr,
                                                        epochs=self.trainer.max_epochs,
                                                        steps_per_epoch=len(self.train_dataloader()),
                                                        pct_start=5 / self.trainer.max_epochs,
                                                        div_factor=100,
                                                        final_div_factor=100,
                                                        three_phase=False,
                                                        verbose=False
                                                        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                'interval': 'step',  # or 'epoch'
                'frequency': 1
            },
        }

    def on_validation_epoch_end(self):
        if self.metric['train_total']:
            print('Epoch ', self.current_epoch)
            train_acc = 100 * self.metric['train'] / self.metric['train_total']
            epoch_loss = sum(self.metric['epoch_train_loss']) / len(self.metric['epoch_train_loss'])
            self.metric['train_loss'].append(epoch_loss.item())
            self.metric['train_acc'].append(train_acc)

            print('Train Loss: ', epoch_loss.item(), ' Accuracy: ', str(train_acc) + '%', ' [',
                  self.metric['train'], '/', self.metric['train_total'], ']')

            self.metric['train'] = 0
            self.metric['train_total'] = 0
            self.metric['epoch_train_loss'] = []

            val_acc = 100 * self.metric['val'] / self.metric['val_total']

            epoch_loss = sum(self.metric['epoch_val_loss']) / len(self.metric['epoch_val_loss'])
            self.metric['val_loss'].append(epoch_loss.item())
            self.metric['val_acc'].append(val_acc)

            print('Validation Loss: ', epoch_loss.item(), ' Accuracy: ', str(val_acc) + '%', ' [', self.metric['val'],
                  '/', self.metric['val_total'], ']\n')

            self.metric['val'] = 0
            self.metric['val_total'] = 0
            self.metric['epoch_val_loss'] = []

    def find_lr(self, optimizer):
        if not self.is_find_max_lr:
            return

        lr_finder = LRFinder(self, optimizer, self.criterion)
        lr_finder.range_test(self.train_dataloader(), end_lr=100, num_iter=100)
        _, best_lr = lr_finder.plot()  # to inspect the loss-learning rate graph
        lr_finder.reset()
        self.max_lr = best_lr

    def plot_model_performance(self):
        fig, axs = plt.subplots(2, 2, figsize=(15, 10))
        axs[0, 0].plot(self.metric['train_loss'])
        axs[0, 0].set_title("Training Loss")
        axs[1, 0].plot(self.metric['train_acc'])
        axs[1, 0].set_title("Training Accuracy")
        axs[0, 1].plot(self.metric['val_loss'])
        axs[0, 1].set_title("Test Loss")
        axs[1, 1].plot(self.metric['val_acc'])
        axs[1, 1].set_title("Test Accuracy")

    def plot_grad_cam(self, mean, std, target_layers, get_data_label_name, count=10, missclassified=True, grad_opacity=1.0):
        cam = GradCAM(model=self, target_layers=target_layers)

        for i in range(count):
            plt.subplot(int(count / 5), 5, i + 1)
            plt.tight_layout()
            pred_dict = test_incorrect_pred if missclassified else test_correct_pred

            targets = [ClassifierOutputTarget(pred_dict['ground_truths'][i].cpu().item())]

            grayscale_cam = cam(input_tensor=pred_dict['images'][i][None, :].cpu(), targets=targets)

            x = denormalize(pred_dict['images'][i].cpu(), mean, std)

            image = np.array(255 * x, np.int16).transpose(1, 2, 0)
            img_tensor = np.array(x, np.float16).transpose(1, 2, 0)

            visualization = show_cam_on_image(img_tensor, grayscale_cam.transpose(1, 2, 0), use_rgb=True,
                                              image_weight=(1.0 - grad_opacity))

            plt.imshow(image, vmin=0, vmax=255)
            plt.imshow(visualization, vmin=0, vmax=255, alpha=grad_opacity)
            plt.xticks([])
            plt.yticks([])

            # Title: ground truth / prediction
            title = get_data_label_name(pred_dict['ground_truths'][i].item()) + ' / ' + \
                    get_data_label_name(pred_dict['predicted_vals'][i].item())
            plt.title(title, fontsize=8)
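A quick shape sanity check for the architecture above: with base_channels=64, the three stride-2 max pools take a 32x32 CIFAR input down to 4x4, the final MaxPool2d(4) collapses it to 1x1 with 512 channels, and the linear head yields 10 scores. A minimal sketch, no checkpoint required:

import torch
from models.custom_resnet_lightning_s10 import S10LightningModel

model = S10LightningModel(64)
x = torch.randn(2, 3, 32, 32)        # CIFAR-10-shaped batch

log_probs = model(x)                 # forward path ending in log_softmax
logits = model(x, no_softmax=True)   # raw logits, as app.py uses

assert log_probs.shape == (2, 10)
assert logits.shape == (2, 10)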
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+torch
+torch-lr-finder
+torchvision
+pillow
+gradio
+grad-cam
+numpy
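A few of these distributions install under one name but import under another, which is easy to trip over in a fresh environment. A quick import check; note that models/custom_resnet_lightning_s10.py also imports pytorch_lightning and matplotlib, which are not in the list above and would need to be available at runtime as well:

# PyPI name       -> import name
# grad-cam        -> pytorch_grad_cam
# pillow          -> PIL
# torch-lr-finder -> torch_lr_finder
import gradio
import numpy
import PIL
import pytorch_grad_cam
import torch
import torch_lr_finder
import torchvision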
utils.py
ADDED
@@ -0,0 +1,62 @@
import torch


def get_dataset_labels():
    return ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


def get_data_label_name(idx):
    if idx < 0:
        return ''

    return get_dataset_labels()[idx]


def get_data_idx_from_name(name):
    if not name:
        return -1

    # Call the function before indexing (the original indexed the function object)
    return get_dataset_labels().index(name.lower()) if name.lower() in get_dataset_labels() else -1


def load_model_from_checkpoint(device, file_name='ckpt.pth'):
    # Load the requested checkpoint file (the original ignored file_name)
    checkpoint = torch.load(file_name, map_location=device)

    return checkpoint


def denormalize(img, mean, std):
    MEAN = torch.tensor(mean)
    STD = torch.tensor(std)

    img = img * STD[:, None, None] + MEAN[:, None, None]
    i_min = img.min().item()
    i_max = img.max().item()

    # Min-max rescale into [0, 1] for display
    img_bar = (img - i_min) / (i_max - i_min)

    return img_bar


# Data to plot accuracy and loss graphs
train_losses = []
test_losses = []
train_acc = []
test_acc = []

test_incorrect_pred = {'images': [], 'ground_truths': [], 'predicted_vals': []}
test_correct_pred = {'images': [], 'ground_truths': [], 'predicted_vals': []}


def get_correct_pred_count(pPrediction, pLabels):
    return pPrediction.argmax(dim=1).eq(pLabels).sum().item()


def add_predictions(data, pred, target):
    diff_preds = pred.argmax(dim=1) - target
    for idx, d in enumerate(diff_preds):
        if d.item() != 0:
            test_incorrect_pred['images'].append(data[idx])
            test_incorrect_pred['ground_truths'].append(target[idx])
            test_incorrect_pred['predicted_vals'].append(pred.argmax(dim=1)[idx])
        else:
            test_correct_pred['images'].append(data[idx])
            test_correct_pred['ground_truths'].append(target[idx])
            test_correct_pred['predicted_vals'].append(pred.argmax(dim=1)[idx])
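Note that denormalize does not strictly invert transforms.Normalize: after undoing the affine normalization it min-max rescales the result into [0, 1] for display. A small sketch with the CIFAR-10 statistics used in app.py:

import torch
from torchvision import transforms

from utils import denormalize

mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)

img = torch.rand(3, 32, 32)                  # stand-in image already in [0, 1]
norm = transforms.Normalize(mean, std)(img)  # what the eval pipeline feeds the model

restored = denormalize(norm, mean, std)
print(restored.min().item(), restored.max().item())  # exactly 0.0 and 1.0 after rescaling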