Divytak commited on
Commit
3bf67c5
·
verified ·
1 Parent(s): 895e6ba

Delete src/BrainIAC/app.py

Browse files
Files changed (1) hide show
  1. src/BrainIAC/app.py +0 -728
src/BrainIAC/app.py DELETED
@@ -1,728 +0,0 @@
1
- import os
2
- import torch
3
- import nibabel as nib
4
- from flask import Flask, request, render_template, redirect, url_for, flash, jsonify
5
- import tempfile
6
- import yaml
7
- import traceback # For detailed error printing
8
- import zipfile
9
- import dicom2nifti
10
- import shutil
11
- import subprocess # To run unzip command
12
- import SimpleITK as sitk
13
- import itk
14
- import numpy as np
15
- from scipy.signal import medfilt
16
- import skimage.filters
17
- import cv2 # For Gaussian Blur
18
- import io # For saving plots to memory
19
- import base64 # For encoding plots
20
- import uuid # For unique IDs
21
-
22
- # Configure Matplotlib for non-GUI backend *before* importing pyplot
23
- import matplotlib
24
- matplotlib.use('Agg')
25
- import matplotlib.pyplot as plt
26
-
27
- # --- Preprocessing Imports ---
28
- try:
29
- # Adjust import path based on Docker structure
30
- # Assumes HD_BET is now at /app/BrainIAC/HD_BET
31
- from HD_BET.run import run_hd_bet
32
- # Import MONAI saliency visualizer
33
- from monai.visualize.gradient_based import GuidedBackpropSmoothGrad
34
- except ImportError as e:
35
- print(f"Could not import HD_BET or MONAI visualize: {e}. Advanced features might fail.")
36
- run_hd_bet = None
37
- GuidedBackpropSmoothGrad = None
38
-
39
- # Import necessary components from your existing modules
40
- from model import Backbone, SingleScanModel, Classifier
41
- # Removed: from dataset2 import NormalSynchronizedTransform3D
42
- # Import specific MONAI transforms needed
43
- from monai.transforms import Resized, ScaleIntensityd # Removed ToTensord, will handle manually
44
-
45
- app = Flask(__name__)
46
- app.secret_key = 'supersecretkey' # Needed for flashing messages
47
-
48
- # --- Constants for Preprocessing ---
49
- APP_DIR = os.path.dirname(__file__)
50
- TEMPLATE_DIR = os.path.join(APP_DIR, "golden_image", "mni_templates")
51
- PARAMS_RIGID_PATH = os.path.join(APP_DIR, "golden_image", "mni_templates", "Parameters_Rigid.txt")
52
- DEFAULT_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "nihpd_asym_13.0-18.5_t1w.nii") # Using adult template as default
53
- HD_BET_CONFIG_PATH = os.path.join(APP_DIR, "HD_BET", "config.py")
54
- HD_BET_MODEL_DIR = os.path.join(APP_DIR, "hdbet_model") # Path to copied models
55
-
56
- # --- Configuration Loading ---
57
- def load_config():
58
- # Assuming config.yml is in the same directory as app.py
59
- config_path = os.path.join(APP_DIR, 'config.yml')
60
- try:
61
- with open(config_path, 'r') as file:
62
- config = yaml.safe_load(file)
63
- # Add default image_size if not present in config
64
- if 'data' not in config: config['data'] = {}
65
- if 'image_size' not in config['data']: config['data']['image_size'] = [128, 128, 128]
66
-
67
- except FileNotFoundError:
68
- print(f"Error: Configuration file not found at {config_path}")
69
- # Provide default config or handle error appropriately
70
- config = {
71
- 'gpu': {'device': 'cpu'},
72
- 'infer': {'checkpoints': 'checkpoints/brainage_model_latest.pt'},
73
- 'data': {'image_size': [128, 128, 128]} # Default image size
74
- }
75
- return config
76
-
77
- config = load_config()
78
- # Ensure image_size is available, e.g., from config or a default
79
- DEFAULT_IMAGE_SIZE = (128, 128, 128)
80
- image_size_cfg = config.get('data', {}).get('image_size', DEFAULT_IMAGE_SIZE)
81
- # Validate image_size format
82
- if not isinstance(image_size_cfg, (list, tuple)) or len(image_size_cfg) != 3:
83
- print(f"Warning: Invalid image_size in config ({image_size_cfg}). Using default {DEFAULT_IMAGE_SIZE}.")
84
- image_size = DEFAULT_IMAGE_SIZE
85
- else:
86
- image_size = tuple(image_size_cfg) # Ensure it's a tuple for transforms
87
-
88
- # --- Model Loading ---
89
- def load_model(device, checkpoint_path):
90
- backbone = Backbone()
91
- classifier = Classifier(d_model=2048) # Make sure d_model matches your trained model
92
- model = SingleScanModel(backbone, classifier)
93
-
94
- try:
95
- # Construct absolute path if checkpoint_path is relative
96
- relative_path = config.get('infer', {}).get('checkpoints', 'checkpoints/brainage_model_latest.pt')
97
- # Use path relative to app.py location
98
- checkpoint_path_abs = os.path.join(APP_DIR, relative_path)
99
-
100
- checkpoint = torch.load(checkpoint_path_abs, map_location=device)
101
- # Adjust key if necessary based on how model was saved
102
- if 'model_state_dict' in checkpoint:
103
- model.load_state_dict(checkpoint['model_state_dict'])
104
- else:
105
- model.load_state_dict(checkpoint)
106
- model.to(device)
107
- model.eval()
108
- print(f"Model loaded successfully from {checkpoint_path_abs} onto {device}.")
109
- return model
110
- except FileNotFoundError:
111
- print(f"Error: Checkpoint file not found at {checkpoint_path_abs}")
112
- return None
113
- except Exception as e:
114
- print(f"Error loading model checkpoint: {e}")
115
- traceback.print_exc()
116
- return None
117
-
118
- device = torch.device(config.get('gpu', {}).get('device', 'cpu')) # Default to CPU
119
- model = load_model(device, config) # Pass full config for path finding
120
-
121
- # --- Preprocessing Functions from preprocess_utils.py ---
122
- def bias_field_correction(img_array):
123
- """Performs N4 bias field correction using SimpleITK."""
124
- image = sitk.GetImageFromArray(img_array)
125
- # Ensure image is float32 for N4
126
- if image.GetPixelID() != sitk.sitkFloat32:
127
- image = sitk.Cast(image, sitk.sitkFloat32)
128
- maskImage = sitk.OtsuThreshold(image, 0, 1, 200)
129
- corrector = sitk.N4BiasFieldCorrectionImageFilter()
130
- numberFittingLevels = 4
131
- # Define iterations per level more robustly
132
- max_iters = [min(50 * (2**i), 200) for i in range(numberFittingLevels)]
133
- corrector.SetMaximumNumberOfIterations(max_iters)
134
- # Set convergence threshold (optional, can speed up)
135
- # corrector.SetConvergenceThreshold(1e-6)
136
- print(" Running N4 Bias Field Correction...")
137
- corrected_image = corrector.Execute(image, maskImage)
138
- print(" N4 Correction finished.")
139
- return sitk.GetArrayFromImage(corrected_image)
140
-
141
- def denoise(volume, kernel_size=3):
142
- """Applies median filter for denoising."""
143
- print(f" Applying median filter denoising (kernel={kernel_size})...")
144
- return medfilt(volume, kernel_size)
145
-
146
- def rescale_intensity(volume, percentils=[0.5, 99.5], bins_num=256):
147
- """Rescales intensity after removing background via Otsu."""
148
- print(" Rescaling intensity...")
149
- # Ensure input is float for Otsu and calculations
150
- volume_float = volume.astype(np.float32)
151
- try:
152
- t = skimage.filters.threshold_otsu(volume_float, nbins=256)
153
- print(f" Otsu threshold found: {t}")
154
- volume_masked = np.copy(volume_float)
155
- volume_masked[volume_masked < t] = 0 # Apply mask based on original values
156
- obj_volume = volume_masked[np.where(volume_masked > 0)]
157
- except ValueError: # Handle cases with near-uniform intensity
158
- print(" Otsu failed (likely uniform image), skipping background mask.")
159
- obj_volume = volume_float.flatten()
160
-
161
- if obj_volume.size == 0:
162
- print(" Warning: No foreground voxels found after Otsu. Scaling full volume.")
163
- obj_volume = volume_float.flatten() # Fallback to full volume
164
- min_value = np.min(obj_volume)
165
- max_value = np.max(obj_volume)
166
- else:
167
- min_value = np.percentile(obj_volume, percentils[0])
168
- max_value = np.percentile(obj_volume, percentils[1])
169
-
170
- print(f" Intensity range used for scaling: [{min_value:.2f}, {max_value:.2f}]")
171
- # Avoid division by zero if max == min
172
- denominator = max_value - min_value
173
- if denominator < 1e-6: denominator = 1e-6
174
-
175
- # Create a copy to modify for output
176
- output_volume = np.copy(volume_float)
177
- # Apply scaling only to the object volume identified (or full volume as fallback)
178
- if bins_num == 0:
179
- # Scale to 0-1 (float)
180
- output_volume = (volume_float - min_value) / denominator
181
- output_volume = np.clip(output_volume, 0.0, 1.0) # Clip results to [0, 1]
182
- else:
183
- # Scale and bin
184
- output_volume = np.round((volume_float - min_value) / denominator * (bins_num - 1))
185
- output_volume = np.clip(output_volume, 0, bins_num - 1) # Ensure within bin range
186
-
187
- # Ensure output is float32 for consistency
188
- return output_volume.astype(np.float32)
189
-
190
- def equalize_hist(volume, bins_num=256):
191
- """Performs histogram equalization on non-zero voxels."""
192
- print(" Performing histogram equalization...")
193
- # Create a mask of non-zero voxels
194
- mask = volume > 1e-6 # Use a small epsilon for float comparison
195
- obj_volume = volume[mask]
196
-
197
- if obj_volume.size == 0:
198
- print(" Warning: No non-zero voxels found for histogram equalization. Skipping.")
199
- return volume # Return original volume if no foreground
200
-
201
- # Compute histogram and CDF on the non-zero voxels
202
- hist, bins = np.histogram(obj_volume, bins_num, range=(obj_volume.min(), obj_volume.max()))
203
- cdf = hist.cumsum()
204
-
205
- # Normalize CDF
206
- cdf_normalized = (bins_num - 1) * cdf / float(cdf[-1])
207
-
208
- # Interpolate new values for the object volume
209
- equalized_obj_volume = np.interp(obj_volume, bins[:-1], cdf_normalized)
210
-
211
- # Create a copy of the original volume to put the results back
212
- equalized_volume = np.copy(volume)
213
- equalized_volume[mask] = equalized_obj_volume
214
-
215
- # Ensure output is float32
216
- return equalized_volume.astype(np.float32)
217
-
218
- def enhance(img_array, run_bias_correction=True, kernel_size=3, percentils=[0.5, 99.5], bins_num=256, run_equalize_hist=True):
219
- """Full enhancement pipeline from preprocess_utils."""
220
- print("Starting enhancement pipeline...")
221
- volume = img_array.astype(np.float32) # Ensure float input
222
- try:
223
- if run_bias_correction:
224
- volume = bias_field_correction(volume)
225
- volume = denoise(volume, kernel_size)
226
- volume = rescale_intensity(volume, percentils, bins_num)
227
- if run_equalize_hist:
228
- volume = equalize_hist(volume, bins_num)
229
- print("Enhancement pipeline finished.")
230
- return volume
231
- except Exception as e:
232
- print(f"Error during enhancement: {e}")
233
- traceback.print_exc()
234
- raise RuntimeError(f"Failed enhancing image: {e}") # Re-raise to stop processing
235
-
236
- # --- Registration Function (modified enhance call) ---
237
- def register_image(input_nifti_path, output_nifti_path):
238
- """Registers input NIfTI to the default template using Elastix."""
239
- print(f"Registering {input_nifti_path} to {DEFAULT_TEMPLATE_PATH}")
240
- if not os.path.exists(PARAMS_RIGID_PATH):
241
- raise FileNotFoundError(f"Elastix parameter file not found at {PARAMS_RIGID_PATH}")
242
- if not os.path.exists(DEFAULT_TEMPLATE_PATH):
243
- raise FileNotFoundError(f"Default template file not found at {DEFAULT_TEMPLATE_PATH}")
244
-
245
- fixed_image = itk.imread(DEFAULT_TEMPLATE_PATH, itk.F)
246
- moving_image = itk.imread(input_nifti_path, itk.F)
247
-
248
- parameter_object = itk.ParameterObject.New()
249
- parameter_object.AddParameterFile(PARAMS_RIGID_PATH)
250
-
251
- result_image, _ = itk.elastix_registration_method(
252
- fixed_image, moving_image,
253
- parameter_object=parameter_object,
254
- log_to_console=False # Keep console clean
255
- )
256
- itk.imwrite(result_image, output_nifti_path)
257
- print(f"Registration output saved to {output_nifti_path}")
258
-
259
- # --- Enhanced Image Function (calls actual enhance) ---
260
- def run_enhance_on_file(input_nifti_path, output_nifti_path):
261
- """Reads NIfTI, runs enhance pipeline, saves NIfTI."""
262
- print(f"Running full enhancement on {input_nifti_path}")
263
- img_sitk = sitk.ReadImage(input_nifti_path)
264
- img_array = sitk.GetArrayFromImage(img_sitk)
265
-
266
- # Run the actual enhancement pipeline
267
- enhanced_array = enhance(img_array, run_bias_correction=True) # Assuming N4 is desired
268
-
269
- enhanced_img_sitk = sitk.GetImageFromArray(enhanced_array)
270
- enhanced_img_sitk.CopyInformation(img_sitk) # Preserve metadata
271
- sitk.WriteImage(enhanced_img_sitk, output_nifti_path)
272
- print(f"Enhanced image saved to {output_nifti_path}")
273
-
274
- # --- Skull Stripping Function (Set Environment Variable) ---
275
- def run_skull_stripping(input_nifti_path, output_dir):
276
- """Runs HD-BET skull stripping."""
277
- print(f"Running HD-BET skull stripping on {input_nifti_path}")
278
- if run_hd_bet is None:
279
- raise RuntimeError("HD-BET module could not be imported. Cannot perform skull stripping.")
280
-
281
- # Removed environment variable setting as path is fixed in HD_BET/paths.py
282
- # # Set environment variable *before* calling run_hd_bet
283
- # # Ensure the target directory exists
284
- # if not os.path.isdir(HD_BET_MODEL_DIR):
285
- # raise FileNotFoundError(f"HD-BET model directory not found at specified path: {HD_BET_MODEL_DIR}")
286
- # print(f"Setting HD_BET_MODELS environment variable to: {HD_BET_MODEL_DIR}")
287
- # os.environ['HD_BET_MODELS'] = HD_BET_MODEL_DIR
288
-
289
- # Check config path
290
- if not os.path.exists(HD_BET_CONFIG_PATH):
291
- alt_config_path = os.path.join(APP_DIR, "HD_BET", "HD_BET", "config.py")
292
- if os.path.exists(alt_config_path):
293
- print(f"Warning: Using alternative HD-BET config path: {alt_config_path}")
294
- config_to_use = alt_config_path
295
- else:
296
- raise FileNotFoundError(f"HD-BET config file not found at {HD_BET_CONFIG_PATH} or {alt_config_path}")
297
- else:
298
- config_to_use = HD_BET_CONFIG_PATH
299
-
300
- # Define output paths
301
- base_name = os.path.basename(input_nifti_path).replace(".nii.gz", "").replace(".nii", "")
302
- output_file_path = os.path.join(output_dir, f"{base_name}_bet.nii.gz")
303
- output_mask_path = os.path.join(output_dir, f"{base_name}_bet_mask.nii.gz")
304
-
305
- # Make sure output directory exists
306
- os.makedirs(output_dir, exist_ok=True)
307
-
308
- # Run HD-BET
309
- run_hd_bet(input_nifti_path, output_file_path,
310
- mode="fast",
311
- device='cpu',
312
- config_file=config_to_use,
313
- postprocess=False,
314
- do_tta=False,
315
- keep_mask=True,
316
- overwrite=True)
317
-
318
- # Unset environment variable after use (optional, good practice)
319
- # del os.environ['HD_BET_MODELS']
320
-
321
- if not os.path.exists(output_file_path):
322
- raise RuntimeError(f"HD-BET did not produce the expected output file: {output_file_path}")
323
-
324
- print(f"Skull stripping output saved to {output_file_path}")
325
- return output_file_path, output_mask_path
326
-
327
- # --- Image Preprocessing ---
328
- # Define necessary MONAI transforms directly
329
- # Keys must match the dictionary keys we create later ('image')
330
- resize_transform = Resized(keys=["image"], spatial_size=image_size)
331
- scale_transform = ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0)
332
-
333
- def preprocess_nifti(nifti_path):
334
- """Loads and preprocesses a NIfTI file, returning a 5D tensor."""
335
- print(f"Preprocessing NIfTI: {nifti_path}")
336
- scan_data = nib.load(nifti_path).get_fdata()
337
- print(f" Loaded scan data shape: {scan_data.shape}")
338
- scan_tensor = torch.tensor(scan_data, dtype=torch.float32).unsqueeze(0) # Add C dim
339
- print(f" Shape after tensor+channel: {scan_tensor.shape}")
340
- sample = {"image": scan_tensor}
341
- sample_resized = resize_transform(sample)
342
- print(f" Shape after resize: {sample_resized['image'].shape}")
343
- sample_scaled = scale_transform(sample_resized)
344
- print(f" Shape after scaling: {sample_scaled['image'].shape}")
345
- input_tensor = sample_scaled["image"].unsqueeze(0).to(device) # Add B dim
346
- print(f" Final shape for model: {input_tensor.shape}")
347
- if input_tensor.dim() != 5:
348
- raise ValueError(f"Preprocessing resulted in incorrect shape: {input_tensor.shape}. Expected 5D.")
349
- return input_tensor
350
-
351
- # --- Final NIfTI Preprocessing for Model ---
352
- def preprocess_nifti_for_model(nifti_path):
353
- """Loads final NIfTI and prepares 5D tensor for the model."""
354
- # ... (Same as previous preprocess_nifti function) ...
355
- print(f"Preprocessing NIfTI for model: {nifti_path}")
356
- scan_data = nib.load(nifti_path).get_fdata()
357
- print(f" Loaded scan data shape: {scan_data.shape}")
358
- scan_tensor = torch.tensor(scan_data, dtype=torch.float32).unsqueeze(0) # Add C dim
359
- print(f" Shape after tensor+channel: {scan_tensor.shape}")
360
- sample = {"image": scan_tensor}
361
- sample_resized = resize_transform(sample)
362
- print(f" Shape after resize: {sample_resized['image'].shape}")
363
- sample_scaled = scale_transform(sample_resized)
364
- print(f" Shape after scaling: {sample_scaled['image'].shape}")
365
- input_tensor = sample_scaled["image"].unsqueeze(0).to(device) # Add B dim
366
- print(f" Final shape for model: {input_tensor.shape}")
367
- if input_tensor.dim() != 5:
368
- raise ValueError(f"Preprocessing resulted in incorrect shape: {input_tensor.shape}. Expected 5D.")
369
- return input_tensor
370
-
371
- # --- Saliency Map Generation ---
372
- def generate_saliency(model, input_tensor_5d):
373
- """Generates saliency map using GuidedBackpropSmoothGrad."""
374
- if GuidedBackpropSmoothGrad is None:
375
- raise ImportError("MONAI visualize components not imported. Cannot generate saliency map.")
376
- if model is None:
377
- raise ValueError("Model not loaded. Cannot generate saliency map.")
378
-
379
- print("Generating saliency map...")
380
- input_tensor_5d.requires_grad_(True)
381
- # Use the backbone for saliency as in the original script
382
- # Ensure model and backbone are on the correct device (CPU in this case)
383
- visualizer = GuidedBackpropSmoothGrad(model=model.backbone.to(device),
384
- stdev_spread=0.15,
385
- n_samples=10,
386
- magnitude=True)
387
-
388
- try:
389
- with torch.enable_grad():
390
- saliency_map_5d = visualizer(input_tensor_5d.to(device))
391
- print("Saliency map generated.")
392
-
393
- # Detach, move to CPU, remove Batch and Channel dims for processing/plotting -> (D, H, W)
394
- input_3d = input_tensor_5d.squeeze().cpu().detach().numpy()
395
- saliency_3d = saliency_map_5d.squeeze().cpu().detach().numpy()
396
-
397
- return input_3d, saliency_3d
398
-
399
- except Exception as e:
400
- print(f"Error during saliency map generation: {e}")
401
- traceback.print_exc()
402
- # Return None or empty arrays if generation fails
403
- return None, None
404
- finally:
405
- # Ensure requires_grad is turned off if it was modified
406
- input_tensor_5d.requires_grad_(False)
407
-
408
- # --- Plotting Function for Single Slice ---
409
- def create_plot_images_for_slice(mri_data_3d, saliency_data_3d, slice_index):
410
- """Creates base64 encoded PNGs for a specific axial slice index."""
411
- print(f" Generating plots for slice index: {slice_index}")
412
- if mri_data_3d is None or saliency_data_3d is None:
413
- print(" Input or Saliency data is None, cannot generate plot.")
414
- return None
415
- if slice_index < 0 or slice_index >= mri_data_3d.shape[2]:
416
- print(f" Error: Slice index {slice_index} out of bounds (0-{mri_data_3d.shape[2]-1}).")
417
- return None
418
-
419
- # Function to save plot to base64 string (copied from previous version)
420
- def save_plot_to_base64(fig):
421
- buf = io.BytesIO()
422
- fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=75)
423
- plt.close(fig) # Close the figure immediately
424
- buf.seek(0)
425
- img_str = base64.b64encode(buf.read()).decode('utf-8')
426
- buf.close()
427
- return img_str
428
-
429
- try:
430
- mri_slice = mri_data_3d[:, :, slice_index]
431
- saliency_slice_orig = saliency_data_3d[:, :, slice_index]
432
-
433
- # --- Normalize MRI Slice (using volume stats if available, otherwise slice stats) ---
434
- # For consistency, ideally pass volume stats, but recalculating per slice is fallback
435
- p1_vol, p99_vol = np.percentile(mri_data_3d, (1, 99))
436
- mri_norm_denom = p99_vol - p1_vol
437
- if mri_norm_denom < 1e-6: mri_norm_denom = 1e-6
438
- mri_slice_norm = np.clip(mri_slice, p1_vol, p99_vol)
439
- mri_slice_norm = (mri_slice_norm - p1_vol) / mri_norm_denom
440
-
441
- # --- Process Saliency Slice ---
442
- saliency_slice = np.copy(saliency_slice_orig)
443
- saliency_slice[saliency_slice < 0] = 0 # Ensure non-negative
444
- saliency_slice_blurred = cv2.GaussianBlur(saliency_slice, (15, 15), 0)
445
- # Use volume max for normalization if possible, fallback to slice max
446
- s_max_vol = np.max(saliency_data_3d[saliency_data_3d >= 0]) # Max of non-negative values in volume
447
- if s_max_vol < 1e-6: s_max_vol = 1e-6
448
- # --- Add logging for the calculated global max ---
449
- print(f" Calculated Global Max Saliency (s_max_vol) for normalization: {s_max_vol:.4f}")
450
- # --------------------------------------------------
451
- saliency_slice_norm = saliency_slice_blurred / s_max_vol
452
- threshold_value = 0.0
453
- saliency_slice_thresholded = np.where(saliency_slice_norm > threshold_value, saliency_slice_norm, 0)
454
-
455
- # --- Generate Plots ---
456
- slice_plots = {}
457
-
458
- # Plot 1: Input Slice
459
- fig1, ax1 = plt.subplots(figsize=(3, 3))
460
- ax1.imshow(mri_slice_norm, cmap='gray', interpolation='none', origin='lower')
461
- ax1.axis('off')
462
- slice_plots['input_slice'] = save_plot_to_base64(fig1)
463
-
464
- # Plot 2: Saliency Heatmap
465
- fig2, ax2 = plt.subplots(figsize=(3, 3))
466
- ax2.imshow(saliency_slice_thresholded, cmap='magma', interpolation='none', origin='lower')
467
- ax2.axis('off')
468
- slice_plots['heatmap_slice'] = save_plot_to_base64(fig2)
469
-
470
- # Plot 3: Overlay
471
- fig3, ax3 = plt.subplots(figsize=(3, 3))
472
- ax3.imshow(mri_slice_norm, cmap='gray', interpolation='none', origin='lower')
473
- if np.max(saliency_slice_thresholded) > 0:
474
- # Remove fixed levels to let contour auto-determine levels based on slice data
475
- ax3.contour(saliency_slice_thresholded, cmap='magma', origin='lower', linewidths=1.0)
476
- ax3.axis('off')
477
- slice_plots['overlay_slice'] = save_plot_to_base64(fig3)
478
-
479
- print(f" Generated plots successfully for slice {slice_index}.")
480
- return slice_plots
481
-
482
- except Exception as e:
483
- print(f"Error generating plots for slice {slice_index}: {e}")
484
- traceback.print_exc()
485
- return None
486
-
487
- # --- Flask Routes ---
488
- @app.route('/', methods=['GET'])
489
- def index():
490
- return render_template('index.html')
491
-
492
- @app.route('/predict', methods=['POST'])
493
- def predict():
494
- if model is None:
495
- flash('Model not loaded. Cannot perform prediction.', 'error')
496
- return redirect(url_for('index'))
497
-
498
- # Get form data
499
- file_type = request.form.get('file_type')
500
- run_preprocess_flag = request.form.get('preprocess') == 'yes'
501
- generate_saliency_flag = request.form.get('generate_saliency') == 'yes' # Get saliency flag
502
- file = request.files.get('scan_file')
503
-
504
- # --- Basic Input Validation ---
505
- if not file_type:
506
- flash('Please select a file type (NIfTI or DICOM).', 'error')
507
- return redirect(url_for('index'))
508
- if not file or file.filename == '':
509
- flash('No scan file selected', 'error')
510
- return redirect(url_for('index'))
511
-
512
- print(f"Received upload: type='{file_type}', filename='{file.filename}', preprocess={run_preprocess_flag}, saliency={generate_saliency_flag}")
513
-
514
- # --- Setup Temporary Directory ---
515
- # temp_dir_obj = tempfile.TemporaryDirectory() # <--- PROBLEM: Cleans up automatically
516
- # Use mkdtemp to create a persistent temporary directory
517
- # NOTE: Requires a manual cleanup strategy later!
518
- try:
519
- temp_dir = tempfile.mkdtemp()
520
- except Exception as e:
521
- print(f"Error creating temporary directory: {e}")
522
- flash("Server error: Could not create temporary directory.", "error")
523
- return redirect(url_for('index'))
524
-
525
- # Generate a unique ID based on the temp directory name
526
- unique_id = os.path.basename(temp_dir)
527
- print(f"Created persistent temp directory: {temp_dir} (ID: {unique_id})")
528
- nifti_for_preprocessing_path = None # Path to the NIfTI before optional preprocessing
529
-
530
- try:
531
- # --- Handle Upload and DICOM Conversion ---
532
- # --- Handle NIfTI Upload ---
533
- if file_type == 'nifti':
534
- if not file.filename.endswith('.nii.gz'):
535
- flash('Invalid file type for NIfTI selection. Please upload .nii.gz', 'error')
536
- # temp_dir_obj.cleanup() # No object to cleanup, need manual rmtree
537
- shutil.rmtree(temp_dir, ignore_errors=True)
538
- return redirect(url_for('index'))
539
- uploaded_file_path = os.path.join(temp_dir, "uploaded_scan.nii.gz")
540
- file.save(uploaded_file_path)
541
- print(f"Saved uploaded NIfTI file to: {uploaded_file_path}")
542
- nifti_for_preprocessing_path = uploaded_file_path
543
-
544
- # --- Handle DICOM Upload ---
545
- elif file_type == 'dicom':
546
- if not file.filename.endswith('.zip'):
547
- flash('Invalid file type for DICOM selection. Please upload a .zip file.', 'error')
548
- # temp_dir_obj.cleanup()
549
- shutil.rmtree(temp_dir, ignore_errors=True)
550
- return redirect(url_for('index'))
551
- uploaded_zip_path = os.path.join(temp_dir, "dicom_files.zip")
552
- file.save(uploaded_zip_path)
553
- print(f"Saved uploaded DICOM zip to: {uploaded_zip_path}")
554
- dicom_input_dir = os.path.join(temp_dir, "dicom_input")
555
- nifti_output_dir = os.path.join(temp_dir, "nifti_output")
556
- os.makedirs(dicom_input_dir, exist_ok=True)
557
- os.makedirs(nifti_output_dir, exist_ok=True)
558
- try:
559
- # Use shutil.unpack_archive for better cross-platform compatibility potentially
560
- shutil.unpack_archive(uploaded_zip_path, dicom_input_dir)
561
- print(f"Unzip successful.")
562
- except Exception as e:
563
- print(f"Unzip failed: {e}")
564
- flash(f'Error unzipping DICOM file: {e}', 'error')
565
- # temp_dir_obj.cleanup()
566
- shutil.rmtree(temp_dir, ignore_errors=True)
567
- return redirect(url_for('index'))
568
- try:
569
- dicom2nifti.convert_directory(dicom_input_dir, nifti_output_dir, compression=True, reorient=True)
570
- nifti_files = [f for f in os.listdir(nifti_output_dir) if f.endswith('.nii.gz')]
571
- if not nifti_files:
572
- raise RuntimeError("dicom2nifti did not produce a .nii.gz file.")
573
- nifti_for_preprocessing_path = os.path.join(nifti_output_dir, nifti_files[0])
574
- print(f"DICOM conversion successful. NIfTI file: {nifti_for_preprocessing_path}")
575
- except Exception as e:
576
- print(f"DICOM to NIfTI conversion failed: {e}")
577
- flash(f'Error converting DICOM to NIfTI: {e}', 'error')
578
- # temp_dir_obj.cleanup()
579
- shutil.rmtree(temp_dir, ignore_errors=True)
580
- return redirect(url_for('index'))
581
- else:
582
- flash('Invalid file type selected.', 'error')
583
- # temp_dir_obj.cleanup()
584
- shutil.rmtree(temp_dir, ignore_errors=True)
585
- return redirect(url_for('index'))
586
-
587
- if not nifti_for_preprocessing_path or not os.path.exists(nifti_for_preprocessing_path):
588
- flash('Error: Could not find the NIfTI file for processing.', 'error')
589
- # temp_dir_obj.cleanup()
590
- shutil.rmtree(temp_dir, ignore_errors=True)
591
- return redirect(url_for('index'))
592
-
593
- # --- Optional Preprocessing Steps ---
594
- nifti_to_predict_path = nifti_for_preprocessing_path
595
- if run_preprocess_flag:
596
- print("--- Running Optional Preprocessing Pipeline ---")
597
- try:
598
- registered_path = os.path.join(temp_dir, "registered.nii.gz")
599
- register_image(nifti_for_preprocessing_path, registered_path)
600
- enhanced_path = os.path.join(temp_dir, "enhanced.nii.gz")
601
- run_enhance_on_file(registered_path, enhanced_path)
602
- skullstrip_output_dir = os.path.join(temp_dir, "skullstripped")
603
- skullstripped_path, _ = run_skull_stripping(enhanced_path, skullstrip_output_dir)
604
- nifti_to_predict_path = skullstripped_path
605
- print("--- Optional Preprocessing Pipeline Complete ---")
606
- except Exception as e:
607
- print(f"Error during optional preprocessing pipeline: {e}")
608
- traceback.print_exc()
609
- flash(f'Error during preprocessing: {e}', 'error')
610
- # temp_dir_obj.cleanup()
611
- shutil.rmtree(temp_dir, ignore_errors=True)
612
- return redirect(url_for('index'))
613
- else:
614
- print("--- Skipping Optional Preprocessing Pipeline ---")
615
-
616
- # --- Final Preprocessing for Model & Prediction ---
617
- input_tensor_5d = preprocess_nifti_for_model(nifti_to_predict_path)
618
- print("Performing prediction...")
619
- with torch.no_grad():
620
- output = model(input_tensor_5d)
621
- predicted_age = output.item()
622
- predicted_age_years = predicted_age / 12 # Adjust if needed
623
- print(f"Prediction successful: {predicted_age_years:.2f} years")
624
-
625
- # --- Saliency Data Handling (Generate, Save, Get Initial Plot) ---
626
- saliency_output_for_template = None # Initialize
627
- if generate_saliency_flag:
628
- print("--- Generating & Saving Saliency Data ---")
629
- try:
630
- input_3d_for_plot, saliency_3d = generate_saliency(model, input_tensor_5d)
631
-
632
- if input_3d_for_plot is not None and saliency_3d is not None:
633
- num_slices = input_3d_for_plot.shape[2]
634
- center_slice_index = num_slices // 2
635
-
636
- # Save the numpy arrays for the dynamic route
637
- input_array_path = os.path.join(temp_dir, f"{unique_id}_input.npy")
638
- saliency_array_path = os.path.join(temp_dir, f"{unique_id}_saliency.npy")
639
- np.save(input_array_path, input_3d_for_plot)
640
- np.save(saliency_array_path, saliency_3d)
641
- print(f"Saved input array to {input_array_path}")
642
- print(f"Saved saliency array to {saliency_array_path}")
643
-
644
- # Generate ONLY the center slice plots for the initial view
645
- center_slice_plots = create_plot_images_for_slice(input_3d_for_plot, saliency_3d, center_slice_index)
646
-
647
- if center_slice_plots:
648
- # Prepare data structure for the template
649
- saliency_output_for_template = {
650
- 'center_slice_plots': center_slice_plots,
651
- 'num_slices': num_slices,
652
- 'center_slice_index': center_slice_index,
653
- 'unique_id': unique_id, # Pass the ID for filenames
654
- 'temp_dir_path': temp_dir # Pass the full path for lookup
655
- }
656
- print("--- Saliency Data Saved & Initial Plot Generated ---")
657
- else:
658
- print("--- Center Slice Plotting Failed ---")
659
- flash('Failed to generate initial saliency plot.', 'warning')
660
- else:
661
- print("--- Saliency Generation Failed --- ")
662
- flash('Saliency map generation failed.', 'warning')
663
-
664
- except Exception as e:
665
- print(f"Error during saliency processing/saving: {e}")
666
- traceback.print_exc()
667
- flash('Could not generate or save saliency maps due to an error.', 'error')
668
-
669
- # Render result, passing prediction and potentially the NEW saliency structure
670
- return render_template('index.html',
671
- prediction=f"{predicted_age_years:.2f} years",
672
- saliency_info=saliency_output_for_template) # Pass the new dict
673
-
674
- except Exception as e:
675
- flash(f'Error processing file: {e}', 'error')
676
- print(f"Caught Exception during prediction process: {e}")
677
- traceback.print_exc()
678
- # Ensure cleanup happens even if exception occurs mid-process
679
- # temp_dir_obj.cleanup()
680
- if temp_dir and os.path.exists(temp_dir):
681
- shutil.rmtree(temp_dir, ignore_errors=True) # Manual cleanup on general error
682
- return redirect(url_for('index'))
683
-
684
- # NOTE: Temporary directory created with mkdtemp is NOT automatically cleaned.
685
- # Need a separate mechanism (e.g., cron job, background task) to remove old directories
686
- # from the system's temporary location (e.g., /tmp) based on age.
687
- # Leaving the directory here so /get_slice can access the files.
688
-
689
- # --- New Route for Dynamic Slice Loading ---
690
- @app.route('/get_slice/<unique_id>/<int:slice_index>')
691
- def get_slice(unique_id, slice_index):
692
- # Get the actual temporary directory path from query parameter
693
- temp_dir_path = request.args.get('path')
694
- if not temp_dir_path:
695
- print("Error: 'path' query parameter missing in /get_slice request")
696
- return jsonify({"error": "Required path information missing."}), 400
697
-
698
- # Construct paths using the provided directory path and unique ID
699
- input_array_path = os.path.join(temp_dir_path, f"{unique_id}_input.npy")
700
- saliency_array_path = os.path.join(temp_dir_path, f"{unique_id}_saliency.npy")
701
- print(f"Attempting to load slice {slice_index} for ID {unique_id} from actual path: {temp_dir_path}")
702
-
703
- try:
704
- # Check using the exact paths constructed above
705
- if not os.path.exists(input_array_path) or not os.path.exists(saliency_array_path):
706
- print(f"Error: .npy files not found for ID {unique_id} at {temp_dir_path}")
707
- return jsonify({"error": "Saliency data not found. It might have expired or failed to save."}), 404
708
-
709
- input_3d = np.load(input_array_path)
710
- saliency_3d = np.load(saliency_array_path)
711
- print(f"Loaded arrays for ID {unique_id}. Input shape: {input_3d.shape}, Saliency shape: {saliency_3d.shape}")
712
-
713
- # Generate plots for the requested slice using the helper function
714
- slice_plots = create_plot_images_for_slice(input_3d, saliency_3d, slice_index)
715
-
716
- if slice_plots:
717
- return jsonify(slice_plots) # Return plot data as JSON
718
- else:
719
- return jsonify({"error": f"Failed to generate plots for slice {slice_index}."}), 500
720
-
721
- except Exception as e:
722
- print(f"Error in /get_slice for ID {unique_id}, slice {slice_index}: {e}")
723
- traceback.print_exc()
724
- return jsonify({"error": "An internal error occurred while fetching the slice data."}), 500
725
-
726
- if __name__ == '__main__':
727
- # Use '0.0.0.0' to make it accessible outside the container
728
- app.run(host='0.0.0.0', port=5000, debug=False) # Turn off debug for production/docker