File size: 17,456 Bytes
3efa812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
# Benchmark: recover literature-reported BCR::ABL and EML4::ALK mutations from
# FusOn-pLM masked-token (position-wise) predictions.
import fuson_plm.benchmarking.mutation_prediction.recovery.config as config
import os
# CUDA_VISIBLE_DEVICES is set from the config here, before torch is imported further
# down, so the device restriction is in place before CUDA is initialized.
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import pandas as pd
import numpy as np
import transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import argparse
import os  # NOTE(review): duplicate of the `import os` above; harmless
import torch.nn.functional as F

from fuson_plm.utils.logging import open_logfile, log_update, get_local_time, print_configpy
from fuson_plm.benchmarking.embed import load_fuson_model

def check_env_variables():
    """Log the CUDA setup visible to this process.

    Writes the CUDA_VISIBLE_DEVICES environment variable (``None`` if unset),
    the number of CUDA devices torch reports, and the name of each device.
    """
    log_update("\nChecking on environment variables...")
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    log_update(f"\tCUDA_VISIBLE_DEVICES: {visible}")
    n_devices = torch.cuda.device_count()
    log_update(f"\ttorch.cuda.device_count(): {n_devices}")
    for device_idx in range(n_devices):
        device_name = torch.cuda.get_device_name(device_idx)
        log_update(f"\t\tDevice {device_idx}: {device_name}")

def get_top_k_aa_mutations(all_probabilities, sequence, i, top_k_mutations, k=10):
    """
    Record the k most probable residues for position ``i``.

    ``all_probabilities`` maps residue string -> probability. Its entries are ranked
    from most to least probable, and ``top_k_mutations[(sequence[i], i)]`` is set to
    ``(comma-separated top-k residues, all_probabilities)``. The dict is updated in
    place and also returned.
    """
    ranked = (
        pd.DataFrame.from_dict(all_probabilities, orient='index')
        .sort_values(by=0, ascending=False)
    )
    best_residues = ",".join(ranked.index[:k].tolist())
    top_k_mutations[(sequence[i], i)] = (best_residues, all_probabilities)
    return top_k_mutations

def get_top_k_mutations(tokenizer, mask_token_logits, all_probabilities, sequence, i, top_k_mutations, k=3):
    """
    Record the k highest-scoring vocabulary tokens for position ``i``.

    The ranking is over every column of ``mask_token_logits`` (it is not restricted to
    amino-acid tokens). Only the first row of the top-k result is used, i.e. the first
    masked position.

    Args:
        tokenizer: object with a ``decode(list_of_ids)`` method.
        mask_token_logits: 2-D tensor of logits; ranked along dim=1.
        all_probabilities: stored alongside the top-k string, unchanged.
        sequence: original sequence; ``sequence[i]`` forms part of the dict key.
        i: 0-based position.
        top_k_mutations: dict updated in place.
        k: number of tokens to keep.

    Returns:
        ``top_k_mutations`` (same object, updated), matching get_top_k_aa_mutations.
    """
    top_k_tokens = torch.topk(mask_token_logits, k, dim=1).indices[0].tolist()
    top_k_mutation = ",".join(tokenizer.decode([token]) for token in top_k_tokens)
    top_k_mutations[(sequence[i], i)] = (top_k_mutation, all_probabilities)
    return top_k_mutations

