"""Gradio app that predicts brain age from a T1-weighted MRI scan.

The module loads a model checkpoint at import time, optionally runs a
registration / enhancement / skull-stripping pipeline on the uploaded scan,
predicts an age, and can render saliency maps that are browsed slice by
slice through a slider.
"""

# NOTE: the import order below intentionally mirrors the original file; only
# the matplotlib backend selection was moved ahead of the pyplot import.
import os
import torch
import nibabel as nib
import gradio as gr
import tempfile
import yaml
import traceback
import zipfile
import dicom2nifti
import shutil
import subprocess
import SimpleITK as sitk
import itk
import numpy as np
from scipy.signal import medfilt
import skimage.filters
import cv2
import io
import base64
import uuid
import matplotlib

# Select the non-interactive backend before pyplot is imported so that no GUI
# backend is ever initialised on the server.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

try:
    # Assumes HD_BET is now at /app/BrainIAC/HD_BET or adjacent in src
    from HD_BET.run import run_hd_bet
    from monai.visualize.gradient_based import GuidedBackpropSmoothGrad
except ImportError as e:
    print(f"Warning: Could not import HD_BET or MONAI visualize: {e}. Saliency/Preprocessing might fail.")
    run_hd_bet = None
    GuidedBackpropSmoothGrad = None

# Import necessary components from your existing modules
from model import Backbone, SingleScanModel, Classifier
from monai.transforms import Resized, ScaleIntensityd

# --- Constants ---
APP_DIR = os.path.dirname(__file__)
TEMPLATE_DIR = os.path.join(APP_DIR, "golden_image", "mni_templates")
PARAMS_RIGID_PATH = os.path.join(APP_DIR, "golden_image", "mni_templates", "Parameters_Rigid.txt")
DEFAULT_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "nihpd_asym_13.0-18.5_t1w.nii")
HD_BET_CONFIG_PATH = os.path.join(APP_DIR, "HD_BET", "config.py")  # May need adjustment based on actual HD_BET location
HD_BET_MODEL_DIR = os.path.join(APP_DIR, "hdbet_model")  # Path to copied models


# --- Configuration Loading ---
def load_config():
    """Load ``config.yml`` from the app directory.

    Returns:
        dict: The parsed configuration, guaranteed to contain
        ``config['data']['image_size']``. When the file is missing, a
        built-in default configuration is returned instead.
    """
    config_path = os.path.join(APP_DIR, 'config.yml')
    try:
        with open(config_path, 'r') as file:
            config = yaml.safe_load(file)
    except FileNotFoundError:
        print(f"Warning: Configuration file not found at {config_path}. Using defaults.")
        return {
            'gpu': {'device': 'cpu'},
            'infer': {'checkpoints': 'checkpoints/brainage_model_latest.pt'},
            'data': {'image_size': [128, 128, 128]},
        }

    # safe_load returns None for an empty document; a null ``data:`` key
    # likewise yields None. Normalise both so the lookups below cannot fail.
    if not isinstance(config, dict):
        config = {}
    if not isinstance(config.get('data'), dict):
        config['data'] = {}
    config['data'].setdefault('image_size', [128, 128, 128])
    return config


config = load_config()
DEFAULT_IMAGE_SIZE = (128, 128, 128)
image_size_cfg = config.get('data', {}).get('image_size', DEFAULT_IMAGE_SIZE)
if not isinstance(image_size_cfg, (list, tuple)) or len(image_size_cfg) != 3:
    print(f"Warning: Invalid image_size in config ({image_size_cfg}). Using default {DEFAULT_IMAGE_SIZE}.")
    image_size = DEFAULT_IMAGE_SIZE
else:
    image_size = tuple(image_size_cfg)


# --- Model Loading ---
def load_model(cfg):
    """Build the model and load its checkpoint.

    Args:
        cfg: Configuration dict; reads ``cfg['gpu']['device']`` and
            ``cfg['infer']['checkpoints']`` (a path relative to the app
            directory), falling back to defaults when absent.

    Returns:
        tuple: ``(model, device)``. ``model`` is ``None`` when the checkpoint
        is missing or cannot be loaded; the error is printed, not raised.
    """
    device = torch.device((cfg.get('gpu') or {}).get('device', 'cpu'))
    backbone = Backbone()
    classifier = Classifier(d_model=2048)
    model = SingleScanModel(backbone, classifier)

    relative_path = (cfg.get('infer') or {}).get('checkpoints', 'checkpoints/brainage_model_latest.pt')
    checkpoint_path_abs = os.path.join(APP_DIR, relative_path)
    try:
        checkpoint = torch.load(checkpoint_path_abs, map_location=device)
        # Accept both a wrapped checkpoint and a bare state dict.
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        print(f"Model loaded successfully from {checkpoint_path_abs} onto {device}.")
        return model, device
    except FileNotFoundError:
        print(f"Error: Checkpoint file not found at {checkpoint_path_abs}")
        return None, device
    except Exception as e:
        print(f"Error loading model checkpoint: {e}")
        traceback.print_exc()
        return None, device


