#!/usr/bin/env python
"""
Application for ResNet50 trained on ImageNet-1K.
"""

# Third Party Imports
import gradio as gr
import torch
from torchvision import models

# Local Imports
from inference import inference


def load_model(model_path: str):
    """
    Load the model.
    """
    # Check if CUDA is available and set the device accordingly
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Build the ResNet50 architecture without pre-trained weights
    model = models.resnet50(weights=None)
    model = model.to(device)

    # Load custom weights from a .pth checkpoint
    state_dict = torch.load(model_path, map_location=device)

    # Keep only the keys that exist in the model's state dict
    filtered_state_dict = {
        k: v for k, v in state_dict['model_state_dict'].items()
        if k in model.state_dict()
    }

    # Load the filtered state dictionary into the model
    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()

    return model


def load_classes():
    """
    Load the ImageNet-1K class names.
    """
    # Get the class names from the ResNet50 weights metadata
    weights = models.ResNet50_Weights.IMAGENET1K_V2
    classes = weights.meta["categories"]
    return classes


def inference_wrapper(image, alpha, top_k, target_layer):
    """
    Wrapper around inference() with error handling.
    """
    try:
        if image is None:
            return {"No image provided": 1.0}, None

        results = inference(
            image,
            alpha,
            top_k,
            target_layer,
            model=model,
            classes=classes
        )

        if results is None:
            return {"Processing failed": 1.0}, None

        return results

    except Exception as e:
        error_msg = str(e)
        print(f"Error in inference: {error_msg}")

        # Handle GPU quota errors specifically
        if "GPU quota" in error_msg:
            return {"GPU quota exceeded - Please try again later": 1.0}, None

        # Handle all other errors
        return {f"Error: {error_msg}": 1.0}, None


def main():
    """
    Main function for the application.
    """
    global model, classes

    try:
        print(f"Gradio version: {gr.__version__}")

        # Load the model at startup
        model = load_model("resnet50_imagenet1k.pth")

        # Load the class names at startup
        classes = load_classes()

        with gr.Blocks() as demo:
            gr.Markdown(
                """
                # ResNet50 trained on ImageNet-1K
                """
            )

            with gr.Tab("GradCAM"):
                gr.Markdown(
                    """
                    Visualize the class activation map generated by the selected model layer
                    for the predicted class.
                    """
                )

                # Define inputs and outputs
                with gr.Row():
                    img_input = gr.Image(
                        label="Input Image",
                        type="numpy",
                        height=224,
                        width=224
                    )
                    with gr.Column():
                        label_output = gr.Label(label="Predictions")
                        gradcam_output = gr.Image(
                            label="GradCAM Output",
                            height=224,
                            width=224
                        )

                with gr.Row():
                    alpha_slider = gr.Slider(
                        minimum=0, maximum=1, value=0.5, step=0.1,
                        label="Activation Map Transparency"
                    )
                    top_k_slider = gr.Slider(
                        minimum=1, maximum=10, value=3, step=1,
                        label="Number of Top Predictions"
                    )
                    target_layer_slider = gr.Slider(
                        minimum=1, maximum=6, value=4, step=1,
                        label="Target Layer Number"
                    )

                gradcam_button = gr.Button("Generate GradCAM")

                # Set up the click event
                gradcam_button.click(
                    fn=inference_wrapper,
                    inputs=[
                        img_input,
                        alpha_slider,
                        top_k_slider,
                        target_layer_slider
                    ],
                    outputs=[
                        label_output,
                        gradcam_output
                    ]
                )

                # Example section
                gr.Examples(
                    examples=[
                        ["assets/examples/dog.jpg", 0.5, 3, 4],
                        ["assets/examples/cat.jpg", 0.5, 3, 4],
                        ["assets/examples/frog.jpg", 0.5, 3, 4],
                        ["assets/examples/bird.jpg", 0.5, 3, 4],
                        ["assets/examples/shark-plane.jpg", 0.5, 3, 4],
                        ["assets/examples/car.jpg", 0.5, 3, 4],
                        ["assets/examples/truck.jpg", 0.5, 3, 4],
                        ["assets/examples/horse.jpg", 0.5, 3, 4],
                        ["assets/examples/plane.jpg", 0.5, 3, 4],
                        ["assets/examples/ship.png", 0.5, 3, 4]
                    ],
                    inputs=[
                        img_input,
                        alpha_slider,
                        top_k_slider,
                        target_layer_slider
                    ],
                    outputs=[
                        label_output,
                        gradcam_output
                    ],
                    fn=inference_wrapper,
                    cache_examples=True,
                    label="Click on any example to run GradCAM"
                )

        # Configure the queue to limit concurrent processing
        demo.queue(concurrency_count=1)

        # Launch with compatible parameters
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )

    except Exception as e:
        print(f"Error during startup: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()