def predict_positionwise_mutations(model, tokenizer, device, sequence):
    """
    Mask every position of ``sequence`` in turn and collect the model's predictions there.

    For each position i, the residue is replaced by '<mask>'. The masked sequence is then
    tokenized (truncation at max_length=2000) and run through ``model`` under
    torch.no_grad(). Finally a softmax is taken over the logits at the mask token.
    The probabilities of token IDs 4-23 are kept (the code treats these as the natural
    amino acids) and ranked to give the top-10 candidate residues for the position.

    Args:
        model: masked LM; called as ``model(**inputs).logits``.
        tokenizer: tokenizer for ``model``; must provide ``mask_token_id`` and ``decode``.
        device: torch device the tokenized inputs are moved to.
        sequence: amino-acid sequence string.

    Returns:
        (df, decoded_full_sequence, mut_count):
          * df: one row per position. Columns are 'Original Residue', 'Position' (1-based),
            'Top 10 Mutations' (comma-separated) and 'All Probabilities' (token -> probability
            dict), plus 'Top Mutation' and 'Top 3/4/5 Mutations' derived from the top 10.
          * decoded_full_sequence: the arg-max token over the whole vocabulary at each
            position, concatenated.
          * mut_count: number of positions where that arg-max token differs from the
            original residue.

    Raises:
        ValueError: if the mask token is missing from the tokenized input, e.g. because
            truncation at max_length=2000 cut it off.
    """
    log_update("\t\tPredicting position-wise mutations...")
    top_10_mutations = {}
    decoded_residues = []
    mut_count = 0

    # Token IDs 4-23 (inclusive) are treated as the natural amino acids. Their decoded
    # strings do not depend on the position, so decode them once up front.
    aa_token_ids = list(range(4, 23 + 1))
    aa_token_strs = [tokenizer.decode([token_idx]) for token_idx in aa_token_ids]

    # Mask and unmask sequentially
    for i, original_residue in enumerate(sequence):
        log_update(f"\t\t\t- pos {i+1}/{len(sequence)}")

        # Mask JUST the current position
        masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
        inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
        inputs = {key: value.to(device) for key, value in inputs.items()}

        # Forward pass
        with torch.no_grad():
            logits = model(**inputs).logits

        # Locate the masked token (exactly one is expected)
        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        if mask_token_index.numel() == 0:
            raise ValueError(
                f"Mask token for position {i+1} not found in the tokenized input; "
                f"the sequence (length {len(sequence)}) may have been truncated at max_length=2000"
            )
        mask_token_logits = logits[0, mask_token_index, :]
        mask_token_probs = F.softmax(mask_token_logits, dim=-1)

        # Probability of each amino-acid token at this position (residue string -> probability).
        # One indexed .tolist() replaces 20 separate .item() device syncs.
        aa_probs = mask_token_probs[0, aa_token_ids].tolist()
        all_probabilities = dict(zip(aa_token_strs, aa_probs))

        # Isolate top 10 amino-acid candidates for this position
        get_top_k_aa_mutations(all_probabilities, sequence, i, top_10_mutations, k=10)

        # The arg-max over the FULL vocabulary (not only tokens 4-23) builds the decoded sequence.
        # NOTE(review): this can differ from the 'Top Mutation' column, which is restricted to
        # tokens 4-23; confirm which definition the mutation count is meant to use.
        top_1_token = torch.topk(mask_token_logits, 1, dim=1).indices[0].item()
        new_residue = tokenizer.decode([top_1_token])
        decoded_residues.append(new_residue)
        if original_residue != new_residue:
            mut_count += 1

    decoded_full_sequence = ''.join(decoded_residues)

    # Convert results into a DataFrame; dict keys are (original residue, 0-based position)
    original_residues = []
    top10_mutations = []
    positions = []
    all_probs = []
    for (original, position), (top10, probs) in top_10_mutations.items():
        original_residues.append(original)
        top10_mutations.append(top10)
        positions.append(position + 1)  # report 1-based positions
        all_probs.append(probs)

    df = pd.DataFrame({
        'Original Residue': original_residues,
        'Position': positions,
        'Top 10 Mutations': top10_mutations,
        'All Probabilities': all_probs,
    })
    df['Top Mutation'] = df['Top 10 Mutations'].apply(lambda x: x.split(',')[0])
    for n in (3, 4, 5):
        df[f'Top {n} Mutations'] = df['Top 10 Mutations'].apply(lambda x, n=n: ','.join(x.split(',')[:n]))

    return df, decoded_full_sequence, mut_count

