ImageNet / app.py
Shilpaj's picture
Fix: Runtime issue
430f33e
raw
history blame
6.67 kB
#!/usr/bin/env python
"""
Application for ResNet50 trained on ImageNet-1K.
"""
# UI Framework Imports
import gradio as gr
# Third Party Imports
import torch
from torchvision import models
# Local Imports
from inference import inference
def load_model(model_path: str):
    """
    Build a ResNet50 and populate it with weights from a checkpoint file.

    The checkpoint may either hold the weights under a ``'model_state_dict'``
    entry or be the bare state dict itself. Entries whose names the ResNet50
    architecture does not define are skipped; skipped entries and parameters
    that received no value are reported on stdout so a mismatched checkpoint
    does not go unnoticed.

    :param model_path: Path to a checkpoint file readable by ``torch.load``.
    :return: The ResNet50 model on the selected device, in eval mode.
    """
    # Check if CUDA is available and set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Architecture only; the weights come from the checkpoint below
    model = models.resnet50(weights=None)
    model = model.to(device)

    # NOTE: torch.load unpickles the file, so only use trusted checkpoints
    checkpoint = torch.load(model_path, map_location=device)

    # Accept both a wrapped checkpoint and a bare state dict
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Computed once: model.state_dict() builds a fresh mapping on every call
    expected_keys = set(model.state_dict())

    # Keep only the entries the architecture knows, remembering the rest
    filtered_state_dict = {}
    skipped_keys = []
    for key, value in state_dict.items():
        if key in expected_keys:
            filtered_state_dict[key] = value
        else:
            skipped_keys.append(key)

    if skipped_keys:
        print(
            f"Warning: skipped {len(skipped_keys)} checkpoint entries "
            f"unknown to ResNet50 (e.g. {skipped_keys[0]})"
        )

    # strict=False tolerates absent parameters; report them instead of hiding them
    load_result = model.load_state_dict(filtered_state_dict, strict=False)
    if load_result.missing_keys:
        print(
            f"Warning: {len(load_result.missing_keys)} model parameters were "
            f"not found in the checkpoint and keep their initial values"
        )

    model.eval()
    return model
def load_classes():
    """
    Return the category names stored in the metadata of torchvision's
    ``ResNet50_Weights.IMAGENET1K_V2`` weights entry.
    """
    return models.ResNet50_Weights.IMAGENET1K_V2.meta["categories"]
def inference_wrapper(image, alpha, top_k, target_layer):
    """
    Call ``inference`` on behalf of the UI and convert failures into labels.

    On success the value returned by ``inference`` is passed through unchanged.
    In every other case the result is a pair: a one-entry dict mapping a
    message to 1.0, and ``None``. Uses the module-level ``model`` and
    ``classes`` set up by ``main``.
    """
    # Nothing to run without an image
    if image is None:
        return {"No image provided": 1.0}, None

    try:
        outcome = inference(
            image, alpha, top_k, target_layer, model=model, classes=classes
        )
    except Exception as e:
        error_msg = str(e)
        print(f"Error in inference: {error_msg}")

        # A GPU quota failure gets a dedicated message; anything else is echoed
        if "GPU quota" in error_msg:
            message = "GPU quota exceeded - Please try again later"
        else:
            message = "Error: " + error_msg
        return {message: 1.0}, None

    if outcome is None:
        return {"Processing failed": 1.0}, None
    return outcome
def main():
    """
    Start the application.

    Loads the model and class names into the module-level globals read by
    ``inference_wrapper``, assembles the Gradio Blocks interface and serves it
    on port 7860. Any exception raised along the way is printed instead of
    being propagated.
    """
    global model, classes

    try:
        print(f"Gradio version: {gr.__version__}")

        # Fill the globals once, before the interface can trigger inference
        model = load_model("resnet50_imagenet1k.pth")
        classes = load_classes()

        # Every bundled example uses the sliders' default values
        example_rows = [
            [example_path, 0.5, 3, 4]
            for example_path in (
                "assets/examples/dog.jpg",
                "assets/examples/cat.jpg",
                "assets/examples/frog.jpg",
                "assets/examples/bird.jpg",
                "assets/examples/shark-plane.jpg",
                "assets/examples/car.jpg",
                "assets/examples/truck.jpg",
                "assets/examples/horse.jpg",
                "assets/examples/plane.jpg",
                "assets/examples/ship.png",
            )
        ]

        # Both image components are displayed at the same size
        image_size = {"height": 224, "width": 224}

        with gr.Blocks() as demo:
            gr.Markdown("\n# ImageNet-1K trained on ResNet50v2\n")

            with gr.Tab("GradCam"):
                gr.Markdown(
                    "\nVisualize Class Activations Maps generated by the "
                    "model's layer for the predicted class.\n"
                )

                # Input image next to a column holding the two outputs
                with gr.Row():
                    img_input = gr.Image(
                        label="Input Image", type="numpy", **image_size
                    )
                    with gr.Column():
                        label_output = gr.Label(label="Predictions")
                        gradcam_output = gr.Image(
                            label="GradCAM Output", **image_size
                        )

                # Controls forwarded to inference_wrapper
                with gr.Row():
                    alpha_slider = gr.Slider(
                        minimum=0, maximum=1, value=0.5, step=0.1,
                        label="Activation Map Transparency",
                    )
                    top_k_slider = gr.Slider(
                        minimum=1, maximum=10, value=3, step=1,
                        label="Number of Top Predictions",
                    )
                    target_layer_slider = gr.Slider(
                        minimum=1, maximum=6, value=4, step=1,
                        label="Target Layer Number",
                    )

                gradcam_button = gr.Button("Generate GradCAM")

                # The button and the examples drive the same function
                # with the same components
                event_inputs = [
                    img_input, alpha_slider, top_k_slider, target_layer_slider
                ]
                event_outputs = [label_output, gradcam_output]

                gradcam_button.click(
                    fn=inference_wrapper,
                    inputs=event_inputs,
                    outputs=event_outputs,
                )

                gr.Examples(
                    examples=example_rows,
                    inputs=event_inputs,
                    outputs=event_outputs,
                    fn=inference_wrapper,
                    cache_examples=True,
                    label="Click on any example to run GradCAM",
                )

        # NOTE(review): per the Gradio docs max_size bounds how many events may
        # wait in the queue; confirm this is the intended way to throttle load
        demo.queue(max_size=1)

        # Listen on all interfaces at port 7860, without a public share link
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
    except Exception as e:
        print(f"Error during startup: {str(e)}")
        # Release cached GPU memory after a failed start
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Start the app only when this file is run as a script, not when it is imported
if __name__ == "__main__":
    main()