#!/usr/bin/env python
"""
Gradio application for a ResNet50 model trained on ImageNet-1K.
"""
# Third Party Imports
import gradio as gr
import torch
from torchvision import models

# Local Imports
from inference import inference

def load_model(model_path: str):
    """
    Load the ResNet50 model and apply the custom checkpoint weights.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the model with default ImageNet-1K weights first
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    model = model.to(device)

    # Load custom weights
    state_dict = torch.load(model_path, map_location=device)

    # Debug: print state dict info
    print("\nState dict keys:", list(state_dict['model_state_dict'].keys())[:5])
    print("Model state dict keys:", list(model.state_dict().keys())[:5])

    # Check if the final layer weights match
    fc_weight_shape = state_dict['model_state_dict']['fc.weight'].shape
    print(f"\nFC layer weight shape: {fc_weight_shape}")

    # Keep only the checkpoint entries whose keys exist in the model
    filtered_state_dict = {
        k: v for k, v in state_dict['model_state_dict'].items()
        if k in model.state_dict()
    }
    print(f"Filtered state dict size: {len(filtered_state_dict)} / {len(state_dict['model_state_dict'])}")

    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()

    # Verify model
    print("\nModel architecture:")
    print(model)
    return model
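
# Example usage (a minimal sketch; assumes the checkpoint file named below
# sits next to this script):
#
#   model = load_model("resnet50_imagenet1k.pth")
#   dummy = torch.randn(1, 3, 224, 224, device=next(model.parameters()).device)
#   with torch.no_grad():
#       logits = model(dummy)
#   print(logits.shape)  # torch.Size([1, 1000])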

def load_classes():
    """
    Load the ImageNet-1K class names.
    """
    # Load classes from the same weights version the model was trained with
    weights = models.ResNet50_Weights.IMAGENET1K_V1
    classes = weights.meta["categories"]
    return classes
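
# Quick sanity check (torchvision's IMAGENET1K_V1 metadata lists the 1000
# ImageNet-1K category names in label order):
#
#   classes = load_classes()
#   assert len(classes) == 1000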

def inference_wrapper(image, alpha, top_k, target_layer):
    """
    Wrapper function for inference with error handling.
    """
    try:
        if image is None:
            return {"Error": 1.0}, None

        results = inference(
            image,
            alpha,
            top_k,
            target_layer,
            model=model,
            classes=classes
        )
        if results is None:
            return {"Error": 1.0}, None
        return results
    except RuntimeError as e:
        error_msg = str(e)
        print(f"Error in inference: {error_msg}")
        if "out of memory" in error_msg.lower():
            return {"GPU Memory Error - Please try again": 1.0}, None
        return {f"Runtime Error: {error_msg}": 1.0}, None
    except Exception as e:
        error_msg = str(e)
        print(f"Error in inference: {error_msg}")
        return {f"Error: {error_msg}": 1.0}, None

def main():
    """
    Main function for the application.
    """
    global model, classes
    try:
        print(f"Gradio version: {gr.__version__}")

        # Load the model and class names once at startup
        model = load_model("resnet50_imagenet1k.pth")
        classes = load_classes()
        with gr.Blocks() as demo:
            gr.Markdown(
                """
                # ResNet50 Trained on ImageNet-1K
                """
            )
            with gr.Tab("GradCam"):
                gr.Markdown(
                    """
                    Visualize the Class Activation Maps generated by a chosen model layer for the predicted class.
                    """
                )

                # Define inputs
                with gr.Row():
                    img_input = gr.Image(
                        label="Input Image",
                        type="numpy",
                        height=224,
                        width=224
                    )
                    with gr.Column():
                        label_output = gr.Label(label="Predictions")
                        gradcam_output = gr.Image(
                            label="GradCAM Output",
                            height=224,
                            width=224
                        )
                with gr.Row():
                    alpha_slider = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.5,
                        step=0.1,
                        label="Activation Map Transparency"
                    )
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="Number of Top Predictions"
                    )
                    target_layer_slider = gr.Slider(
                        minimum=1,
                        maximum=6,
                        value=4,
                        step=1,
                        label="Target Layer Number"
                    )
                gradcam_button = gr.Button("Generate GradCAM")

                # Set up the click event
                gradcam_button.click(
                    fn=inference_wrapper,
                    inputs=[
                        img_input,
                        alpha_slider,
                        top_k_slider,
                        target_layer_slider
                    ],
                    outputs=[
                        label_output,
                        gradcam_output
                    ]
                )

                # Examples section for Gradio 5.x
                # Each example is [image_path, alpha, top_k, target_layer]
                examples = [
                    # ["assets/examples/dog.jpg", 0.5, 3, 4],
                    ["assets/examples/cat.jpg", 0.5, 3, 4],
                    ["assets/examples/frog.jpg", 0.5, 3, 4],
                    ["assets/examples/bird.jpg", 0.5, 3, 4],
                    # ["assets/examples/shark-plane.jpg", 0.5, 3, 4],
                    ["assets/examples/car.jpg", 0.5, 3, 4],
                    ["assets/examples/truck.jpg", 0.5, 3, 4],
                    ["assets/examples/horse.jpg", 0.5, 3, 4],
                    ["assets/examples/plane.jpg", 0.5, 3, 4],
                    ["assets/examples/ship.png", 0.5, 3, 4]
                ]
                gr.Examples(
                    examples=examples,
                    inputs=[
                        img_input,
                        alpha_slider,
                        top_k_slider,
                        target_layer_slider
                    ],
                    outputs=[
                        label_output,
                        gradcam_output
                    ],
                    fn=inference_wrapper,
                    cache_examples=False,  # Disable caching to prevent memory issues
                    label="Click on any example to run GradCAM"
                )

        # Queue configuration: only allow one job at a time
        demo.queue(max_size=1)

        # Launch with minimal memory usage
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        print(f"Error during startup: {e}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
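
# To run locally (assuming this file is saved as app.py, the usual entry
# point name for a Hugging Face Space, and that the checkpoint plus the
# assets/examples/ images are in place):
#
#   python app.py
#   # then open http://localhost:7860 in a browser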