ImageNet / app.py
Shilpaj's picture
Refactor: Modifications for inference on GPU
f8ecba6
raw
history blame
5.85 kB
#!/usr/bin/env python
"""
Application for ResNet50 trained on ImageNet-1K.
"""
# Third Party Imports (UI framework)
import gradio as gr
# Third Party Imports (deep learning)
import torch
from torchvision import models
# Local Imports
from inference import inference
def load_model(model_path: str):
    """
    Build a ResNet50 and load custom weights from a checkpoint file.

    The model is placed on CUDA when available (otherwise CPU) and switched
    to evaluation mode.

    Args:
        model_path: Path to a file written by ``torch.save``. Either a dict
            holding the weights under ``'model_state_dict'`` (the format this
            app was written for) or a bare state dict.

    Returns:
        torch.nn.Module: The ResNet50 model in evaluation mode.
    """
    # Check if CUDA is available and set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Architecture only; the weights come from the checkpoint below.
    model = models.resnet50(weights=None)
    model = model.to(device)

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        # Fall back to treating the file itself as a state dict.
        state_dict = checkpoint

    # Keep only keys the architecture knows about. The key set is computed once;
    # calling model.state_dict() per item rebuilt the whole dict for every key.
    model_keys = set(model.state_dict().keys())
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
    dropped = len(state_dict) - len(filtered_state_dict)
    if dropped:
        print(f"Ignoring {dropped} checkpoint entries not present in ResNet50")

    # strict=False tolerates gaps, but anything missing keeps its random
    # initialization, so report it instead of failing silently.
    result = model.load_state_dict(filtered_state_dict, strict=False)
    if result.missing_keys:
        print(
            f"Warning: {len(result.missing_keys)} model entries were not found in "
            f"{model_path} and keep their random initialization "
            f"(e.g. {result.missing_keys[:3]})"
        )

    model.eval()
    return model
def load_classes():
    """
    Return the ImageNet-1K category names as a list of strings.

    The names come from the metadata attached to torchvision's
    ``ResNet50_Weights.IMAGENET1K_V2`` entry; only that metadata is read here,
    the weights themselves are not loaded.
    """
    imagenet_meta = models.ResNet50_Weights.IMAGENET1K_V2.meta
    return imagenet_meta["categories"]
def inference_wrapper(image, alpha, top_k, target_layer):
    """
    Run GradCAM inference for the Gradio UI and surface failures as UI errors.

    Relies on the module-level ``model`` and ``classes`` set by ``main()``.

    Args:
        image: Image from the ``gr.Image(type="numpy")`` input, or ``None``
            when no image has been provided.
        alpha: Transparency of the activation-map overlay (slider value).
        top_k: Number of top predictions to report (slider value).
        target_layer: Layer number used for GradCAM (slider value).

    Returns:
        The result of ``inference`` (mapped by Gradio onto the predictions
        label and the GradCAM image), or ``(None, None)`` if ``image`` is None.

    Raises:
        gr.Error: If inference fails; Gradio displays the message in the UI.
    """
    if image is None:
        return None, None
    try:
        # Mixed precision only applies on CUDA. Disabling it explicitly on
        # CPU-only hosts avoids the per-call warning (and the deprecated
        # torch.cuda.amp.autocast entry point).
        with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            # NOTE(review): GradCAM normally needs gradients; confirm that
            # `inference` re-enables them (e.g. torch.enable_grad()) inside
            # this no_grad context.
            with torch.no_grad():
                return inference(
                    image,
                    alpha,
                    top_k,
                    target_layer,
                    model=model,
                    classes=classes,
                )
    except Exception as e:
        print(f"Error in inference: {str(e)}")
        # gr.Error is an exception: it must be raised (not returned) for
        # Gradio to show the message instead of failing on the outputs.
        raise gr.Error(f"Error processing image: {str(e)}") from e
# Default values of the GradCAM sliders; the example rows reuse them so every
# example starts from the same settings as the UI controls.
_DEFAULT_ALPHA = 0.5
_DEFAULT_TOP_K = 3
_DEFAULT_TARGET_LAYER = 4

# Example images listed under the GradCAM controls (relative paths).
_EXAMPLE_IMAGES = (
    "assets/examples/dog.jpg",
    "assets/examples/cat.jpg",
    "assets/examples/frog.jpg",
    "assets/examples/bird.jpg",
    "assets/examples/shark-plane.jpg",
    "assets/examples/car.jpg",
    "assets/examples/truck.jpg",
    "assets/examples/horse.jpg",
    "assets/examples/plane.jpg",
    "assets/examples/ship.png",
)


def _build_demo():
    """
    Assemble the Gradio Blocks UI containing the GradCAM tab.

    Both the button and the examples call ``inference_wrapper``, so the
    module-level ``model`` and ``classes`` must already be set: with
    ``cache_examples=True`` Gradio may run inference on the examples while
    the UI is being built or launched.

    Returns:
        gr.Blocks: The assembled, not yet launched, demo.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # ImageNet-1K trained on ResNet50v2
            """
        )
        with gr.Tab("GradCam"):
            gr.Markdown(
                """
                Visualize Class Activations Maps generated by the model's layer for the predicted class.
                """
            )
            # Input image on the left; predictions and GradCAM overlay on the right.
            with gr.Row():
                img_input = gr.Image(
                    label="Input Image",
                    type="numpy",
                    height=224,
                    width=224,
                )
                with gr.Column():
                    label_output = gr.Label(label="Predictions")
                    gradcam_output = gr.Image(
                        label="GradCAM Output",
                        height=224,
                        width=224,
                    )
            with gr.Row():
                alpha_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=_DEFAULT_ALPHA,
                    step=0.1,
                    label="Activation Map Transparency",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=_DEFAULT_TOP_K,
                    step=1,
                    label="Number of Top Predictions",
                )
                target_layer_slider = gr.Slider(
                    minimum=1,
                    maximum=6,
                    value=_DEFAULT_TARGET_LAYER,
                    step=1,
                    label="Target Layer Number",
                )
            gradcam_button = gr.Button("Generate GradCAM")

            # The button and the examples share the same input/output wiring;
            # the order must match inference_wrapper's parameters/return.
            inputs = [img_input, alpha_slider, top_k_slider, target_layer_slider]
            outputs = [label_output, gradcam_output]

            gradcam_button.click(
                fn=inference_wrapper,
                inputs=inputs,
                outputs=outputs,
            )

            gr.Examples(
                examples=[
                    [path, _DEFAULT_ALPHA, _DEFAULT_TOP_K, _DEFAULT_TARGET_LAYER]
                    for path in _EXAMPLE_IMAGES
                ],
                inputs=inputs,
                outputs=outputs,
                fn=inference_wrapper,
                cache_examples=True,
                label="Click on any example to run GradCAM",
            )
    return demo


def main():
    """
    Load the model and class names, build the UI and start the Gradio server.
    """
    global model, classes  # Module-level state read by inference_wrapper
    # Load both before building the UI: cached examples may run inference.
    model = load_model("resnet50_imagenet1k.pth")
    classes = load_classes()

    demo = _build_demo()
    # `launch(enable_queue=...)` is not accepted by Gradio 4.x (and was
    # deprecated before that); `queue()` enables request queueing on both.
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
        show_error=True,
        max_threads=4,
    )
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()