#!/usr/bin/env python
"""
Application for ResNet50 trained on ImageNet-1K.
"""

# Standard Library Imports
import traceback

# Third Party Imports
import gradio as gr
import torch
from torchvision import models

# Local Imports
from inference import inference


# Populated by main() at startup and read by inference_wrapper().
model = None
classes = None

# Example images offered under the GradCAM tab. dog.jpg and shark-plane.jpg
# are currently disabled.
EXAMPLE_IMAGES = [
    "assets/examples/cat.jpg",
    "assets/examples/frog.jpg",
    "assets/examples/bird.jpg",
    "assets/examples/car.jpg",
    "assets/examples/truck.jpg",
    "assets/examples/horse.jpg",
    "assets/examples/plane.jpg",
    "assets/examples/ship.png",
]
# Slider values used for every example: (alpha, top_k, target_layer).
EXAMPLE_SLIDER_VALUES = (0.5, 3, 4)

_DATAPARALLEL_PREFIX = "module."


def _extract_state_dict(checkpoint):
    """Return the weights dict from *checkpoint*.

    A checkpoint of the form ``{'model_state_dict': ...}`` is unwrapped;
    anything else is assumed to already be a state dict and returned as-is.
    """
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        return checkpoint['model_state_dict']
    return checkpoint


def _strip_module_prefix(state_dict):
    """Remove the ``module.`` key prefix added when saving a DataParallel model."""
    return {
        (key[len(_DATAPARALLEL_PREFIX):]
         if key.startswith(_DATAPARALLEL_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def load_model(model_path: str):
    """
    Load a ResNet50 and overlay the weights from a custom checkpoint.

    The network is first initialised with torchvision's IMAGENET1K_V1
    weights; any parameter that is absent from the checkpoint, or whose
    shape does not match, keeps that pretrained value.

    Args:
        model_path: Path to a file written with ``torch.save``, holding either
            a bare state dict or a dict with the weights under
            ``'model_state_dict'``.

    Returns:
        The model on the selected device, switched to eval mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the model with default weights first
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    model = model.to(device)

    # torch.load unpickles the file: only point this at trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    checkpoint_state = _strip_module_prefix(_extract_state_dict(checkpoint))
    model_state = model.state_dict()

    # Debug: Print state dict info
    print("\nState dict keys:", list(checkpoint_state.keys())[:5])
    print("Model state dict keys:", list(model_state.keys())[:5])

    if 'fc.weight' in checkpoint_state:
        print(f"\nFC layer weight shape: {checkpoint_state['fc.weight'].shape}")

    # Keep only entries the model knows AND whose shapes match: with
    # strict=False, load_state_dict tolerates missing/unexpected keys but
    # still raises on a size mismatch.
    filtered_state_dict = {
        key: value
        for key, value in checkpoint_state.items()
        if key in model_state and value.shape == model_state[key].shape
    }
    print(f"Filtered state dict size: {len(filtered_state_dict)} / {len(checkpoint_state)}")

    load_result = model.load_state_dict(filtered_state_dict, strict=False)
    if load_result.missing_keys:
        print(f"Entries kept at pretrained values: {len(load_result.missing_keys)}")
    model.eval()

    # Verify model
    print("\nModel architecture:")
    print(model)
    return model


def load_classes():
    """
    Return the ImageNet-1K class names.

    Uses the category list bundled with the same torchvision weights
    (IMAGENET1K_V1) that load_model() initialises the network from.
    """
    return models.ResNet50_Weights.IMAGENET1K_V1.meta["categories"]


def inference_wrapper(image, alpha, top_k, target_layer):
    """
    Run GradCAM inference, turning failures into displayable outputs.

    Args:
        image: Image from the Gradio ``Image`` component (``type="numpy"``),
            or ``None`` when nothing has been provided.
        alpha: Transparency of the activation-map overlay.
        top_k: Number of top predictions to report.
        target_layer: Layer number to visualise.

    Returns:
        Whatever ``inference`` returns on success. On any failure, a tuple of
        a single-entry ``{message: 1.0}`` dict (shown in the Label) and
        ``None`` for the image.
    """
    if model is None or classes is None:
        return {"Error: model not loaded": 1.0}, None
    if image is None:
        return {"Error": 1.0}, None

    try:
        # Gradio sliders pass their value as a float; these two are counts.
        results = inference(
            image,
            alpha,
            int(top_k),
            int(target_layer),
            model=model,
            classes=classes
        )
    except RuntimeError as e:
        error_msg = str(e)
        print(f"Error in inference: {error_msg}")
        if "out of memory" in error_msg.lower():
            # Release cached blocks so the next request has a chance to fit.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return {"GPU Memory Error - Please try again": 1.0}, None
        return {"Runtime Error: " + error_msg: 1.0}, None
    except Exception as e:
        error_msg = str(e)
        print(f"Error in inference: {error_msg}")
        traceback.print_exc()
        return {"Error: " + error_msg: 1.0}, None

    if results is None:
        return {"Error": 1.0}, None
    return results


def _build_demo():
    """Construct the Gradio Blocks UI with the GradCAM tab and its examples."""
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # ResNet50 trained on ImageNet-1K
            """
        )

        with gr.Tab("GradCam"):
            gr.Markdown(
                """
                Visualize Class Activation Maps generated by the model's layer for the predicted class.
                """
            )

            # Define inputs
            with gr.Row():
                img_input = gr.Image(
                    label="Input Image",
                    type="numpy",
                    height=224,
                    width=224
                )
                with gr.Column():
                    label_output = gr.Label(label="Predictions")
                    gradcam_output = gr.Image(
                        label="GradCAM Output",
                        height=224,
                        width=224
                    )

            with gr.Row():
                alpha_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.5,
                    step=0.1,
                    label="Activation Map Transparency"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Number of Top Predictions"
                )
                target_layer_slider = gr.Slider(
                    minimum=1,
                    maximum=6,
                    value=4,
                    step=1,
                    label="Target Layer Number"
                )

            gradcam_button = gr.Button("Generate GradCAM")

            inputs = [img_input, alpha_slider, top_k_slider, target_layer_slider]
            outputs = [label_output, gradcam_output]

            # Set up the click event
            gradcam_button.click(
                fn=inference_wrapper,
                inputs=inputs,
                outputs=outputs
            )

            # Examples section for Gradio 5.x
            examples = [[path, *EXAMPLE_SLIDER_VALUES] for path in EXAMPLE_IMAGES]
            gr.Examples(
                examples=examples,
                inputs=inputs,
                outputs=outputs,
                fn=inference_wrapper,
                cache_examples=False,  # Disable caching to prevent memory issues
                label="Click on any example to run GradCAM"
            )

    return demo


def main():
    """
    Load the model and class names, then build and launch the Gradio app.

    Startup errors are logged, the CUDA cache is cleared, and the exception
    is re-raised so the process exits with a failure status.
    """
    global model, classes
    try:
        print(f"Gradio version: {gr.__version__}")

        # Load the model at startup
        model = load_model("resnet50_imagenet1k.pth")
        classes = load_classes()

        demo = _build_demo()

        # max_size bounds how many requests may wait in the queue; once it is
        # full, new requests are rejected. How many run at once is governed
        # separately by Gradio's concurrency limit.
        demo.queue(max_size=1)

        # Listen on all interfaces without creating a public share link.
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )
    except Exception as e:
        print(f"Error during startup: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise


if __name__ == "__main__":
    main()