cifar10 / app.py
Shilpaj's picture
Fix: Runtime error
9ea8844 verified
#!/usr/bin/env python3
"""
Gradio Application for model trained on CIFAR10 dataset
Author: Shilpaj Bhalerao
Date: Aug 06, 2023
"""
# Standard Library Imports
import os
from collections import OrderedDict
# Third-Party Imports
import gradio as gr
import numpy as np
import torch
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
# Local Imports
from resnet import LITResNet
from visualize import FeatureMapVisualizer
# Locations of the bundled example images and of the trained checkpoint
example_directory = 'examples/'
model_path = 'epoch=23-step=2112.ckpt'

# CIFAR10 class names; handed to the model as `class_names` and used to label predictions
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Restore the network from the checkpoint onto the CPU and switch it to evaluation mode
model = LITResNet.load_from_checkpoint(
    model_path,
    map_location=torch.device('cpu'),
    strict=False,
    class_names=classes,
)
model.eval()

# Visualizer wrapping the loaded model
viz = FeatureMapVisualizer(model)
def inference(input_img,
              transparency=0.5,
              number_of_top_classes=3,
              target_layer_number=4):
    """
    Function to run inference on the input image and overlay its GradCAM heatmap
    :param input_img: Image provided by the user as a numpy array
    :param transparency: Weight of the input image in the overlay (passed on as `image_weight`)
    :param number_of_top_classes: Number of top predictions for the input image
    :param target_layer_number: 1-based number of the block whose first layer is used for GradCAM
    :return: Ordered mapping of class name to confidence (highest first) and the overlay image
    :raises gr.Error: If no input image is provided
    :raises ValueError: If `target_layer_number` does not refer to one of the listed blocks
    """
    if input_img is None:
        raise gr.Error("Please provide an input image")

    # Name of layers defined in the model
    _layers = ('prep_layer', 'custom_block1', 'resnet_block1',
               'custom_block2', 'custom_block3', 'resnet_block3')

    # UI sliders may deliver floats; slicing and indexing below need integers
    number_of_top_classes = int(number_of_top_classes)
    target_layer_number = int(target_layer_number)
    if not 1 <= target_layer_number <= len(_layers):
        # Without this check 0 would silently select the last block via index -1
        raise ValueError(f"target_layer_number must be between 1 and {len(_layers)}")

    # Force three channels and resize the image to (32, 32)
    resized = np.array(Image.fromarray(input_img).convert('RGB').resize((32, 32)))

    # Per-channel mean and standard deviation of the image scaled to [0, 1]
    scaled = resized / 255.
    channel_mean = scaled.mean(axis=(0, 1))
    # A uniformly coloured image has zero deviation, which Normalize rejects
    channel_std = np.maximum(scaled.std(axis=(0, 1)), np.finfo(np.float32).eps)

    # Convert img to tensor and normalize it
    _transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean.tolist(), channel_std.tolist()),
    ])

    # Apply the transforms and add a batch dimension to perform inference
    input_tensor = _transform(resized).unsqueeze(0)

    # Get Model Predictions
    with torch.no_grad():
        outputs = model(input_tensor)
    # The outputs are exponentiated, i.e. treated as log-probabilities
    probabilities = torch.exp(outputs)[0]
    confidences = {name: float(score) for name, score in zip(classes, probabilities)}

    # Select the top classes based on user input
    sorted_confidences = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    show_confidences = OrderedDict(sorted_confidences[:number_of_top_classes])

    # First layer of the selected block; attribute lookup instead of eval()
    target_layers = [getattr(model, _layers[target_layer_number - 1])[0]]

    # Get the class activations from the selected layer.
    # The context manager releases the hooks GradCAM registers on the model,
    # so they do not pile up across requests.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]

    # Overlay input image with Class activations
    visualization = show_cam_on_image(scaled, grayscale_cam, use_rgb=True, image_weight=transparency)
    return show_confidences, visualization
