#!/usr/bin/env python
"""
Inference script for ResNet50 trained on ImageNet-1K.
"""

# Standard Library Imports
from collections import OrderedDict

# Third-Party Imports
import numpy as np
import torch
from torch.nn import functional as F
from torchvision import transforms
from torchvision.models import resnet50
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

def inference(input_img,
              model,
              classes,
              transparency=0.5,
              number_of_top_classes=3,
              target_layer_number=4):
    """
    Run inference on the input image and build a GradCAM visualization.

    :param input_img: RGB image provided by the user (H x W x 3 array)
    :param model: Model to use for inference
    :param classes: List of class names indexed by class id
    :param transparency: Blend weight of the input image in the CAM overlay
    :param number_of_top_classes: Number of top predictions to return
    :param target_layer_number: Layer for which the GradCAM is computed
    :return: Tuple of (top-class confidences, GradCAM overlay image)
    """
    # Save a copy of the input image for the overlay
    org_img = input_img.copy()

    # Calculate the mean of each channel of the input image (scaled to [0, 1])
    mean_r, mean_g, mean_b = (np.mean(input_img[:, :, 0] / 255.),
                              np.mean(input_img[:, :, 1] / 255.),
                              np.mean(input_img[:, :, 2] / 255.))

    # Calculate the standard deviation of each channel
    std_r, std_g, std_b = (np.std(input_img[:, :, 0] / 255.),
                           np.std(input_img[:, :, 1] / 255.),
                           np.std(input_img[:, :, 2] / 255.))

    # Convert the image to a tensor and normalize it
    _transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((mean_r, mean_g, mean_b), (std_r, std_g, std_b))
    ])

    # Preprocess the input image
    input_tensor = _transform(input_img)

    # Create a mini-batch as expected by the model
    input_tensor = input_tensor.unsqueeze(0)
    # Move the input and model to GPU if available, and switch to eval mode
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = input_tensor.to(device)
    model.to(device)
    model.eval()

    # Get model predictions
    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)[0]
        del outputs
    confidences = {classes[i]: float(probabilities[i]) for i in range(1000)}

    # Select the top classes based on user input
    sorted_confidences = sorted(confidences.items(), key=lambda val: val[1], reverse=True)
    show_confidences = OrderedDict(sorted_confidences[:number_of_top_classes])
    # Map layer numbers to meaningful parts of the ResNet architecture
    _layers = {
        1: model.conv1,       # Initial convolution layer
        2: model.layer1[-1],  # Last bottleneck of first residual block
        3: model.layer2[-1],  # Last bottleneck of second residual block
        4: model.layer3[-1],  # Last bottleneck of third residual block
        5: model.layer4[-1],  # Last bottleneck of fourth residual block
        6: model.layer4[-1]   # Last conv layer (instead of fc) for better visualization
    }

    # Ensure a valid layer selection
    target_layer_number = min(max(target_layer_number, 1), 6)
    target_layers = [_layers[target_layer_number]]

    # Get the class activations from the selected layer
    cam = GradCAM(model=model, target_layers=target_layers)

    # Get the most probable class index
    top_class = max(confidences.items(), key=lambda x: x[1])[0]
    class_idx = classes.index(top_class)
    # Generate GradCAM for the top predicted class
    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=[ClassifierOutputTarget(class_idx)],
                        aug_smooth=True,
                        eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]

    # Overlay the input image with the class activations
    visualization = show_cam_on_image(org_img / 255., grayscale_cam, use_rgb=True, image_weight=transparency)
    return show_confidences, visualization
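

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original application): it assumes a
# pretrained torchvision ResNet50 as a stand-in for the app's own checkpoint,
# an "imagenet_classes.txt" file with one class name per line, and an RGB
# image "sample.jpg". These file names are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from torchvision.models import ResNet50_Weights

    # Load a pretrained ResNet50 (stand-in for the model used by the Space)
    demo_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    # Load the 1000 ImageNet class names (hypothetical file path)
    with open("imagenet_classes.txt") as fp:
        demo_classes = [line.strip() for line in fp]

    # Load an example image as an RGB numpy array (hypothetical file path)
    demo_img = np.asarray(Image.open("sample.jpg").convert("RGB"))

    # Run inference and print the top predictions
    top_predictions, cam_overlay = inference(demo_img,
                                             demo_model,
                                             demo_classes,
                                             transparency=0.5,
                                             number_of_top_classes=3,
                                             target_layer_number=5)
    print(top_predictions)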