model, device = load_model(config)


# --- Preprocessing Functions (Copied/Adapted from app.py) ---
def bias_field_correction(img_array):
    """Run SimpleITK N4 bias field correction on a volume.

    Args:
        img_array: NumPy array holding the image volume.

    Returns:
        numpy.ndarray: The corrected volume.
    """
    print(" Running N4 Bias Field Correction...")
    image = sitk.GetImageFromArray(img_array)
    if image.GetPixelID() != sitk.sitkFloat32:
        image = sitk.Cast(image, sitk.sitkFloat32)
    mask_image = sitk.OtsuThreshold(image, 0, 1, 200)
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    number_fitting_levels = 4
    # Iterations per fitting level double from 50 and are capped at 200.
    max_iters = [min(50 * (2**i), 200) for i in range(number_fitting_levels)]
    corrector.SetMaximumNumberOfIterations(max_iters)
    corrected_image = corrector.Execute(image, mask_image)
    print(" N4 Correction finished.")
    return sitk.GetArrayFromImage(corrected_image)


def denoise(volume, kernel_size=3):
    """Apply a median filter of the given kernel size to the volume."""
    print(f" Applying median filter denoising (kernel={kernel_size})...")
    return medfilt(volume, kernel_size)


def rescale_intensity(volume, percentils=(0.5, 99.5), bins_num=256):
    """Rescale intensities using percentiles of the foreground voxels.

    Foreground is every voxel at or above the Otsu threshold (and above
    zero). If Otsu fails or no foreground is found, the whole volume is used.

    Args:
        volume: Input volume.
        percentils: ``(low, high)`` percentiles that map to the output range.
        bins_num: ``0`` scales to ``[0, 1]``; otherwise values are rounded
            and clipped to ``[0, bins_num - 1]``.

    Returns:
        numpy.ndarray: The rescaled volume as ``float32``.
    """
    print(" Rescaling intensity...")
    volume_float = volume.astype(np.float32)
    try:
        threshold = skimage.filters.threshold_otsu(volume_float, nbins=256)
        obj_volume = volume_float[(volume_float >= threshold) & (volume_float > 0)]
    except ValueError:
        print(" Otsu failed, skipping background mask.")
        obj_volume = volume_float.flatten()

    if obj_volume.size == 0:
        print(" Warning: No foreground voxels found. Scaling full volume.")
        obj_volume = volume_float.flatten()
        min_value = np.min(obj_volume)
        max_value = np.max(obj_volume)
    else:
        min_value = np.percentile(obj_volume, percentils[0])
        max_value = np.percentile(obj_volume, percentils[1])

    # Avoid dividing by (almost) zero on flat volumes.
    denominator = max(max_value - min_value, 1e-6)

    if bins_num == 0:
        output_volume = np.clip((volume_float - min_value) / denominator, 0.0, 1.0)
    else:
        output_volume = np.round((volume_float - min_value) / denominator * (bins_num - 1))
        output_volume = np.clip(output_volume, 0, bins_num - 1)
    return output_volume.astype(np.float32)


def equalize_hist(volume, bins_num=256):
    """Histogram-equalize the voxels greater than ``1e-6``.

    Voxels at or below that threshold are left untouched. If there are no
    voxels above it, the input is returned unchanged.

    Returns:
        numpy.ndarray: The equalized volume as ``float32`` (or the original
        ``volume`` object when equalization is skipped).
    """
    print(" Performing histogram equalization...")
    mask = volume > 1e-6
    obj_volume = volume[mask]
    if obj_volume.size == 0:
        print(" Warning: No non-zero voxels. Skipping equalization.")
        return volume

    hist, bins = np.histogram(obj_volume, bins_num, range=(obj_volume.min(), obj_volume.max()))
    cdf = hist.cumsum()
    cdf_normalized = (bins_num - 1) * cdf / float(cdf[-1])
    equalized_volume = np.copy(volume)
    equalized_volume[mask] = np.interp(obj_volume, bins[:-1], cdf_normalized)
    return equalized_volume.astype(np.float32)


