import torch
import torch.nn as nn
import os
import pickle
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, average_precision_score
from fuson_plm.utils.logging import log_update, open_logfile
from fuson_plm.benchmarking.caid.plot import plot_fusion_stats_boxplots, plot_fusion_frac_disorder_r2
# Calculate AUROC, AUPRC, and label-based metrics for each sequence
def calc_metrics(row):
    # Per-residue disorder probabilities are stored as a comma-separated string
    probs = row['prob_1']
    probs = [float(y) for y in probs.split(',')]
    # True and predicted labels are stored as strings of 0s and 1s
    true_labels = [int(y) for y in list(row['Label'])]
    pred_labels = [int(y) for y in list(row['pred_labels'])]

    # Calculate the threshold-based stats from the predicted labels
    flat_binary_preds = np.array(pred_labels)
    flat_prob_preds = np.array(probs)
    flat_labels = np.array(true_labels)
    accuracy = accuracy_score(flat_labels, flat_binary_preds)
    precision = precision_score(flat_labels, flat_binary_preds)
    recall = recall_score(flat_labels, flat_binary_preds)
    f1 = f1_score(flat_labels, flat_binary_preds)

    # Calculate AUROC and AUPRC from the probabilities; either can fail
    # (e.g. when only one class is present in the labels), in which case return NaN
    try:
        roc_auc = roc_auc_score(flat_labels, flat_prob_preds)
    except Exception:
        roc_auc = np.nan
    try:
        auprc = average_precision_score(flat_labels, flat_prob_preds)
    except Exception:
        auprc = np.nan

    return pd.Series({
        'Accuracy': round(accuracy, 3),
        'Precision': round(precision, 3),
        'Recall': round(recall, 3),
        'F1': round(f1, 3),
        'AUROC': round(roc_auc, 3) if not np.isnan(roc_auc) else roc_auc,
        'AUPRC': round(auprc, 3) if not np.isnan(auprc) else auprc,
    })
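
# Illustrative sketch (not part of the original script): calc_metrics expects each row to
# carry 'prob_1' as a comma-separated probability string and 'Label'/'pred_labels' as
# strings of 0s and 1s of the same length, e.g.
#   _example_row = pd.Series({'prob_1': '0.9,0.2,0.8', 'Label': '101', 'pred_labels': '101'})
#   calc_metrics(_example_row)  # -> Series with Accuracy, Precision, Recall, F1, AUROC, AUPRC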
def get_model_preds_with_metrics(path_to_model_predictions):
    # Define paths and dataframes that we will need
    fusion_benchmark_set = pd.read_csv('splits/fusion_bench_df.csv')
    model_predictions = pd.read_csv(path_to_model_predictions)
    fusion_structure_data = pd.read_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')
    fusion_structure_data['Fusion_Structure_Link'] = fusion_structure_data['Fusion_Structure_Link'].apply(lambda x: x.split('/')[-1])

    # Merge fusion structure data with sequence IDs from the FusOn database
    fuson_db = pd.read_csv('../../data/fuson_db.csv')
    fuson_db = fuson_db[['aa_seq', 'seq_id']].rename(columns={'aa_seq': 'Fusion_Seq'})
    fusion_structure_data = pd.merge(
        fusion_structure_data,
        fuson_db,
        on='Fusion_Seq',
        how='inner'
    )

    # Merge fusion structure data with top SwissProt alignments
    swissprot_top_alignments = pd.read_csv("../../data/blast/blast_outputs/swissprot_top_alignments.csv")
    fusion_structure_data = pd.merge(
        fusion_structure_data,
        swissprot_top_alignments,
        on="seq_id",
        how="left"
    )

    # Attach benchmark labels and structure/alignment metadata to the model predictions
    model_predictions_labeled = pd.merge(model_predictions, fusion_benchmark_set.rename(columns={'Sequence': 'sequence'}), on='sequence', how='inner')
    model_predictions_labeled = pd.merge(
        model_predictions_labeled,
        fusion_structure_data[['FusionGene', 'Fusion_Seq', 'Fusion_Structure_Link', 'Fusion_pLDDT', 'Fusion_AA_pLDDTs',
                               'top_hg_UniProtID', 'top_hg_UniProt_isoform', 'top_hg_UniProt_fus_indices', 'top_tg_UniProtID', 'top_tg_UniProt_isoform',
                               'top_tg_UniProt_fus_indices', 'top_UniProtID', 'top_UniProt_isoform', 'top_UniProt_fus_indices', 'top_UniProt_nIdentities',
                               'top_UniProt_nPositives']].rename(columns={'Fusion_Seq': 'sequence'}),
        on='sequence',
        how='left'
    )
    model_predictions_labeled['length'] = model_predictions_labeled['sequence'].str.len()
    model_predictions_labeled['Fusion_Structure_Link'] = model_predictions_labeled['Fusion_Structure_Link'].apply(lambda x: x.split('/')[-1])

    # Compute per-sequence metrics and sort best-to-worst
    model_predictions_labeled[['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC', 'AUPRC']] = model_predictions_labeled.apply(lambda row: calc_metrics(row), axis=1)
    model_predictions_labeled = model_predictions_labeled.sort_values(
        by=['AUROC', 'F1', 'AUPRC', 'Accuracy', 'Precision', 'Recall'],
        ascending=[False, False, False, False, False, False]
    ).reset_index(drop=True)

    # True and predicted percent disorder per sequence
    model_predictions_labeled['pcnt_disordered'] = round(100 * model_predictions_labeled['Label'].apply(lambda x: sum([int(y) for y in list(x)])) / model_predictions_labeled['sequence'].str.len(), 2)
    model_predictions_labeled['pred_pcnt_disordered'] = round(100 * model_predictions_labeled['pred_labels'].apply(lambda x: sum([int(y) for y in list(x)])) / model_predictions_labeled['sequence'].str.len(), 2)

    log_update(f"Model predictions for {len(model_predictions_labeled)} fusion oncoproteins. Preview:")
    log_update(
        model_predictions_labeled[['sequence', 'length', 'FusionGene', 'Fusion_pLDDT', 'pcnt_disordered', 'pred_pcnt_disordered', 'AUROC', 'F1', 'AUPRC', 'Accuracy', 'Precision', 'Recall']].head()
    )
    cols_str = '\n\t' + '\n\t'.join(list(model_predictions_labeled.columns))
    log_update(f"Columns in model predictions stats database: {cols_str}")

    # Check for duplicate sequences; confirm their predicted label vectors are identical before dropping them
    duplicate_sequences = model_predictions_labeled.loc[model_predictions_labeled['sequence'].duplicated()]['sequence'].unique().tolist()
    log_update(f"\nTotal duplicate sequences: {len(duplicate_sequences)}")
    gb = model_predictions_labeled.groupby('sequence').agg(
        pred_labels=("pred_labels", list),
    )
    gb['pred_labels'] = gb['pred_labels'].apply(lambda x: list(set(x)))
    gb['unique_label_vectors'] = gb['pred_labels'].apply(lambda x: len(x))
    log_update(f"Duplicate entries for sequences have the exact same label vector: {(gb['unique_label_vectors']==1).all()}")
    log_update("\tSince above statement is true, randomly dropping duplicate sequence rows - should make no difference to prediction.")
    model_predictions_labeled = model_predictions_labeled.drop_duplicates('sequence').reset_index(drop=True)
    #os.makedirs("processed_data/fusion_predictions",exist_ok=True)
    return model_predictions_labeled
def calc_average_stats(model_pred_stats):
    # Average the per-sequence metric columns (Accuracy, Precision, Recall, F1, AUROC, AUPRC)
    # across the benchmark set; pcnt_disordered and pred_pcnt_disordered are left out
    averages = model_pred_stats[[
        'Accuracy', 'Precision', 'Recall', 'F1', 'AUROC', 'AUPRC'
    ]].mean()
    log_update(f"\nAverage metrics across all fusion oncoproteins:\n{averages}")
    return averages
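
# Optional usage sketch (hypothetical output path, not in the original script): the returned
# Series could also be persisted next to the other result files, e.g.
#   calc_average_stats(model_preds_with_metrics).to_csv("results/final/average_metrics.csv")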
def main():
    with open_logfile("analyze_fusion_preds.txt"):
        ## Path to the model predictions you'd like to use for benchmarking
        path_to_model_predictions = "trained_models/fuson_plm/best/caid_hyperparam_screen_fusion_benchmark_probs.csv"
        save_dir = "results/final"
        os.makedirs(save_dir, exist_ok=True)  # make sure the output directory exists
        preds_with_metrics_save_path = f"{save_dir}/model_preds_with_metrics.csv"
        boxplot_save_path = f"{save_dir}/fusion_disorder_boxplots.png"
        r2_save_path = f"{save_dir}/fusion_pred_disorder_r2.png"

        # Additional benchmarking on fusion predictions
        fuson_db = pd.read_csv("../../data/fuson_db.csv")
        seq_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
        model_preds_with_metrics = get_model_preds_with_metrics(path_to_model_predictions)
        model_preds_with_metrics['seq_id'] = model_preds_with_metrics['sequence'].map(seq_id_dict)
        model_preds_with_metrics.to_csv(preds_with_metrics_save_path, index=False)
        print(model_preds_with_metrics.columns)

        # Box plot of per-sequence metrics
        boxplot_model_preds = model_preds_with_metrics[['seq_id', 'FusionGene',
                                                        'Accuracy', 'Precision', 'Recall', 'F1', 'AUROC']]
        boxplot_model_preds.to_csv(boxplot_save_path.replace(".png", "_source_data.csv"), index=False)
        plot_fusion_stats_boxplots(boxplot_model_preds[['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC']], save_path=boxplot_save_path)

        # R2 plot of true vs. predicted percent disorder
        r2_model_preds = model_preds_with_metrics[['seq_id', 'FusionGene', 'pcnt_disordered', 'pred_pcnt_disordered']]
        r2_model_preds.to_csv(r2_save_path.replace(".png", "_source_data.csv"), index=False)
        plot_fusion_frac_disorder_r2(r2_model_preds['pcnt_disordered'], r2_model_preds['pred_pcnt_disordered'], save_path=r2_save_path)

        # Log average metrics across the benchmark
        calc_average_stats(model_preds_with_metrics)

if __name__ == "__main__":
    main()