#!/usr/bin/env python
"""
Gradio application for a ResNet50 model trained on ImageNet-1K.
"""
# Standard Library Imports
from functools import partial

# Third Party Imports
import gradio as gr
import torch
from torchvision import models

# Local Imports
from inference import inference
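# NOTE (assumption): the local `inference` helper is expected to take the input
# image plus the three slider values wired up below, along with the bound
# `model` and `classes` keyword arguments, and to return a predictions dict and
# a GradCAM overlay image matching the Label and Image outputs.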


def load_model(model_path: str):
    """
    Load the model.
    """
    # Build the ResNet50 architecture without downloading pre-trained weights
    model = models.resnet50(weights=None)
    # Load custom weights from a .pth checkpoint, mapping tensors to the CPU
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    # Keep only the keys that exist in the model's state dict (drop unexpected ones)
    filtered_state_dict = {k: v for k, v in state_dict['model_state_dict'].items() if k in model.state_dict()}
    # Load the filtered state dictionary into the model
    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()
    return model


def load_classes():
    """
    Load the classes.
    """
    # Get the ImageNet-1K class names from the torchvision ResNet50 weights metadata
    classes = models.ResNet50_Weights.IMAGENET1K_V2.meta["categories"]
    return classes


def main():
    """
    Main function for the application.
    """
    # Load the model at startup
    model = load_model("resnet50_imagenet1k.pth")
    # Load the classes at startup
    classes = load_classes()

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # ResNet50 trained on ImageNet-1K
            """
        )
        # #############################################################################
        # ################################ GradCAM Tab ###############################
        # #############################################################################
        with gr.Tab("GradCAM"):
            gr.Markdown(
                """
                Visualize the Class Activation Map generated by a chosen model layer for the predicted class.
                This shows which regions of the image the model is actually looking at.
                """
            )
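            # How Grad-CAM works (sketch of the idea; the actual computation lives in
            # the local inference.py): gradients of the predicted class score are taken
            # w.r.t. the activations of the chosen layer, global-average-pooled into
            # per-channel weights, and used to form a weighted sum of the activation
            # maps, which is upsampled and overlaid on the input image as a heatmap.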
            with gr.Row():
                img_input = [gr.Image(label="Input Image", type="numpy", height=224)]
                gradcam_outputs = [
                    gr.Label(label="Predictions"),
                    gr.Image(label="GradCAM Output", height=224)
                ]
            with gr.Row():
                gradcam_inputs = [
                    gr.Slider(0, 1, value=0.5, label="Activation Map Transparency"),
                    gr.Slider(1, 10, value=3, step=1, label="Number of Top Predictions"),
                    gr.Slider(1, 6, value=4, step=1, label="Target Layer Number")
                ]
            gradcam_button = gr.Button("Generate GradCAM")
            # Bind the loaded model and class names to the inference function
            inference_fn = partial(inference, model=model, classes=classes)
            gradcam_button.click(inference_fn, inputs=img_input + gradcam_inputs, outputs=gradcam_outputs)
gr.Markdown("## Examples") | |
gr.Examples( | |
examples=[ | |
["./assets/examples/dog.jpg", 0.5, 3, 4], | |
["./assets/examples/cat.jpg", 0.5, 3, 4], | |
["./assets/examples/frog.jpg", 0.5, 3, 4], | |
["./assets/examples/bird.jpg", 0.5, 3, 4], | |
["./assets/examples/shark-plane.jpg", 0.5, 3, 4], | |
["./assets/examples/car.jpg", 0.5, 3, 4], | |
["./assets/examples/truck.jpg", 0.5, 3, 4], | |
["./assets/examples/horse.jpg", 0.5, 3, 4], | |
["./assets/examples/plane.jpg", 0.5, 3, 4], | |
["./assets/examples/ship.png", 0.5, 3, 4] | |
], | |
inputs=img_input + gradcam_inputs, | |
fn=inference_fn, | |
outputs=gradcam_outputs | |
) | |
        # Launch the demo (inside the Blocks context)
        demo.launch(debug=True)


if __name__ == "__main__":
    main()
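# Run locally with `python app.py` (assumes gradio, torch, and torchvision are
# installed and the `resnet50_imagenet1k.pth` checkpoint sits next to this file).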