#!/usr/bin/env python
"""
Inference script for ResNet50 trained on ImageNet-1K.
"""
# Standard library, NumPy and PyTorch imports
import numpy as np
import torch
from collections import OrderedDict

# Third Party Imports
import spaces
from torchvision import transforms
from torch.nn import functional as F
from torchvision.models import resnet50
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget


def _select_target_layer(model, target_layer):
    """Return the ResNet sub-module that GradCAM should hook.

    ``target_layer`` is a 1-6 choice. Floats such as UI slider values are
    rounded, and the result is clamped to that range:

    - 1 -> ``conv1``
    - 2-5 -> the last block of ``layer1``..``layer4``
    - 6 -> also the last block of ``layer4``

    NOTE(review): choices 5 and 6 resolve to the same module, exactly as in
    the original mapping — confirm whether 6 was meant to target something else.
    """
    layers = {
        1: model.conv1,
        2: model.layer1[-1],
        3: model.layer2[-1],
        4: model.layer3[-1],
        5: model.layer4[-1],
        6: model.layer4[-1],
    }
    choice = min(max(int(round(target_layer)), 1), 6)
    return layers[choice]


def _top_confidences(probabilities, classes, top_k):
    """Return an OrderedDict of the ``top_k`` most probable distinct class names.

    Output indices are visited in descending probability order. When a name
    occurs more than once in ``classes`` (some ImageNet-1K label lists repeat
    names such as "crane"), only its highest-probability entry is kept. A
    lower score therefore cannot overwrite a higher one, and keys stay unique.
    """
    # One device->host copy instead of a sync per element.
    probs = probabilities.detach().float().cpu()
    result = OrderedDict()
    for idx in torch.argsort(probs, descending=True).tolist():
        if len(result) >= top_k:
            break
        result.setdefault(classes[idx], float(probs[idx]))
    return result


@spaces.GPU
def inference(image, alpha, top_k, target_layer, model=None, classes=None):
    """
    Classify ``image`` and render a GradCAM heat-map for the top prediction.

    Args:
        image: Input image. It is passed to ``transforms.ToTensor`` and also
            divided by 255 for the overlay. Presumably an HxWxC RGB uint8
            array or a PIL image; TODO confirm with the caller.
        alpha: Weight of the original image in the overlay (forwarded as
            ``image_weight`` to ``show_cam_on_image``).
        top_k: Number of distinct class names to return.
        target_layer: Layer choice 1-6 (clamped); see ``_select_target_layer``.
        model: ResNet-style model exposing ``conv1`` and ``layer1``..``layer4``.
            Required despite the ``None`` default.
        classes: Sequence of class names indexed by model output position.
            Required despite the ``None`` default.

    Returns:
        Tuple ``(show_confidences, visualization)``:

        - ``show_confidences``: an OrderedDict mapping class name to softmax
          probability for the ``top_k`` best classes.
        - ``visualization``: the CAM overlay produced by ``show_cam_on_image``.

    Raises:
        ValueError: If ``model`` or ``classes`` is missing.
        Exception: Any error from preprocessing, the model or GradCAM is
            printed and re-raised unchanged.
    """
    if model is None or classes is None:
        raise ValueError("inference() requires both `model` and `classes`")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    on_gpu = device.type == 'cuda'
    if on_gpu:
        torch.cuda.empty_cache()

    try:
        # Debug: Print model mode
        print(f"Model mode: {model.training}")

        # Ensure model is on correct device and in eval mode
        model = model.to(device)
        model.eval()

        # Convert img to tensor and normalize it with ImageNet statistics
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # Debug: Print image tensor stats
        input_tensor = preprocess(image).to(device)
        print(f"Input tensor shape: {input_tensor.shape}")
        print(f"Input tensor range: [{input_tensor.min():.2f}, {input_tensor.max():.2f}]")

        input_tensor = input_tensor.unsqueeze(0)
        # Keeps a differentiable path for GradCAM's backward pass even if the
        # caller froze the model's parameters.
        input_tensor.requires_grad = True

        # torch.cuda.amp.autocast() is deprecated and warns on CPU-only hosts;
        # this enables mixed precision on CUDA only.
        with torch.autocast(device_type=device.type, enabled=on_gpu):
            # Prediction pass only: no autograd graph is needed here because
            # GradCAM runs its own forward/backward below.
            with torch.no_grad():
                outputs = model(input_tensor)
            print(f"Raw output shape: {outputs.shape}")
            print(f"Raw output range: [{outputs.min():.2f}, {outputs.max():.2f}]")

            probabilities = torch.softmax(outputs.float(), dim=1)[0]
            print(f"Probabilities sum: {probabilities.sum():.2f}")  # Should be close to 1.0

            # Debug listing; topk raises if k exceeds the number of outputs.
            n_debug = min(5, probabilities.numel())
            top_probs, top_indices = torch.topk(probabilities, n_debug)
            print(f"\nTop {n_debug} predictions:")
            for rank, (prob, idx) in enumerate(
                zip(top_probs.tolist(), top_indices.tolist()), start=1
            ):
                print(f"{rank}. {classes[idx]}: {prob:.4f}")

            show_confidences = _top_confidences(probabilities, classes, top_k)

            target_layers = [_select_target_layer(model, target_layer)]
            print(f"\nUsing target layer: {target_layers[0]}")

            # Use the arg-max output index directly. Looking the class up by
            # name (classes.index) returns the first occurrence, which picks
            # the wrong index when a name is duplicated.
            class_idx = int(top_indices[0])
            print(f"\nSelected class for GradCAM: {classes[class_idx]} (index: {class_idx})")

            # The context manager releases the hooks GradCAM registers on the
            # (reused) model. NOTE(review): some pytorch-grad-cam versions
            # print and suppress certain exceptions in __exit__, hence the
            # explicit check below.
            grayscale_cam = None
            with GradCAM(model=model, target_layers=target_layers) as cam:
                grayscale_cam = cam(
                    input_tensor=input_tensor,
                    targets=[ClassifierOutputTarget(class_idx)],
                    aug_smooth=False,
                    eigen_smooth=False,
                )[0, :]
            if grayscale_cam is None:
                raise RuntimeError("GradCAM did not produce an activation map")

        # np.asarray also accepts PIL images; show_cam_on_image expects [0, 1].
        visualization = show_cam_on_image(
            np.asarray(image) / 255.0, grayscale_cam, use_rgb=True, image_weight=alpha
        )
        return show_confidences, visualization

    except Exception as e:
        print(f"Error in inference: {e}")
        raise  # bare raise re-raises the active exception untouched
    finally:
        # Runs on both the success and the error path.
        if on_gpu:
            torch.cuda.empty_cache()