def enhance(img_array, run_bias_correction=True, kernel_size=3, percentils=(0.5, 99.5),
            bins_num=256, run_equalize_hist=True):
    """Run the enhancement pipeline on a volume.

    Steps, in order: optional N4 bias correction, median denoising, intensity
    rescaling, optional histogram equalization.

    Raises:
        RuntimeError: If any step fails; the original exception is chained.
    """
    print("Starting enhancement pipeline...")
    volume = img_array.astype(np.float32)
    try:
        if run_bias_correction:
            volume = bias_field_correction(volume)
        volume = denoise(volume, kernel_size)
        volume = rescale_intensity(volume, percentils, bins_num)
        if run_equalize_hist:
            volume = equalize_hist(volume, bins_num)
        print("Enhancement pipeline finished.")
        return volume
    except Exception as e:
        print(f"Error during enhancement: {e}")
        traceback.print_exc()
        raise RuntimeError(f"Failed enhancing image: {e}") from e


def register_image(input_nifti_path, output_nifti_path):
    """Register a NIfTI image to the default template with itk-elastix.

    Uses the parameter file at ``PARAMS_RIGID_PATH`` and writes the result to
    ``output_nifti_path``.

    Raises:
        FileNotFoundError: If the parameter file or the template is missing.
    """
    print(f"Registering {input_nifti_path} to {DEFAULT_TEMPLATE_PATH}")
    if not all(os.path.exists(p) for p in [PARAMS_RIGID_PATH, DEFAULT_TEMPLATE_PATH]):
        raise FileNotFoundError("Elastix parameter or template file not found.")
    fixed_image = itk.imread(DEFAULT_TEMPLATE_PATH, itk.F)
    moving_image = itk.imread(input_nifti_path, itk.F)
    parameter_object = itk.ParameterObject.New()
    parameter_object.AddParameterFile(PARAMS_RIGID_PATH)
    result_image, _ = itk.elastix_registration_method(
        fixed_image, moving_image, parameter_object=parameter_object, log_to_console=False
    )
    itk.imwrite(result_image, output_nifti_path)
    print(f"Registration output saved to {output_nifti_path}")


def run_enhance_on_file(input_nifti_path, output_nifti_path):
    """Read an image, run ``enhance`` on it, and write the result.

    The spatial metadata of the input image is copied onto the output.
    """
    print(f"Running full enhancement on {input_nifti_path}")
    img_sitk = sitk.ReadImage(input_nifti_path)
    img_array = sitk.GetArrayFromImage(img_sitk)
    enhanced_array = enhance(img_array, run_bias_correction=True)
    enhanced_img_sitk = sitk.GetImageFromArray(enhanced_array)
    enhanced_img_sitk.CopyInformation(img_sitk)
    sitk.WriteImage(enhanced_img_sitk, output_nifti_path)
    print(f"Enhanced image saved to {output_nifti_path}")


def run_skull_stripping(input_nifti_path, output_dir):
    """Skull-strip an image with HD-BET (fast mode, CPU).

    Args:
        input_nifti_path: Path to the input ``.nii`` / ``.nii.gz`` file.
        output_dir: Directory for the outputs; created if needed.

    Returns:
        tuple: ``(output_file_path, output_mask_path)``. Only the existence
        of the first is verified here.

    Raises:
        RuntimeError: If HD-BET was not imported or produced no output file.
        FileNotFoundError: If the HD-BET config or model directory is missing.
    """
    print(f"Running HD-BET skull stripping on {input_nifti_path}")
    if run_hd_bet is None:
        raise RuntimeError("HD-BET module not imported.")
    if not os.path.exists(HD_BET_CONFIG_PATH):
        raise FileNotFoundError(f"HD-BET config not found at {HD_BET_CONFIG_PATH}")
    if not os.path.isdir(HD_BET_MODEL_DIR):
        raise FileNotFoundError(f"HD-BET models not found at {HD_BET_MODEL_DIR}")

    base_name = os.path.basename(input_nifti_path).replace(".nii.gz", "").replace(".nii", "")
    output_file_path = os.path.join(output_dir, f"{base_name}_bet.nii.gz")
    output_mask_path = os.path.join(output_dir, f"{base_name}_bet_mask.nii.gz")
    os.makedirs(output_dir, exist_ok=True)

    run_hd_bet(input_nifti_path, output_file_path, mode="fast", device='cpu',
               config_file=HD_BET_CONFIG_PATH, postprocess=False, do_tta=False,
               keep_mask=True, overwrite=True)

    if not os.path.exists(output_file_path):
        raise RuntimeError("HD-BET did not produce output file.")
    print(f"Skull stripping output saved to {output_file_path}")
    return output_file_path, output_mask_path


