import torch import numpy as np import pandas as pd import random import yaml import os import argparse from tqdm import tqdm from torch.utils.data import DataLoader from dataset2 import MedicalImageDatasetBalancedIntensity3D from load_brainiac import load_brainiac # fix random seed seed = 42 random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) # Set GPU os.environ['CUDA_VISIBLE_DEVICES'] = "0" device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) # Define custom collate function for data loading def custom_collate(batch): images = [item['image'] for item in batch] labels = [item['label'] for item in batch] max_len = 1 padded_images = [] for img in images: pad_size = max_len - img.shape[0] if pad_size > 0: padding = torch.zeros((pad_size,) + img.shape[1:]) img_padded = torch.cat([img, padding], dim=0) padded_images.append(img_padded) else: padded_images.append(img) return {"image": torch.stack(padded_images, dim=0), "label": torch.stack(labels)} #========================= # Inference function #========================= def infer(model, test_loader): features_df = None # Placeholder for feature DataFrame model.eval() with torch.no_grad(): for sample in tqdm(test_loader, desc="Inference", unit="batch"): inputs = sample['image'].to(device) class_labels = sample['label'].float().to(device) # Get features from the model features = model(inputs) features_numpy = features.cpu().numpy() # Expand features into separate columns feature_columns = [f'Feature_{i}' for i in range(features_numpy.shape[1])] batch_features = pd.DataFrame( features_numpy, columns=feature_columns ) batch_features['GroundTruthClassLabel'] = class_labels.cpu().numpy().flatten() # Append batch features to features_df if features_df is None: features_df = batch_features else: features_df = pd.concat([features_df, batch_features], ignore_index=True) return features_df #========================= # Main inference pipeline #========================= def main(): # argparse parser = argparse.ArgumentParser(description='Extract BrainIAC features from images') parser.add_argument('--checkpoint', type=str, required=True, help='Path to the BrainIAC model checkpoint') parser.add_argument('--input_csv', type=str, required=True, help='Path to the input CSV file containing image paths') parser.add_argument('--output_csv', type=str, required=True, help='Path to save the output features CSV') parser.add_argument('--root_dir', type=str, required=True, help='Root directory containing the image data') args = parser.parse_args() # spinup the dataloader test_dataset = MedicalImageDatasetBalancedIntensity3D( csv_path=args.input_csv, root_dir=args.root_dir ) test_loader = DataLoader( test_dataset, batch_size=1, shuffle=False, collate_fn=custom_collate, num_workers=1 ) # Load brainiac model = load_brainiac(args.checkpoint, device) model = model.to(device) # infer features_df = infer(model, test_loader) # Save features features_df.to_csv(args.output_csv, index=False) print(f"Features saved to {args.output_csv}") if __name__ == "__main__": main()