import torch import numpy as np import random import yaml import os import argparse from tqdm import tqdm from torch.utils.data import DataLoader import nibabel as nib from monai.visualize.gradient_based import SmoothGrad, GuidedBackpropSmoothGrad from dataset2 import MedicalImageDatasetBalancedIntensity3D from load_brainiac import load_brainiac # Fix random seed seed = 42 random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) # collate funcntion (unneccerary for single timpoint input) def custom_collate(batch): """Handles variable size of the scans and pads the sequence dimension.""" images = [item['image'] for item in batch] labels = [item['label'] for item in batch] patids = [item['pat_id'] for item in batch] max_len = 1 # singlescan input padded_images = [] for img in images: pad_size = max_len - img.shape[0] if pad_size > 0: padding = torch.zeros((pad_size,) + img.shape[1:]) img_padded = torch.cat([img, padding], dim=0) padded_images.append(img_padded) else: padded_images.append(img) return {"image": torch.stack(padded_images, dim=0), "label": labels, "pat_id": patids} def generate_saliency_maps(model, data_loader, output_dir, device): """Generate saliency maps using guided backprop method""" model.eval() visualizer = GuidedBackpropSmoothGrad(model=model.backbone, stdev_spread=0.15, n_samples=10, magnitude=True) for sample in tqdm(data_loader, desc="Generating saliency maps"): inputs = sample['image'].requires_grad_(True) patids = sample["pat_id"] imagename = patids[0] input_tensor = inputs.to(device) with torch.enable_grad(): saliency_map = visualizer(input_tensor) # Save input image and saliency map inputs_np = input_tensor.squeeze().cpu().detach().numpy() saliency_np = saliency_map.squeeze().cpu().detach().numpy() input_nifti = nib.Nifti1Image(inputs_np, np.eye(4)) saliency_nifti = nib.Nifti1Image(saliency_np, np.eye(4)) # Save files nib.save(input_nifti, os.path.join(output_dir, f"{imagename}_image.nii.gz")) nib.save(saliency_nifti, os.path.join(output_dir, f"{imagename}_saliencymap.nii.gz")) def main(): parser = argparse.ArgumentParser(description='Generate saliency maps for medical images') parser.add_argument('--checkpoint', type=str, required=True, help='Path to the model checkpoint') parser.add_argument('--input_csv', type=str, required=True, help='Path to the input CSV file containing image paths') parser.add_argument('--output_dir', type=str, required=True, help='Directory to save saliency maps') parser.add_argument('--root_dir', type=str, required=True, help='Root directory containing the image data') args = parser.parse_args() device = torch.device("cpu") # Create output directory if it doesn't exist os.makedirs(args.output_dir, exist_ok=True) # Initialize dataset and dataloader dataset = MedicalImageDatasetBalancedIntensity3D( csv_path=args.input_csv, root_dir=args.root_dir ) dataloader = DataLoader( dataset, batch_size=1, shuffle=False, collate_fn=custom_collate, num_workers=1 ) # Load brainiac and ensure it's on CPU model = load_brainiac(args.checkpoint, device) model = model.to(device) # Make sure model weights are on CPU model.backbone = model.backbone.to(device) # Generate saliency maps generate_saliency_maps(model, dataloader, args.output_dir, device) print(f"Saliency maps generated and saved to {args.output_dir}") if __name__ == "__main__": main()