# --- MONAI Transforms ---
resize_transform = Resized(keys=["image"], spatial_size=image_size)
scale_transform = ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0)


def preprocess_nifti_for_model(nifti_path):
    """Load a NIfTI file and turn it into a 5-D model input tensor.

    The volume gets a channel dimension, is resized to ``image_size``, scaled
    to ``[0, 1]``, given a batch dimension, and moved to ``device``.

    Raises:
        ValueError: If the resulting tensor is not 5-dimensional.
    """
    print(f"Preprocessing NIfTI for model: {nifti_path}")
    scan_data = nib.load(nifti_path).get_fdata()
    scan_tensor = torch.tensor(scan_data, dtype=torch.float32).unsqueeze(0)  # Add C dim
    sample = scale_transform(resize_transform({"image": scan_tensor}))
    input_tensor = sample["image"].unsqueeze(0).to(device)  # Add B dim
    if input_tensor.dim() != 5:
        raise ValueError(f"Preprocessing resulted in incorrect shape: {input_tensor.shape}")
    print(f" Final shape for model: {input_tensor.shape}")
    return input_tensor


# --- Saliency Generation ---
def generate_saliency(model_to_use, input_tensor_5d):
    """Compute a guided-backprop SmoothGrad saliency map on the backbone.

    Args:
        model_to_use: Model exposing a ``backbone`` attribute.
        input_tensor_5d: Model input tensor.

    Returns:
        tuple: ``(input_3d, saliency_3d)`` as squeezed NumPy arrays, or
        ``(None, None)`` if the computation fails (the error is printed).

    Raises:
        ImportError: If the MONAI visualizer could not be imported.
        ValueError: If ``model_to_use`` is ``None``.
    """
    if GuidedBackpropSmoothGrad is None:
        raise ImportError("MONAI visualize components not imported.")
    if model_to_use is None:
        raise ValueError("Model not loaded.")

    print("Generating saliency map...")
    input_tensor_5d.requires_grad_(True)
    visualizer = GuidedBackpropSmoothGrad(
        model=model_to_use.backbone.to(device), stdev_spread=0.15, n_samples=10, magnitude=True
    )
    try:
        with torch.enable_grad():
            saliency_map_5d = visualizer(input_tensor_5d.to(device))
        input_3d = input_tensor_5d.squeeze().cpu().detach().numpy()
        saliency_3d = saliency_map_5d.squeeze().cpu().detach().numpy()
        print("Saliency map generated.")
        return input_3d, saliency_3d
    except Exception as e:
        print(f"Error during saliency map generation: {e}")
        traceback.print_exc()
        return None, None
    finally:
        # Restore the caller's tensor to its non-differentiable state.
        input_tensor_5d.requires_grad_(False)


# --- Plotting Function (Returns NumPy arrays for Gradio) ---
def _figure_to_uint8_array(fig):
    """Render a Matplotlib figure to PNG in memory and return it as uint8."""
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=75)  # Adjust DPI as needed
        buf.seek(0)
        img_arr = plt.imread(buf, format='png')
    return (img_arr * 255).astype(np.uint8)


