import torch
import numpy as np


def sliding_window_inference(model, image, window_size, overlap=0.5):
    """Perform sliding window inference on large images.

    Expects `image` of shape (C, H, W) and a model whose output has a single
    channel and the same spatial size as its input. Overlapping window
    predictions are averaged.
    """
    model.eval()
    device = next(model.parameters()).device

    _, height, width = image.shape
    stride = max(int(window_size * (1 - overlap)), 1)

    # Number of windows needed to cover each dimension.
    n_h = int(np.ceil(max(height - window_size, 0) / stride)) + 1
    n_w = int(np.ceil(max(width - window_size, 0) / stride)) + 1

    # Accumulators: summed predictions and per-pixel window counts.
    pred_map = torch.zeros((1, height, width), device=device)
    count_map = torch.zeros((1, height, width), device=device)

    with torch.no_grad():
        for i in range(n_h):
            for j in range(n_w):
                # Clamp so the window never runs past the image border;
                # the last window along each axis is shifted back instead.
                h_start = max(min(i * stride, height - window_size), 0)
                w_start = max(min(j * stride, width - window_size), 0)
                h_end = min(h_start + window_size, height)
                w_end = min(w_start + window_size, width)

                window = image[:, h_start:h_end, w_start:w_end]
                win_h, win_w = window.shape[1], window.shape[2]

                # Pad to the full window size when the image itself is
                # smaller than the window along either axis.
                if (win_h, win_w) != (window_size, window_size):
                    pad_h = window_size - win_h
                    pad_w = window_size - win_w
                    window = torch.nn.functional.pad(window, (0, pad_w, 0, pad_h))

                pred = model(window.unsqueeze(0).to(device))
                pred = pred.squeeze(0)

                # Discard any padded region before accumulating.
                pred = pred[:, :win_h, :win_w]

                pred_map[:, h_start:h_start + win_h, w_start:w_start + win_w] += pred
                count_map[:, h_start:h_start + win_h, w_start:w_start + win_w] += 1

    # Average overlapping predictions.
    final_pred = pred_map / count_map
    return final_pred.cpu()
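

# --- Usage sketch (illustrative only) ---
# `dummy_model` and `large_image` below are placeholders, not part of the
# function above: any network that maps (N, C, window_size, window_size)
# inputs to single-channel outputs of the same spatial size should work.
# A 1x1 convolution stands in for a real model so the call runs end to end.
if __name__ == "__main__":
    dummy_model = torch.nn.Conv2d(3, 1, kernel_size=1)  # stand-in for a real network
    large_image = torch.rand(3, 1024, 768)               # (C, H, W), larger than one window

    prediction = sliding_window_inference(dummy_model, large_image,
                                          window_size=256, overlap=0.5)
    print(prediction.shape)  # torch.Size([1, 1024, 768])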