# BrainIAC-Brainage-V0 / src / BrainIAC / get_brainiac_saliencymap.py
# Divyanshu Tak
# Initial commit of BrainIAC Docker application
# f5288df
import torch
import numpy as np
import random
import yaml
import os
import argparse
from tqdm import tqdm
from torch.utils.data import DataLoader
import nibabel as nib
from monai.visualize.gradient_based import SmoothGrad, GuidedBackpropSmoothGrad
from dataset2 import MedicalImageDatasetBalancedIntensity3D
from load_brainiac import load_brainiac
# Seed the Python, NumPy and PyTorch random number generators with one shared
# value so repeated runs draw the same random numbers.
seed = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
del _seed_rng  # keep the loop helper out of the module namespace
# Collate function (unnecessary for single-timepoint input)
def custom_collate(batch):
    """Collate dataset items into a single batch dictionary.

    Each item in ``batch`` is indexed by the keys ``'image'``, ``'label'`` and
    ``'pat_id'``. An image whose leading dimension is shorter than the target
    length (1, i.e. a single scan) is zero-padded along that dimension; every
    other image is passed through untouched. The images are then stacked
    along a new leading batch dimension.

    Returns:
        A dict with ``"image"`` (the stacked tensor), ``"label"`` (list of
        labels) and ``"pat_id"`` (list of patient ids), in batch order.
    """
    target_len = 1  # single-scan input

    def _pad_leading_dim(scan):
        # Append zeros along dim 0 only when the scan is shorter than target_len.
        missing = target_len - scan.shape[0]
        if missing <= 0:
            return scan
        filler = torch.zeros((missing,) + scan.shape[1:])
        return torch.cat([scan, filler], dim=0)

    scans = [entry['image'] for entry in batch]
    targets = [entry['label'] for entry in batch]
    patient_ids = [entry['pat_id'] for entry in batch]

    stacked = torch.stack([_pad_leading_dim(scan) for scan in scans], dim=0)
    return {"image": stacked, "label": targets, "pat_id": patient_ids}
def generate_saliency_maps(model, data_loader, output_dir, device):
    """Compute and save a saliency map for every batch in ``data_loader``.

    The model is switched to eval mode and ``model.backbone`` is wrapped in
    MONAI's ``GuidedBackpropSmoothGrad`` (``stdev_spread=0.15``,
    ``n_samples=10``, ``magnitude=True``). For each batch, two NIfTI files are
    written to ``output_dir``, both named after the first patient id in the
    batch: ``<pat_id>_image.nii.gz`` (the input) and
    ``<pat_id>_saliencymap.nii.gz`` (the saliency map).

    Args:
        model: Object exposing ``eval()`` and a ``backbone`` attribute.
        data_loader: Iterable of batches with ``'image'`` and ``'pat_id'`` keys.
        output_dir: Directory that receives the NIfTI files; it is not created
            here.
        device: Device the input tensor is moved to before attribution.
    """
    model.eval()
    explainer = GuidedBackpropSmoothGrad(
        model=model.backbone,
        stdev_spread=0.15,
        n_samples=10,
        magnitude=True,
    )

    def _as_nifti(volume):
        # Drop singleton dims, move to host memory and wrap with an identity affine.
        array = volume.squeeze().cpu().detach().numpy()
        return nib.Nifti1Image(array, np.eye(4))

    for batch in tqdm(data_loader, desc="Generating saliency maps"):
        scan = batch['image'].requires_grad_(True)
        # Only the first patient id of the batch is used to name the outputs.
        case_name = batch["pat_id"][0]
        scan_on_device = scan.to(device)

        with torch.enable_grad():
            attribution = explainer(scan_on_device)

        # Convert both volumes before touching the disk.
        image_nifti = _as_nifti(scan_on_device)
        saliency_nifti = _as_nifti(attribution)

        image_path = os.path.join(output_dir, f"{case_name}_image.nii.gz")
        saliency_path = os.path.join(output_dir, f"{case_name}_saliencymap.nii.gz")
        nib.save(image_nifti, image_path)
        nib.save(saliency_nifti, saliency_path)
def main():
    """Command-line entry point: parse arguments and generate saliency maps.

    Requires four string options: ``--checkpoint``, ``--input_csv``,
    ``--output_dir`` and ``--root_dir``. The output directory is created if
    missing. The dataset is loaded one sample at a time, the model is loaded
    onto the CPU, and the saliency maps are then written to the output
    directory.
    """
    cli = argparse.ArgumentParser(description='Generate saliency maps for medical images')
    required_options = (
        ('--checkpoint', 'Path to the model checkpoint'),
        ('--input_csv', 'Path to the input CSV file containing image paths'),
        ('--output_dir', 'Directory to save saliency maps'),
        ('--root_dir', 'Root directory containing the image data'),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    opts = cli.parse_args()

    cpu = torch.device("cpu")
    output_dir = opts.output_dir

    # Make sure the destination exists before anything is written.
    os.makedirs(output_dir, exist_ok=True)

    # One sample per batch, in CSV order.
    scans = MedicalImageDatasetBalancedIntensity3D(
        csv_path=opts.input_csv,
        root_dir=opts.root_dir,
    )
    loader = DataLoader(
        scans,
        batch_size=1,
        shuffle=False,
        collate_fn=custom_collate,
        num_workers=1,
    )

    # Load the model and pin both it and its backbone to the CPU.
    net = load_brainiac(opts.checkpoint, cpu).to(cpu)
    net.backbone = net.backbone.to(cpu)

    generate_saliency_maps(net, loader, output_dir, cpu)
    print(f"Saliency maps generated and saved to {output_dir}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()