# Spaces: Running | File size: 8,017 Bytes | f5288df
import sys
import os
import glob
import SimpleITK as sitk
from tqdm import tqdm
import random
from HD_BET.hd_bet import hd_bet
import argparse
import torch
def brain_extraction(input_dir, output_dir, device):
    """
    Skull-strip the registered images with the HDBET package (UNet based DL method)
    Args:
        input_dir {path} -- directory holding the registered images
        output_dir {path} -- directory handed to HD-BET as its output location
        device {str} -- device identifier, forwarded unchanged to HD-BET
    Returns:
        None -- nothing is returned; the listing of output_dir is printed
    """
    banner = (
        "Running brain extraction...",
        f"Input directory: {input_dir}",
        f"Output directory: {output_dir}",
    )
    for line in banner:
        print(line)

    # HD-BET is pointed straight at the final output directory.
    # NOTE(review): tta=0 presumably disables test-time augmentation -- confirm
    # against the HD_BET wrapper in use.
    hd_bet(input_dir, output_dir, device=device, mode='fast', tta=0)
    print('Brain extraction complete!')

    print("\nContents of output directory after brain extraction:")
    produced = os.listdir(output_dir)
    print(produced)
def _image_id(path):
    """Return the file name of ``path`` without directory and '.nii.gz' suffix."""
    suffix = '.nii.gz'
    name = os.path.basename(path)
    # Strip only the known suffix so IDs containing dots stay intact.
    if name.lower().endswith(suffix):
        name = name[:-len(suffix)]
    return name


def _resolve_interpolator(interp_type):
    """Map an interpolation name to its SimpleITK constant (non-strings pass through)."""
    if not isinstance(interp_type, str):
        # Already a SimpleITK interpolator constant supplied by the caller.
        return interp_type
    by_name = {
        'linear': sitk.sitkLinear,
        'bspline': sitk.sitkBSpline,
        'nearest_neighbor': sitk.sitkNearestNeighbor,
    }
    if interp_type not in by_name:
        raise ValueError(
            f"Unknown interp_type {interp_type!r}; expected one of {sorted(by_name)}")
    return by_name[interp_type]


def _resample_to_isotropic(image, interpolator, spacing=(1.0, 1.0, 1.0)):
    """Resample ``image`` onto a grid with the given spacing, keeping origin and direction."""
    # Keep the physical extent: new voxel count = old count * old spacing / new spacing.
    size = [
        int(round(count * old / float(new)))
        for count, old, new in zip(image.GetSize(), image.GetSpacing(), spacing)
    ]
    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(spacing)
    resample.SetSize(size)
    resample.SetOutputOrigin(image.GetOrigin())
    resample.SetOutputDirection(image.GetDirection())
    resample.SetInterpolator(interpolator)
    # Fill voxels outside the source image with 0 (the pixel-type ID is an
    # enum value, not an intensity, so it must not be used here).
    resample.SetDefaultPixelValue(0.0)
    resample.SetOutputPixelType(sitk.sitkFloat32)
    return resample.Execute(image)


def _estimate_transform(fixed_img, moving_img):
    """Estimate the Euler3D transform that aligns ``moving_img`` to ``fixed_img``."""
    # Start from a transform that aligns the geometric centres of both images.
    initial = sitk.CenteredTransformInitializer(
        fixed_img,
        moving_img,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY)

    method = sitk.ImageRegistrationMethod()
    # Similarity metric: Mattes mutual information on a 1% random sample.
    method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    method.SetMetricSamplingStrategy(method.RANDOM)
    method.SetMetricSamplingPercentage(0.01)
    method.SetInterpolator(sitk.sitkLinear)
    method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10)
    method.SetOptimizerScalesFromPhysicalShift()
    # Three-level multi-resolution pyramid, coarse to fine.
    method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    method.SetInitialTransform(initial)
    return method.Execute(fixed_img, moving_img)


