|
import os
import sys
import warnings

import numpy as np
import pandas as pd
import torch
from lifelines.utils import concordance_index
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm

# Put the parent of this file's directory on sys.path before the
# project-local imports below (presumably so they can be resolved).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from dataset2 import MedicalImageDatasetBalancedIntensity3D
from model import Backbone, SingleScanModel, Classifier
from utils import BaseConfig
|
|
|
|
|
|
|
def calculate_metrics(pred, target):
    """
    Calculate evaluation metrics.

    Args:
        pred (numpy.ndarray): Model predictions
        target (numpy.ndarray): Ground truth values

    Returns:
        dict: Dictionary with the keys ``mae``, ``rmse``, ``r2`` and
            ``c_index``. ``c_index`` falls back to 0 (and a
            ``RuntimeWarning`` is emitted) when the concordance index
            cannot be computed.
    """
    mae = mean_absolute_error(target, pred)
    rmse = np.sqrt(mean_squared_error(target, pred))
    r2 = r2_score(target, pred)

    try:
        # The all-ones third argument marks every sample as an observed
        # event, i.e. no sample is treated as censored.
        c_index = concordance_index(target, pred, np.ones_like(target))
    except Exception as exc:
        # Best-effort metric: keep the historical fallback value, but do not
        # swallow KeyboardInterrupt/SystemExit and do not hide the cause.
        warnings.warn(
            f"Could not compute concordance index ({exc!r}); reporting 0.",
            RuntimeWarning,
            stacklevel=2,
        )
        c_index = 0

    return {'mae': mae, 'rmse': rmse, 'r2': r2, 'c_index': c_index}
|
|
|
class StrokeTimeInference(BaseConfig):
    """
    Inference class for time to stroke prediction model.

    Constructing an instance loads the checkpointed model and builds the
    test data loader; call ``infer`` to run the evaluation pass.

    NOTE(review): ``get_config``, ``device`` and ``custom_collate`` are not
    defined in this class -- presumably they are provided by ``BaseConfig``;
    confirm against ``utils``.
    """

    def __init__(self):
        super().__init__()
        self.setup_model()
        self.setup_data()

    def setup_model(self):
        """
        Build the model, load the checkpoint weights and switch to eval mode.
        """
        config = self.get_config()
        self.backbone = Backbone()
        self.classifier = Classifier(d_model=2048)
        self.model = SingleScanModel(self.backbone, self.classifier)

        # NOTE(security): weights_only=False lets torch.load unpickle
        # arbitrary objects, so the configured checkpoint path must only
        # ever point at a trusted file.
        checkpoint = torch.load(
            config["infer"]["checkpoints"],
            map_location=self.device,
            weights_only=False,
        )
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model = self.model.to(self.device)
        self.model.eval()
        print("Model and checkpoint loaded!")

    def setup_data(self):
        """
        Set up data loader for inference.
        """
        config = self.get_config()
        self.test_dataset = MedicalImageDatasetBalancedIntensity3D(
            csv_path=config["data"]["test_csv"],
            root_dir=config["data"]["root_dir"],
        )
        # One sample per batch, in dataset order.
        self.test_loader = DataLoader(
            self.test_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self.custom_collate,
            num_workers=1,
        )

    def infer(self, output_path='./data/output/stroke_time_predictions.csv'):
        """
        Run inference pass

        Iterates over the test loader with gradients disabled, prints MAE,
        RMSE, R² and the concordance index, and writes one row per
        prediction (predicted value, true value, absolute error) to a CSV.

        Args:
            output_path (str): Destination of the per-sample results CSV.
                Its parent directory is created if it does not exist.

        Returns:
            dict: Dictionary containing evaluation metrics
        """
        all_labels = []
        all_predictions = []

        with torch.no_grad():
            for sample in tqdm(self.test_loader, desc="Inference", unit="batch"):
                inputs = sample['image'].to(self.device)

                with autocast():
                    outputs = self.model(inputs)

                all_predictions.extend(outputs.cpu().numpy().flatten())
                # Labels are only consumed on the CPU, so they are not
                # moved to the device first.
                all_labels.extend(
                    sample['label'].float().cpu().numpy().flatten()
                )

        predictions = np.array(all_predictions)
        targets = np.array(all_labels)
        metrics = calculate_metrics(predictions, targets)

        print("\nTest Set Metrics:")
        print(f"Mean Absolute Error (MAE): {metrics['mae']:.4f}")
        print(f"Root Mean Squared Error (RMSE): {metrics['rmse']:.4f}")
        print(f"R² Score: {metrics['r2']:.4f}")
        print(f"Concordance Index: {metrics['c_index']:.4f}")

        # Build the results table once instead of concatenating a new
        # one-batch frame on every iteration.
        results_df = pd.DataFrame({
            'PredictedTime': predictions,
            'TrueTime': targets,
        })
        results_df['AbsoluteError'] = abs(
            results_df['PredictedTime'] - results_df['TrueTime']
        )

        # Create the destination directory so a missing folder does not
        # discard the results of a completed inference pass.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        results_df.to_csv(output_path, index=False)

        return metrics
|
|
|
if __name__ == "__main__":
    # Script entry point: constructing the class loads the checkpoint and the
    # test data; infer() then evaluates, prints the metrics and writes the CSV.
    metrics = StrokeTimeInference().infer()