def create_slice_plots(mri_data_3d, saliency_data_3d, slice_index):
    """Render the input, saliency heatmap and overlay for one slice.

    Slices are taken along the last axis of both volumes.

    Returns:
        tuple: ``(input_plot, heatmap_plot, overlay_plot)`` as uint8 image
        arrays, or ``(None, None, None)`` when the inputs are missing, the
        index is out of range, or plotting fails.
    """
    print(f" Generating plots for slice index: {slice_index}")
    if mri_data_3d is None or saliency_data_3d is None:
        return None, None, None
    if not (0 <= slice_index < mri_data_3d.shape[2]):
        print(f" Error: Slice index {slice_index} out of bounds (0-{mri_data_3d.shape[2]-1}).")
        return None, None, None

    figures = []  # Every figure created below, so all can be closed on any exit path
    try:
        mri_slice = mri_data_3d[:, :, slice_index]
        saliency_slice_orig = saliency_data_3d[:, :, slice_index]

        # Normalize MRI slice using volume-wide percentiles so that brightness
        # is consistent from slice to slice.
        p1_vol, p99_vol = np.percentile(mri_data_3d, (1, 99))
        mri_norm_denom = max(p99_vol - p1_vol, 1e-6)
        mri_slice_norm = np.clip((mri_slice - p1_vol) / mri_norm_denom, 0, 1)

        # Process saliency slice: drop negatives, blur, normalize by volume max.
        saliency_slice = np.copy(saliency_slice_orig)
        saliency_slice[saliency_slice < 0] = 0
        saliency_slice_blurred = cv2.GaussianBlur(saliency_slice, (15, 15), 0)
        non_negative = saliency_data_3d[saliency_data_3d >= 0]
        # np.max raises on an empty selection, so fall back to the floor value.
        s_max_vol = max(np.max(non_negative), 1e-6) if non_negative.size else 1e-6
        saliency_slice_norm = saliency_slice_blurred / s_max_vol
        saliency_slice_thresholded = np.where(saliency_slice_norm > 0.0, saliency_slice_norm, 0)

        # Plot 1: Input Slice
        fig1, ax1 = plt.subplots(figsize=(6, 6))
        figures.append(fig1)
        ax1.imshow(mri_slice_norm, cmap='gray', interpolation='none', origin='lower')
        ax1.axis('off')
        input_plot_np = _figure_to_uint8_array(fig1)

        # Plot 2: Saliency Heatmap
        fig2, ax2 = plt.subplots(figsize=(6, 6))
        figures.append(fig2)
        ax2.imshow(saliency_slice_thresholded, cmap='magma', interpolation='none',
                   origin='lower', vmin=0)
        ax2.axis('off')
        heatmap_plot_np = _figure_to_uint8_array(fig2)

        # Plot 3: Overlay
        fig3, ax3 = plt.subplots(figsize=(6, 6))
        figures.append(fig3)
        ax3.imshow(mri_slice_norm, cmap='gray', interpolation='none', origin='lower')
        level_min = saliency_slice_thresholded.min()
        level_max = saliency_slice_thresholded.max()
        # Contour levels must be increasing, so skip flat (constant) slices.
        if level_max > 0 and level_max > level_min:
            ax3.contour(saliency_slice_thresholded, cmap='magma', origin='lower',
                        linewidths=1.0, levels=np.linspace(level_min, level_max, 5))
        ax3.axis('off')
        overlay_plot_np = _figure_to_uint8_array(fig3)

        print(f" Generated numpy plots successfully for slice {slice_index}.")
        return input_plot_np, heatmap_plot_np, overlay_plot_np
    except Exception as e:
        print(f"Error generating numpy plots for slice {slice_index}: {e}")
        traceback.print_exc()
        return None, None, None
    finally:
        for fig in figures:
            plt.close(fig)


# --- Gradio Processing Function ---
def _stage_uploaded_scan(file_type, file_path, temp_dir):
    """Copy/convert the upload into ``temp_dir`` and return a NIfTI path."""
    filename = os.path.basename(file_path)
    lowered = filename.lower()

    if file_type == 'NIfTI':
        if not lowered.endswith(('.nii.gz', '.nii')):
            raise gr.Error("Invalid NIfTI file. Please upload .nii or .nii.gz")
        # Always stage as .nii.gz for consistency.
        dest_path = os.path.join(temp_dir, "uploaded_scan.nii.gz")
        if lowered.endswith('.nii.gz'):
            print(f"Copying compressed NIfTI {filename} to: {dest_path}")
            shutil.copy(file_path, dest_path)
        else:
            print(f"Detected uncompressed .nii file: {filename}. Compressing to {dest_path}")
            try:
                nib.save(nib.load(file_path), dest_path)
                print(f"Successfully compressed and saved to: {dest_path}")
            except Exception as e:
                raise gr.Error(f"Failed to load or compress .nii file: {e}") from e
        return dest_path

    if file_type == 'DICOM (zip)':
        if not lowered.endswith('.zip'):
            raise gr.Error("Invalid DICOM file. Please upload a .zip archive.")
        uploaded_zip_path = os.path.join(temp_dir, "dicom_files.zip")
        shutil.copy(file_path, uploaded_zip_path)
        print(f"Copied DICOM zip to: {uploaded_zip_path}")

        dicom_input_dir = os.path.join(temp_dir, "dicom_input")
        nifti_output_dir = os.path.join(temp_dir, "nifti_output")
        os.makedirs(dicom_input_dir, exist_ok=True)
        os.makedirs(nifti_output_dir, exist_ok=True)

        try:
            shutil.unpack_archive(uploaded_zip_path, dicom_input_dir)
            print("Unzip successful.")
        except Exception as e:
            raise gr.Error(f"Error unzipping DICOM file: {e}") from e

        try:
            dicom2nifti.convert_directory(dicom_input_dir, nifti_output_dir,
                                          compression=True, reorient=True)
            # Sorted so the chosen series does not depend on directory order.
            nifti_files = sorted(f for f in os.listdir(nifti_output_dir) if f.endswith('.nii.gz'))
            if not nifti_files:
                raise RuntimeError("dicom2nifti did not produce a .nii.gz file.")
            nifti_path = os.path.join(nifti_output_dir, nifti_files[0])
            print(f"DICOM conversion successful. NIfTI: {nifti_path}")
        except Exception as e:
            raise gr.Error(f"Error converting DICOM to NIfTI: {e}") from e
        return nifti_path

    raise gr.Error("Invalid file type selected.")


