#!/usr/bin/env python
"""
Application for ResNet50 trained on ImageNet-1K.

Loads a ResNet50 with custom weights and serves a Gradio UI that shows the
model's top predictions for an image together with a GradCAM visualization.
"""
# Third Party Imports
import gradio as gr
import torch
from torchvision import models

# Local Imports
from inference import inference

# Prefix stripped from checkpoint keys so they match the parameter names of
# torchvision's resnet50.
_STATE_DICT_PREFIX = "model."

# Default slider values; reused for every example row so examples run with
# the same settings the UI starts with.
_DEFAULT_ALPHA = 0.5
_DEFAULT_TOP_K = 3
_DEFAULT_TARGET_LAYER = 4

# Example images, relative to assets/examples/.
_EXAMPLE_IMAGES = (
    "cat.jpg",
    "frog.jpg",
    "bird.jpg",
    "car.jpg",
    "truck.jpg",
    "horse.jpg",
    "plane.jpg",
    "ship.png",
)

_HEADER_MD = """
# ResNet50 trained on ImageNet-1K
A large-scale image classification dataset with 1.2 million training images across 1,000 object categories.
"""

_INSTRUCTIONS_MD = """
View model predictions and visualize where the model is looking using GradCAM.

## Steps to use:
1. Upload an image or select one from the examples below
2. Adjust the sliders (optional):
    - Activation Map Transparency: Controls the blend between original image and activation map
    - Number of Top Predictions: How many top class predictions to show
    - Target Layer Number: Which network layer to visualize (deeper layers show higher-level features)
3. Click "Generate GradCAM" to run the model
4. View the results:
    - Left: Original uploaded image
    - Right: Model predictions and GradCAM visualization showing where the model focused
"""

# Set by main() at startup and read by inference_wrapper(); None until the
# model and class names have been loaded.
model = None
classes = None


