from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import (
    accuracy_score,
    recall_score,
    f1_score,
    precision_score,
    roc_auc_score,
    average_precision_score,
)

from fuson_plm.utils.logging import log_update

import time
import xgboost as xgb
import numpy as np
import pandas as pd


def train_final_predictor(X_train, y_train, n_estimators=50, tree_method="hist"):
    """Fit an XGBoost classifier on the full training set and return it."""
    clf = xgb.XGBClassifier(n_estimators=n_estimators, tree_method=tree_method)
    clf.fit(X_train, y_train)
    return clf


def evaluate_predictor(clf, X_test, y_test, class1_thresh=None):
    """Evaluate a trained classifier on a held-out test set.

    Returns two DataFrames: metrics at the default 0.5 decision threshold,
    and metrics at the custom `class1_thresh` (or None if no custom
    threshold is given).
    """
    # Hard labels at the default 0.5 threshold, plus class-1 probabilities
    y_pred_test = clf.predict(X_test)
    y_pred_prob_test = clf.predict_proba(X_test)[:, 1]

    # Re-threshold the class-1 probabilities if a custom cutoff was provided
    if class1_thresh is not None:
        y_pred_customthresh_test = np.where(y_pred_prob_test >= class1_thresh, 1, 0)

    # Metrics at the default threshold. AUROC/AUPRC are computed twice:
    # once from the predicted probabilities and once from the hard labels.
    accuracy = accuracy_score(y_test, y_pred_test)
    precision = precision_score(y_test, y_pred_test)
    recall = recall_score(y_test, y_pred_test)
    f1 = f1_score(y_test, y_pred_test)
    auroc_prob = roc_auc_score(y_test, y_pred_prob_test)
    auprc_prob = average_precision_score(y_test, y_pred_prob_test)
    auroc_label = roc_auc_score(y_test, y_pred_test)
    auprc_label = average_precision_score(y_test, y_pred_test)

    automatic_stats_df = pd.DataFrame(data={
        'Accuracy': [accuracy],
        'Precision': [precision],
        'Recall': [recall],
        'F1 Score': [f1],
        'AUROC': [auroc_prob],
        'AUROC Label': [auroc_label],
        'AUPRC': [auprc_prob],
        'AUPRC Label': [auprc_label]
    })

    if class1_thresh is not None:
        # Label-based metrics recomputed at the custom threshold. The
        # probability-based AUROC/AUPRC do not depend on the threshold,
        # so they match the values in automatic_stats_df.
        accuracy_custom = accuracy_score(y_test, y_pred_customthresh_test)
        precision_custom = precision_score(y_test, y_pred_customthresh_test)
        recall_custom = recall_score(y_test, y_pred_customthresh_test)
        f1_custom = f1_score(y_test, y_pred_customthresh_test)
        auroc_prob_custom = roc_auc_score(y_test, y_pred_prob_test)
        auprc_prob_custom = average_precision_score(y_test, y_pred_prob_test)
        auroc_label_custom = roc_auc_score(y_test, y_pred_customthresh_test)
        auprc_label_custom = average_precision_score(y_test, y_pred_customthresh_test)

        custom_stats_df = pd.DataFrame(data={
            'Accuracy': [accuracy_custom],
            'Precision': [precision_custom],
            'Recall': [recall_custom],
            'F1 Score': [f1_custom],
            'AUROC': [auroc_prob_custom],
            'AUROC Label': [auroc_label_custom],
            'AUPRC': [auprc_prob_custom],
            'AUPRC Label': [auprc_label_custom]
        })
    else:
        custom_stats_df = None

    return automatic_stats_df, custom_stats_df
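

# Hypothetical usage sketch (not part of the original module): exercises both
# functions end to end on randomly generated data. The synthetic X/y, the
# 80/20 stratified split, and the 0.4 custom threshold are illustrative
# assumptions only, not values from the original pipeline.
if __name__ == "__main__":
    rng = np.random.default_rng(seed=0)
    X = rng.normal(size=(500, 16))      # 500 samples, 16 features
    y = rng.integers(0, 2, size=500)    # binary labels

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=0
    )

    clf = train_final_predictor(X_train, y_train, n_estimators=50)
    default_stats, custom_stats = evaluate_predictor(
        clf, X_test, y_test, class1_thresh=0.4
    )
    print(default_stats.to_string(index=False))
    print(custom_stats.to_string(index=False))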