def display_misclassified_images(number: int = 1):
    """
    Display the misclassified images saved during training
    File names are expected to look like `<correct>_<misclassified>.<extension>`.
    Hidden files and files whose name does not follow that pattern are skipped
    :param number: Number of images to display
    :return: List of image paths and list of (correct label, misclassified label) tuples,
             both in the same (alphabetical) order
    """
    directory = 'misclassified/'

    # Paths for Gradio to access the images, and the matching label pairs
    file_path = []
    data = []

    # os.listdir gives no ordering guarantee; sort so every call returns the same images
    for file in sorted(os.listdir(directory)):
        if file.startswith('.'):
            continue
        # splitext only strips the last extension, so dots inside the name are harmless
        file_name, _extension = os.path.splitext(file)
        correct_label, separator, misclassified = file_name.partition('_')
        if not (separator and correct_label and misclassified):
            continue
        # Append to both lists together so paths and labels stay aligned
        file_path.append(directory + file)
        data.append((correct_label, misclassified))

    # The slider may deliver a float; a negative count would otherwise slice from the end
    count = max(int(number), 0)

    # Return the file path and names of correct and misclassified images
    return file_path[:count], data[:count]
def feature_maps(input_img, kernel_number=32):
    """
    Function to return feature maps for the selected image
    :param input_img: User input image as a numpy array
    :param kernel_number: Number of kernel in all 6 layers
    :return: Result of `viz.visualize_feature_map_of_kernel` for the normalized image
    :raises gr.Error: If no input image is provided
    """
    if input_img is None:
        raise gr.Error("Please provide an input image")

    # Force three channels and resize the image to (32, 32)
    resized = np.array(Image.fromarray(input_img).convert('RGB').resize((32, 32)))

    # Per-channel mean and standard deviation of the image scaled to [0, 1]
    scaled = resized / 255.
    channel_mean = scaled.mean(axis=(0, 1))
    # A uniformly coloured image has zero deviation, which Normalize rejects
    channel_std = np.maximum(scaled.std(axis=(0, 1)), np.finfo(np.float32).eps)

    # Convert img to tensor and normalize it
    _transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean.tolist(), channel_std.tolist()),
    ])

    # Apply transforms on the input image (no batch dimension is added here)
    input_tensor = _transform(resized)

    # Visualize feature maps for the requested kernel; the slider may deliver a float
    return viz.visualize_feature_map_of_kernel(image=input_tensor, kernel_number=int(kernel_number))
def get_kernels(layer_number):
    """
    Plot kernels taken from one layer of the model
    :param layer_number: Number of the layer whose kernels are to be drawn
    :return: Whatever `viz.visualize_kernels_from_layer` produces for that layer
    """
    # Delegate directly to the shared visualizer
    return viz.visualize_kernels_from_layer(layer_number=layer_number)