def evaluate_literature_mut_performance(predicted_mutations_df, literature_mutations_df, decoded_full_sequence, mut_count, sequence="", focus_region_start=0, focus_region_end=0, offset=0):
    """
    Score position-wise predictions against mutations reported in the literature.

    Each literature row (columns 'Position', 'Mutation', 'Original Residue') is mapped onto
    the predictions with  seq_position = focus_region_start + (lit_position - offset),
    which is compared with the predictions' 'Position' column. If a predicted row exists
    at that position and carries the same original residue, it gets two annotations. The
    first is the literature mutation. The second is, for k in 1/3/4/5/10, whether any of
    the model's top-k residues occurs in the literature 'Mutation' value. Literature
    positions that fall outside the predictions, or whose residue does not match, are
    logged and left untested.

    Args:
        predicted_mutations_df: output of predict_positionwise_mutations (or its cached CSV).
        literature_mutations_df: literature mutations table.
        decoded_full_sequence: passed through unchanged to the result.
        mut_count: passed through unchanged; used for the mutation percentage.
        sequence: full input sequence; only its length is used.
        focus_region_start: added to (lit_position - offset) to get the prediction 'Position'.
        focus_region_end: not used; kept for interface compatibility.
        offset: literature numbering offset.

    Returns:
        (return_df, (decoded_full_sequence, mut_count, mut_pcnt)). return_df is a copy of
        predicted_mutations_df with 'Literature Mutation' and 'Top {k} Hit' columns
        (NaN for untested positions). mut_pcnt = 100 * mut_count / len(sequence), or NaN
        when ``sequence`` is empty.
    """
    log_update("\t\tComparing predicted mutations to literature-provided mutations")

    # k -> column holding the predicted residues used for the 'Top {k} Hit' check
    prediction_columns = {
        1: 'Top Mutation',
        3: 'Top 3 Mutations',
        4: 'Top 4 Mutations',
        5: 'Top 5 Mutations',
        10: 'Top 10 Mutations',
    }

    return_df = predicted_mutations_df.copy(deep=True)
    # Object dtype: these columns receive strings / bools. Writing those into an all-NaN
    # float64 column is deprecated in pandas >= 2.1 (incompatible-dtype setitem).
    result_columns = ['Literature Mutation'] + [f'Top {k} Hit' for k in prediction_columns]
    for column in result_columns:
        return_df[column] = pd.Series(np.nan, index=return_df.index, dtype=object)

    log_update(f"\tFormula: new position = {focus_region_start} + lit_position - {offset}")
    for _, row in literature_mutations_df.iterrows():
        lit_position = row['Position']
        lit_mutations = row['Mutation']
        original_residue = row['Original Residue']
        seq_position = focus_region_start + (lit_position - offset)  # position in the full sequence

        matching_row = return_df[return_df['Position'] == seq_position]
        if matching_row.empty:
            log_update(f"\tLit pos: {lit_position}, OG residue: {original_residue}, Full sequence pos: {seq_position} not among predicted positions; skipping")
            continue

        matching_row_index = matching_row.index
        matching_residue = matching_row.iloc[0]['Original Residue']
        match = original_residue == matching_residue
        log_update(f"\tLit pos: {lit_position}, OG residue: {original_residue}, Full sequence pos: {seq_position}, Full sequence residue: {matching_residue}\n\t\tMatch: {match}")

        # Only score positions where the original residue agrees with the literature
        if not match:
            continue

        return_df.loc[matching_row_index, 'Literature Mutation'] = lit_mutations
        for k, column in prediction_columns.items():
            predicted = matching_row.iloc[0][column].split(',')
            # Hit if any predicted residue occurs in the literature 'Mutation' value.
            # NOTE(review): this is a substring test; it assumes 'Mutation' is a string
            # listing the reported residue(s) — confirm against the literature CSVs.
            hit = any(letter in lit_mutations for letter in predicted)
            return_df.loc[matching_row_index, f'Top {k} Hit'] = hit

    # Guard the default sequence="" against ZeroDivisionError
    mut_pcnt = (mut_count/len(sequence)) * 100 if len(sequence) else float('nan')
    return return_df, (decoded_full_sequence, mut_count, mut_pcnt)

