#!/usr/bin/env python
"""
Application for ResNet50 trained on ImageNet-1K.
"""
# Third Party Imports
import gradio as gr
import torch
from torchvision import models

# Local Imports
from inference import inference
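
# NOTE (assumption): the local inference module is expected to expose
# inference(image, alpha, top_k, target_layer, model=..., classes=...) and to
# return a (predictions, gradcam_image) pair matching the gr.Label and
# gr.Image outputs wired up in main() below. Since the call is made under
# torch.no_grad(), inference() is also assumed to re-enable gradients
# internally where GradCAM needs them.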


def load_model(model_path: str):
    """
    Load the model.
    """
    # Check if CUDA is available and set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Build the ResNet50 architecture (custom weights are loaded below)
    model = models.resnet50(weights=None)
    model = model.to(device)

    # Load custom weights from a .pth file
    state_dict = torch.load(model_path, map_location=device)

    # Filter out unexpected keys
    filtered_state_dict = {
        k: v for k, v in state_dict['model_state_dict'].items()
        if k in model.state_dict()
    }

    # Load the filtered state dictionary into the model
    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()
    return model
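
# NOTE (assumption): load_model() above expects a checkpoint saved as a dict
# with a 'model_state_dict' key, e.g. roughly:
#   torch.save({'model_state_dict': model.state_dict()}, "resnet50_imagenet1k.pth")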


def load_classes():
    """
    Load the classes.
    """
    # Get ImageNet class names from ResNet50 weights
    weights = models.ResNet50_Weights.IMAGENET1K_V2
    classes = weights.meta["categories"]
    return classes
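
# weights.meta["categories"] above is torchvision's list of the 1,000
# ImageNet-1K class names, so prediction indices map directly to labels.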


def inference_wrapper(image, alpha, top_k, target_layer):
    """
    Wrapper function for inference with error handling.
    """
    try:
        if image is None:
            return None, None

        with torch.cuda.amp.autocast():  # Enable automatic mixed precision
            with torch.no_grad():  # Disable gradient calculation
                return inference(
                    image,
                    alpha,
                    top_k,
                    target_layer,
                    model=model,
                    classes=classes
                )
    except Exception as e:
        print(f"Error in inference: {str(e)}")
        # Raise (not return) gr.Error so Gradio surfaces the message in the UI
        raise gr.Error(f"Error processing image: {str(e)}")


def main():
    """
    Main function for the application.
    """
    global model, classes  # Make these global so they're accessible to inference_wrapper

    # Load the model at startup
    model = load_model("resnet50_imagenet1k.pth")

    # Load the classes at startup
    classes = load_classes()
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # ResNet50 trained on ImageNet-1K
            """
        )
        with gr.Tab("GradCam"):
            gr.Markdown(
                """
                Visualize the Class Activation Maps generated by the selected model layer for the predicted class.
                """
            )
            # Define inputs
            with gr.Row():
                img_input = gr.Image(
                    label="Input Image",
                    type="numpy",
                    height=224,
                    width=224
                )
                with gr.Column():
                    label_output = gr.Label(label="Predictions")
                    gradcam_output = gr.Image(
                        label="GradCAM Output",
                        height=224,
                        width=224
                    )
            with gr.Row():
                alpha_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.5,
                    step=0.1,
                    label="Activation Map Transparency"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Number of Top Predictions"
                )
                target_layer_slider = gr.Slider(
                    minimum=1,
                    maximum=6,
                    value=4,
                    step=1,
                    label="Target Layer Number"
                )
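                # The slider picks how deep the GradCAM hook sits in the
                # network (the exact mapping to ResNet50 blocks is assumed to
                # live in inference.py); deeper layers generally give coarser
                # but more class-discriminative activation maps.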
            gradcam_button = gr.Button("Generate GradCAM")

            # Set up the click event
            gradcam_button.click(
                fn=inference_wrapper,
                inputs=[
                    img_input,
                    alpha_slider,
                    top_k_slider,
                    target_layer_slider
                ],
                outputs=[
                    label_output,
                    gradcam_output
                ]
            )
            # Example section
            gr.Examples(
                examples=[
                    ["assets/examples/dog.jpg", 0.5, 3, 4],
                    ["assets/examples/cat.jpg", 0.5, 3, 4],
                    ["assets/examples/frog.jpg", 0.5, 3, 4],
                    ["assets/examples/bird.jpg", 0.5, 3, 4],
                    ["assets/examples/shark-plane.jpg", 0.5, 3, 4],
                    ["assets/examples/car.jpg", 0.5, 3, 4],
                    ["assets/examples/truck.jpg", 0.5, 3, 4],
                    ["assets/examples/horse.jpg", 0.5, 3, 4],
                    ["assets/examples/plane.jpg", 0.5, 3, 4],
                    ["assets/examples/ship.png", 0.5, 3, 4]
                ],
                inputs=[
                    img_input,
                    alpha_slider,
                    top_k_slider,
                    target_layer_slider
                ],
                outputs=[
                    label_output,
                    gradcam_output
                ],
                fn=inference_wrapper,
                cache_examples=True,
                label="Click on any example to run GradCAM"
            )
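            # cache_examples=True makes Gradio run inference_wrapper on every
            # example at startup and reuse the cached outputs when an example
            # is clicked, at the cost of a slower launch.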

    # Launch the demo
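    # Binding to 0.0.0.0 on port 7860 matches what Hugging Face Spaces
    # expects from a Gradio app.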
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
        enable_queue=True,
        show_error=True,
        max_threads=4
    )


if __name__ == "__main__":
    main()