#!/usr/bin/env python
"""
Gradio application for a ResNet50 model trained on ImageNet-1K.
"""
# Standard Library Imports
from functools import partial

# Third Party Imports
import gradio as gr
import torch
from torchvision import models

# Local Imports
from inference import inference
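
# NOTE: `inference` comes from the local inference.py, which is not shown here.
# Based on how it is wired up below, it is assumed to take the input image and the
# three slider values (activation-map transparency, number of top predictions,
# target layer number) as positional arguments, plus `model=` and `classes=` as
# keyword arguments, and to return a predictions mapping for gr.Label together
# with the GradCAM overlay image. The parameter names and ordering are assumptions
# inferred from the Gradio wiring, not taken from the module itself.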


def load_model(model_path: str):
    """
    Load a ResNet50 model and its custom weights from a .pth checkpoint.
    """
    # Build a ResNet50 with randomly initialised weights (no download);
    # the trained weights are loaded from the checkpoint below
    model = models.resnet50(weights=None)
    # Load the checkpoint (map to CPU so it also works on machines without a GPU)
    state_dict = torch.load(model_path, map_location="cpu")
    # Keep only the keys that exist in the model's state dict
    filtered_state_dict = {k: v for k, v in state_dict["model_state_dict"].items() if k in model.state_dict()}
    # Load the filtered state dictionary into the model
    model.load_state_dict(filtered_state_dict, strict=False)
    # Inference mode: fixes batch-norm statistics (gradients still flow for GradCAM)
    model.eval()
    return model
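
# The checkpoint is expected to be a dict with a "model_state_dict" entry,
# i.e. saved roughly as follows (illustrative; the actual training code is not shown):
#   torch.save({"model_state_dict": model.state_dict()}, "resnet50_imagenet1k.pth")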


def load_classes():
    """
    Load the ImageNet-1K class names.
    """
    # Get the ImageNet class names from the torchvision ResNet50 weights metadata
    # (this only reads metadata; it does not download the pretrained weights)
    classes = models.ResNet50_Weights.IMAGENET1K_V2.meta["categories"]
    return classes
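
# `classes` is the ordered list of 1,000 ImageNet-1K category names, indexed by
# class id, e.g. ["tench", "goldfish", "great white shark", ...].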


def main():
    """
    Main entry point: builds and launches the Gradio interface.
    """
    # Load the model at startup
    model = load_model("resnet50_imagenet1k.pth")
    # Load the classes at startup
    classes = load_classes()

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # ResNet50 trained on ImageNet-1K
            """
        )
        # #############################################################################
        # ################################ GradCAM Tab ################################
        # #############################################################################
        with gr.Tab("GradCAM"):
            gr.Markdown(
                """
                Visualize the Class Activation Maps generated by a chosen layer of the model
                for the predicted class. This shows which regions of the image the model is
                actually looking at when it makes its prediction.
                """
            )
            with gr.Row():
                # Input image (displayed at the model's 224x224 input height)
                img_input = [gr.Image(label="Input Image", type="numpy", height=224)]
                gradcam_outputs = [
                    gr.Label(label="Predictions"),
                    gr.Image(label="GradCAM Output", height=224)  # Match the input image height
                ]
            with gr.Row():
                gradcam_inputs = [
                    gr.Slider(0, 1, value=0.5, label="Activation Map Transparency"),
                    gr.Slider(1, 10, value=3, step=1, label="Number of Top Predictions"),
                    gr.Slider(1, 6, value=4, step=1, label="Target Layer Number")
                ]
            gradcam_button = gr.Button("Generate GradCAM")

            # Bind the model and class names to the inference function so the UI
            # only needs to supply the image and slider values
            inference_fn = partial(inference, model=model, classes=classes)
            gradcam_button.click(inference_fn, inputs=img_input + gradcam_inputs, outputs=gradcam_outputs)

            gr.Markdown("## Examples")
            gr.Examples(
                examples=[
                    ["./assets/examples/dog.jpg", 0.5, 3, 4],
                    ["./assets/examples/cat.jpg", 0.5, 3, 4],
                    ["./assets/examples/frog.jpg", 0.5, 3, 4],
                    ["./assets/examples/bird.jpg", 0.5, 3, 4],
                    ["./assets/examples/shark-plane.jpg", 0.5, 3, 4],
                    ["./assets/examples/car.jpg", 0.5, 3, 4],
                    ["./assets/examples/truck.jpg", 0.5, 3, 4],
                    ["./assets/examples/horse.jpg", 0.5, 3, 4],
                    ["./assets/examples/plane.jpg", 0.5, 3, 4],
                    ["./assets/examples/ship.png", 0.5, 3, 4]
                ],
                inputs=img_input + gradcam_inputs,
                fn=inference_fn,
                outputs=gradcam_outputs
            )

    gr.close_all()
    demo.launch(debug=True)


if __name__ == "__main__":
    main()