def evaluate_eml4_alk(model, tokenizer, device, model_str):
    """
    Evaluate recovery of literature-reported EML4::ALK mutations.

    Per-position predictions over EML4_ALK_SEQ are loaded from
    eml4_alk_mutations/{model_str}/ when cached there. Otherwise they are computed with
    predict_positionwise_mutations and written to that directory. They are then scored
    against alk_mutations.csv. Literature positions are mapped onto the full sequence
    using the location of cons_domain_alk and a literature offset of 1115.

    Returns:
        (lit_performance_df, (mut_seq, mut_count, mut_pcnt)) from
        evaluate_literature_mut_performance.

    Raises:
        ValueError: if the placeholder sequences have not been filled in, or if
            cons_domain_alk does not occur in EML4_ALK_SEQ.
    """
    EML4_ALK_SEQ = np.nan       ## not publicly available; replace with the sequence string
    cons_domain_alk = np.nan    ## not publicly available; replace with the sequence string

    # The sequences ship as np.nan placeholders: fail early with a clear message
    # instead of an AttributeError from np.nan.find(...).
    if not isinstance(EML4_ALK_SEQ, str) or not isinstance(cons_domain_alk, str):
        raise ValueError(
            "EML4_ALK_SEQ and cons_domain_alk are placeholders (not publicly available); "
            "set them to the EML4::ALK sequence and its conserved ALK domain before running"
        )
    # 0-based start of the conserved domain; -1 would silently shift every position
    focus_region_start = EML4_ALK_SEQ.find(cons_domain_alk)
    if focus_region_start == -1:
        raise ValueError("cons_domain_alk was not found in EML4_ALK_SEQ")

    alk_muts = pd.read_csv("alk_mutations.csv")

    cache_dir = f"eml4_alk_mutations/{model_str}"
    if os.path.isfile(f"{cache_dir}/mutated_df.csv"):
        log_update(f"Mutation predictions for {model_str} have already been calculated. Loading from {cache_dir}/mutated_df.csv")
        mutated_df = pd.read_csv(f"{cache_dir}/mutated_df.csv")
        mutated_summary = pd.read_csv(f"{cache_dir}/mutated_summary.csv")
        decoded_full_sequence = mutated_summary['decoded_full_sequence'][0]
        mut_count = mutated_summary['mut_count'][0]
    else:
        mutated_df, decoded_full_sequence, mut_count = predict_positionwise_mutations(model, tokenizer, device, EML4_ALK_SEQ)
        mutated_summary = pd.DataFrame(data={'decoded_full_sequence': [decoded_full_sequence], 'mut_count': [mut_count]})
        os.makedirs(cache_dir, exist_ok=True)
        mutated_df.to_csv(f"{cache_dir}/mutated_df.csv", index=False)
        mutated_summary.to_csv(f"{cache_dir}/mutated_summary.csv", index=False)

    return evaluate_literature_mut_performance(
        mutated_df, alk_muts, decoded_full_sequence, mut_count,
        sequence=EML4_ALK_SEQ,
        focus_region_start=focus_region_start,
        focus_region_end=focus_region_start + len(cons_domain_alk),
        offset=1115,  # original: 1116
    )

