import torch
import torch.nn as nn
import os
import pickle
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, average_precision_score

from fuson_plm.utils.logging import log_update, open_logfile
from fuson_plm.benchmarking.caid.plot import plot_fusion_stats_boxplots, plot_fusion_frac_disorder_r2


# Calculate AUROC, AUPRC, and threshold-based metrics for one sequence (one row)
def calc_metrics(row):
    # Per-residue disorder probabilities are stored as a comma-separated string;
    # true and predicted labels are stored as strings of 0s and 1s
    probs = [float(y) for y in row['prob_1'].split(',')]
    true_labels = [int(y) for y in list(row['Label'])]
    pred_labels = [int(y) for y in list(row['pred_labels'])]

    flat_binary_preds = np.array(pred_labels)
    flat_prob_preds = np.array(probs)
    flat_labels = np.array(true_labels)

    # Threshold-based metrics from the predicted labels
    accuracy = accuracy_score(flat_labels, flat_binary_preds)
    precision = precision_score(flat_labels, flat_binary_preds)
    recall = recall_score(flat_labels, flat_binary_preds)
    f1 = f1_score(flat_labels, flat_binary_preds)

    # AUROC and AUPRC from the predicted probabilities; both are undefined when a
    # sequence contains only one class, so fall back to NaN in that case
    try:
        roc_auc = roc_auc_score(flat_labels, flat_prob_preds)
    except Exception:
        roc_auc = np.nan
    try:
        auprc = average_precision_score(flat_labels, flat_prob_preds)
    except Exception:
        auprc = np.nan

    return pd.Series({
        'Accuracy': round(accuracy, 3),
        'Precision': round(precision, 3),
        'Recall': round(recall, 3),
        'F1': round(f1, 3),
        'AUROC': round(roc_auc, 3) if not np.isnan(roc_auc) else roc_auc,
        'AUPRC': round(auprc, 3) if not np.isnan(auprc) else auprc,
    })
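
# Illustrative example of the row format calc_metrics expects (hypothetical values,
# not part of the pipeline): 'prob_1' is a comma-separated probability string and
# 'Label'/'pred_labels' are binary strings of the same length.
#   example_row = pd.Series({
#       'prob_1': '0.9,0.8,0.2,0.1',
#       'Label': '1100',
#       'pred_labels': '1100',
#   })
#   calc_metrics(example_row)
#   # -> Accuracy/Precision/Recall/F1/AUROC/AUPRC all 1.0 for this perfect prediction
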
def get_model_preds_with_metrics(path_to_model_predictions):
    """Merge per-residue disorder predictions with the fusion benchmark labels,
    FusionPDB structure info, and top SwissProt alignments, then compute
    per-sequence metrics with calc_metrics."""
    # Load the dataframes we will need
    fusion_benchmark_set = pd.read_csv('splits/fusion_bench_df.csv')
    model_predictions = pd.read_csv(path_to_model_predictions)
    fusion_structure_data = pd.read_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')
    # Reduce structure links to their file names
    fusion_structure_data['Fusion_Structure_Link'] = fusion_structure_data['Fusion_Structure_Link'].apply(lambda x: x.split('/')[-1])

    # Merge fusion structure data with seq ids
    fuson_db = pd.read_csv('../../data/fuson_db.csv')
    fuson_db = fuson_db[['aa_seq', 'seq_id']].rename(columns={'aa_seq': 'Fusion_Seq'})
    fusion_structure_data = pd.merge(
        fusion_structure_data,
        fuson_db,
        on='Fusion_Seq',
        how='inner'
    )

    # Merge fusion structure data with top SwissProt alignments
    swissprot_top_alignments = pd.read_csv("../../data/blast/blast_outputs/swissprot_top_alignments.csv")
    fusion_structure_data = pd.merge(
        fusion_structure_data,
        swissprot_top_alignments,
        on="seq_id",
        how="left"
    )

    # Attach benchmark labels to the model predictions, then the structure/alignment columns
    model_predictions_labeled = pd.merge(
        model_predictions,
        fusion_benchmark_set.rename(columns={'Sequence': 'sequence'}),
        on='sequence',
        how='inner'
    )
    structure_cols = ['FusionGene', 'Fusion_Seq', 'Fusion_Structure_Link', 'Fusion_pLDDT', 'Fusion_AA_pLDDTs',
                      'top_hg_UniProtID', 'top_hg_UniProt_isoform', 'top_hg_UniProt_fus_indices',
                      'top_tg_UniProtID', 'top_tg_UniProt_isoform', 'top_tg_UniProt_fus_indices',
                      'top_UniProtID', 'top_UniProt_isoform', 'top_UniProt_fus_indices',
                      'top_UniProt_nIdentities', 'top_UniProt_nPositives']
    model_predictions_labeled = pd.merge(
        model_predictions_labeled,
        fusion_structure_data[structure_cols].rename(columns={'Fusion_Seq': 'sequence'}),
        on='sequence',
        how='left'
    )
    model_predictions_labeled['length'] = model_predictions_labeled['sequence'].str.len()
    # Links were already reduced to file names above; guard against NaN from the left merge
    model_predictions_labeled['Fusion_Structure_Link'] = model_predictions_labeled['Fusion_Structure_Link'].apply(
        lambda x: x.split('/')[-1] if isinstance(x, str) else x
    )

    # Per-sequence metrics, plus true and predicted percent disorder
    model_predictions_labeled[['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC', 'AUPRC']] = model_predictions_labeled.apply(lambda row: calc_metrics(row), axis=1)
    model_predictions_labeled = model_predictions_labeled.sort_values(
        by=['AUROC', 'F1', 'AUPRC', 'Accuracy', 'Precision', 'Recall'],
        ascending=False
    ).reset_index(drop=True)
    model_predictions_labeled['pcnt_disordered'] = round(100 * model_predictions_labeled['Label'].apply(lambda x: sum([int(y) for y in list(x)])) / model_predictions_labeled['sequence'].str.len(), 2)
    model_predictions_labeled['pred_pcnt_disordered'] = round(100 * model_predictions_labeled['pred_labels'].apply(lambda x: sum([int(y) for y in list(x)])) / model_predictions_labeled['sequence'].str.len(), 2)

    log_update(f"Model predictions for {len(model_predictions_labeled)} fusion oncoproteins. Preview:")
    log_update(
        model_predictions_labeled[['sequence', 'length', 'FusionGene', 'Fusion_pLDDT', 'pcnt_disordered', 'pred_pcnt_disordered', 'AUROC', 'F1', 'AUPRC', 'Accuracy', 'Precision', 'Recall']].head()
    )
    cols_str = '\n\t' + '\n\t'.join(list(model_predictions_labeled.columns))
    log_update(f"Columns in model predictions stats database: {cols_str}")

    # Check for duplicate sequences (there is one duplicate row)
    duplicate_sequences = model_predictions_labeled.loc[model_predictions_labeled['sequence'].duplicated()]['sequence'].unique().tolist()
    log_update(f"\nTotal duplicate sequences: {len(duplicate_sequences)}")
    gb = model_predictions_labeled.groupby('sequence').agg(
        pred_labels=("pred_labels", list),
    )
    gb['pred_labels'] = gb['pred_labels'].apply(lambda x: list(set(x)))
    gb['unique_label_vectors'] = gb['pred_labels'].apply(lambda x: len(x))
    log_update(f"Duplicate entries for sequences have the exact same label vector: {(gb['unique_label_vectors']==1).all()}")
    log_update("\tSince above statement is true, randomly dropping duplicate sequence rows - should make no difference to prediction.")
    model_predictions_labeled = model_predictions_labeled.drop_duplicates('sequence').reset_index(drop=True)
    #os.makedirs("processed_data/fusion_predictions",exist_ok=True)
    return model_predictions_labeled


def calc_average_stats(model_pred_stats):
    # Macro-average the per-sequence metrics; pandas mean() skips NaN AUROC/AUPRC values by default
    averages = model_pred_stats[[
        'Accuracy', 'Precision', 'Recall', 'F1', 'AUROC', 'AUPRC'
    ]].mean()
    return averages
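
# Illustrative usage of calc_average_stats (hypothetical toy values, not real results):
#   toy = pd.DataFrame({'Accuracy': [0.9, 0.8], 'Precision': [0.9, 0.7], 'Recall': [0.8, 0.6],
#                       'F1': [0.85, 0.65], 'AUROC': [0.95, 0.75], 'AUPRC': [0.9, 0.7]})
#   calc_average_stats(toy)  # -> Series of column means (e.g. Accuracy 0.85, AUROC 0.85)
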
def main():
    with open_logfile("analyze_fusion_preds.txt"):
        # Path to the model predictions to use for benchmarking
        path_to_model_predictions = "trained_models/fuson_plm/best/caid_hyperparam_screen_fusion_benchmark_probs.csv"
        save_dir = "results/final"
        os.makedirs(save_dir, exist_ok=True)  # make sure the output directory exists
        preds_with_metrics_save_path = f"{save_dir}/model_preds_with_metrics.csv"
        boxplot_save_path = f"{save_dir}/fusion_disorder_boxplots.png"
        r2_save_path = f"{save_dir}/fusion_pred_disorder_r2.png"

        # Additional benchmarking on fusion predictions
        fuson_db = pd.read_csv("../../data/fuson_db.csv")
        seq_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
        model_preds_with_metrics = get_model_preds_with_metrics(path_to_model_predictions)
        model_preds_with_metrics['seq_id'] = model_preds_with_metrics['sequence'].map(seq_id_dict)
        model_preds_with_metrics.to_csv(preds_with_metrics_save_path, index=False)
        print(model_preds_with_metrics.columns)

        # Box plot of per-sequence metrics
        boxplot_model_preds = model_preds_with_metrics[['seq_id', 'FusionGene', 'Accuracy', 'Precision', 'Recall', 'F1', 'AUROC']]
        boxplot_model_preds.to_csv(boxplot_save_path.replace(".png", "_source_data.csv"), index=False)
        plot_fusion_stats_boxplots(
            boxplot_model_preds[['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC']],
            save_path=boxplot_save_path
        )

        # R2 plot of true vs. predicted percent disorder
        r2_model_preds = model_preds_with_metrics[['seq_id', 'FusionGene', 'pcnt_disordered', 'pred_pcnt_disordered']]
        r2_model_preds.to_csv(r2_save_path.replace(".png", "_source_data.csv"), index=False)
        plot_fusion_frac_disorder_r2(
            r2_model_preds['pcnt_disordered'],
            r2_model_preds['pred_pcnt_disordered'],
            save_path=r2_save_path
        )

        # Average metrics across the benchmark
        averages = calc_average_stats(model_preds_with_metrics)
        log_update(f"\nAverage metrics across the fusion benchmark:\n{averages}")


if __name__ == "__main__":
    main()
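
# Usage note: run this script from the CAID benchmarking directory so the relative
# paths above (splits/, processed_data/, trained_models/, ../../data/) resolve.
# It logs to analyze_fusion_preds.txt and writes, under results/final/:
# model_preds_with_metrics.csv, fusion_disorder_boxplots.png,
# fusion_pred_disorder_r2.png, and their *_source_data.csv files.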