def registration(input_dir, output_dir, temp_img, interp_type='linear'):
    """
    MRI registration with SimpleITK
    Every '*.nii.gz' file in input_dir (except names containing '_mask') is
    bias-field corrected, registered to the template and written to
    output_dir as '<ID>_0000.nii.gz'. Files that fail to load or to register
    are reported and skipped.
    Args:
        input_dir {path} -- Directory containing input images
        output_dir {path} -- Directory to save registered images (created if missing)
        temp_img {str} -- Registration image template
        interp_type {str} -- Interpolation used to resample the template to
            1mm spacing: 'linear', 'bspline' or 'nearest_neighbor'
            (a SimpleITK interpolator constant is also accepted)
    Returns:
        bool -- True if at least one image was registered and saved
    Raises:
        ValueError -- if interp_type is an unknown interpolation name
    """
    interpolator = _resolve_interpolator(interp_type)
    os.makedirs(output_dir, exist_ok=True)

    # The template is loop-invariant: read and resample it to 1mm isotropic once.
    fixed_img = sitk.ReadImage(temp_img, sitk.sitkFloat32)
    fixed_img = _resample_to_isotropic(fixed_img, interpolator)

    pattern = os.path.join(glob.escape(input_dir), '*.nii.gz')
    list_of_files = sorted(glob.glob(pattern))

    count = 0
    print("Registering images...")
    for img_dir in tqdm(list_of_files):
        ID = _image_id(img_dir)
        if "_mask" in ID:
            continue

        # Each image is read exactly once; unreadable files are skipped here.
        try:
            moving_img = sitk.ReadImage(img_dir, sitk.sitkFloat32)
        except Exception as e:
            print(f"Error loading {ID}: {e}")
            continue

        print(f"Processing image {count + 1}: {ID}")
        try:
            moving_img = sitk.N4BiasFieldCorrection(moving_img)
            final_transform = _estimate_transform(fixed_img, moving_img)

            # Apply transform and save registered image
            moving_img_resampled = sitk.Resample(
                moving_img,
                fixed_img,
                final_transform,
                sitk.sitkLinear,
                0.0,
                moving_img.GetPixelID())

            # Save with _0000 suffix as required by HD-BET
            output_filename = os.path.join(output_dir, f"{ID}_0000.nii.gz")
            sitk.WriteImage(moving_img_resampled, output_filename)
            print(f"Saved registered image to: {output_filename}")
            count += 1
        except Exception as e:
            # Best effort: one failing image must not abort the whole batch.
            print(f"Error processing {ID}: {e}")
            continue

    print(f"Successfully registered {count} images.")
    # Debug information
    print(f"Contents of output directory {output_dir}:")
    print(os.listdir(output_dir))
    return count > 0
def main(temp_img, input_dir, output_dir):
    """
    Main function to process brain MRI images
    Runs registration into a temporary directory created inside output_dir,
    then brain extraction from that directory into output_dir. The temporary
    directory is removed on every exit path, including failures.
    Args:
        temp_img {str} -- Path to template image
        input_dir {str} -- Path to input directory containing images
        output_dir {str} -- Path to output directory for results
    Returns:
        None
    """
    # Function-scope import, like the original's local ``import shutil``.
    import tempfile

    os.makedirs(output_dir, exist_ok=True)
    # set device
    device = "0" if torch.cuda.is_available() else "cpu"

    print("Starting brain MRI preprocessing...")
    # A fresh, uniquely named directory guarantees no leftovers from an earlier
    # run are picked up; the context manager deletes it even if a step raises.
    with tempfile.TemporaryDirectory(prefix='temp_registered_', dir=output_dir) as temp_reg_dir:
        # Registration
        print("\nStep 1: Image Registration")
        success = registration(
            input_dir=input_dir,
            output_dir=temp_reg_dir,
            temp_img=temp_img
        )
        if not success:
            print("Registration failed! No images were processed successfully.")
            return
        print("\nChecking temporary directory contents:")
        print(os.listdir(temp_reg_dir))

        # skullstripping
        print("\nStep 2: Brain Extraction")
        brain_extraction(
            input_dir=temp_reg_dir,
            output_dir=output_dir,
            device=device
        )

    print("\nPreprocessing complete! Final results saved in:", output_dir)
    print("Final preprocessed files:")
    print(os.listdir(output_dir))
# Command-line entry point: all three paths are mandatory.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Process brain MRI registration and skull stripping.")
    parser.add_argument("--temp_img", type=str, required=True, help="Path to the atlas template image.")
    parser.add_argument("--input_dir", type=str, required=True, help="Path to the input images directory.")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save the processed images.")
    args = parser.parse_args()
    main(temp_img=args.temp_img, input_dir=args.input_dir, output_dir=args.output_dir)