def evaluate_bcr_abl(model, tokenizer, device, model_str):
    """
    Evaluate recovery of literature-reported BCR::ABL mutations.

    Per-position predictions over BCR_ABL_SEQ are loaded from
    bcr_abl_mutations/{model_str}/ when cached there. Otherwise they are computed with
    predict_positionwise_mutations and written to that directory. They are then scored
    against abl_mutations.csv. Literature positions are mapped onto the full sequence
    using the location of cons_domain_abl and a literature offset of 241.

    Returns:
        (lit_performance_df, (mut_seq, mut_count, mut_pcnt)) from
        evaluate_literature_mut_performance.

    Raises:
        ValueError: if the placeholder sequences have not been filled in, or if
            cons_domain_abl does not occur in BCR_ABL_SEQ.
    """
    BCR_ABL_SEQ = np.nan        ## not publicly available; replace with the sequence string
    cons_domain_abl = np.nan    ## not publicly available; replace with the sequence string

    # The sequences ship as np.nan placeholders: fail early with a clear message
    # instead of an AttributeError from np.nan.find(...).
    if not isinstance(BCR_ABL_SEQ, str) or not isinstance(cons_domain_abl, str):
        raise ValueError(
            "BCR_ABL_SEQ and cons_domain_abl are placeholders (not publicly available); "
            "set them to the BCR::ABL sequence and its conserved ABL domain before running"
        )
    # 0-based start of the conserved domain; -1 would silently shift every position
    focus_region_start = BCR_ABL_SEQ.find(cons_domain_abl)
    if focus_region_start == -1:
        raise ValueError("cons_domain_abl was not found in BCR_ABL_SEQ")

    abl_muts = pd.read_csv("abl_mutations.csv")

    cache_dir = f"bcr_abl_mutations/{model_str}"
    if os.path.isfile(f"{cache_dir}/mutated_df.csv"):
        log_update(f"Mutation predictions for {model_str} have already been calculated. Loading from {cache_dir}/mutated_df.csv")
        mutated_df = pd.read_csv(f"{cache_dir}/mutated_df.csv")
        mutated_summary = pd.read_csv(f"{cache_dir}/mutated_summary.csv")
        decoded_full_sequence = mutated_summary['decoded_full_sequence'][0]
        mut_count = mutated_summary['mut_count'][0]
    else:
        mutated_df, decoded_full_sequence, mut_count = predict_positionwise_mutations(model, tokenizer, device, BCR_ABL_SEQ)
        mutated_summary = pd.DataFrame(data={'decoded_full_sequence': [decoded_full_sequence], 'mut_count': [mut_count]})
        os.makedirs(cache_dir, exist_ok=True)
        mutated_df.to_csv(f"{cache_dir}/mutated_df.csv", index=False)
        mutated_summary.to_csv(f"{cache_dir}/mutated_summary.csv", index=False)

    return evaluate_literature_mut_performance(
        mutated_df, abl_muts, decoded_full_sequence, mut_count,
        sequence=BCR_ABL_SEQ,
        focus_region_start=focus_region_start,
        focus_region_end=focus_region_start + len(cons_domain_abl),
        offset=241,  # original: 242
    )

def summarize_individual_performance(performance_df, path_to_lit_df):
    """
    Log per-position results and top-k hit rates for one fusion benchmark.

    Args:
        performance_df: DataFrame from evaluate_literature_mut_performance. Rows with a
            non-null 'Literature Mutation' are the positions that were scored.
        path_to_lit_df: path to the literature mutations CSV. Only its row count is used,
            as the "total positions mutated in literature" denominator.

    When no position was scored for some k, the hit/miss percentages for that k are
    skipped rather than dividing by zero.
    """
    lit_muts = pd.read_csv(path_to_lit_df)
    total_original_muts = len(lit_muts)

    # Keep only positions that were scored against the literature
    mut_rows = performance_df.loc[performance_df['Literature Mutation'].notna()].reset_index(drop=True)
    mut_rows = mut_rows[['Original Residue', 'Position', 'Literature Mutation',
                         'Top Mutation', 'Top 1 Hit',
                         'Top 3 Mutations', 'Top 3 Hit',
                         'Top 4 Mutations', 'Top 4 Hit',
                         'Top 5 Mutations', 'Top 5 Hit',
                         'Top 10 Mutations', 'Top 10 Hit',
                         ]]

    mut_rows_str = "\t\t" + mut_rows.to_string(index=False).replace("\n", "\n\t\t")
    log_update(f"\tPerformance on all mutated positions shown below:\n{mut_rows_str}")

    for k in [1, 3, 4, 5, 10]:
        hits = mut_rows[f'Top {k} Hit']
        # Explicit == comparisons so untested (NaN) entries count as neither hit nor miss
        total_hits = int((hits == True).sum())
        total_misses = int((hits == False).sum())
        total_potential_muts = total_hits + total_misses

        log_update(f"\tTotal positions tested / total positions mutated in literature: {total_potential_muts}/{total_original_muts}")
        if total_potential_muts == 0:
            log_update(f"\n\t\tTop {k} hit performance: no positions tested, hit rate undefined")
            continue

        hit_pcnt = round(100*total_hits/total_potential_muts, 2)
        miss_pcnt = round(100*total_misses/total_potential_muts, 2)
        log_update(f"\n\t\tTop {k} hit performance:\n\t\t\tHits:{total_hits} ({hit_pcnt}%)\n\t\t\tMisses:{total_misses} ({miss_pcnt}%)")
    