def _run_preprocessing_pipeline(nifti_path, temp_dir):
    """Register, enhance and skull-strip; return the skull-stripped path."""
    print("--- Running Optional Preprocessing Pipeline ---")
    try:
        registered_path = os.path.join(temp_dir, "registered.nii.gz")
        register_image(nifti_path, registered_path)
        enhanced_path = os.path.join(temp_dir, "enhanced.nii.gz")
        run_enhance_on_file(registered_path, enhanced_path)
        skullstrip_output_dir = os.path.join(temp_dir, "skullstripped")
        skullstripped_path, _ = run_skull_stripping(enhanced_path, skullstrip_output_dir)
    except Exception as e:
        raise gr.Error(f"Error during preprocessing: {e}") from e
    print("--- Optional Preprocessing Pipeline Complete ---")
    return skullstripped_path


def process_scan(file_type, uploaded_file, run_preprocess, generate_saliency_flag):
    """Handle one prediction request from the Gradio UI.

    Args:
        file_type: ``'NIfTI'`` or ``'DICOM (zip)'``.
        uploaded_file: The uploaded file as passed by Gradio; either an
            object with a ``name`` attribute holding the path, or the path.
        run_preprocess: Whether to run registration / enhancement / skull
            stripping before prediction.
        generate_saliency_flag: Whether to compute saliency maps.

    Returns:
        tuple: ``(prediction_text, input_plot, heatmap_plot, overlay_plot,
        slider_update, saliency_state)``. On a processing error the text is
        ``"Error during processing"``, the plots are ``None``, the slider is
        hidden and the state is empty.

    Raises:
        gr.Error: If the model is not loaded or no file was uploaded.
    """
    if model is None:
        raise gr.Error("Model is not loaded. Cannot perform prediction.")
    if uploaded_file is None:
        raise gr.Error("No file uploaded.")

    # Working directory for the uploaded scan and every intermediate file. It
    # is always removed in ``finally`` so scan data does not pile up on disk.
    temp_dir = tempfile.mkdtemp()
    print(f"Created temp directory: {temp_dir}")
    # Separate directory for the arrays the slider callback reads later;
    # it must outlive this call, so it is only removed on failure.
    saliency_dir = None
    error_message = None
    prediction_text = "Processing..."

    # Initialize outputs to None or placeholder images/values
    input_plot, heatmap_plot, overlay_plot = None, None, None
    saliency_state = {"input_path": None, "saliency_path": None, "num_slices": 0}
    # Initially hidden, use max=1 to avoid log(0) error
    slider_update = gr.Slider(value=0, minimum=0, maximum=1, visible=False)

    try:
        # --- Handle Upload and DICOM Conversion ---
        file_path = getattr(uploaded_file, "name", uploaded_file)
        filename = os.path.basename(file_path)
        print(f"Processing '{filename}' (type: {file_type})")

        nifti_for_preprocessing_path = _stage_uploaded_scan(file_type, file_path, temp_dir)
        if not nifti_for_preprocessing_path or not os.path.exists(nifti_for_preprocessing_path):
            raise gr.Error("Could not find the NIfTI file after initial processing.")

        # --- Optional Preprocessing ---
        nifti_to_predict_path = nifti_for_preprocessing_path
        if run_preprocess:
            nifti_to_predict_path = _run_preprocessing_pipeline(nifti_for_preprocessing_path, temp_dir)
        else:
            print("--- Skipping Optional Preprocessing Pipeline ---")

        # --- Prediction ---
        input_tensor_5d = preprocess_nifti_for_model(nifti_to_predict_path)
        print("Performing prediction...")
        with torch.no_grad():
            output = model(input_tensor_5d)
        predicted_age = output.item()
        predicted_age_years = predicted_age / 12  # Assuming model output is in months
        prediction_text = f"Predicted Age: {predicted_age_years:.2f} years"
        print(prediction_text)

        # --- Saliency Map Generation ---
        if generate_saliency_flag:
            print("--- Generating Saliency Data ---")
            try:
                input_3d, saliency_3d = generate_saliency(model, input_tensor_5d)
                if input_3d is not None and saliency_3d is not None:
                    num_slices = input_3d.shape[2]
                    center_slice_index = num_slices // 2

                    # Save numpy arrays for the slider callback
                    saliency_dir = tempfile.mkdtemp(prefix="saliency_")
                    unique_id = str(uuid.uuid4())
                    input_array_path = os.path.join(saliency_dir, f"{unique_id}_input.npy")
                    saliency_array_path = os.path.join(saliency_dir, f"{unique_id}_saliency.npy")
                    np.save(input_array_path, input_3d)
                    np.save(saliency_array_path, saliency_3d)
                    print(f"Saved input array: {input_array_path}")
                    print(f"Saved saliency array: {saliency_array_path}")

                    # Generate initial plots for the center slice
                    input_plot, heatmap_plot, overlay_plot = create_slice_plots(
                        input_3d, saliency_3d, center_slice_index
                    )

                    # Update state for the slider callback
                    saliency_state = {
                        "input_path": input_array_path,
                        "saliency_path": saliency_array_path,
                        "num_slices": num_slices,
                    }
                    # Update and show the slider
                    slider_update = gr.Slider(value=center_slice_index, minimum=0,
                                              maximum=num_slices - 1, step=1,
                                              label="Select Slice", visible=True)
                    print("--- Saliency Generated and Initial Plot Created ---")
                else:
                    error_message = "Saliency map generation failed."
                    print(f"Warning: {error_message}")
            except ImportError as e:
                error_message = f"Cannot generate saliency: {e}"
                print(f"Warning: {error_message}")
            except Exception as e:
                error_message = f"Error during saliency processing: {e}"
                traceback.print_exc()
                print(f"Warning: {error_message}")

            if error_message:
                # The arrays are unusable if the state never pointed at them.
                if saliency_dir and saliency_state["input_path"] is None:
                    shutil.rmtree(saliency_dir, ignore_errors=True)
                    saliency_dir = None
                # Non-fatal: the prediction is still returned, so only warn.
                gr.Warning(error_message)

    except Exception as e:
        print(f"Error in process_scan: {e}")
        traceback.print_exc()
        # Use gr.Warning for non-fatal errors shown to user
        if error_message:  # Prepend specific error if available
            gr.Warning(f"{error_message}. General error: {e}")
        else:
            gr.Warning(f"An error occurred: {e}")
        # The returned state is empty, so nothing can reference these arrays.
        if saliency_dir:
            shutil.rmtree(saliency_dir, ignore_errors=True)
        # Return default/error states for outputs
        return ("Error during processing", None, None, None, gr.Slider(visible=False),
                {"input_path": None, "saliency_path": None, "num_slices": 0})
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
        print(f"Cleaned up temp directory: {temp_dir}")

    # Return results for Gradio Interface
    return prediction_text, input_plot, heatmap_plot, overlay_plot, slider_update, saliency_state


