import os
import yaml
import torch
import nibabel as nib
import numpy as np
import gradio as gr
from typing import Tuple
import tempfile
import shutil

import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt

import cv2
import io
import base64
import uuid
import traceback

import SimpleITK as sitk
import itk
from scipy.signal import medfilt
import skimage.filters

from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Resized,
    NormalizeIntensityd,
    ToTensord,
    EnsureTyped,
)
from monai.inferers import sliding_window_inference

from model import ViTUNETRSegmentationModel

# HD-BET is optional at import time; run_skull_stripping() raises a clear
# error later if it is missing.
try:
    from HD_BET.run import run_hd_bet
    from HD_BET.hd_bet import hd_bet
except Exception as e:
    print(f"Warning: HD_BET not available: {e}")
    run_hd_bet = None
    hd_bet = None

APP_DIR = os.path.dirname(__file__)
TEMPLATE_DIR = os.path.join(APP_DIR, "golden_image", "mni_templates")
PARAMS_RIGID_PATH = os.path.join(TEMPLATE_DIR, "Parameters_Rigid.txt")
DEFAULT_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "temp_head.nii.gz")
FLAIR_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "nihpd_asym_04.5-18.5_t2w.nii.gz")
HD_BET_CONFIG_PATH = os.path.join(APP_DIR, "HD_BET", "config.py")
HD_BET_MODEL_DIR = os.path.join(APP_DIR, "hdbet_model")

def load_config() -> dict:
    """Load config.yml from the app directory, falling back to built-in defaults."""
    cfg_path = os.path.join(APP_DIR, "config.yml")
    if os.path.exists(cfg_path):
        with open(cfg_path, "r") as f:
            return yaml.safe_load(f)

    return {
        "gpu": {"device": "cpu"},
        "infer": {
            "checkpoints": "./checkpoints/idh_model.ckpt",
            "simclr_checkpoint": None,
            "threshold": 0.5,
            "image_size": [96, 96, 96],
        },
    }
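
# For reference, a minimal config.yml matching the fallback above might look
# like this (illustrative sketch only; a deployed config may also carry
# "model" and "training" sections, which build_model() and
# predict_segmentation() read via .get() with defaults):
#
#   gpu:
#     device: cpu
#   infer:
#     checkpoints: ./checkpoints/idh_model.ckpt
#     simclr_checkpoint: null
#     threshold: 0.5
#     image_size: [96, 96, 96]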


def build_model(cfg: dict):
    device = torch.device(cfg.get("gpu", {}).get("device", "cpu"))
    infer_cfg = cfg.get("infer", {})
    model_cfg = cfg.get("model", {})
    ckpt_path = os.path.join(APP_DIR, infer_cfg.get("checkpoints", ""))

    model = ViTUNETRSegmentationModel(
        simclr_ckpt_path=None,
        img_size=tuple(model_cfg.get("img_size", [96, 96, 96])),
        in_channels=model_cfg.get("in_channels", 1),
        out_channels=model_cfg.get("out_channels", 1),
    )

    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        if "state_dict" in checkpoint:
            # Lightning checkpoints prefix parameter names with "model."; strip
            # the prefix so the keys match the bare module.
            state_dict = checkpoint["state_dict"]
            new_state_dict = {}
            for key, value in state_dict.items():
                if key.startswith("model."):
                    new_state_dict[key[len("model."):]] = value
                else:
                    new_state_dict[key] = value
        else:
            new_state_dict = checkpoint
        model.load_state_dict(new_state_dict, strict=False)
    else:
        print(f"Warning: Segmentation checkpoint not found at {ckpt_path}. Model will use backbone-only weights.")

    model.to(device)
    model.eval()
    return model, device
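
# Usage sketch: build the model once at startup and reuse it across requests,
# as build_interface() does below.
#   cfg = load_config()
#   model, device = build_model(cfg)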


def bias_field_correction(img_array: np.ndarray) -> np.ndarray:
    """Apply SimpleITK N4 bias field correction to a 3D intensity array."""
    image = sitk.GetImageFromArray(img_array.astype(np.float32))
    if image.GetPixelID() != sitk.sitkFloat32:
        image = sitk.Cast(image, sitk.sitkFloat32)
    # Restrict the fit to foreground voxels found by Otsu thresholding.
    mask_image = sitk.OtsuThreshold(image, 0, 1, 200)
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    number_fitting_levels = 4
    max_iters = [min(50 * (2 ** i), 200) for i in range(number_fitting_levels)]
    corrector.SetMaximumNumberOfIterations(max_iters)
    corrected_image = corrector.Execute(image, mask_image)
    return sitk.GetArrayFromImage(corrected_image)


def denoise(volume: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    """Median-filter the volume to suppress speckle noise."""
    return medfilt(volume, kernel_size)


def rescale_intensity(volume: np.ndarray, percentiles=(0.5, 99.5), bins_num=256) -> np.ndarray:
    """Rescale intensities to [0, bins_num - 1] based on foreground percentiles."""
    volume_float = volume.astype(np.float32)
    try:
        # Estimate the foreground with an Otsu threshold; the percentiles are
        # then computed over foreground voxels only.
        t = skimage.filters.threshold_otsu(volume_float, nbins=256)
        volume_masked = np.copy(volume_float)
        volume_masked[volume_masked < t] = 0
        obj_volume = volume_masked[np.where(volume_masked > 0)]
    except ValueError:
        obj_volume = volume_float.flatten()
    if obj_volume.size == 0:
        obj_volume = volume_float.flatten()
        min_value = np.min(obj_volume)
        max_value = np.max(obj_volume)
    else:
        min_value = np.percentile(obj_volume, percentiles[0])
        max_value = np.percentile(obj_volume, percentiles[1])
    denom = max(max_value - min_value, 1e-6)
    if bins_num == 0:
        output_volume = (volume_float - min_value) / denom
        output_volume = np.clip(output_volume, 0.0, 1.0)
    else:
        output_volume = np.round((volume_float - min_value) / denom * (bins_num - 1))
        output_volume = np.clip(output_volume, 0, bins_num - 1)
    return output_volume.astype(np.float32)
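
# Worked example: with foreground percentiles (0.5, 99.5) landing at
# intensities 100 and 900, a voxel at 500 maps to
# round((500 - 100) / 800 * 255) = 128 on the 0-255 output scale.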


def equalize_hist(volume: np.ndarray, bins_num=256) -> np.ndarray:
    """Histogram-equalize the nonzero voxels, leaving the background untouched."""
    mask = volume > 1e-6
    obj_volume = volume[mask]
    if obj_volume.size == 0:
        return volume
    hist, bins = np.histogram(obj_volume, bins_num, range=(obj_volume.min(), obj_volume.max()))
    cdf = hist.cumsum()
    cdf_normalized = (bins_num - 1) * cdf / float(cdf[-1])
    equalized_obj_volume = np.interp(obj_volume, bins[:-1], cdf_normalized)
    equalized_volume = np.copy(volume)
    equalized_volume[mask] = equalized_obj_volume
    return equalized_volume.astype(np.float32)


def run_enhance_on_file(input_nifti_path: str, output_nifti_path: str):
    """
    Simplified enhancement: copy the file through unchanged, since N4 bias
    correction is now applied during registration. Kept as a separate step
    for compatibility with the existing preprocessing pipeline.
    """
    print(f"Enhancement step (N4 already applied during registration): {input_nifti_path}")
    shutil.copy2(input_nifti_path, output_nifti_path)
    print(f"Enhancement complete (passthrough): {output_nifti_path}")


def register_image_sitk(input_nifti_path: str, output_nifti_path: str, template_path: str, interp_type='linear'):
    """
    Rigid (Euler) MRI registration to a template using SimpleITK.

    Args:
        input_nifti_path: Path to input NIfTI file
        output_nifti_path: Path to save registered output
        template_path: Path to template image
        interp_type: Interpolation type ('linear', 'bspline', 'nearest_neighbor')
    """
    print(f"Registering {input_nifti_path} to template {template_path}")

    fixed_img = sitk.ReadImage(template_path, sitk.sitkFloat32)
    moving_img = sitk.ReadImage(input_nifti_path, sitk.sitkFloat32)

    # N4 bias field correction is applied to the moving image here, which is
    # why the downstream enhancement step is a passthrough.
    moving_img = sitk.N4BiasFieldCorrection(moving_img)

    # Resample the template to 1 mm isotropic spacing.
    old_size = fixed_img.GetSize()
    old_spacing = fixed_img.GetSpacing()
    new_spacing = (1, 1, 1)
    new_size = [
        int(round((old_size[0] * old_spacing[0]) / float(new_spacing[0]))),
        int(round((old_size[1] * old_spacing[1]) / float(new_spacing[1]))),
        int(round((old_size[2] * old_spacing[2]) / float(new_spacing[2]))),
    ]

    if interp_type == 'linear':
        interp_type = sitk.sitkLinear
    elif interp_type == 'bspline':
        interp_type = sitk.sitkBSpline
    elif interp_type == 'nearest_neighbor':
        interp_type = sitk.sitkNearestNeighbor
    else:
        interp_type = sitk.sitkLinear

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(new_spacing)
    resample.SetSize(new_size)
    resample.SetOutputOrigin(fixed_img.GetOrigin())
    resample.SetOutputDirection(fixed_img.GetDirection())
    resample.SetInterpolator(interp_type)
    resample.SetDefaultPixelValue(0)  # background intensity, not a pixel-type ID
    resample.SetOutputPixelType(sitk.sitkFloat32)
    fixed_img = resample.Execute(fixed_img)

    # Initialize with a geometry-centered rigid transform.
    transform = sitk.CenteredTransformInitializer(
        fixed_img,
        moving_img,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY)

    # Mattes mutual information with a multi-resolution gradient-descent optimizer.
    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)
    registration_method.SetInterpolator(sitk.sitkLinear)
    registration_method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10)
    registration_method.SetOptimizerScalesFromPhysicalShift()
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    registration_method.SetInitialTransform(transform)

    final_transform = registration_method.Execute(fixed_img, moving_img)

    moving_img_resampled = sitk.Resample(
        moving_img,
        fixed_img,
        final_transform,
        sitk.sitkLinear,
        0.0,
        moving_img.GetPixelID())

    sitk.WriteImage(moving_img_resampled, output_nifti_path)
    print(f"Registration complete. Saved to: {output_nifti_path}")
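
# Usage sketch (hypothetical file names; the templates ship under
# golden_image/mni_templates):
#   register_image_sitk("sub-01_flair.nii.gz", "sub-01_registered.nii.gz",
#                       FLAIR_TEMPLATE_PATH, interp_type='linear')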


def register_image(input_nifti_path: str, output_nifti_path: str):
    """Compatibility wrapper around the SimpleITK registration above."""
    if not os.path.exists(DEFAULT_TEMPLATE_PATH):
        raise FileNotFoundError(f"Template file missing: {DEFAULT_TEMPLATE_PATH}")
    register_image_sitk(input_nifti_path, output_nifti_path, DEFAULT_TEMPLATE_PATH)


def run_skull_stripping(input_nifti_path: str, output_dir: str):
    """
    Brain extraction via the bundled HD-BET package.

    Args:
        input_nifti_path: Path to input NIfTI file
        output_dir: Directory to save skull-stripped output

    Returns:
        tuple: (output_file_path, output_mask_path)
    """
    print(f"Running HD-BET skull stripping on {input_nifti_path}")

    if hd_bet is None:
        raise RuntimeError("HD-BET not available. Please include HD_BET and hdbet_model in src/IDH.")

    if not os.path.exists(HD_BET_MODEL_DIR):
        raise FileNotFoundError(f"HD-BET models not found at {HD_BET_MODEL_DIR}")

    os.makedirs(output_dir, exist_ok=True)

    base_name = os.path.basename(input_nifti_path).replace('.nii.gz', '').replace('.nii', '')

    # Stage the input in a temporary directory using the _0000 modality-suffix
    # naming convention that hd_bet expects.
    temp_input_dir = os.path.join(output_dir, "temp_input")
    os.makedirs(temp_input_dir, exist_ok=True)

    temp_input_path = os.path.join(temp_input_dir, f"{base_name}_0000.nii.gz")
    shutil.copy2(input_nifti_path, temp_input_path)

    device = "0" if torch.cuda.is_available() else "cpu"

    try:
        # Diagnostic logging: confirm the local model weights are in place.
        model_file = os.path.join(HD_BET_MODEL_DIR, '0.model')
        if os.path.exists(model_file):
            print(f"Local model file exists at: {model_file}")
        else:
            print(f"Warning: Model file not found at: {model_file}")

        if os.path.exists(HD_BET_MODEL_DIR):
            print(f"Contents of {HD_BET_MODEL_DIR}: {os.listdir(HD_BET_MODEL_DIR)}")
        else:
            print(f"Directory {HD_BET_MODEL_DIR} does not exist")

        print(f"Running hd_bet with input_dir: {temp_input_dir}, output_dir: {output_dir}")
        hd_bet(temp_input_dir, output_dir, device=device, mode='fast', tta=0)

        # hd_bet names its outputs after the staged input; rename them to the
        # *_bet convention used downstream.
        output_file_path = os.path.join(output_dir, f"{base_name}_0000.nii.gz")
        output_mask_path = os.path.join(output_dir, f"{base_name}_0000_mask.nii.gz")

        final_output_path = os.path.join(output_dir, f"{base_name}_bet.nii.gz")
        final_mask_path = os.path.join(output_dir, f"{base_name}_bet_mask.nii.gz")

        if os.path.exists(output_file_path):
            shutil.move(output_file_path, final_output_path)
        if os.path.exists(output_mask_path):
            shutil.move(output_mask_path, final_mask_path)

        shutil.rmtree(temp_input_dir, ignore_errors=True)

        if not os.path.exists(final_output_path):
            raise RuntimeError(f"HD-BET did not produce output file: {final_output_path}")

        print(f"Skull stripping complete. Output saved to: {final_output_path}")
        return final_output_path, final_mask_path

    except Exception as e:
        shutil.rmtree(temp_input_dir, ignore_errors=True)
        raise RuntimeError(f"HD-BET skull stripping failed: {e}") from e
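
# Usage sketch (hypothetical paths), mirroring the call in predict_segmentation():
#   stripped_path, mask_path = run_skull_stripping("flair_enhanced.nii.gz",
#                                                  "/tmp/skullstripped")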


def create_segmentation_plots(input_data_3d, seg_mask_3d, slice_index):
    """Create segmentation visualization plots: Input, Mask, and Overlay."""
    print(f"Generating segmentation plots for slice index: {slice_index}")

    if any(data is None for data in [input_data_3d, seg_mask_3d]):
        return None, None, None

    if not (0 <= slice_index < input_data_3d.shape[2]):
        print(f"Error: Slice index {slice_index} out of bounds (0-{input_data_3d.shape[2] - 1}).")
        return None, None, None

    def save_plot_to_numpy(fig):
        # Render the figure to an in-memory PNG and return it as a uint8 array.
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=75)
            plt.close(fig)
            buf.seek(0)
            img_arr = plt.imread(buf, format='png')
        return (img_arr * 255).astype(np.uint8)

    try:
        input_slice = input_data_3d[:, :, slice_index]
        mask_slice = seg_mask_3d[:, :, slice_index]

        def normalize_slice(slice_data, volume_data):
            # Window to the volume's 1st-99th percentile range for display.
            p1, p99 = np.percentile(volume_data, (1, 99))
            denom = max(p99 - p1, 1e-6)
            return np.clip((slice_data - p1) / denom, 0, 1)

        input_slice_norm = normalize_slice(input_slice, input_data_3d)

        plots = []

        # Panel 1: input image.
        fig1, ax1 = plt.subplots(figsize=(6, 6))
        ax1.imshow(input_slice_norm, cmap='gray', interpolation='none', origin='lower')
        ax1.axis('off')
        ax1.set_title('Input Image', fontsize=14, color='white', pad=10)
        plots.append(save_plot_to_numpy(fig1))

        # Panel 2: binary segmentation mask.
        fig2, ax2 = plt.subplots(figsize=(6, 6))
        ax2.imshow(mask_slice, cmap='hot', interpolation='none', origin='lower', vmin=0, vmax=1)
        ax2.axis('off')
        ax2.set_title('Segmentation Mask', fontsize=14, color='white', pad=10)
        plots.append(save_plot_to_numpy(fig2))

        # Panel 3: mask overlaid on the input.
        fig3, ax3 = plt.subplots(figsize=(6, 6))
        ax3.imshow(input_slice_norm, cmap='gray', interpolation='none', origin='lower')
        mask_overlay = np.ma.masked_where(mask_slice < 0.5, mask_slice)
        ax3.imshow(mask_overlay, cmap='Reds', interpolation='none', origin='lower', alpha=0.7, vmin=0, vmax=1)
        ax3.axis('off')
        ax3.set_title('Overlay', fontsize=14, color='white', pad=10)
        plots.append(save_plot_to_numpy(fig3))

        print(f"Generated 3 segmentation plots successfully for axial slice {slice_index}.")
        return tuple(plots)

    except Exception as e:
        print(f"Error generating segmentation plots for slice {slice_index}: {e}")
        traceback.print_exc()
        return tuple([None] * 3)


def extract_attention_map(vit_model, image, layer_idx=-1, img_size=(96, 96, 96), patch_size=16):
    """
    Extracts an attention map from a Vision Transformer (ViT) model.

    The ViT's attention blocks are temporarily wrapped to capture attention
    weights during a forward pass. The CLS-token attention is then reshaped
    into a 3D saliency map aligned with the input image.
    """
    attention_maps = {}
    original_attns = {}

    class AttentionWithWeights(torch.nn.Module):
        """Wraps an attention module and records its softmaxed attention weights."""

        def __init__(self, original_attn_module):
            super().__init__()
            self.original_attn_module = original_attn_module
            self.attn_weights = None

        def forward(self, x):
            output = self.original_attn_module(x)
            # Recompute the attention weights from the qkv projection, since
            # the wrapped module does not expose them directly.
            if hasattr(self.original_attn_module, 'qkv'):
                qkv = self.original_attn_module.qkv(x)
                batch_size, seq_len, _ = x.shape
                qkv = qkv.reshape(batch_size, seq_len, 3, self.original_attn_module.num_heads, -1)
                qkv = qkv.permute(2, 0, 3, 1, 4)
                q, k, v = qkv[0], qkv[1], qkv[2]
                attn = (q @ k.transpose(-2, -1)) * self.original_attn_module.scale
                self.attn_weights = attn.softmax(dim=-1)
            return output

    for i, block in enumerate(vit_model.blocks):
        if hasattr(block, 'attn'):
            original_attns[i] = block.attn
            block.attn = AttentionWithWeights(block.attn)

    try:
        with torch.no_grad():
            _ = vit_model(image)

        for i, block in enumerate(vit_model.blocks):
            if hasattr(block.attn, 'attn_weights') and block.attn.attn_weights is not None:
                attention_maps[f"layer_{i}"] = block.attn.attn_weights.detach()

    finally:
        # Always restore the original attention modules.
        for i, original_attn in original_attns.items():
            vit_model.blocks[i].attn = original_attn

    if not attention_maps:
        raise RuntimeError("Could not extract any attention maps. Please check the ViT model structure.")

    if layer_idx < 0:
        layer_idx = len(attention_maps) + layer_idx
    layer_name = f"layer_{layer_idx}"
    if layer_name not in attention_maps:
        raise ValueError(f"Layer {layer_idx} not found. Available layers: {list(attention_maps.keys())}")

    layer_attn = attention_maps[layer_name]

    # Average over heads for the first batch element.
    head_attn = layer_attn[0].mean(dim=0)

    # Attention of the CLS token (assumed at position 0) to all patch tokens.
    cls_attn = head_attn[0, 1:]

    patches_per_dim = img_size[0] // patch_size
    total_patches = patches_per_dim ** 3

    # Pad or truncate in case the token count does not match the patch grid.
    if cls_attn.shape[0] != total_patches:
        if cls_attn.shape[0] > total_patches:
            cls_attn = cls_attn[:total_patches]
        else:
            padded = torch.zeros(total_patches, device=cls_attn.device)
            padded[:cls_attn.shape[0]] = cls_attn
            cls_attn = padded

    cls_attn_3d = cls_attn.reshape(patches_per_dim, patches_per_dim, patches_per_dim)
    cls_attn_3d = cls_attn_3d.unsqueeze(0).unsqueeze(0)

    # Upsample the patch-level map to the full image resolution.
    upsampled_attn = torch.nn.functional.interpolate(
        cls_attn_3d,
        size=img_size,
        mode='trilinear',
        align_corners=False
    ).squeeze()

    # Min-max normalize to [0, 1]; the epsilon guards against a constant map.
    upsampled_attn = upsampled_attn.cpu().numpy()
    upsampled_attn = (upsampled_attn - upsampled_attn.min()) / (upsampled_attn.max() - upsampled_attn.min() + 1e-8)
    return upsampled_attn
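
# Usage sketch: generate_saliency_dual() below calls this with the ViT exposed
# at model.backbone.backbone; any ViT whose blocks expose .attn with a .qkv
# projection should work:
#   attn_3d = extract_attention_map(vit_model, image_tensor, layer_idx=-1)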


def generate_saliency_dual(model, input_tensor, layer_idx=-1):
    """
    Generate saliency maps for the dual-input IDH model.

    Args:
        model: The complete IDH model
        input_tensor: Dual input tensor (batch_size, 2, C, D, H, W)
        layer_idx: ViT layer to visualize

    Returns:
        tuple: (flair_input_3d, t1c_input_3d, flair_saliency_3d)
    """
    print("Generating saliency maps for dual input...")

    try:
        # Split the stacked modalities: index 0 is FLAIR, index 1 is T1c.
        flair_tensor = input_tensor[:, 0]
        t1c_tensor = input_tensor[:, 1]

        vit_model = model.backbone.backbone

        flair_attn = extract_attention_map(vit_model, flair_tensor, layer_idx)

        flair_input_3d = flair_tensor.squeeze().cpu().detach().numpy()
        t1c_input_3d = t1c_tensor.squeeze().cpu().detach().numpy()

        print("Saliency maps generated successfully.")
        return flair_input_3d, t1c_input_3d, flair_attn

    except Exception as e:
        print(f"Error during saliency generation: {e}")
        traceback.print_exc()
        return None, None, None


def create_slice_plots_dual(flair_data_3d, t1c_data_3d, flair_saliency_3d, slice_index):
    """Create slice plots for the dual-input visualization: T1c, FLAIR, and FLAIR attention."""
    print(f"Generating plots for slice index: {slice_index}")

    if any(data is None for data in [flair_data_3d, t1c_data_3d, flair_saliency_3d]):
        return None, None, None

    if not (0 <= slice_index < flair_data_3d.shape[2]):
        print(f"Error: Slice index {slice_index} out of bounds (0-{flair_data_3d.shape[2] - 1}).")
        return None, None, None

    def save_plot_to_numpy(fig):
        # Render the figure to an in-memory PNG and return it as a uint8 array.
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=75)
            plt.close(fig)
            buf.seek(0)
            img_arr = plt.imread(buf, format='png')
        return (img_arr * 255).astype(np.uint8)

    try:
        flair_slice = flair_data_3d[:, :, slice_index]
        t1c_slice = t1c_data_3d[:, :, slice_index]
        flair_saliency_slice = flair_saliency_3d[:, :, slice_index]

        def normalize_slice(slice_data, volume_data):
            # Window to the volume's 1st-99th percentile range for display.
            p1, p99 = np.percentile(volume_data, (1, 99))
            denom = max(p99 - p1, 1e-6)
            return np.clip((slice_data - p1) / denom, 0, 1)

        flair_slice_norm = normalize_slice(flair_slice, flair_data_3d)
        t1c_slice_norm = normalize_slice(t1c_slice, t1c_data_3d)

        def process_saliency_slice(saliency_slice, saliency_volume):
            # Clamp negatives, blur for smoother display, and normalize by the
            # volume-wide maximum so all slices share a common scale.
            saliency_slice = np.copy(saliency_slice)
            saliency_slice[saliency_slice < 0] = 0
            saliency_slice_blurred = cv2.GaussianBlur(saliency_slice, (15, 15), 0)
            s_max = max(np.max(saliency_volume[saliency_volume >= 0]), 1e-6)
            saliency_slice_norm = saliency_slice_blurred / s_max
            return np.where(saliency_slice_norm > 0.0, saliency_slice_norm, 0)

        flair_sal_processed = process_saliency_slice(flair_saliency_slice, flair_saliency_3d)

        plots = []

        # Panel 1: T1c input.
        fig1, ax1 = plt.subplots(figsize=(6, 6))
        ax1.imshow(t1c_slice_norm, cmap='gray', interpolation='none', origin='lower')
        ax1.axis('off')
        ax1.set_title('T1c Input', fontsize=14, color='white', pad=10)
        plots.append(save_plot_to_numpy(fig1))

        # Panel 2: FLAIR input.
        fig2, ax2 = plt.subplots(figsize=(6, 6))
        ax2.imshow(flair_slice_norm, cmap='gray', interpolation='none', origin='lower')
        ax2.axis('off')
        ax2.set_title('FLAIR Input', fontsize=14, color='white', pad=10)
        plots.append(save_plot_to_numpy(fig2))

        # Panel 3: FLAIR attention map.
        fig3, ax3 = plt.subplots(figsize=(6, 6))
        ax3.imshow(flair_sal_processed, cmap='magma', interpolation='none', origin='lower', vmin=0)
        ax3.axis('off')
        ax3.set_title('FLAIR Attention', fontsize=14, color='white', pad=10)
        plots.append(save_plot_to_numpy(fig3))

        print(f"Generated 3 plots successfully for axial slice {slice_index}.")
        return tuple(plots)

    except Exception as e:
        print(f"Error generating plots for slice {slice_index}: {e}")
        traceback.print_exc()
        return tuple([None] * 3)


def get_validation_transform(image_size: Tuple[int, int, int]):
    return Compose([
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Resized(keys=["image"], spatial_size=tuple(image_size), mode="trilinear"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        EnsureTyped(keys=["image"]),
        ToTensord(keys=["image"]),
    ])
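
# The composed transform maps {"image": "/path/scan.nii.gz"} to a dict whose
# "image" entry is a channel-first tensor of shape (1, *image_size), with the
# nonzero voxels normalized per channel.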


def preprocess_nifti(image_path: str, image_size: Tuple[int, int, int], device: torch.device) -> torch.Tensor:
    """Load a NIfTI file and return a (1, 1, D, H, W) batch tensor on the target device."""
    transform = get_validation_transform(image_size)
    sample = {"image": image_path}
    sample = transform(sample)
    image = sample["image"].unsqueeze(0).to(device)
    return image


def save_nifti_for_download(data_array: np.ndarray, reference_path: str, output_path: str, affine=None):
    """
    Save a numpy array as a NIfTI file for download, preserving spatial
    information from a reference image.

    Args:
        data_array: 3D numpy array to save
        reference_path: Path to reference NIfTI file for header info
        output_path: Path where to save the output file
        affine: Optional affine matrix; if None, the reference affine is used
    """
    try:
        ref_img = nib.load(reference_path)
        if affine is None:
            affine = ref_img.affine

        new_img = nib.Nifti1Image(data_array, affine, ref_img.header)
        nib.save(new_img, output_path)
        print(f"Saved NIfTI file: {output_path}")
        return output_path

    except Exception as e:
        print(f"Error saving NIfTI file: {e}")
        return None


def predict_segmentation(input_file, threshold: float, do_preprocess: bool, cfg: dict, model, device):
    try:
        if input_file is None:
            return {"error": "Please upload a NIfTI file (.nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "mask_paths": None, "num_slices": 0}, None, None

        input_path = input_file.name if hasattr(input_file, 'name') else input_file

        if not (input_path.endswith(".nii") or input_path.endswith(".nii.gz")):
            return {"error": "Input must be a NIfTI file (.nii or .nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "mask_paths": None, "num_slices": 0}, None, None

        # The work directory is kept alive after a successful run: the slice
        # viewer and the download widgets read from it.
        work_dir = tempfile.mkdtemp()
        final_input_path = input_path

        try:
            if do_preprocess:
                # Registration (with N4), passthrough enhancement, then HD-BET.
                reg_path = os.path.join(work_dir, "flair_registered.nii.gz")
                register_image_sitk(input_path, reg_path, FLAIR_TEMPLATE_PATH)

                enh_path = os.path.join(work_dir, "flair_enhanced.nii.gz")
                run_enhance_on_file(reg_path, enh_path)

                skullstrip_dir = os.path.join(work_dir, "skullstripped")
                bet_path, _ = run_skull_stripping(enh_path, skullstrip_dir)
                final_input_path = bet_path

            image_size = cfg.get("infer", {}).get("image_size", [96, 96, 96])
            training_cfg = cfg.get("training", {})
            input_tensor = preprocess_nifti(final_input_path, image_size, device)

            with torch.no_grad():
                seg_logits = sliding_window_inference(
                    inputs=input_tensor,
                    roi_size=tuple(image_size),
                    sw_batch_size=training_cfg.get("sw_batch_size", 2),
                    predictor=model,
                    overlap=0.5
                )

                seg_prob = torch.sigmoid(seg_logits)
                seg_mask = (seg_prob > threshold).float()

            input_3d = input_tensor.squeeze().cpu().detach().numpy()
            seg_prob_3d = seg_prob.squeeze().cpu().detach().numpy()
            seg_mask_3d = seg_mask.squeeze().cpu().detach().numpy()

            # Summary statistics reported in the Results panel.
            total_voxels = np.prod(seg_mask_3d.shape)
            segmented_voxels = int(np.sum(seg_mask_3d))
            segmentation_percentage = (segmented_voxels / total_voxels) * 100

            prediction_result = {
                "segmented_voxels": segmented_voxels,
                "total_voxels": total_voxels,
                "segmentation_percentage": float(segmentation_percentage),
                "threshold": float(threshold),
                "preprocessing": bool(do_preprocess),
                "max_probability": float(np.max(seg_prob_3d)),
                "mean_probability": float(np.mean(seg_prob_3d))
            }

            input_img = seg_mask_img = overlay_img = None
            slider_update = gr.Slider(visible=False)
            viz_state = {"input_paths": None, "mask_paths": None, "num_slices": 0}

            download_preprocessed = None
            download_mask = None

            print("--- Generating Visualizations ---")
            try:
                num_slices = input_3d.shape[2]
                center_slice_index = num_slices // 2

                # Cache the volumes to disk so the slice slider can re-render
                # without re-running inference.
                unique_id = str(uuid.uuid4())
                temp_paths = []
                for name, data in [("input", input_3d), ("seg_prob", seg_prob_3d), ("seg_mask", seg_mask_3d)]:
                    path = os.path.join(work_dir, f"{unique_id}_{name}.npy")
                    np.save(path, data)
                    temp_paths.append(path)

                plots = create_segmentation_plots(input_3d, seg_mask_3d, center_slice_index)
                if plots and all(p is not None for p in plots):
                    input_img, seg_mask_img, overlay_img = plots

                viz_state = {
                    "input_paths": [temp_paths[0]],
                    "mask_paths": temp_paths[1:],
                    "num_slices": num_slices
                }
                slider_update = gr.Slider(value=center_slice_index, minimum=0, maximum=num_slices - 1, step=1, label="Select Slice", visible=True)
                print("--- Visualization Generation Complete ---")

            except Exception as e:
                print(f"Error during visualization generation: {e}")
                traceback.print_exc()

            print("--- Generating Download Files ---")
            try:
                # Strip .nii / .nii.gz to recover the bare base name.
                base_name = os.path.splitext(os.path.basename(input_path))[0]
                if base_name.endswith('.nii'):
                    base_name = os.path.splitext(base_name)[0]

                preprocessed_download_path = os.path.join(work_dir, f"{base_name}_preprocessed.nii.gz")
                saved_preprocessed_path = save_nifti_for_download(
                    input_3d,
                    input_path,
                    preprocessed_download_path
                )
                if saved_preprocessed_path:
                    download_preprocessed = gr.File(value=saved_preprocessed_path, visible=True, label="Download Preprocessed Image")

                mask_download_path = os.path.join(work_dir, f"{base_name}_segmentation_mask.nii.gz")
                saved_mask_path = save_nifti_for_download(
                    seg_mask_3d,
                    final_input_path,
                    mask_download_path
                )
                if saved_mask_path:
                    download_mask = gr.File(value=saved_mask_path, visible=True, label="Download Segmentation Mask")

                print("--- Download Files Generated ---")

            except Exception as e:
                print(f"Error generating download files: {e}")
                traceback.print_exc()

            return (prediction_result, input_img, seg_mask_img, overlay_img, slider_update, viz_state, download_preprocessed, download_mask)

        except Exception as e:
            shutil.rmtree(work_dir, ignore_errors=True)
            return {"error": f"Processing failed: {str(e)}"}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "mask_paths": None, "num_slices": 0}, None, None

    except Exception as e:
        return {"error": str(e)}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "mask_paths": None, "num_slices": 0}, None, None


def update_slice_viewer_segmentation(slice_index, current_state):
    """Update the slice viewer for the segmentation visualization."""
    input_paths = current_state.get("input_paths", [])
    mask_paths = current_state.get("mask_paths", [])

    # The state stores one input volume and two mask volumes (probabilities, binary mask).
    if not input_paths or not mask_paths or len(input_paths) != 1 or len(mask_paths) != 2:
        print(f"Warning: Invalid state for slice viewer update: {current_state}")
        return None, None, None

    try:
        input_3d = np.load(input_paths[0])
        seg_mask_3d = np.load(mask_paths[1])  # index 1 holds the binary mask

        slice_index = int(slice_index)
        if not (0 <= slice_index < input_3d.shape[2]):
            print(f"Warning: Invalid slice index {slice_index}")
            return None, None, None

        plots = create_segmentation_plots(input_3d, seg_mask_3d, slice_index)
        return plots if plots else tuple([None] * 3)

    except Exception as e:
        print(f"Error updating slice viewer for index {slice_index}: {e}")
        traceback.print_exc()
        return tuple([None] * 3)


def build_interface():
    cfg = load_config()
    model, device = build_model(cfg)
    default_threshold = float(cfg.get("infer", {}).get("threshold", 0.5))

    with gr.Blocks(title="BrainIAC: Glioma Segmentation", css="""
        #header-row {
            min-height: 150px;
            align-items: center;
        }
        .logo-img img {
            height: 150px;
            object-fit: contain;
        }
    """) as demo:

        with gr.Row(elem_id="header-row"):
            with gr.Column(scale=1):
                gr.Image(os.path.join(APP_DIR, "static/images/kannlab.png"),
                         show_label=False, interactive=False,
                         show_download_button=False,
                         container=False,
                         elem_classes=["logo-img"])
            with gr.Column(scale=3):
                gr.Markdown(
                    "<h1 style='text-align: center; margin-bottom: 2.5rem'>"
                    "BrainIAC: Glioma Segmentation"
                    "</h1>"
                )
            with gr.Column(scale=1):
                gr.Image(os.path.join(APP_DIR, "static/images/brainiac.jpeg"),
                         show_label=False, interactive=False,
                         show_download_button=False,
                         container=False,
                         elem_classes=["logo-img"])

        with gr.Accordion("ℹ️ Model Details and Usage Guide", open=False):
            gr.Markdown("""
            ### 🧠 BrainIAC: Glioma Segmentation

            **Model Description**
            A Vision Transformer UNETR (ViT-UNETR) model with BrainIAC as the pre-trained backbone, designed for glioma segmentation from MRI scans.

            **Training Dataset**
            - **Subjects**: Trained on MRI scans from glioma patients
            - **Imaging Modalities**: Single-modality MRI FLAIR with binary tumor masks
            - **Preprocessing**: N4 bias correction, MNI registration, and skull stripping (HD-BET)

            **Input**
            - Format: NIfTI (.nii or .nii.gz)
            - Single MRI FLAIR sequence
            - Image size: Automatically resized to 96×96×96 voxels

            **Output**
            - Binary segmentation mask highlighting glioma regions
            - Segmentation statistics (volume, percentage)
            - Probability maps and overlay visualization

            **Intended Use**
            - Research use only!

            **NOTE**
            - Single-modality FLAIR input
            - Not validated on other MRI sequences
            - Not validated for brain pathologies other than gliomas
            - Upload PHI data at your own risk!
            - The model is hosted on a cloud-based CPU instance
            - Data is not stored, shared, or collected for any purpose!

            **Visualization**
            The interface shows three views for each slice:
            - **Input Image**: The preprocessed MRI scan
            - **Segmentation Mask**: The predicted binary mask
            - **Overlay**: The mask overlaid on the input image
            """)

        viz_state = gr.State({"input_paths": None, "mask_paths": None, "num_slices": 0})

        gr.Markdown("**Upload an MRI NIfTI volume.** Optional preprocessing performs registration to MNI, enhancement, and skull stripping.")

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Controls")
                    input_file = gr.File(label="MRI Image (.nii or .nii.gz)")
                    preprocess_checkbox = gr.Checkbox(value=False, label="Preprocess NIfTI (debiasing + registration + skull stripping)")
                    threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=default_threshold, step=0.01, label="Segmentation Threshold")
                    predict_btn = gr.Button("Generate Segmentation", variant="primary")

            with gr.Column(scale=2):
                with gr.Group():
                    gr.Markdown("### Segmentation Result")
                    output_json = gr.JSON(label="Results")

        with gr.Row():
            download_preprocessed_btn = gr.File(label="Download Preprocessed Image", visible=False)
            download_mask_btn = gr.File(label="Download Segmentation Mask", visible=False)

        with gr.Group():
            gr.Markdown("### Segmentation Viewer (Axial Slice)")
            slice_slider = gr.Slider(label="Select Slice", minimum=0, maximum=0, step=1, value=0, visible=False)

            with gr.Row():
                with gr.Column():
                    gr.Markdown("<p style='text-align: center;'>Input Image</p>")
                    input_img = gr.Image(label="Input Image", type="numpy", show_label=False)
                with gr.Column():
                    gr.Markdown("<p style='text-align: center;'>Segmentation Mask</p>")
                    seg_mask_img = gr.Image(label="Segmentation Mask", type="numpy", show_label=False)
                with gr.Column():
                    gr.Markdown("<p style='text-align: center;'>Overlay</p>")
                    overlay_img = gr.Image(label="Overlay", type="numpy", show_label=False)

        # The lambda reorders the Gradio inputs (file, preprocess, threshold)
        # into predict_segmentation's (file, threshold, preprocess) signature.
        predict_btn.click(
            fn=lambda f, prep, thr: predict_segmentation(f, thr, prep, cfg, model, device),
            inputs=[input_file, preprocess_checkbox, threshold_input],
            outputs=[output_json, input_img, seg_mask_img, overlay_img, slice_slider, viz_state, download_preprocessed_btn, download_mask_btn],
        )

        slice_slider.change(
            fn=update_slice_viewer_segmentation,
            inputs=[slice_slider, viz_state],
            outputs=[input_img, seg_mask_img, overlay_img]
        )

    return demo


if __name__ == "__main__":
    iface = build_interface()
    iface.launch(server_name="0.0.0.0", server_port=7860)