ImageNet / app.py
Shilpaj's picture
Debug: Issue with prediction
ebbea61
raw
history blame
8.48 kB
#!/usr/bin/env python
"""
Application for ResNet50 trained on ImageNet-1K.
"""
# Standard Library Imports
import traceback

# Third Party Imports
import gradio as gr
import torch
from torchvision import models

# Local Imports
from inference import inference
def load_model(model_path: str) -> torch.nn.Module:
    """
    Build ResNet50 and overlay the custom checkpoint stored at ``model_path``.

    The network is first created with torchvision's ``IMAGENET1K_V1`` weights, so
    any parameter the checkpoint does not provide keeps its pretrained value.

    :param model_path: Path to a checkpoint written with ``torch.save``. Either a
        dict holding the weights under ``'model_state_dict'`` or a bare state dict.
    :return: The model, moved to CUDA when available (CPU otherwise), in eval mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the model with default weights first
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    model = model.to(device)

    # Load custom weights
    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept both a training checkpoint (weights nested under
    # 'model_state_dict') and a bare state dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        custom_weights = checkpoint['model_state_dict']
    else:
        custom_weights = checkpoint

    # Snapshot the model's own state dict once; rebuilding it for every key
    # inside the filter below is needless repeated work.
    model_state = model.state_dict()

    # Debug: Print state dict info
    print("\nState dict keys:", list(custom_weights.keys())[:5])
    print("Model state dict keys:", list(model_state.keys())[:5])

    # Check if the final layer weights match (purely informational, so a
    # checkpoint without 'fc.weight' must not abort loading here)
    fc_weight = custom_weights.get('fc.weight')
    if fc_weight is not None:
        print(f"\nFC layer weight shape: {fc_weight.shape}")
    else:
        print("\nFC layer weight not present in checkpoint")

    # Keep only the entries whose names exist in the model
    filtered_state_dict = {k: v for k, v in custom_weights.items() if k in model_state}
    print(f"Filtered state dict size: {len(filtered_state_dict)} / {len(custom_weights)}")

    # strict=False tolerates parameters missing from the checkpoint; report
    # them so a partially applied checkpoint does not go unnoticed.
    load_result = model.load_state_dict(filtered_state_dict, strict=False)
    if load_result.missing_keys:
        print(
            f"Warning: {len(load_result.missing_keys)} model parameters were not found "
            f"in the checkpoint, e.g. {load_result.missing_keys[:5]}"
        )
    model.eval()

    # Verify model
    print("\nModel architecture:")
    print(model)
    return model
def load_classes():
    """
    Return the category names used to label the model's predictions.
    """
    # The names are read from the metadata bundled with torchvision's
    # ResNet50 IMAGENET1K_V1 weights (V1 rather than V2).
    return models.ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
def inference_wrapper(image, alpha, top_k, target_layer):
    """
    Call ``inference`` with the module-level ``model`` and ``classes`` and turn
    every failure into a value the UI can display.

    Returns whatever ``inference`` returns on success. When there is no image,
    when ``inference`` returns ``None`` or when it raises, the result is a
    ``(label_dict, None)`` pair whose single label describes the problem.
    """
    # Nothing to do without an input image
    if image is None:
        return {"Error": 1.0}, None

    try:
        outcome = inference(
            image, alpha, top_k, target_layer, model=model, classes=classes
        )
    except Exception as exc:
        message = str(exc)
        print(f"Error in inference: {message}")
        if not isinstance(exc, RuntimeError):
            return {"Error: " + message: 1.0}, None
        # RuntimeError: single out out-of-memory failures with a friendlier label
        if "out of memory" in message.lower():
            return {"GPU Memory Error - Please try again": 1.0}, None
        return {"Runtime Error: " + message: 1.0}, None

    return ({"Error": 1.0}, None) if outcome is None else outcome
def _build_demo():
    """
    Build the Gradio Blocks UI and wire it to ``inference_wrapper``.

    The "GradCam" tab holds an input image, a predictions label, a GradCAM output
    image, three sliders (transparency, number of top predictions, target layer),
    a button that triggers inference and a list of clickable example images.

    :return: The assembled ``gr.Blocks`` instance (not yet queued or launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # ImageNet-1K trained on ResNet50v2
            """
        )
        with gr.Tab("GradCam"):
            gr.Markdown(
                """
                Visualize Class Activations Maps generated by the model's layer for the predicted class.
                """
            )

            # Define inputs
            with gr.Row():
                img_input = gr.Image(
                    label="Input Image",
                    type="numpy",
                    height=224,
                    width=224
                )
                with gr.Column():
                    label_output = gr.Label(label="Predictions")
                    gradcam_output = gr.Image(
                        label="GradCAM Output",
                        height=224,
                        width=224
                    )

            with gr.Row():
                alpha_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.5,
                    step=0.1,
                    label="Activation Map Transparency"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Number of Top Predictions"
                )
                target_layer_slider = gr.Slider(
                    minimum=1,
                    maximum=6,
                    value=4,
                    step=1,
                    label="Target Layer Number"
                )

            gradcam_button = gr.Button("Generate GradCAM")

            # The button and the examples share the same inputs and outputs
            input_components = [
                img_input,
                alpha_slider,
                top_k_slider,
                target_layer_slider
            ]
            output_components = [
                label_output,
                gradcam_output
            ]

            # Set up the click event
            gradcam_button.click(
                fn=inference_wrapper,
                inputs=input_components,
                outputs=output_components
            )

            # Examples section for Gradio 5.x.
            # Every example uses the same slider values (0.5, 3, 4), so the rows
            # are generated from the file names instead of being spelled out.
            example_files = (
                "cat.jpg",
                "frog.jpg",
                "bird.jpg",
                "car.jpg",
                "truck.jpg",
                "horse.jpg",
                "plane.jpg",
                "ship.png",
            )
            examples = [
                [f"assets/examples/{file_name}", 0.5, 3, 4]
                for file_name in example_files
            ]
            gr.Examples(
                examples=examples,
                inputs=input_components,
                outputs=output_components,
                fn=inference_wrapper,
                cache_examples=False,  # Disable caching to prevent memory issues
                # Without caching, a click only fills the inputs unless
                # run_on_click is set; the label promises that a click runs GradCAM.
                run_on_click=True,
                label="Click on any example to run GradCAM"
            )
    return demo


def main():
    """
    Main function for the application.

    Loads the model and class names into the module-level ``model`` and
    ``classes`` globals (read by ``inference_wrapper``), builds the UI and serves
    it on ``0.0.0.0:7860``. Any exception raised on the way is reported with its
    traceback and is not re-raised.
    """
    global model, classes
    try:
        print(f"Gradio version: {gr.__version__}")

        # Load the model at startup
        model = load_model("resnet50_imagenet1k.pth")
        classes = load_classes()

        demo = _build_demo()

        # Queue configuration: max_size caps how many requests the queue holds
        demo.queue(max_size=1)

        # Launch with minimal memory usage
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        print(f"Error during startup: {str(e)}")
        # The message alone rarely locates the failure; show where it came from
        traceback.print_exc()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Start the application only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()