# --- Gradio Slider Update Function ---
def update_slice_viewer(slice_index, current_state):
    """Re-render the three slice plots for the slider's current position.

    Args:
        slice_index: Slice index selected on the slider.
        current_state: State dict with ``input_path`` and ``saliency_path``
            pointing at saved ``.npy`` arrays.

    Returns:
        tuple: ``(input_plot, heatmap_plot, overlay_plot)``, or three
        ``None`` values when the state is unusable, the index is out of
        range, or rendering fails.
    """
    current_state = current_state or {}
    input_path = current_state.get("input_path")
    saliency_path = current_state.get("saliency_path")
    if not input_path or not saliency_path or not os.path.exists(input_path) or not os.path.exists(saliency_path):
        print(f"Warning: Cannot update slice viewer. Missing or invalid numpy array paths in state: {current_state}")
        return None, None, None

    try:
        input_3d = np.load(input_path)
        saliency_3d = np.load(saliency_path)
        num_slices = input_3d.shape[2]
        # Gradio slider should handle bounds, but double-check
        slice_index = int(slice_index)
        if not (0 <= slice_index < num_slices):
            print(f"Warning: Invalid slice index {slice_index} received by update function.")
            return None, None, None
        return create_slice_plots(input_3d, saliency_3d, slice_index)
    except Exception as e:
        print(f"Error updating slice viewer for index {slice_index}: {e}")
        traceback.print_exc()
        return None, None, None


# --- Build Gradio Interface ---
APP_CSS = """
#header-row {
    min-height: 150px;
    align-items: center;
}
.logo-img img {
    height: 150px;
    object-fit: contain;
}
"""