def load_model(model_path: str) -> torch.nn.Module:
    """
    Load a ResNet50 with custom weights from a checkpoint file.

    Args:
        model_path: Path to a file written by ``torch.save``. Either a dict
            holding the weights under ``'model_state_dict'`` or a bare state
            dict. A leading ``'model.'`` is stripped from every key.

    Returns:
        The ResNet50 on CUDA if available (else CPU), switched to eval mode.

    Raises:
        Exception: Whatever ``load_state_dict`` raises when the weights do
            not match the architecture (printed, then re-raised).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize a fresh model without pretrained weights
    net = models.resnet50(weights=None).to(device)

    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept a training checkpoint ({'model_state_dict': ...}) as well as a
    # plain state dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        raw_state_dict = checkpoint['model_state_dict']
    else:
        raw_state_dict = checkpoint

    # Debug: Print original state dict keys
    print("\nOriginal state dict keys:", list(raw_state_dict.keys())[:5])

    # Strip only a *leading* 'model.': str.replace would also delete the
    # substring wherever else it occurs inside a key.
    prefix_len = len(_STATE_DICT_PREFIX)
    new_state_dict = {
        (key[prefix_len:] if key.startswith(_STATE_DICT_PREFIX) else key): value
        for key, value in raw_state_dict.items()
    }

    # Debug: Print modified state dict keys
    print("Modified state dict keys:", list(new_state_dict.keys())[:5])
    print("Model state dict keys:", list(net.state_dict().keys())[:5])

    try:
        net.load_state_dict(new_state_dict)
    except Exception as e:
        print(f"Error loading state dict: {str(e)}")
        raise
    print("Successfully loaded model weights")

    net.eval()
    return net


def load_classes():
    """
    Return the ImageNet-1K category names bundled with torchvision.

    Returns:
        The ``categories`` list from ``ResNet50_Weights.IMAGENET1K_V1.meta``.
    """
    weights = models.ResNet50_Weights.IMAGENET1K_V1
    classes = weights.meta["categories"]
    print(f"Loaded {len(classes)} classes")
    return classes


def inference_wrapper(image, alpha, top_k, target_layer):
    """
    Run ``inference`` for the UI, turning failures into a displayable label.

    Args:
        image: Image from the ``gr.Image`` input (``type="numpy"``), or None
            when nothing has been uploaded.
        alpha: Activation map transparency (slider range 0-1).
        top_k: Number of top predictions to show (slider range 1-10).
        target_layer: Layer number to visualize (slider range 1-6).

    Returns:
        Whatever ``inference`` returns -- presumably ``(label_dict, image)``
        to match the two UI outputs; TODO confirm. On failure, a tuple
        ``({message: 1.0}, None)`` so the error shows in the label output.
    """
    if image is None:
        return {"Error": 1.0}, None
    if model is None or classes is None:
        # main() has not finished loading the model yet.
        return {"Error: Model not loaded": 1.0}, None

    try:
        results = inference(
            image, alpha, top_k, target_layer, model=model, classes=classes
        )
    except RuntimeError as e:
        error_msg = str(e)
        print(f"Error in inference: {error_msg}")
        if "out of memory" in error_msg.lower():
            # Release cached allocator blocks so a retry has a chance to fit.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return {"GPU Memory Error - Please try again": 1.0}, None
        return {"Runtime Error: " + error_msg: 1.0}, None
    except Exception as e:
        error_msg = str(e)
        print(f"Error in inference: {error_msg}")
        return {"Error: " + error_msg: 1.0}, None

    if results is None:
        return {"Error": 1.0}, None
    return results


def _build_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI and wire its button/examples to inference_wrapper."""
    with gr.Blocks() as demo:
        gr.Markdown(_HEADER_MD)

        with gr.Tab("Predictions & GradCAM"):
            gr.Markdown(_INSTRUCTIONS_MD)

            # Input image on the left; predictions and GradCAM on the right.
            with gr.Row():
                img_input = gr.Image(
                    label="Input Image", type="numpy", height=224, width=224
                )
                with gr.Column():
                    label_output = gr.Label(label="Predictions")
                    gradcam_output = gr.Image(
                        label="GradCAM Output", height=224, width=224
                    )

            with gr.Row():
                alpha_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=_DEFAULT_ALPHA,
                    step=0.1,
                    label="Activation Map Transparency",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=_DEFAULT_TOP_K,
                    step=1,
                    label="Number of Top Predictions",
                )
                target_layer_slider = gr.Slider(
                    minimum=1,
                    maximum=6,
                    value=_DEFAULT_TARGET_LAYER,
                    step=1,
                    label="Target Layer Number",
                )

            gradcam_button = gr.Button("Generate GradCAM")

            inputs = [img_input, alpha_slider, top_k_slider, target_layer_slider]
            outputs = [label_output, gradcam_output]

            # Set up the click event
            gradcam_button.click(
                fn=inference_wrapper, inputs=inputs, outputs=outputs
            )

            # Examples section for Gradio 5.x
            examples = [
                [
                    f"assets/examples/{name}",
                    _DEFAULT_ALPHA,
                    _DEFAULT_TOP_K,
                    _DEFAULT_TARGET_LAYER,
                ]
                for name in _EXAMPLE_IMAGES
            ]
            gr.Examples(
                examples=examples,
                inputs=inputs,
                outputs=outputs,
                fn=inference_wrapper,
                cache_examples=False,  # Disable caching to prevent memory issues
                label="Click on any example to run GradCAM",
            )
    return demo


def main():
    """
    Load the model and class names, build the UI and launch the server.

    Blocks while the server runs. A startup failure is printed, cached CUDA
    memory is released, and the exception is re-raised so the process exits
    with a non-zero status instead of silently succeeding.
    """
    global model, classes
    try:
        print(f"Gradio version: {gr.__version__}")

        # Load the model at startup
        model = load_model("resnet50_imagenet1k.pth")
        classes = load_classes()

        demo = _build_demo()

        # Store at most one waiting event; further requests get a
        # "queue is full" response. How many events *run* at once is set
        # separately by Gradio's concurrency limit.
        demo.queue(max_size=1)

        # Listen on all interfaces and also create a public share link.
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
        )
    except Exception as e:
        print(f"Error during startup: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise


if __name__ == "__main__":
    main()