File size: 3,771 Bytes
077fb0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python
"""
Application for ResNet50 trained on ImageNet-1K.
"""
# Standard Library Imports
import gradio as gr

# Third Party Imports
import torch
from torchvision import models

# Local Imports
from inference import inference


def load_model(model_path: str) -> torch.nn.Module:
    """
    Build a ResNet50 and load custom weights into it from a checkpoint.

    Args:
        model_path: Path to a file saved with ``torch.save``. The weights are
            read from its ``"model_state_dict"`` entry; a bare state dict
            (no such entry) is also accepted.

    Returns:
        The ResNet50 model on CPU, in eval mode, with every checkpoint tensor
        whose name matches a model parameter/buffer loaded into it.

    Raises:
        RuntimeError: from ``load_state_dict`` if a matching key has a
            different tensor shape than the model expects.
    """
    import warnings

    # Architecture only (random init); all weights come from the checkpoint.
    # ``weights=None`` replaces the deprecated ``pretrained=False``.
    model = models.resnet50(weights=None)

    # map_location="cpu" lets a checkpoint written from a GPU load on a
    # CPU-only host; the model above lives on CPU either way.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location="cpu")
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # Keep only keys the model knows about. Build the model's key set once
    # instead of re-creating model.state_dict() for every checkpoint entry.
    # Checkpoints saved from DataParallel/DDP prefix every key with "module.";
    # strip it so those weights are not all silently discarded.
    expected_keys = model.state_dict().keys()
    prefix = "module."
    filtered_state_dict = {}
    for key, value in state_dict.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name in expected_keys:
            filtered_state_dict[name] = value

    # strict=False tolerates missing keys, but report them: any parameter or
    # buffer absent from the checkpoint keeps its random initialization.
    result = model.load_state_dict(filtered_state_dict, strict=False)
    missing = result.missing_keys
    if missing:
        warnings.warn(
            f"{len(missing)} model parameters/buffers were not found in "
            f"{model_path!r} and keep their random initialization "
            f"(first few: {missing[:5]})",
            stacklevel=2,
        )

    # Eval mode so BatchNorm uses its running statistics (and does not update
    # them on every request) and Dropout is disabled during inference.
    model.eval()
    return model


def load_classes():
    """
    Return the ImageNet-1K category names shipped with torchvision.

    The list is taken from the ``"categories"`` metadata of the
    ``ResNet50_Weights.IMAGENET1K_V2`` enum member.

    NOTE(review): this assumes the custom checkpoint loaded by ``load_model``
    uses the same class-index order as torchvision's ImageNet weights -- confirm.
    """
    imagenet_weights = models.ResNet50_Weights.IMAGENET1K_V2
    return imagenet_weights.meta["categories"]


def main():
    """
    Build and launch the Gradio GradCAM demo for the ResNet50 model.

    Loads ``resnet50_imagenet1k.pth`` and the ImageNet class names once at
    startup, builds a ``gr.Blocks`` UI with a single "GradCam" tab whose
    button and examples call ``inference``, and launches it with
    ``debug=True``.
    """
    from functools import partial

    # Load the model and class names once at startup, not per request.
    model = load_model("resnet50_imagenet1k.pth")
    classes = load_classes()

    # Bind model and classes as keywords so Gradio can pass the UI values
    # positionally: image, transparency, number of top predictions, layer number.
    inference_fn = partial(inference, model=model, classes=classes)

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # ImageNet-1K trained on ResNet50v2
            """
        )

        # #########################################################################
        # ############################## GradCam Tab ##############################
        # #########################################################################
        # The tab must be created inside the ``gr.Blocks`` context: outside it,
        # the tab is not part of ``demo`` and wiring ``.click`` fails.
        with gr.Tab("GradCam"):
            gr.Markdown(
                """
                Visualize Class Activations Maps generated by the model's layer for the predicted class.
                This is used to see what the model is actually looking at in the image.
                """
            )
            with gr.Row():
                img_input = [gr.Image(label="Input Image", type="numpy", height=224)]
                gradcam_outputs = [
                    gr.Label(label="Predictions"),
                    # Same height as the input image.
                    gr.Image(label="GradCAM Output", height=224),
                ]

            with gr.Row():
                gradcam_inputs = [
                    gr.Slider(0, 1, value=0.5, label="Activation Map Transparency"),
                    gr.Slider(1, 10, value=3, step=1, label="Number of Top Predictions"),
                    gr.Slider(1, 6, value=4, step=1, label="Target Layer Number"),
                ]

            gradcam_button = gr.Button("Generate GradCAM")
            gradcam_button.click(
                inference_fn,
                inputs=img_input + gradcam_inputs,
                outputs=gradcam_outputs,
            )

            gr.Markdown("## Examples")
            # Each example row: image path, transparency, top-k, target layer.
            gr.Examples(
                examples=[
                    ["./assets/examples/dog.jpg", 0.5, 3, 4],
                    ["./assets/examples/cat.jpg", 0.5, 3, 4],
                    ["./assets/examples/frog.jpg", 0.5, 3, 4],
                    ["./assets/examples/bird.jpg", 0.5, 3, 4],
                    ["./assets/examples/shark-plane.jpg", 0.5, 3, 4],
                    ["./assets/examples/car.jpg", 0.5, 3, 4],
                    ["./assets/examples/truck.jpg", 0.5, 3, 4],
                    ["./assets/examples/horse.jpg", 0.5, 3, 4],
                    ["./assets/examples/plane.jpg", 0.5, 3, 4],
                    ["./assets/examples/ship.png", 0.5, 3, 4],
                ],
                inputs=img_input + gradcam_inputs,
                fn=inference_fn,
                outputs=gradcam_outputs,
            )

    gr.close_all()
    demo.launch(debug=True)


if __name__ == "__main__":
    main()