# Build the four-tab Gradio UI and launch it when the file is run as a script
if __name__ == '__main__':
    with gr.Blocks() as demo:
        # Page header shown above all tabs
        gr.Markdown(
            """
# CIFAR10 trained on ResNet18 Model
A model architecture by [David C](https://github.com/davidcpage) which is trained on CIFAR10 for 24 Epochs to achieve accuracy of 90+%
The model works for following classes: `plane`, `car`, `bird`, `cat`, `deer`, `dog`, `frog`, `horse`, `ship`, `truck`
"""
        )
        # #############################################################################
        # ################################ GradCam Tab ################################
        # #############################################################################
        with gr.Tab("GradCam"):
            gr.Markdown(
                """
Visualize Class Activations Maps generated by the model's layer for the predicted class
This is used to see what the model is actually looking at in the image
"""
            )
            with gr.Row():
                img_input = gr.Image(label="Input Image")
                # Order matches the two values returned by `inference`
                gradcam_outputs = [gr.Label(),
                                   gr.Image(label="Output")]
            with gr.Row():
                # Order matches the `transparency`, `number_of_top_classes` and
                # `target_layer_number` parameters of `inference`
                gradcam_inputs = [gr.Slider(0, 1, value=0.5,
                                            label="How much percentage overlap of the input image on the activation maps?"),
                                  gr.Slider(1, 10, value=3, step=1,
                                            label="How many top class predictions you want to see?"),
                                  gr.Slider(1, 6, value=4, step=1,
                                            label="From 6 blocks of the model, which block's first convolutional layer's class activation you want to see?")]
            gradcam_button = gr.Button("Submit")
            # The image followed by the three sliders are passed positionally to `inference`
            gradcam_button.click(inference, inputs=[img_input] + gradcam_inputs, outputs=gradcam_outputs)
            gr.Markdown("## Examples")
            # NOTE(review): `fn` is given without `outputs`; presumably selecting an example
            # only fills the input image and does not run inference - confirm
            gr.Examples([example_directory + 'dog.jpg', example_directory + 'cat.jpg', example_directory + 'frog.jpg',
                         example_directory + 'bird.jpg', example_directory + 'shark-plane.jpg',
                         example_directory + 'car.jpg', example_directory + 'truck.jpg',
                         example_directory + 'horse.jpg', example_directory + 'plane.jpg',
                         example_directory + 'ship.png'],
                        inputs=img_input, fn=inference)
        # ###########################################################################################
        # ################################ Misclassified Images Tab #################################
        # ###########################################################################################
        with gr.Tab("Misclassified Images"):
            gr.Markdown(
                """
10% of test images were misclassified by the model at the end of the training
You can visualize those images with their correct label and misclassified label
"""
            )
            with gr.Row():
                mis_inputs = gr.Slider(1, 10, value=1, step=1,
                                       label="Number of misclassified images to display")
                # Order matches the (paths, label pairs) returned by `display_misclassified_images`
                mis_outputs = [
                    gr.Gallery(label="Misclassified Images", show_label=False, elem_id="gallery"),
                    gr.Dataframe(headers=["Correct Label", "Misclassified Label"], type="array", datatype="str",
                                 row_count=10, col_count=2)]
            mis_button = gr.Button("Display Misclassified Images")
            mis_button.click(display_misclassified_images, inputs=mis_inputs, outputs=mis_outputs)
        # ################################################################################################
        # ################################ Feature Maps Visualization Tab ################################
        # ################################################################################################
        with gr.Tab("Feature Map Visualization"):
            # NOTE(review): "kerenel" in the text below is a typo for "kernel"
            gr.Markdown(
                """
The model has 6 convolutional blocks. Each block has two or three convolutional layers
From each block's first convolutional layer, output of specific kernel number is visualized
In the below images `l1` represents first block and `kx` represents the number of kerenel from the first convolutional layer of that block
"""
            )
            with gr.Column():
                feature_map_input = gr.Image(label="Feature Map Input Image")
                feature_map_slider = gr.Slider(1, 32, value=16, step=1,
                                               label="Select a Kernel number whose Features Maps from all 6 block's to be shown")
                feature_map_output = gr.Plot()
            feature_map_button = gr.Button("Visualize FeatureMaps")
            # Image and slider map to the `input_img` and `kernel_number` parameters of `feature_maps`
            feature_map_button.click(feature_maps, inputs=[feature_map_input, feature_map_slider], outputs=feature_map_output)
        # ##########################################################################################
        # ################################ Kernel Visualization Tab ################################
        # ##########################################################################################
        with gr.Tab("Kernel Visualization"):
            gr.Markdown(
                """
The model has 6 convolutional blocks. Each block has two or three convolutional layers
Some of the Kernels from the first convolutional layer of selected block number are visualized below
"""
            )
            with gr.Column():
                # NOTE(review): the text above mentions 6 blocks but this slider only offers 1-4;
                # confirm which block numbers `visualize_kernels_from_layer` supports
                kernel_input = gr.Slider(1, 4, value=2, step=1,
                                         label="Select a block number whose first convolutional layer's Kernels to be shown")
                kernel_output = gr.Plot()
            kernel_button = gr.Button("Visualize Kernels")
            kernel_button.click(get_kernels, inputs=kernel_input, outputs=kernel_output)
    # Presumably shuts down any Gradio apps still running before starting this one - confirm
    gr.close_all()
    demo.launch()