#!/usr/bin/env python
"""
Application for ResNet50 trained on ImageNet-1K.

Launches a Gradio UI with a GradCAM tab: the user uploads an image, the
model predicts top-k ImageNet classes, and a class-activation heatmap is
overlaid on the input to show where the model "looked".
"""
# Standard Library Imports
from functools import partial

# Third Party Imports
import gradio as gr
import torch
from torchvision import models

# Local Imports
from inference import inference


def load_model(model_path: str):
    """
    Load a ResNet50 and populate it with custom weights.

    Args:
        model_path: Path to a ``.pth`` file containing either a raw
            ``state_dict`` or a training checkpoint dict that wraps the
            weights under the ``'model_state_dict'`` key.

    Returns:
        The ResNet50 model in eval mode, on CPU.
    """
    # Build the architecture without downloading pretrained weights.
    # (`weights=None` replaces the deprecated `pretrained=False`.)
    model = models.resnet50(weights=None)

    # Map to CPU so the app also runs on machines without a GPU.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

    # Accept both a bare state_dict and a full training checkpoint
    # (the original code assumed only the latter and crashed otherwise).
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Keep only keys the architecture knows about (drops e.g. auxiliary
    # entries saved during training); hoist the key set out of the loop.
    model_keys = set(model.state_dict())
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_keys}

    # strict=False tolerates keys missing after the filtering above.
    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()  # inference mode: freezes dropout / batch-norm statistics
    return model


def load_classes():
    """
    Load the ImageNet-1K class names.

    Returns:
        A list of 1000 human-readable category names, indexed by class id.
    """
    # The names ship with the torchvision weight metadata, so no extra
    # label file is needed.
    classes = models.ResNet50_Weights.IMAGENET1K_V2.meta["categories"]
    return classes


def main():
    """
    Build the Gradio Blocks UI and launch the demo.
    """
    # Load the model and labels once at startup; they are shared by every
    # request via functools.partial below.
    model = load_model("resnet50_imagenet1k.pth")
    classes = load_classes()

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # ImageNet-1K trained on ResNet50v2
            """
        )

        # #############################################################################
        # ################################ GradCam Tab ################################
        # #############################################################################
        with gr.Tab("GradCam"):
            gr.Markdown(
                """
                Visualize Class Activations Maps generated by the model's layer for the predicted class.
                This is used to see what the model is actually looking at in the image.
                """
            )
            with gr.Row():
                img_input = [gr.Image(label="Input Image", type="numpy", height=224)]
                gradcam_outputs = [
                    gr.Label(label="Predictions"),
                    gr.Image(label="GradCAM Output", height=224)
                ]

            with gr.Row():
                gradcam_inputs = [
                    gr.Slider(0, 1, value=0.5, label="Activation Map Transparency"),
                    gr.Slider(1, 10, value=3, step=1, label="Number of Top Predictions"),
                    gr.Slider(1, 6, value=4, step=1, label="Target Layer Number")
                ]

            gradcam_button = gr.Button("Generate GradCAM")

            # Bind the shared model/classes so the click handler only
            # receives the per-request UI inputs.
            inference_fn = partial(inference, model=model, classes=classes)
            gradcam_button.click(inference_fn,
                                 inputs=img_input + gradcam_inputs,
                                 outputs=gradcam_outputs)

            gr.Markdown("## Examples")
            gr.Examples(
                examples=[
                    ["./assets/examples/dog.jpg", 0.5, 3, 4],
                    ["./assets/examples/cat.jpg", 0.5, 3, 4],
                    ["./assets/examples/frog.jpg", 0.5, 3, 4],
                    ["./assets/examples/bird.jpg", 0.5, 3, 4],
                    ["./assets/examples/shark-plane.jpg", 0.5, 3, 4],
                    ["./assets/examples/car.jpg", 0.5, 3, 4],
                    ["./assets/examples/truck.jpg", 0.5, 3, 4],
                    ["./assets/examples/horse.jpg", 0.5, 3, 4],
                    ["./assets/examples/plane.jpg", 0.5, 3, 4],
                    ["./assets/examples/ship.png", 0.5, 3, 4]
                ],
                inputs=img_input + gradcam_inputs,
                fn=inference_fn,
                outputs=gradcam_outputs
            )

        # Launch the demo (moved inside the Blocks context)
        demo.launch(debug=True)


if __name__ == "__main__":
    main()