import torch
import numpy as np
import random
import os
import argparse
from tqdm import tqdm
from torch.utils.data import DataLoader
import nibabel as nib
from monai.visualize.gradient_based import GuidedBackpropSmoothGrad
from dataset2 import MedicalImageDatasetBalancedIntensity3D
from load_brainiac import load_brainiac
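# dataset2 and load_brainiac are project-local modules providing the dataset
# class and the model loader, respectively.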

# Fix random seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
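
# For stricter run-to-run reproducibility one could additionally enable
# deterministic algorithms (optional; not part of the original script):
# torch.use_deterministic_algorithms(True)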

# Collate function (unnecessary for single-timepoint input)
def custom_collate(batch):
    """Handles a variable number of scans per patient by padding the sequence dimension."""
    images = [item['image'] for item in batch]
    labels = [item['label'] for item in batch]
    patids = [item['pat_id'] for item in batch]
    
    max_len = 1  # single-scan input
    padded_images = []
    
    for img in images:
        pad_size = max_len - img.shape[0]
        if pad_size > 0:
            padding = torch.zeros((pad_size,) + img.shape[1:])
            img_padded = torch.cat([img, padding], dim=0)
            padded_images.append(img_padded)
        else:
            padded_images.append(img)

    return {"image": torch.stack(padded_images, dim=0), "label": labels, "pat_id": patids}


def generate_saliency_maps(model, data_loader, output_dir, device):
    """Generate saliency maps using guided backprop method"""
    model.eval()
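    # GuidedBackpropSmoothGrad combines guided backpropagation with SmoothGrad:
    # gradients are averaged over n_samples noisy copies of the input, with
    # Gaussian noise whose std is stdev_spread times the input intensity range;
    # magnitude=True squares the gradients before averaging.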
    visualizer = GuidedBackpropSmoothGrad(model=model.backbone, stdev_spread=0.15, n_samples=10, magnitude=True)
    
    for sample in tqdm(data_loader, desc="Generating saliency maps"):
        patids = sample["pat_id"]
        imagename = patids[0]  # batch_size is 1, so one patient ID per batch
        
        # Move to the target device first, then enable gradients on the leaf tensor
        input_tensor = sample['image'].to(device).requires_grad_(True)
        
        with torch.enable_grad():
            saliency_map = visualizer(input_tensor)
        
        # Save input image and saliency map
        inputs_np = input_tensor.detach().squeeze().cpu().numpy()
        saliency_np = saliency_map.detach().squeeze().cpu().numpy()
        
        input_nifti = nib.Nifti1Image(inputs_np, np.eye(4))
        saliency_nifti = nib.Nifti1Image(saliency_np, np.eye(4))
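        # Note: np.eye(4) is an identity affine, so the saved NIfTI volumes do
        # not preserve the source scan's orientation or voxel-spacing metadata.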
        
        # Save files
        nib.save(input_nifti, os.path.join(output_dir, f"{imagename}_image.nii.gz"))
        nib.save(saliency_nifti, os.path.join(output_dir, f"{imagename}_saliencymap.nii.gz"))

def main():
    parser = argparse.ArgumentParser(description='Generate saliency maps for medical images')
    parser.add_argument('--checkpoint', type=str, required=True,
                      help='Path to the model checkpoint')
    parser.add_argument('--input_csv', type=str, required=True,
                      help='Path to the input CSV file containing image paths')
    parser.add_argument('--output_dir', type=str, required=True,
                      help='Directory to save saliency maps')
    parser.add_argument('--root_dir', type=str, required=True,
                      help='Root directory containing the image data')
    
    args = parser.parse_args()
    device = torch.device("cpu")
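    # Runs on CPU; a GPU variant (an assumption, not part of the original
    # script) could be:
    #   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")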
    
    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Initialize dataset and dataloader
    dataset = MedicalImageDatasetBalancedIntensity3D(
        csv_path=args.input_csv,
        root_dir=args.root_dir
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=custom_collate,
        num_workers=1
    )
    
    # Load BrainIAC and ensure it is on the CPU
    model = load_brainiac(args.checkpoint, device)
    model = model.to(device)
    
    # Explicitly move the backbone as well (defensive: model.to(device) already
    # covers registered submodules)
    model.backbone = model.backbone.to(device)
    
    # Generate saliency maps
    generate_saliency_maps(model, dataloader, args.output_dir, device)
    
    print(f"Saliency maps generated and saved to {args.output_dir}")

if __name__ == "__main__":
    main()
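
# Example invocation (the script name and paths are placeholders):
#   python saliency_maps.py \
#       --checkpoint /path/to/checkpoint.pt \
#       --input_csv scans.csv \
#       --output_dir ./saliency_out \
#       --root_dir /path/to/images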