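"""Analyze disorder predictions on the fusion oncoprotein benchmark.

Loads per-residue disorder predictions for the fusion benchmark, merges them
with FusionPDB structure data and top SwissProt alignments, computes
per-sequence metrics (Accuracy, Precision, Recall, F1, AUROC, AUPRC), and
writes summary tables and plots.
"""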

import os

import numpy as np
import pandas as pd

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score,
)
from fuson_plm.utils.logging import log_update, open_logfile
from fuson_plm.benchmarking.caid.plot import plot_fusion_stats_boxplots, plot_fusion_frac_disorder_r2

# Calculate per-sequence metrics (AUROC, AUPRC, and label-based stats)
def calc_metrics(row):
    """Compute classification metrics for one sequence from its per-residue
    probabilities ('prob_1'), true labels ('Label'), and predicted labels
    ('pred_labels')."""
    # Per-residue probabilities are stored as a comma-separated string
    probs = [float(y) for y in row['prob_1'].split(',')]
    # True and predicted labels are stored as strings of 0s and 1s
    true_labels = [int(y) for y in row['Label']]
    pred_labels = [int(y) for y in row['pred_labels']]

    flat_binary_preds = np.array(pred_labels)
    flat_prob_preds = np.array(probs)
    flat_labels = np.array(true_labels)

    # Threshold-based metrics from the predicted labels
    accuracy = accuracy_score(flat_labels, flat_binary_preds)
    precision = precision_score(flat_labels, flat_binary_preds)
    recall = recall_score(flat_labels, flat_binary_preds)
    f1 = f1_score(flat_labels, flat_binary_preds)

    # Ranking metrics from the probabilities; sklearn raises ValueError when a
    # sequence contains only one class, in which case we record NaN
    try:
        roc_auc = roc_auc_score(flat_labels, flat_prob_preds)
    except ValueError:
        roc_auc = np.nan
    try:
        auprc = average_precision_score(flat_labels, flat_prob_preds)
    except ValueError:
        auprc = np.nan

    return pd.Series({
        'Accuracy': round(accuracy, 3),
        'Precision': round(precision, 3),
        'Recall': round(recall, 3),
        'F1': round(f1, 3),
        'AUROC': round(roc_auc, 3) if not np.isnan(roc_auc) else roc_auc,
        'AUPRC': round(auprc, 3) if not np.isnan(auprc) else auprc,
    })
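# Illustrative example (hypothetical values, not from the benchmark): a row with
#   prob_1 = '0.9,0.2,0.8', Label = '101', pred_labels = '101'
# scores perfectly on all six metrics:
#   calc_metrics(pd.Series({'prob_1': '0.9,0.2,0.8', 'Label': '101', 'pred_labels': '101'}))
#   -> Accuracy 1.0, Precision 1.0, Recall 1.0, F1 1.0, AUROC 1.0, AUPRC 1.0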

def get_model_preds_with_metrics(path_to_model_predictions):
    """Load model predictions, attach benchmark labels, FusionPDB structure
    data, and top SwissProt alignments, then compute per-sequence metrics."""
    # Load the data we will need
    fusion_benchmark_set = pd.read_csv('splits/fusion_bench_df.csv')
    model_predictions = pd.read_csv(path_to_model_predictions)
    fusion_structure_data = pd.read_csv('processed_data/fusionpdb/FusionPDB_level2-3_cleaned_structure_info.csv')
    # Keep only the file name portion of the structure link
    fusion_structure_data['Fusion_Structure_Link'] = fusion_structure_data['Fusion_Structure_Link'].apply(lambda x: x.split('/')[-1])

    # Merge fusion data with sequence IDs
    fuson_db = pd.read_csv('../../data/fuson_db.csv')
    fuson_db = fuson_db[['aa_seq','seq_id']].rename(columns={'aa_seq':'Fusion_Seq'})
    fusion_structure_data = pd.merge(
        fusion_structure_data,
        fuson_db,
        on='Fusion_Seq',
        how='inner'
    )

    # Merge fusion structure data with the top SwissProt alignments
    swissprot_top_alignments = pd.read_csv("../../data/blast/blast_outputs/swissprot_top_alignments.csv")
    fusion_structure_data = pd.merge(
        fusion_structure_data,
        swissprot_top_alignments,
        on="seq_id",
        how="left"
    )

    # Attach ground-truth labels from the benchmark set
    model_predictions_labeled = pd.merge(
        model_predictions,
        fusion_benchmark_set.rename(columns={'Sequence': 'sequence'}),
        on='sequence',
        how='inner'
    )
    # Attach structure and alignment annotations; the left merge keeps
    # predictions that lack structure data (their annotations become NaN)
    model_predictions_labeled = pd.merge(
        model_predictions_labeled,
        fusion_structure_data[['FusionGene','Fusion_Seq','Fusion_Structure_Link','Fusion_pLDDT','Fusion_AA_pLDDTs',
                               'top_hg_UniProtID', 'top_hg_UniProt_isoform', 'top_hg_UniProt_fus_indices', 'top_tg_UniProtID', 'top_tg_UniProt_isoform',
                               'top_tg_UniProt_fus_indices', 'top_UniProtID', 'top_UniProt_isoform', 'top_UniProt_fus_indices', 'top_UniProt_nIdentities',
                               'top_UniProt_nPositives']].rename(columns={'Fusion_Seq': 'sequence'}),
        on='sequence',
        how='left'
    )
    model_predictions_labeled['length'] = model_predictions_labeled['sequence'].str.len()
    # Keep only the file name portion of the structure link (NaN-safe, since
    # rows without structure data survive the left merge with NaN links)
    model_predictions_labeled['Fusion_Structure_Link'] = model_predictions_labeled['Fusion_Structure_Link'].apply(
        lambda x: x.split('/')[-1] if isinstance(x, str) else x
    )

    # Compute per-sequence metrics and sort from best to worst
    model_predictions_labeled[['Accuracy','Precision','Recall','F1','AUROC','AUPRC']] = model_predictions_labeled.apply(calc_metrics, axis=1)
    model_predictions_labeled = model_predictions_labeled.sort_values(
        by=['AUROC','F1','AUPRC','Accuracy','Precision','Recall'], ascending=False
    ).reset_index(drop=True)
    # Percent of residues labeled disordered, true vs. predicted
    model_predictions_labeled['pcnt_disordered'] = round(
        100 * model_predictions_labeled['Label'].apply(lambda x: sum(int(y) for y in x)) / model_predictions_labeled['sequence'].str.len(), 2
    )
    model_predictions_labeled['pred_pcnt_disordered'] = round(
        100 * model_predictions_labeled['pred_labels'].apply(lambda x: sum(int(y) for y in x)) / model_predictions_labeled['sequence'].str.len(), 2
    )
    log_update(f"Model predictions for {len(model_predictions_labeled)} fusion oncoproteins. Preview:")
    log_update(
        model_predictions_labeled[['sequence','length','FusionGene','Fusion_pLDDT','pcnt_disordered','pred_pcnt_disordered','AUROC','F1','AUPRC','Accuracy','Precision','Recall']].head()
    )
    cols_str = '\n\t'+ '\n\t'.join(list(model_predictions_labeled.columns))
    log_update(f"Columns in model predictions stats database: {cols_str}")
    
    # Check for duplicate sequences (the benchmark is known to contain at least one)
    duplicate_sequences = model_predictions_labeled.loc[model_predictions_labeled['sequence'].duplicated()]['sequence'].unique().tolist()
    log_update(f"\nTotal duplicate sequences: {len(duplicate_sequences)}")
    gb = model_predictions_labeled.groupby('sequence').agg(
        pred_labels=("pred_labels", list),
    )
    gb['pred_labels'] = gb['pred_labels'].apply(lambda x: list(set(x)))
    gb['unique_label_vectors'] = gb['pred_labels'].apply(len)
    duplicates_agree = (gb['unique_label_vectors'] == 1).all()
    log_update(f"Duplicate entries for sequences have the exact same label vector: {duplicates_agree}")
    # It is only safe to drop duplicates arbitrarily when every copy of a
    # sequence carries identical predictions
    assert duplicates_agree, "Duplicate sequences have conflicting prediction vectors"
    log_update("\tSince the above statement is true, dropping duplicate sequence rows - this makes no difference to the predictions.")

    model_predictions_labeled = model_predictions_labeled.drop_duplicates('sequence').reset_index(drop=True)
    return model_predictions_labeled
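# Hypothetical quick check of the function above (path taken from main() below;
# run from this directory so the relative CSV paths resolve):
# preds = get_model_preds_with_metrics(
#     "trained_models/fuson_plm/best/caid_hyperparam_screen_fusion_benchmark_probs.csv"
# )
# assert {'Accuracy','Precision','Recall','F1','AUROC','AUPRC'} <= set(preds.columns)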

def calc_average_stats(model_pred_stats):
    """Log and return the mean of each per-sequence metric across the benchmark."""
    averages = model_pred_stats[[
        'Accuracy', 'Precision', 'Recall', 'F1', 'AUROC', 'AUPRC'
    ]].mean()
    log_update(f"\nAverage metrics across all fusion sequences:\n{averages}")
    return averages
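# Usage example: calc_average_stats returns a pd.Series indexed by metric name,
# e.g. calc_average_stats(preds)['AUROC'] is the mean AUROC over all sequences.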

def main():
    with open_logfile("analyze_fusion_preds.txt"):
        # Path to the model predictions to use for benchmarking
        path_to_model_predictions = "trained_models/fuson_plm/best/caid_hyperparam_screen_fusion_benchmark_probs.csv"
        save_dir = "results/final"
        os.makedirs(save_dir, exist_ok=True)
        preds_with_metrics_save_path = f"{save_dir}/model_preds_with_metrics.csv"
        boxplot_save_path = f"{save_dir}/fusion_disorder_boxplots.png"
        r2_save_path = f"{save_dir}/fusion_pred_disorder_r2.png"

        # Additional benchmarking on fusion predictions
        fuson_db = pd.read_csv("../../data/fuson_db.csv")
        seq_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
        model_preds_with_metrics = get_model_preds_with_metrics(path_to_model_predictions)
        model_preds_with_metrics['seq_id'] = model_preds_with_metrics['sequence'].map(seq_id_dict)
        model_preds_with_metrics.to_csv(preds_with_metrics_save_path, index=False)
        log_update(f"Columns in saved predictions file: {model_preds_with_metrics.columns.tolist()}")
        
        # Box plot of per-sequence metrics (source data saved alongside the figure)
        boxplot_model_preds = model_preds_with_metrics[[
            'seq_id', 'FusionGene', 'Accuracy', 'Precision', 'Recall', 'F1', 'AUROC'
        ]]
        boxplot_model_preds.to_csv(boxplot_save_path.replace(".png", "_source_data.csv"), index=False)
        plot_fusion_stats_boxplots(
            boxplot_model_preds[['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC']],
            save_path=boxplot_save_path
        )
        
        # R^2 plot of true vs. predicted percent disorder
        r2_model_preds = model_preds_with_metrics[['seq_id', 'FusionGene', 'pcnt_disordered', 'pred_pcnt_disordered']]
        r2_model_preds.to_csv(r2_save_path.replace(".png", "_source_data.csv"), index=False)
        plot_fusion_frac_disorder_r2(
            r2_model_preds['pcnt_disordered'],
            r2_model_preds['pred_pcnt_disordered'],
            save_path=r2_save_path
        )
        calc_average_stats(model_preds_with_metrics)

if __name__ == "__main__":
    main()