# Scraped from a Hugging Face Spaces file view (Space status: Running).
# File size: 3,916 bytes; revision f5288df. The editor's line-number gutter (1-111) was removed.
import torch
import numpy as np
import random
import yaml
import os
import argparse
from tqdm import tqdm
from torch.utils.data import DataLoader
import nibabel as nib
from monai.visualize.gradient_based import SmoothGrad, GuidedBackpropSmoothGrad
from dataset2 import MedicalImageDatasetBalancedIntensity3D
from load_brainiac import load_brainiac
# Seed every RNG this script draws from (Python, NumPy, PyTorch) for reproducible runs.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
del _seed_fn  # keep the module namespace free of the loop helper
# Collate function for the DataLoader. For single-timepoint input (presumably a
# leading dim of 1 per scan) the padding below is a no-op.
def custom_collate(batch):
    """Collate dataset samples into a batch, zero-padding the leading (sequence) dim.

    Each scan is zero-padded along dim 0 to the longest scan in the batch (and to at
    least 1, the single-scan case). All scans are then stacked along a new batch dim.

    Args:
        batch: List of sample dicts with keys ``'image'`` (tensor), ``'label'`` and
            ``'pat_id'``.

    Returns:
        dict with ``'image'``: stacked tensor of shape ``(len(batch), max_len, ...)``,
        and ``'label'`` / ``'pat_id'``: lists in batch order.

    Raises:
        RuntimeError: if ``batch`` is empty (``torch.stack`` of an empty list) or the
            scans differ in any dimension other than dim 0.
    """
    images = [item["image"] for item in batch]
    labels = [item["label"] for item in batch]
    patids = [item["pat_id"] for item in batch]
    # Pad to the longest scan in the batch. The floor of 1 keeps the previous
    # single-scan behavior. max_len used to be hard-coded to 1, so scans of
    # differing lengths > 1 were never padded and torch.stack failed.
    max_len = max([1, *(img.shape[0] for img in images)])
    padded_images = []
    for img in images:
        pad_size = max_len - img.shape[0]
        if pad_size > 0:
            # new_zeros matches img's dtype and device, so torch.cat never mixes them.
            padding = img.new_zeros((pad_size, *img.shape[1:]))
            img = torch.cat([img, padding], dim=0)
        padded_images.append(img)
    return {"image": torch.stack(padded_images, dim=0), "label": labels, "pat_id": patids}
def generate_saliency_maps(model, data_loader, output_dir, device):
    """Generate and save guided-backprop SmoothGrad saliency maps.

    For every sample yielded by ``data_loader``, a saliency map is computed with
    MONAI's ``GuidedBackpropSmoothGrad`` on ``model.backbone``. It is written,
    together with the input volume, as NIfTI files in ``output_dir``:
    ``<pat_id>_image.nii.gz`` and ``<pat_id>_saliencymap.nii.gz``.

    Args:
        model: Network exposing a ``backbone`` submodule; it is put in eval mode here.
        data_loader: Iterable of dicts with an ``'image'`` tensor batch and a
            ``'pat_id'`` list (the format produced by ``custom_collate``).
        output_dir: Existing directory that receives the NIfTI files.
        device: Device the inputs are moved to before the forward/backward passes.
    """
    model.eval()
    visualizer = GuidedBackpropSmoothGrad(
        model=model.backbone, stdev_spread=0.15, n_samples=10, magnitude=True
    )
    for sample in tqdm(data_loader, desc="Generating saliency maps"):
        images = sample["image"].to(device)
        # MONAI's gradient-based visualizers only accept a batch of one, so each
        # sample is explained and saved under its own pat_id. Previously the whole
        # batch was passed in and saved under patids[0].
        for idx, imagename in enumerate(sample["pat_id"]):
            input_tensor = images[idx : idx + 1]
            # Gradients are needed even if the caller runs under torch.no_grad().
            with torch.enable_grad():
                saliency_map = visualizer(input_tensor)
            # squeeze() drops every size-1 dimension, including the batch dim.
            inputs_np = input_tensor.squeeze().cpu().detach().numpy()
            saliency_np = saliency_map.squeeze().cpu().detach().numpy()
            # NOTE(review): identity affine, so the scan's original orientation/spacing
            # is not carried over to the saved files.
            input_nifti = nib.Nifti1Image(inputs_np, np.eye(4))
            saliency_nifti = nib.Nifti1Image(saliency_np, np.eye(4))
            nib.save(input_nifti, os.path.join(output_dir, f"{imagename}_image.nii.gz"))
            nib.save(saliency_nifti, os.path.join(output_dir, f"{imagename}_saliencymap.nii.gz"))
def main():
    """Parse CLI arguments, build the data pipeline and model, and write saliency maps.

    Command-line options:
        --checkpoint   path to the model checkpoint passed to ``load_brainiac``
        --input_csv    CSV file containing the image paths
        --output_dir   directory for the NIfTI outputs (created if missing)
        --root_dir     root directory containing the image data
        --device       torch device to run on (default ``cpu``, the previously
                       hard-coded device)
    """
    parser = argparse.ArgumentParser(description='Generate saliency maps for medical images')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to the model checkpoint')
    parser.add_argument('--input_csv', type=str, required=True,
                        help='Path to the input CSV file containing image paths')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save saliency maps')
    parser.add_argument('--root_dir', type=str, required=True,
                        help='Root directory containing the image data')
    # Optional: defaults to CPU so existing invocations behave exactly as before.
    parser.add_argument('--device', type=str, default='cpu',
                        help='Torch device to run on, e.g. "cpu" or "cuda:0" (default: cpu)')
    args = parser.parse_args()
    device = torch.device(args.device)

    # Create the output directory if it doesn't exist.
    os.makedirs(args.output_dir, exist_ok=True)

    # Dataset and dataloader; one scan per batch.
    dataset = MedicalImageDatasetBalancedIntensity3D(
        csv_path=args.input_csv,
        root_dir=args.root_dir,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=custom_collate,
        num_workers=1,
    )

    # Load BrainIAC and place it on the selected device.
    model = load_brainiac(args.checkpoint, device)
    model = model.to(device)
    # The backbone is the module the saliency visualizer wraps, so move it
    # explicitly (presumably redundant if it is a registered submodule).
    model.backbone = model.backbone.to(device)

    generate_saliency_maps(model, dataloader, args.output_dir, device)
    print(f"Saliency maps generated and saved to {args.output_dir}")
# Run the CLI only when executed as a script, not on import.
# (The stray " |" page-scrape residue after main() was a syntax error.)
if __name__ == "__main__":
    main()