def _write_mutation_summary(path, mut_seq, mut_count, mut_pcnt):
    """Write the decoded sequence, mutation count and mutated percentage to ``path``, one item per line."""
    with open(path, 'w') as f:
        f.write(f'{mut_seq}\n')
        f.write(f'number of mutations: {mut_count}\n')
        f.write(f'percentage of seq mutated: {mut_pcnt}\n')


def main():
    """
    Run the literature mutation-recovery benchmark for FusOn-pLM.

    The run proceeds in these steps:
      * Creates a timestamped results/ directory holding a log file.
      * Loads the FusOn-pLM checkpoint selected by config.FUSON_PLM_CKPT.
      * Evaluates BCR::ABL and EML4::ALK mutation recovery; per-position predictions
        are cached under bcr_abl_mutations/ and eml4_alk_mutations/.
      * Writes per-fusion result CSVs and summary text files.
      * Logs top-k hit statistics.
    """
    os.makedirs('results', exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs("bcr_abl_mutations", exist_ok=True)
    os.makedirs("eml4_alk_mutations", exist_ok=True)
    with open_logfile(f"{output_dir}/mutation_discovery_log.txt"):
        print_configpy(config)

        # Make sure environment variables are set correctly
        check_env_variables()

        # Get device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"Using device: {device}")

        # Resolve which FusOn-pLM checkpoint to load
        fuson_ckpt_path = config.FUSON_PLM_CKPT
        if fuson_ckpt_path == "FusOn-pLM":
            # Released model: loaded from four directories up (presumably the repository root)
            fuson_ckpt_path = "../../../.."
            model_name = "fuson_plm"
            model_epoch = "best"
        else:
            # Otherwise config.FUSON_PLM_CKPT is read as a one-entry {model_name: epoch} dict
            model_name = list(fuson_ckpt_path.keys())[0]
            epoch = list(fuson_ckpt_path.values())[0]
            fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
            model_name, model_epoch = fuson_ckpt_path.split('/')[-2::]
            model_epoch = model_epoch.split('checkpoint_')[-1]
        model_str = f"{model_name}/{model_epoch}"

        log_update(f"\nLoading FusOn-pLM model from {fuson_ckpt_path}")
        fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)
        fuson_model = AutoModelForMaskedLM.from_pretrained(fuson_ckpt_path)
        fuson_model.to(device)
        fuson_model.eval()

        # (label, prediction cache dir, output file prefix, evaluation function, literature CSV)
        benchmarks = [
            ("BCR::ABL", "bcr_abl_mutations", "BCR_ABL", evaluate_bcr_abl, "abl_mutations.csv"),
            ("EML4::ALK", "eml4_alk_mutations", "EML4_ALK", evaluate_eml4_alk, "alk_mutations.csv"),
        ]

        performance = {}
        for label, cache_root, file_prefix, evaluate_fn, _ in benchmarks:
            os.makedirs(f"{cache_root}/{model_name}/{model_epoch}", exist_ok=True)
            log_update(f"\tEvaluating performance on {label} mutation prediction with FusOn")
            lit_performance, (mut_seq, mut_count, mut_pcnt) = evaluate_fn(fuson_model, fuson_tokenizer, device, model_str)
            lit_performance.to_csv(f'{output_dir}/{file_prefix}_mutation_recovery_fuson.csv', index=False)
            _write_mutation_summary(f'{output_dir}/{file_prefix}_mutation_recovery_fuson_summary.txt',
                                    mut_seq, mut_count, mut_pcnt)
            performance[label] = lit_performance

        ### Summarize
        for label, _, _, _, lit_csv in benchmarks:
            log_update(f"\nSummarizing FusOn-pLM performance on {label}")
            summarize_individual_performance(performance[label], lit_csv)

# Script entry point: run the BCR::ABL and EML4::ALK benchmarks when executed directly.
if __name__ == "__main__":
    main()