import time

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.metrics import accuracy_score, average_precision_score, f1_score, precision_score, recall_score, roc_auc_score
from sklearn.model_selection import train_test_split, StratifiedKFold

from fuson_plm.utils.logging import log_update

def train_final_predictor(X_train, y_train, n_estimators=50, tree_method="hist"):
    """Fit an XGBoost classifier on the full training set and return it."""
    clf = xgb.XGBClassifier(n_estimators=n_estimators, tree_method=tree_method)
    clf.fit(X_train, y_train)
    return clf

def evaluate_predictor(clf, X_test, y_test, class1_thresh=None):
    """Score a fitted classifier on the test set.

    Returns a one-row DataFrame of metrics at the classifier's default
    decision threshold, plus a second DataFrame recomputed at
    `class1_thresh` (None if no custom threshold is given).
    """
    # Predict labels (default threshold) and class-1 probabilities on the test set
    y_pred_test = clf.predict(X_test)
    y_pred_prob_test = clf.predict_proba(X_test)[:, 1]
    if class1_thresh is not None:
        y_pred_customthresh_test = np.where(y_pred_prob_test >= class1_thresh, 1, 0)

    # Calculating metrics - automatic threshold
    accuracy = accuracy_score(y_test, y_pred_test)
    precision = precision_score(y_test, y_pred_test)
    recall = recall_score(y_test, y_pred_test)
    f1 = f1_score(y_test, y_pred_test)
    auroc_prob = roc_auc_score(y_test, y_pred_prob_test)
    auprc_prob = average_precision_score(y_test, y_pred_prob_test)
    auroc_label = roc_auc_score(y_test, y_pred_test)
    auprc_label = average_precision_score(y_test, y_pred_test)
    automatic_stats_df = pd.DataFrame(data={
        'Accuracy': [accuracy],
        'Precision': [precision],
        'Recall': [recall],
        'F1 Score': [f1],
        'AUROC': [auroc_prob],
        'AUROC Label': [auroc_label],
        'AUPRC': [auprc_prob],
        'AUPRC Label': [auprc_label]
    })

    # Calculating metrics - custom threshold (the probability-based AUROC/AUPRC
    # are threshold-independent, so they match the values above)
    if class1_thresh is not None:
        accuracy_custom = accuracy_score(y_test, y_pred_customthresh_test)
        precision_custom = precision_score(y_test, y_pred_customthresh_test)
        recall_custom = recall_score(y_test, y_pred_customthresh_test)
        f1_custom = f1_score(y_test, y_pred_customthresh_test)
        auroc_prob_custom = roc_auc_score(y_test, y_pred_prob_test)
        auprc_prob_custom = average_precision_score(y_test, y_pred_prob_test)
        auroc_label_custom = roc_auc_score(y_test, y_pred_customthresh_test)
        auprc_label_custom = average_precision_score(y_test, y_pred_customthresh_test)
        custom_stats_df = pd.DataFrame(data={
            'Accuracy': [accuracy_custom],
            'Precision': [precision_custom],
            'Recall': [recall_custom],
            'F1 Score': [f1_custom],
            'AUROC': [auroc_prob_custom],
            'AUROC Label': [auroc_label_custom],
            'AUPRC': [auprc_prob_custom],
            'AUPRC Label': [auprc_label_custom]
        })
    else:
        custom_stats_df = None

    return automatic_stats_df, custom_stats_df
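

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original module): train and
    # evaluate on synthetic data. The make_classification call and the 0.4
    # threshold are assumptions chosen purely for demonstration.
    from sklearn.datasets import make_classification

    X, y = make_classification(n_samples=500, n_features=20, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=42
    )

    clf = train_final_predictor(X_train, y_train, n_estimators=50)
    automatic_stats_df, custom_stats_df = evaluate_predictor(
        clf, X_test, y_test, class1_thresh=0.4
    )
    print(automatic_stats_df)
    print(custom_stats_df)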