ImageNet / app.py
Shilpaj's picture
Fix: Runtime error
671ad7d verified
raw
history blame
3.98 kB
#!/usr/bin/env python
"""
Application for ResNet50 trained on ImageNet-1K.
"""
# Standard Library Imports
from functools import partial
import warnings

# Third Party Imports
import gradio as gr
import torch
from torchvision import models

# Local Imports
from inference import inference
def load_model(model_path: str) -> torch.nn.Module:
    """
    Build a ResNet50 and fill it with weights from a local checkpoint.

    No ImageNet weights are downloaded; every parameter comes from
    ``model_path``.

    Args:
        model_path: Path to a file readable by ``torch.load``. It may be a
            training checkpoint dict holding the weights under
            ``'model_state_dict'``, or a bare state dict.

    Returns:
        The ResNet50 model, switched to evaluation mode, with tensors mapped
        to the CPU.
    """
    # `weights=None` builds the bare architecture; it replaces the deprecated
    # `pretrained=False` (the `ResNet50_Weights` enum used in this file already
    # requires a torchvision version that supports it).
    model = models.resnet50(weights=None)

    # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

    # Training checkpoints nest the weights; a bare state dict is used as-is.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Build the expected key set once instead of once per checkpoint entry.
    expected_keys = model.state_dict().keys()

    filtered_state_dict = {}
    for key, value in state_dict.items():
        # Wrappers such as nn.DataParallel ("module.") and torch.compile
        # ("_orig_mod.") prefix every key; strip them so the keys match.
        for prefix in ('module.', '_orig_mod.'):
            if key.startswith(prefix):
                key = key[len(prefix):]
        # Drop entries the bare ResNet50 has no slot for.
        if key in expected_keys:
            filtered_state_dict[key] = value

    # strict=False tolerates gaps, but surface them: anything not loaded keeps
    # its random initialisation and silently degrades predictions.
    load_result = model.load_state_dict(filtered_state_dict, strict=False)
    if load_result.missing_keys:
        warnings.warn(
            f"{len(load_result.missing_keys)} of {len(expected_keys)} model "
            f"state entries were not found in {model_path!r}; they keep their "
            "random initialisation.",
            RuntimeWarning,
            stacklevel=2,
        )

    model.eval()
    return model
def load_classes():
    """
    Return the ImageNet-1K category names used to label predictions.

    The names come from the ``meta`` dictionary attached to torchvision's
    ``ResNet50_Weights.IMAGENET1K_V2`` enum member. Only that metadata is
    read; no weight tensors are touched here.
    """
    weights_meta = models.ResNet50_Weights.IMAGENET1K_V2.meta
    return weights_meta["categories"]
def main(model_path: str = "resnet50_imagenet1k.pth"):
    """
    Load the model, build the GradCAM demo UI and serve it with Gradio.

    Args:
        model_path: Checkpoint handed to ``load_model``. The default is the
            file the app has always loaded, so ``main()`` behaves as before.
    """
    # Load the model and class names once at startup; every request reuses them.
    model = load_model(model_path)
    classes = load_classes()

    demo = _build_demo(model, classes)
    # Launch after the Blocks context has closed (Gradio's documented pattern).
    demo.launch(debug=True)


def _build_demo(model, classes) -> gr.Blocks:
    """
    Assemble the Gradio Blocks UI: a GradCAM tab plus clickable examples.

    Args:
        model: Network forwarded unchanged to ``inference``.
        classes: Category names forwarded unchanged to ``inference``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    # Slider defaults, shared with the example rows so the two stay in sync.
    default_transparency = 0.5
    default_top_k = 3
    default_layer = 4
    example_images = [
        "dog.jpg", "cat.jpg", "frog.jpg", "bird.jpg", "shark-plane.jpg",
        "car.jpg", "truck.jpg", "horse.jpg", "plane.jpg", "ship.png",
    ]
    examples = [
        [f"./assets/examples/{name}", default_transparency, default_top_k, default_layer]
        for name in example_images
    ]

    # Bind model and classes so Gradio only has to pass the UI input values.
    inference_fn = partial(inference, model=model, classes=classes)

    with gr.Blocks() as demo:
        gr.Markdown("# ResNet50v2 trained on ImageNet-1K")

        # ############################### GradCam Tab ###############################
        with gr.Tab("GradCam"):
            gr.Markdown(
                "Visualize Class Activation Maps generated by the model's layer "
                "for the predicted class. This is used to see what the model is "
                "actually looking at in the image."
            )
            with gr.Row():
                img_input = [gr.Image(label="Input Image", type="numpy", height=224)]
                gradcam_outputs = [
                    gr.Label(label="Predictions"),
                    gr.Image(label="GradCAM Output", height=224),
                ]
            with gr.Row():
                gradcam_inputs = [
                    gr.Slider(0, 1, value=default_transparency,
                              label="Activation Map Transparency"),
                    gr.Slider(1, 10, value=default_top_k, step=1,
                              label="Number of Top Predictions"),
                    gr.Slider(1, 6, value=default_layer, step=1,
                              label="Target Layer Number"),
                ]
            gradcam_button = gr.Button("Generate GradCAM")
            # Inputs are passed positionally in this order: image, then sliders.
            gradcam_button.click(
                inference_fn,
                inputs=img_input + gradcam_inputs,
                outputs=gradcam_outputs,
            )

            gr.Markdown("## Examples")
            gr.Examples(
                examples=examples,
                inputs=img_input + gradcam_inputs,
                fn=inference_fn,
                outputs=gradcam_outputs,
            )

    return demo
# Start the app only when this file is executed directly (e.g. `python app.py`);
# importing the module does not launch a server.
if __name__ == "__main__":
    main()