import torch
import numpy as np

def sliding_window_inference(model, image, window_size, overlap=0.5):
    """Perform sliding window inference on large images"""
    model.eval()

    # Get dimensions
    _, height, width = image.shape
    stride = int(window_size * (1 - overlap))

    # Calculate number of windows needed
    n_h = int(np.ceil((height - window_size) / stride) + 1)
    n_w = int(np.ceil((width - window_size) / stride) + 1)

    # Create empty prediction map and count map for averaging
    pred_map = torch.zeros((1, height, width)).to(model.device)
    count_map = torch.zeros((1, height, width)).to(model.device)

    # Slide window over image
    with torch.no_grad():
        for i in range(n_h):
            for j in range(n_w):
                # Calculate window boundaries
                h_start = min(i * stride, height - window_size)
                w_start = min(j * stride, width - window_size)
                h_end = h_start + window_size
                w_end = w_start + window_size

                # Extract window
                window = image[:, h_start:h_end, w_start:w_end]

                # If window is smaller than window_size, pad it
                if window.shape[1:] != (window_size, window_size):
                    pad_h = window_size - window.shape[1]
                    pad_w = window_size - window.shape[2]
                    window = torch.nn.functional.pad(window, (0, pad_w, 0, pad_h))

                # Make prediction
                window = window.unsqueeze(0)  # Add batch dimension
                pred = model(window)
                pred = pred.squeeze(0)  # Remove batch dimension

                # If window was padded, remove padding from prediction
                if window.shape[2] - h_end + h_start > 0 or window.shape[3] - w_end + w_start > 0:
                    pred = pred[:, :h_end - h_start, :w_end - w_start]

                # Add prediction to map
                pred_map[:, h_start:h_end, w_start:w_end] += pred
                count_map[:, h_start:h_end, w_start:w_end] += 1

    # Average overlapping predictions
    final_pred = pred_map / count_map
    return final_pred.cpu()