|
|
|
import torch |
|
import torch.nn as nn |
|
import os |
|
import pickle |
|
import pandas as pd |
|
import numpy as np |
|
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, average_precision_score |
|
from fuson_plm.utils.logging import log_update, open_logfile |
|
from fuson_plm.benchmarking.caid.plot import plot_fusion_stats_boxplots, plot_fusion_frac_disorder_r2 |
|
|
|
|
|
def calc_metrics(row):
    """Compute classification metrics for a single prediction row.

    Args:
        row: Mapping-like record (e.g. one DataFrame row) with string fields
            ``prob_1`` (comma-separated floats), ``Label`` and
            ``pred_labels`` (one integer character per position;
            presumably 0/1 labels - TODO confirm against the data files).

    Returns:
        pd.Series with keys ``Accuracy``, ``Precision``, ``Recall``, ``F1``,
        ``AUROC`` and ``AUPRC``, each rounded to 3 decimals. ``AUROC`` and
        ``AUPRC`` are NaN when scikit-learn raises ``ValueError`` for them.
    """
    # Parse the serialized per-position vectors into flat numpy arrays.
    flat_prob_preds = np.array([float(p) for p in row['prob_1'].split(',')])
    flat_labels = np.array([int(c) for c in row['Label']])
    flat_binary_preds = np.array([int(c) for c in row['pred_labels']])

    accuracy = accuracy_score(flat_labels, flat_binary_preds)
    precision = precision_score(flat_labels, flat_binary_preds)
    recall = recall_score(flat_labels, flat_binary_preds)
    f1 = f1_score(flat_labels, flat_binary_preds)

    # Ranking metrics can be undefined for a row (scikit-learn raises
    # ValueError, e.g. when only one class is present in the true labels).
    # Catch only that, so interrupts and genuine bugs still surface.
    try:
        roc_auc = roc_auc_score(flat_labels, flat_prob_preds)
    except ValueError:
        roc_auc = np.nan

    try:
        auprc = average_precision_score(flat_labels, flat_prob_preds)
    except ValueError:
        auprc = np.nan

    return pd.Series({
        'Accuracy': round(accuracy, 3),
        'Precision': round(precision, 3),
        'Recall': round(recall, 3),
        'F1': round(f1, 3),
        # Leave NaN untouched rather than rounding it.
        'AUROC': roc_auc if np.isnan(roc_auc) else round(roc_auc, 3),
        'AUPRC': auprc if np.isnan(auprc) else round(auprc, 3),
    })