MODEL_DETAILS_MD = """
### 🧠 BrainIAC: Brain Age Prediction

**Model Description**
A 3D ResNet50 model trained to predict brain age from T1-weighted MRI scans.

**Training Dataset**
- **Subjects**: trained and tested on dataset consisting of 6,240 scans with age range of 0–35 years (Pediatric and young Adult)
- **Imaging Modality**: T1-weighted MRI
- **Preprocessing**: Registration to MNI, N4 bias correction, histogram equalization, skull stripping

**Input**
- Format: NIfTI or zipped DICOM
- Required sequence: T1w (3D)

**Output**
- Brain age in years, Range 0-35 years

**Intended Use**
- Research use only! (See [LICENSE](file/LICENSE))

**NOTE**
- Not validated on T2, FLAIR, DWI or other sequences
- Not validated on pathological cases
- Upload PHI data at own risk!
- The model is hosted on a cloud-based CPU instance.
- The data is not stored, shared or collected for any purpose!
"""

with gr.Blocks(css=APP_CSS) as demo:
    # Header Row with Logos and Title
    with gr.Row(elem_id="header-row"):
        with gr.Column(scale=1):
            gr.Image(os.path.join(APP_DIR, "static/images/kannlab.png"),
                     show_label=False, interactive=False, show_download_button=False,
                     container=False, elem_classes=["logo-img"])
        with gr.Column(scale=3):
            # NOTE(review): the strings around the title contain only blank
            # lines; heading markup may have been lost - confirm and restore.
            gr.Markdown(
                "\n\n"
                "BrainIAC: Brain Age Prediction"
                "\n\n"
            )
        with gr.Column(scale=1):
            gr.Image(os.path.join(APP_DIR, "static/images/brainiac.jpeg"),
                     show_label=False, interactive=False, show_download_button=False,
                     container=False, elem_classes=["logo-img"])

    # --- Add model description section ---
    with gr.Accordion("ℹ️ Model Details and Usage Guide", open=False):
        gr.Markdown(MODEL_DETAILS_MD)

    # Use gr.State to store paths to numpy arrays for the slider callback
    saliency_state = gr.State({"input_path": None, "saliency_path": None, "num_slices": 0})

    # Main Content Row (Controls Left, Output Right)
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Controls")
                file_type = gr.Radio(["NIfTI", "DICOM (zip)"], label="Select Input File Type", value="NIfTI")
                scan_file = gr.File(label="Upload Scan File")
                run_preprocess = gr.Checkbox(label="Run Preprocessing Pipeline ", value=False)
                generate_saliency_checkbox = gr.Checkbox(label="Generate Saliency Maps ", value=True)
                submit_btn = gr.Button("Predict Brain Age", variant="primary")

        with gr.Column(scale=3):
            with gr.Group():
                gr.Markdown("### Prediction Result")
                prediction_output = gr.Label(label="Prediction Result")
            with gr.Group():
                gr.Markdown("### Saliency Map Viewer (Axial Slice)")
                # Hidden placeholder; maximum=1 matches the hidden slider that
                # process_scan returns before any saliency data exists.
                slice_slider = gr.Slider(label="Select Slice", minimum=0, maximum=1,
                                         step=1, value=0, visible=False)
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("\n\nInput Slice\n\n")
                        input_slice_img = gr.Image(label="Input Slice", type="numpy", show_label=False)
                    with gr.Column():
                        gr.Markdown("\n\nSaliency Heatmap\n\n")
                        heatmap_slice_img = gr.Image(label="Saliency Heatmap", type="numpy", show_label=False)
                    with gr.Column():
                        gr.Markdown("\n\nOverlay\n\n")
                        overlay_slice_img = gr.Image(label="Overlay", type="numpy", show_label=False)

    # --- Wire Components ---
    submit_btn.click(
        fn=process_scan,
        inputs=[file_type, scan_file, run_preprocess, generate_saliency_checkbox],
        outputs=[prediction_output, input_slice_img, heatmap_slice_img,
                 overlay_slice_img, slice_slider, saliency_state],
    )
    slice_slider.change(
        fn=update_slice_viewer,
        inputs=[slice_slider, saliency_state],
        outputs=[input_slice_img, heatmap_slice_img, overlay_slice_img],
    )

# --- Launch the App ---
if __name__ == "__main__":
    if model is None:
        print("ERROR: Model failed to load. Gradio app cannot start.")
    else:
        print("Launching Gradio Interface...")
        demo.launch(server_name="0.0.0.0", server_port=7860, debug=False, share=False)