# FusOn-pLM: fuson_plm/benchmarking/caid/analyze_fusion_preds.py
import torch
import torch.nn as nn
import os
import pickle
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, average_precision_score
from fuson_plm.utils.logging import log_update, open_logfile
from fuson_plm.benchmarking.caid.plot import plot_fusion_stats_boxplots, plot_fusion_frac_disorder_r2
# Calculate per-sequence metrics (accuracy, precision, recall, F1, AUROC, AUPRC)
def calc_metrics(row):
    # Parse the comma-separated per-residue disorder probabilities and the binary label strings
    probs = row['prob_1']
    probs = [float(y) for y in probs.split(',')]
    true_labels = row['Label']
    true_labels = [int(y) for y in list(true_labels)]
    pred_labels = row['pred_labels']
    pred_labels = [int(y) for y in list(pred_labels)]

    flat_binary_preds = np.array(pred_labels)
    flat_prob_preds = np.array(probs)
    flat_labels = np.array(true_labels)

    # Calculate threshold-based stats from the predicted labels
    accuracy = accuracy_score(flat_labels, flat_binary_preds)
    precision = precision_score(flat_labels, flat_binary_preds)
    recall = recall_score(flat_labels, flat_binary_preds)
    f1 = f1_score(flat_labels, flat_binary_preds)

    # Calculate AUROC and AUPRC from the probabilities; both are undefined when a
    # sequence contains only one class, so fall back to NaN in that case
    try:
        roc_auc = roc_auc_score(flat_labels, flat_prob_preds)
    except ValueError:
        roc_auc = np.nan
    try:
        auprc = average_precision_score(flat_labels, flat_prob_preds)
    except ValueError:
        auprc = np.nan

    return pd.Series({
        'Accuracy': round(accuracy, 3),
        'Precision': round(precision, 3),
        'Recall': round(recall, 3),
        'F1': round(f1, 3),
        'AUROC': round(roc_auc, 3) if not np.isnan(roc_auc) else roc_auc,
        'AUPRC': round(auprc, 3) if not np.isnan(auprc) else auprc,
    })
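
# Illustrative sketch (not part of the pipeline): calc_metrics expects 'prob_1' as a
# comma-separated string of per-residue disorder probabilities and 'Label'/'pred_labels'
# as binary strings of the same length. A toy row such as
#   pd.Series({'prob_1': '0.9,0.1,0.8', 'Label': '101', 'pred_labels': '101'})
# would score 1.0 on every metric, since each residue is classified correctly and the
# probabilities rank both disordered residues above the ordered one.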
def get_model_preds_with_metrics(path_to_model_predictions):
    # Define paths and dataframes that we will need
    fusion_benchmark_set = pd.read_csv('splits/fusion_bench_df.csv')
    model_predictions = pd.read_csv(path_to_model_predictions)
    fusion_structure_data = pd.read_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')
    fusion_structure_data['Fusion_Structure_Link'] = fusion_structure_data['Fusion_Structure_Link'].apply(lambda x: x.split('/')[-1])

    # Merge fusion data with seq ids
    fuson_db = pd.read_csv('../../data/fuson_db.csv')
    fuson_db = fuson_db[['aa_seq','seq_id']].rename(columns={'aa_seq':'Fusion_Seq'})
    fusion_structure_data = pd.merge(
        fusion_structure_data,
        fuson_db,
        on='Fusion_Seq',
        how='inner'
    )

    # Merge fusion structure data with top swissprot alignments
    swissprot_top_alignments = pd.read_csv("../../data/blast/blast_outputs/swissprot_top_alignments.csv")
    fusion_structure_data = pd.merge(
        fusion_structure_data,
        swissprot_top_alignments,
        on="seq_id",
        how="left"
    )

    # Attach ground-truth labels, then structure/alignment metadata, to the model predictions
    model_predictions_labeled = pd.merge(
        model_predictions,
        fusion_benchmark_set.rename(columns={'Sequence': 'sequence'}),
        on='sequence',
        how='inner'
    )
    model_predictions_labeled = pd.merge(
        model_predictions_labeled,
        fusion_structure_data[['FusionGene', 'Fusion_Seq', 'Fusion_Structure_Link', 'Fusion_pLDDT', 'Fusion_AA_pLDDTs',
                               'top_hg_UniProtID', 'top_hg_UniProt_isoform', 'top_hg_UniProt_fus_indices', 'top_tg_UniProtID', 'top_tg_UniProt_isoform',
                               'top_tg_UniProt_fus_indices', 'top_UniProtID', 'top_UniProt_isoform', 'top_UniProt_fus_indices', 'top_UniProt_nIdentities',
                               'top_UniProt_nPositives']].rename(
            columns={'Fusion_Seq': 'sequence'}
        ),
        on='sequence',
        how='left'
    )
    model_predictions_labeled['length'] = model_predictions_labeled['sequence'].str.len()
    model_predictions_labeled['Fusion_Structure_Link'] = model_predictions_labeled['Fusion_Structure_Link'].apply(lambda x: x.split('/')[-1])

    # Per-sequence metrics, sorted from best to worst
    model_predictions_labeled[['Accuracy','Precision','Recall','F1','AUROC','AUPRC']] = model_predictions_labeled.apply(lambda row: calc_metrics(row), axis=1)
    model_predictions_labeled = model_predictions_labeled.sort_values(
        by=['AUROC','F1','AUPRC','Accuracy','Precision','Recall'],
        ascending=[False, False, False, False, False, False]
    ).reset_index(drop=True)

    # Percent of residues that are disordered (true vs. predicted)
    model_predictions_labeled['pcnt_disordered'] = round(100*model_predictions_labeled['Label'].apply(lambda x: sum([int(y) for y in list(x)]))/model_predictions_labeled['sequence'].str.len(), 2)
    model_predictions_labeled['pred_pcnt_disordered'] = round(100*model_predictions_labeled['pred_labels'].apply(lambda x: sum([int(y) for y in list(x)]))/model_predictions_labeled['sequence'].str.len(), 2)

    log_update(f"Model predictions for {len(model_predictions_labeled)} fusion oncoproteins. Preview:")
    log_update(
        model_predictions_labeled[['sequence','length','FusionGene','Fusion_pLDDT','pcnt_disordered','pred_pcnt_disordered','AUROC','F1','AUPRC','Accuracy','Precision','Recall']].head()
    )
    cols_str = '\n\t' + '\n\t'.join(list(model_predictions_labeled.columns))
    log_update(f"Columns in model predictions stats database: {cols_str}")

    # There is one duplicate row; confirm that duplicated sequences share identical predictions before dropping them
    duplicate_sequences = model_predictions_labeled.loc[model_predictions_labeled['sequence'].duplicated()]['sequence'].unique().tolist()
    log_update(f"\nTotal duplicate sequences: {len(duplicate_sequences)}")
    gb = model_predictions_labeled.groupby('sequence').agg(
        pred_labels=("pred_labels", list),
    )
    gb['pred_labels'] = gb['pred_labels'].apply(lambda x: list(set(x)))
    gb['unique_label_vectors'] = gb['pred_labels'].apply(lambda x: len(x))
    log_update(f"Duplicate entries for sequences have the exact same label vector: {(gb['unique_label_vectors']==1).all()}")
    log_update("\tSince above statement is true, randomly dropping duplicate sequence rows - should make no difference to prediction.")
    model_predictions_labeled = model_predictions_labeled.drop_duplicates('sequence').reset_index(drop=True)
    #os.makedirs("processed_data/fusion_predictions",exist_ok=True)
    return model_predictions_labeled
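
# Minimal usage sketch (assumes the benchmark, FusionPDB, and BLAST CSVs above exist at the
# relative paths hard-coded in this function, and that the predictions CSV provides the
# 'sequence', 'prob_1', and 'pred_labels' columns consumed by calc_metrics):
#   preds = get_model_preds_with_metrics(
#       "trained_models/fuson_plm/best/caid_hyperparam_screen_fusion_benchmark_probs.csv"
#   )
#   preds[['FusionGene', 'AUROC', 'AUPRC']].head()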
def calc_average_stats(model_pred_stats):
    # cols: Accuracy Precision Recall F1 AUROC AUPRC pcnt_disordered pred_pcnt_disordered
    averages = model_pred_stats[[
        'Accuracy', 'Precision', 'Recall', 'F1', 'AUROC', 'AUPRC'
    ]].mean()
    log_update(f"\nAverage metrics across all fusion oncoproteins:\n{averages}")
    return averages
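
# Usage sketch: calc_average_stats returns a pandas Series indexed by metric name, so a
# single value can be pulled out directly, e.g.
#   avg = calc_average_stats(model_preds_with_metrics)
#   avg['AUROC']  # mean per-sequence AUROC across the benchmark set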
def main():
    with open_logfile("analyze_fusion_preds.txt"):
        ## Put path to model predictions you'd like to use for benchmarking
        path_to_model_predictions = "trained_models/fuson_plm/best/caid_hyperparam_screen_fusion_benchmark_probs.csv"
        save_dir = "results/final"
        os.makedirs(save_dir, exist_ok=True)
        preds_with_metrics_save_path = f"{save_dir}/model_preds_with_metrics.csv"
        boxplot_save_path = f"{save_dir}/fusion_disorder_boxplots.png"
        r2_save_path = f"{save_dir}/fusion_pred_disorder_r2.png"

        # Additional benchmarking on fusion predictions
        fuson_db = pd.read_csv("../../data/fuson_db.csv")
        seq_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
        model_preds_with_metrics = get_model_preds_with_metrics(path_to_model_predictions)
        model_preds_with_metrics['seq_id'] = model_preds_with_metrics['sequence'].map(seq_id_dict)
        model_preds_with_metrics.to_csv(preds_with_metrics_save_path, index=False)
        print(model_preds_with_metrics.columns)

        # Box plot
        boxplot_model_preds = model_preds_with_metrics[['seq_id', 'FusionGene',
            'Accuracy', 'Precision', 'Recall', 'F1', 'AUROC'
        ]]
        boxplot_model_preds.to_csv(boxplot_save_path.replace(".png", "_source_data.csv"), index=False)
        plot_fusion_stats_boxplots(
            boxplot_model_preds[['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC']],
            save_path=boxplot_save_path
        )

        # R2 plot
        r2_model_preds = model_preds_with_metrics[['seq_id', 'FusionGene', 'pcnt_disordered', 'pred_pcnt_disordered']]
        r2_model_preds.to_csv(r2_save_path.replace(".png", "_source_data.csv"), index=False)
        plot_fusion_frac_disorder_r2(r2_model_preds['pcnt_disordered'], r2_model_preds['pred_pcnt_disordered'], save_path=r2_save_path)

        # Average metrics across all sequences
        calc_average_stats(model_preds_with_metrics)

if __name__ == "__main__":
    main()