|
|
|
def get_model_preds_with_metrics(path_to_model_predictions):
    """Join model predictions with benchmark labels and structure metadata, then score each row.

    Reads the benchmark set, the FusionPDB structure table, the fuson_db
    sequence table and the SwissProt top-alignment table from fixed relative
    paths, so the result depends on the current working directory.

    Args:
        path_to_model_predictions: Path to a CSV of model predictions. It is
            merged on a ``sequence`` column and must provide the ``prob_1``
            and ``pred_labels`` fields that ``calc_metrics`` reads.

    Returns:
        pd.DataFrame with one row per unique sequence, sorted by
        AUROC, F1, AUPRC, Accuracy, Precision, Recall (all descending), with
        the per-row metrics plus ``length``, ``pcnt_disordered`` and
        ``pred_pcnt_disordered`` columns added.
    """

    def _link_basename(link):
        # Keep only the last path component. Non-string values (NaN for rows
        # that have no structure entry after a left merge) pass through
        # unchanged instead of raising AttributeError on ``.split``.
        return link.split('/')[-1] if isinstance(link, str) else link

    def _label_sum(label_string):
        # Sum of the integer value of every character in a label string.
        return sum(int(ch) for ch in label_string)

    fusion_benchmark_set = pd.read_csv('splits/fusion_bench_df.csv')
    model_predictions = pd.read_csv(path_to_model_predictions)
    fusion_structure_data = pd.read_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')
    fusion_structure_data['Fusion_Structure_Link'] = fusion_structure_data['Fusion_Structure_Link'].apply(_link_basename)

    # Attach seq_id to each structure row by matching on the amino-acid sequence.
    fuson_db = pd.read_csv('../../data/fuson_db.csv')
    fuson_db = fuson_db[['aa_seq', 'seq_id']].rename(columns={'aa_seq': 'Fusion_Seq'})
    fusion_structure_data = pd.merge(
        fusion_structure_data,
        fuson_db,
        on='Fusion_Seq',
        how='inner'
    )

    # Attach the SwissProt top-alignment columns by seq_id.
    swissprot_top_alignments = pd.read_csv("../../data/blast/blast_outputs/swissprot_top_alignments.csv")
    fusion_structure_data = pd.merge(
        fusion_structure_data,
        swissprot_top_alignments,
        on="seq_id",
        how="left"
    )

    structure_columns = [
        'FusionGene', 'Fusion_Seq', 'Fusion_Structure_Link', 'Fusion_pLDDT', 'Fusion_AA_pLDDTs',
        'top_hg_UniProtID', 'top_hg_UniProt_isoform', 'top_hg_UniProt_fus_indices',
        'top_tg_UniProtID', 'top_tg_UniProt_isoform', 'top_tg_UniProt_fus_indices',
        'top_UniProtID', 'top_UniProt_isoform', 'top_UniProt_fus_indices',
        'top_UniProt_nIdentities', 'top_UniProt_nPositives',
    ]

    # Keep only predictions for benchmark sequences, then add structure info where available.
    model_predictions_labeled = pd.merge(
        model_predictions,
        fusion_benchmark_set.rename(columns={'Sequence': 'sequence'}),
        on='sequence',
        how='inner'
    )
    model_predictions_labeled = pd.merge(
        model_predictions_labeled,
        fusion_structure_data[structure_columns].rename(columns={'Fusion_Seq': 'sequence'}),
        on='sequence',
        how='left'
    )
    model_predictions_labeled['length'] = model_predictions_labeled['sequence'].str.len()
    # The left merge can leave NaN links; _link_basename tolerates them.
    model_predictions_labeled['Fusion_Structure_Link'] = model_predictions_labeled['Fusion_Structure_Link'].apply(_link_basename)

    # Per-row metrics, then rank rows from best to worst.
    model_predictions_labeled[['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC', 'AUPRC']] = model_predictions_labeled.apply(calc_metrics, axis=1)
    model_predictions_labeled = model_predictions_labeled.sort_values(
        by=['AUROC', 'F1', 'AUPRC', 'Accuracy', 'Precision', 'Recall'],
        ascending=False
    ).reset_index(drop=True)

    # Percentage of positions labelled / predicted as 1, relative to sequence length.
    seq_lengths = model_predictions_labeled['sequence'].str.len()
    model_predictions_labeled['pcnt_disordered'] = round(100 * model_predictions_labeled['Label'].apply(_label_sum) / seq_lengths, 2)
    model_predictions_labeled['pred_pcnt_disordered'] = round(100 * model_predictions_labeled['pred_labels'].apply(_label_sum) / seq_lengths, 2)

    log_update(f"Model predictions for {len(model_predictions_labeled)} fusion oncoproteins. Preview:")
    log_update(
        model_predictions_labeled[['sequence', 'length', 'FusionGene', 'Fusion_pLDDT', 'pcnt_disordered', 'pred_pcnt_disordered', 'AUROC', 'F1', 'AUPRC', 'Accuracy', 'Precision', 'Recall']].head()
    )
    cols_str = '\n\t' + '\n\t'.join(list(model_predictions_labeled.columns))
    log_update(f"Columns in model predictions stats database: {cols_str}")

    # Report duplicated sequences and check that their predicted labels agree.
    duplicate_sequences = model_predictions_labeled.loc[model_predictions_labeled['sequence'].duplicated()]['sequence'].unique().tolist()
    log_update(f"\nTotal duplicate sequences: {len(duplicate_sequences)}")
    gb = model_predictions_labeled.groupby('sequence').agg(
        pred_labels=("pred_labels", list),
    )
    gb['unique_label_vectors'] = gb['pred_labels'].apply(lambda vectors: len(set(vectors)))
    labels_consistent = (gb['unique_label_vectors'] == 1).all()
    log_update(f"Duplicate entries for sequences have the exact same label vector: {labels_consistent}")
    if labels_consistent:
        log_update("\tSince above statement is true, randomly dropping duplicate sequence rows - should make no difference to prediction.")
    else:
        # Do not claim the check passed when it did not.
        log_update("\tWARNING: duplicate sequence rows have differing predicted label vectors; keeping only the first (highest-ranked) row for each sequence.")

    model_predictions_labeled = model_predictions_labeled.drop_duplicates('sequence').reset_index(drop=True)

    return model_predictions_labeled
|
|
|
def calc_average_stats(model_pred_stats):
    """Return the column-wise mean of the per-sequence metrics.

    Args:
        model_pred_stats: DataFrame containing the columns ``Accuracy``,
            ``Precision``, ``Recall``, ``F1``, ``AUROC`` and ``AUPRC``.

    Returns:
        pd.Series indexed by those six metric names, holding each column's
        mean (pandas skips NaN entries by default).
    """
    metric_columns = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC', 'AUPRC']
    averages = model_pred_stats[metric_columns].mean()
    # The original ended with a bare ``averages`` expression (a notebook
    # leftover), so the function always returned None.
    return averages
|
|
|
def main():
    """Score the model predictions and write the metrics table, plots and plot source data.

    Everything runs inside the ``open_logfile("analyze_fusion_preds.txt")``
    context. Input and output locations are fixed relative paths, so the
    script depends on the current working directory.
    """
    with open_logfile("analyze_fusion_preds.txt"):
        path_to_model_predictions = "trained_models/fuson_plm/best/caid_hyperparam_screen_fusion_benchmark_probs.csv"
        save_dir = "results/final"
        # Create the output directory up front; otherwise the first to_csv
        # call fails when it does not exist yet.
        os.makedirs(save_dir, exist_ok=True)
        preds_with_metrics_save_path = f"{save_dir}/model_preds_with_metrics.csv"
        boxplot_save_path = f"{save_dir}/fusion_disorder_boxplots.png"
        # Built from save_dir like the other outputs (same resulting path as before).
        r2_save_path = f"{save_dir}/fusion_pred_disorder_r2.png"

        # Map each amino-acid sequence to its seq_id.
        fuson_db = pd.read_csv("../../data/fuson_db.csv")
        seq_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
        model_preds_with_metrics = get_model_preds_with_metrics(path_to_model_predictions)
        model_preds_with_metrics['seq_id'] = model_preds_with_metrics['sequence'].map(seq_id_dict)
        model_preds_with_metrics.to_csv(preds_with_metrics_save_path, index=False)
        print(model_preds_with_metrics.columns)

        # Boxplots of the per-sequence metrics, with the plotted values saved alongside.
        boxplot_metrics = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC']
        boxplot_model_preds = model_preds_with_metrics[['seq_id', 'FusionGene'] + boxplot_metrics]
        boxplot_model_preds.to_csv(boxplot_save_path.replace(".png", "_source_data.csv"), index=False)
        plot_fusion_stats_boxplots(boxplot_model_preds[boxplot_metrics], save_path=boxplot_save_path)

        # True vs. predicted percent-disorder plot, with its source data.
        r2_model_preds = model_preds_with_metrics[['seq_id', 'FusionGene', 'pcnt_disordered', 'pred_pcnt_disordered']]
        r2_model_preds.to_csv(r2_save_path.replace(".png", "_source_data.csv"), index=False)
        plot_fusion_frac_disorder_r2(r2_model_preds['pcnt_disordered'], r2_model_preds['pred_pcnt_disordered'], save_path=r2_save_path)

        calc_average_stats(model_preds_with_metrics)
|
|
|
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()