|
|
|
import fuson_plm.benchmarking.mutation_prediction.recovery.config as config |
|
import os |
|
# Copy the configured device list into the environment. This assignment sits before
# `import torch` further down; presumably so the setting is already in place when
# CUDA is initialised - TODO confirm before reordering the imports.
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
|
|
|
import pandas as pd |
|
import numpy as np |
|
import transformers |
|
from transformers import AutoTokenizer, AutoModelForMaskedLM |
|
import logging |
|
import torch |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import argparse |
|
import os |
|
import torch.nn.functional as F |
|
|
|
from fuson_plm.utils.logging import open_logfile, log_update, get_local_time, print_configpy |
|
from fuson_plm.benchmarking.embed import load_fuson_model |
|
|
|
def check_env_variables():
    """
    Log the CUDA-related runtime environment.

    Writes, via log_update, the value of the CUDA_VISIBLE_DEVICES environment
    variable, the number of devices reported by torch, and the name of each device.
    """
    log_update("\nChecking on environment variables...")

    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    log_update(f"\tCUDA_VISIBLE_DEVICES: {visible_devices}")

    n_devices = torch.cuda.device_count()
    log_update(f"\ttorch.cuda.device_count(): {n_devices}")

    for device_idx in range(n_devices):
        device_name = torch.cuda.get_device_name(device_idx)
        log_update(f"\t\tDevice {device_idx}: {device_name}")
|
|
|
def get_top_k_aa_mutations(all_probabilities, sequence, i, top_k_mutations, k=10):
    """
    Record the k most probable residues for position i of sequence.

    Args:
        all_probabilities: dict mapping a residue string to its probability.
        sequence: indexable sequence; sequence[i] is stored as the original residue.
        i: 0-based position the probabilities belong to.
        top_k_mutations: dict updated in place. The key (sequence[i], i) is set to
            (comma-joined top-k residues, all_probabilities).
        k: how many residues to keep; fewer are kept if fewer are available.

    Returns:
        The same top_k_mutations dict that was passed in.
    """
    def _nan_last(item):
        # NaN compares unequal to itself; rank it below every real probability so it
        # ends up last in the descending order.
        prob = item[1]
        return float("-inf") if prob != prob else prob

    # sorted() is stable, so residues with equal probability keep their insertion
    # order, and an empty dict simply yields an empty ranking.
    ranked = sorted(all_probabilities.items(), key=_nan_last, reverse=True)
    top_k_mutation = ",".join(residue for residue, _ in ranked[:k])
    top_k_mutations[(sequence[i], i)] = (top_k_mutation, all_probabilities)

    return top_k_mutations
|
|
|
def get_top_k_mutations(tokenizer, mask_token_logits, all_probabilities, sequence, i, top_k_mutations, k=3):
    """
    Record the k highest-scoring tokens (any token, not only residues) for position i.

    Args:
        tokenizer: object with a decode(list_of_ids) method returning a string.
        mask_token_logits: 2-D tensor; the top-k is taken along dim 1 and only the
            first row is used.
        all_probabilities: stored unchanged next to the ranking.
        sequence: indexable sequence; sequence[i] is stored as the original residue.
        i: 0-based position the logits belong to.
        top_k_mutations: dict updated in place. The key (sequence[i], i) is set to
            (comma-joined decoded top-k tokens, all_probabilities).
        k: number of tokens to keep; clamped to the width of mask_token_logits.

    Returns:
        The same top_k_mutations dict that was passed in, matching
        get_top_k_aa_mutations.
    """
    # torch.topk raises when k exceeds the size of the ranked dimension.
    k = min(k, mask_token_logits.shape[1])
    top_k_tokens = torch.topk(mask_token_logits, k, dim=1).indices[0].tolist()

    top_k_mutation = ",".join(tokenizer.decode([token]) for token in top_k_tokens)
    top_k_mutations[(sequence[i], i)] = (top_k_mutation, all_probabilities)

    return top_k_mutations
|
|
|
def predict_positionwise_mutations(model, tokenizer, device, sequence):
    """
    Mask each position of sequence in turn and collect the model's predictions.

    For every position the residue is replaced by '<mask>', the masked sequence is
    tokenized (truncation at max_length=2000) and run through model without gradients.
    The softmax probabilities of token ids 4..23 (inclusive) are ranked by
    get_top_k_aa_mutations, and the single highest-logit token over the whole
    vocabulary is appended to the decoded sequence.

    NOTE(review): ids 4..23 are presumably the 20 standard amino acids of the
    tokenizer's vocabulary - confirm against the tokenizer in use.
    NOTE(review): with truncation enabled, a mask placed beyond max_length may be
    cut off, leaving no mask token to read - confirm for long sequences.

    Returns:
        (df, decoded_full_sequence, mut_count) where df has one row per position with
        the original residue, 1-based position, top-10/5/4/3/1 predictions and the
        probability dict; decoded_full_sequence is the concatenation of each
        position's top token; mut_count is how many positions differ from sequence.
    """
    log_update("\t\tPredicting position-wise mutations...")

    seq_len = len(sequence)
    per_position = {}
    decoded_residues = []
    mut_count = 0

    for pos, wild_type in enumerate(sequence):
        log_update(f"\t\t\t- pos {pos+1}/{seq_len}")

        # Replace only the residue at this position with the mask token
        masked_seq = sequence[:pos] + '<mask>' + sequence[pos+1:]
        encoded = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = model(**encoded).logits

        # Column index of the mask token inside the tokenized input
        mask_columns = torch.where(encoded["input_ids"] == tokenizer.mask_token_id)[1]
        mask_token_logits = logits[0, mask_columns, :]
        mask_token_probs = F.softmax(mask_token_logits, dim=-1)

        residue_probs = {
            tokenizer.decode([token_idx]): mask_token_probs[0, token_idx].item()
            for token_idx in range(4, 24)
        }
        get_top_k_aa_mutations(residue_probs, sequence, pos, per_position, k=10)

        # The decoded sequence uses the best token over the full vocabulary
        best_token = torch.topk(mask_token_logits, 1, dim=1).indices[0].item()
        new_residue = tokenizer.decode([best_token])
        decoded_residues.append(new_residue)

        if new_residue != wild_type:
            mut_count += 1

    entries = list(per_position.items())
    df = pd.DataFrame({
        'Original Residue': [key[0] for key, _ in entries],
        'Position': [key[1] + 1 for key, _ in entries],
        'Top 10 Mutations': [value[0] for _, value in entries],
        'All Probabilities': [value[1] for _, value in entries],
    })

    # Shorter rankings are prefixes of the comma-joined top-10 string
    prefix_columns = (
        ('Top Mutation', 1),
        ('Top 3 Mutations', 3),
        ('Top 4 Mutations', 4),
        ('Top 5 Mutations', 5),
    )
    for column, n_keep in prefix_columns:
        df[column] = df['Top 10 Mutations'].apply(
            lambda ranked, n_keep=n_keep: ','.join(ranked.split(',')[:n_keep])
        )

    decoded_full_sequence = ''.join(decoded_residues)
    return df, decoded_full_sequence, mut_count
|
|
|
def evaluate_literature_mut_performance(predicted_mutations_df, literature_mutations_df, decoded_full_sequence, mut_count, sequence="", focus_region_start=0, focus_region_end=0, offset=0):
    """
    Given a dataframe of predicted mutations and literature mutations, see how well the predicted mutations did

    Each literature row is mapped onto the predicted dataframe with
    seq_position = focus_region_start + (lit_position - offset). When a row with that
    'Position' exists and its 'Original Residue' equals the literature residue, the
    literature mutation is recorded and a True/False hit is stored for the top
    1, 3, 4, 5 and 10 predictions. A hit means that at least one predicted residue
    occurs in the literature mutation string (substring test).

    Args:
        predicted_mutations_df: output of predict_positionwise_mutations; not modified.
        literature_mutations_df: dataframe with 'Position', 'Mutation' and
            'Original Residue' columns.
        decoded_full_sequence: passed through to the return value.
        mut_count: number of changed positions, used for the percentage.
        sequence: full sequence; only its length is used.
        focus_region_start: added to the shifted literature position.
        focus_region_end: accepted for interface compatibility; not used here.
        offset: subtracted from the literature position.

    Returns:
        (return_df, (decoded_full_sequence, mut_count, mut_pcnt)). mut_pcnt is
        100 * mut_count / len(sequence), or NaN when sequence is empty.
    """
    log_update("\t\tComparing predicted mutations to literature-provided mutations")
    return_df = predicted_mutations_df.copy(deep=True)

    # Column holding the ranked predictions that back each 'Top k Hit' column
    prediction_columns = {
        1: 'Top Mutation',
        3: 'Top 3 Mutations',
        4: 'Top 4 Mutations',
        5: 'Top 5 Mutations',
        10: 'Top 10 Mutations',
    }

    # Use object dtype from the start: these columns later receive strings and
    # booleans, and assigning those into a float NaN column is deprecated in pandas.
    result_columns = ['Literature Mutation'] + [f'Top {k} Hit' for k in prediction_columns]
    for column in result_columns:
        return_df[column] = pd.Series(np.nan, index=return_df.index, dtype=object)

    log_update(f"\tFormula: new position = {focus_region_start} + lit_position - {offset}")

    for _, row in literature_mutations_df.iterrows():
        lit_position = row['Position']
        lit_mutations = row['Mutation']
        original_residue = row['Original Residue']
        seq_position = focus_region_start + (lit_position - offset)

        matching_row = return_df[return_df['Position'] == seq_position]
        if matching_row.empty:
            # The mapped position is not among the predicted positions; skip it
            # instead of failing on .iloc[0].
            log_update(f"\tLit pos: {lit_position}, OG residue: {original_residue}, Full sequence pos: {seq_position}\n\t\tNo predicted row at this position; skipping")
            continue

        matching_row_index = matching_row.index
        predicted = matching_row.iloc[0]
        matching_residue = predicted['Original Residue']
        match = original_residue == matching_residue
        log_update(f"\tLit pos: {lit_position}, OG residue: {original_residue}, Full sequence pos: {seq_position}, Full sequence residue: {matching_residue}\n\t\tMatch: {match}")

        # Only score positions whose residue agrees with the literature
        if not match:
            continue

        return_df.loc[matching_row_index, 'Literature Mutation'] = lit_mutations
        for k, column in prediction_columns.items():
            candidates = predicted[column].split(',')
            print(candidates)
            hit = any(letter in lit_mutations for letter in candidates)
            return_df.loc[matching_row_index, f'Top {k} Hit'] = hit

    # An empty sequence (the default) has no meaningful mutation percentage
    mut_pcnt = (mut_count / len(sequence)) * 100 if len(sequence) else np.nan
    return return_df, (decoded_full_sequence, mut_count, mut_pcnt)
|
|
|
def evaluate_eml4_alk(model, tokenizer, device, model_str):
    """
    Predict position-wise mutations for EML4::ALK and score them against literature.

    Literature mutations are read from "alk_mutations.csv". Predictions are cached
    under eml4_alk_mutations/{model_str}/ as mutated_df.csv and mutated_summary.csv;
    when both files exist they are loaded instead of re-running the model.

    Args:
        model, tokenizer, device: forwarded to predict_positionwise_mutations.
        model_str: sub-path identifying the model, used for the cache directory.

    Returns:
        (lit_performance_df, (mut_seq, mut_count, mut_pcnt)) as produced by
        evaluate_literature_mut_performance.

    Raises:
        ValueError: if the sequence constants below have not been filled in, or the
            conserved domain does not occur in the full sequence.
    """
    alk_muts = pd.read_csv("alk_mutations.csv")

    # NOTE(review): both constants are np.nan placeholders in this file; the real
    # amino-acid strings must be supplied here before this function can run.
    EML4_ALK_SEQ = np.nan
    cons_domain_alk = np.nan
    if not isinstance(EML4_ALK_SEQ, str) or not isinstance(cons_domain_alk, str):
        raise ValueError("EML4_ALK_SEQ and cons_domain_alk must be set to amino-acid strings before EML4::ALK can be evaluated")

    focus_region_start = EML4_ALK_SEQ.find(cons_domain_alk)
    if focus_region_start < 0:
        # str.find returns -1 when absent, which would silently shift every position
        raise ValueError("cons_domain_alk was not found in EML4_ALK_SEQ")

    cache_dir = f"eml4_alk_mutations/{model_str}"
    predictions_path = f"{cache_dir}/mutated_df.csv"
    summary_path = f"{cache_dir}/mutated_summary.csv"

    # Both cache files are read below, so both must exist for the cache to be usable
    if os.path.isfile(predictions_path) and os.path.isfile(summary_path):
        log_update(f"Mutation predictions for {model_str} have already been calculated. Loading from eml4_alk_mutations/{model_str}/mutated_df.csv")
        mutated_df = pd.read_csv(predictions_path)
        mutated_summary = pd.read_csv(summary_path)
        decoded_full_sequence = mutated_summary['decoded_full_sequence'][0]
        mut_count = mutated_summary['mut_count'][0]
    else:
        mutated_df, decoded_full_sequence, mut_count = predict_positionwise_mutations(model, tokenizer, device, EML4_ALK_SEQ)
        mutated_summary = pd.DataFrame(data={'decoded_full_sequence': [decoded_full_sequence], 'mut_count': [mut_count]})
        # Do not rely on the caller having created the cache directory
        os.makedirs(cache_dir, exist_ok=True)
        mutated_df.to_csv(predictions_path, index=False)
        mutated_summary.to_csv(summary_path, index=False)

    # NOTE(review): offset=1115 is presumably the literature (native ALK) residue
    # number aligned with the start of the conserved domain - confirm.
    lit_performance_df, (mut_seq, mut_count, mut_pcnt) = evaluate_literature_mut_performance(
        mutated_df, alk_muts, decoded_full_sequence, mut_count,
        sequence=EML4_ALK_SEQ,
        focus_region_start=focus_region_start,
        focus_region_end=focus_region_start + len(cons_domain_alk),
        offset=1115,
    )

    return lit_performance_df, (mut_seq, mut_count, mut_pcnt)
|
|
|
def evaluate_bcr_abl(model, tokenizer, device, model_str):
    """
    Predict position-wise mutations for BCR::ABL and score them against literature.

    Literature mutations are read from "abl_mutations.csv". Predictions are cached
    under bcr_abl_mutations/{model_str}/ as mutated_df.csv and mutated_summary.csv;
    when both files exist they are loaded instead of re-running the model.

    Args:
        model, tokenizer, device: forwarded to predict_positionwise_mutations.
        model_str: sub-path identifying the model, used for the cache directory.

    Returns:
        (lit_performance_df, (mut_seq, mut_count, mut_pcnt)) as produced by
        evaluate_literature_mut_performance.

    Raises:
        ValueError: if the sequence constants below have not been filled in, or the
            conserved domain does not occur in the full sequence.
    """
    abl_muts = pd.read_csv("abl_mutations.csv")

    # NOTE(review): both constants are np.nan placeholders in this file; the real
    # amino-acid strings must be supplied here before this function can run.
    BCR_ABL_SEQ = np.nan
    cons_domain_abl = np.nan
    if not isinstance(BCR_ABL_SEQ, str) or not isinstance(cons_domain_abl, str):
        raise ValueError("BCR_ABL_SEQ and cons_domain_abl must be set to amino-acid strings before BCR::ABL can be evaluated")

    focus_region_start = BCR_ABL_SEQ.find(cons_domain_abl)
    if focus_region_start < 0:
        # str.find returns -1 when absent, which would silently shift every position
        raise ValueError("cons_domain_abl was not found in BCR_ABL_SEQ")

    cache_dir = f"bcr_abl_mutations/{model_str}"
    predictions_path = f"{cache_dir}/mutated_df.csv"
    summary_path = f"{cache_dir}/mutated_summary.csv"

    # Both cache files are read below, so both must exist for the cache to be usable
    if os.path.isfile(predictions_path) and os.path.isfile(summary_path):
        log_update(f"Mutation predictions for {model_str} have already been calculated. Loading from bcr_abl_mutations/{model_str}/mutated_df.csv")
        mutated_df = pd.read_csv(predictions_path)
        mutated_summary = pd.read_csv(summary_path)
        decoded_full_sequence = mutated_summary['decoded_full_sequence'][0]
        mut_count = mutated_summary['mut_count'][0]
    else:
        mutated_df, decoded_full_sequence, mut_count = predict_positionwise_mutations(model, tokenizer, device, BCR_ABL_SEQ)
        mutated_summary = pd.DataFrame(data={'decoded_full_sequence': [decoded_full_sequence], 'mut_count': [mut_count]})
        # Do not rely on the caller having created the cache directory
        os.makedirs(cache_dir, exist_ok=True)
        mutated_df.to_csv(predictions_path, index=False)
        mutated_summary.to_csv(summary_path, index=False)

    # NOTE(review): offset=241 is presumably the literature (native ABL) residue
    # number aligned with the start of the conserved domain - confirm.
    lit_performance_df, (mut_seq, mut_count, mut_pcnt) = evaluate_literature_mut_performance(
        mutated_df, abl_muts, decoded_full_sequence, mut_count,
        sequence=BCR_ABL_SEQ,
        focus_region_start=focus_region_start,
        focus_region_end=focus_region_start + len(cons_domain_abl),
        offset=241,
    )

    return lit_performance_df, (mut_seq, mut_count, mut_pcnt)
|
|
|
def summarize_individual_performance(performance_df, path_to_lit_df):
    """
    Log a per-position table and top-k hit/miss totals for one fusion.

    performance_df = dataframe with stats on performance
    path_to_lit_df = original dataframe

    Only rows of performance_df with a non-null 'Literature Mutation' are reported.
    The literature CSV at path_to_lit_df is read solely to obtain the total number of
    literature rows. Nothing is returned; all output goes through log_update.
    """
    lit_muts = pd.read_csv(path_to_lit_df)

    # Keep only the positions that were actually scored against the literature
    mut_rows = performance_df.loc[performance_df['Literature Mutation'].notna()].reset_index(drop=True)
    mut_rows = mut_rows[['Original Residue', 'Position', 'Literature Mutation',
                         'Top Mutation', 'Top 1 Hit',
                         'Top 3 Mutations', 'Top 3 Hit',
                         'Top 4 Mutations', 'Top 4 Hit',
                         'Top 5 Mutations', 'Top 5 Hit',
                         'Top 10 Mutations', 'Top 10 Hit'
                         ]]

    # Indent every line of the table so it nests under the log heading
    mut_rows_str = mut_rows.to_string(index=False)
    mut_rows_str = "\t\t" + mut_rows_str.replace("\n", "\n\t\t")
    log_update(f"\tPerformance on all mutated positions shown below:\n{mut_rows_str}")

    total_original_muts = len(lit_muts)
    for k in [1, 3, 4, 5, 10]:
        hit_flags = mut_rows[f'Top {k} Hit']
        total_hits = int(hit_flags.eq(True).sum())
        total_misses = int(hit_flags.eq(False).sum())
        total_potential_muts = total_hits + total_misses

        if total_potential_muts:
            hit_pcnt = round(100*total_hits/total_potential_muts, 2)
            miss_pcnt = round(100*total_misses/total_potential_muts, 2)
        else:
            # Nothing was scored; report 0% instead of dividing by zero
            hit_pcnt = miss_pcnt = 0.0

        log_update(f"\tTotal positions tested / total positions mutated in literature: {total_potential_muts}/{total_original_muts}")
        log_update(f"\n\t\tTop {k} hit performance:\n\t\t\tHits:{total_hits} ({hit_pcnt}%)\n\t\t\tMisses:{total_misses} ({miss_pcnt}%)")
|
|
|
def main():
    """
    Run the mutation-recovery benchmark for BCR::ABL and EML4::ALK.

    Creates a timestamped directory under results/, opens the log file there, loads
    the tokenizer and masked-LM checkpoint selected by config.FUSON_PLM_CKPT, and for
    each fusion writes the per-position recovery table (CSV) and a short summary
    (TXT) before logging the top-k hit statistics.
    """
    os.makedirs('results', exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs("bcr_abl_mutations", exist_ok=True)
    os.makedirs("eml4_alk_mutations", exist_ok=True)

    with open_logfile(f"{output_dir}/mutation_discovery_log.txt"):
        print_configpy(config)

        check_env_variables()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"Using device: {device}")

        # config.FUSON_PLM_CKPT is either the literal "FusOn-pLM" or a one-entry
        # mapping of {model name: epoch}
        fuson_ckpt_path = config.FUSON_PLM_CKPT
        if fuson_ckpt_path == "FusOn-pLM":
            fuson_ckpt_path = "../../../.."
            model_name = "fuson_plm"
            model_epoch = "best"
            model_str = "fuson_plm/best"
        else:
            model_name = list(fuson_ckpt_path.keys())[0]
            epoch = list(fuson_ckpt_path.values())[0]
            fuson_ckpt_path = f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}'
            model_name, model_epoch = fuson_ckpt_path.split('/')[-2::]
            model_epoch = model_epoch.split('checkpoint_')[-1]
            model_str = f"{model_name}/{model_epoch}"

        log_update(f"\nLoading FusOn-pLM model from {fuson_ckpt_path}")
        fuson_tokenizer = AutoTokenizer.from_pretrained(fuson_ckpt_path)
        fuson_model = AutoModelForMaskedLM.from_pretrained(fuson_ckpt_path)
        fuson_model.to(device)
        fuson_model.eval()

        # (log label, output file stem, cache root, evaluation function, literature CSV)
        fusions = (
            ("BCR::ABL", "BCR_ABL", "bcr_abl_mutations", evaluate_bcr_abl, "abl_mutations.csv"),
            ("EML4::ALK", "EML4_ALK", "eml4_alk_mutations", evaluate_eml4_alk, "alk_mutations.csv"),
        )

        performances = []
        for label, file_stem, cache_root, evaluate, lit_csv in fusions:
            # makedirs creates the intermediate model directory as well
            os.makedirs(f"{cache_root}/{model_name}/{model_epoch}", exist_ok=True)
            log_update(f"\tEvaluating performance on {label} mutation prediction with FusOn")
            lit_performance, (mut_seq, mut_count, mut_pcnt) = evaluate(fuson_model, fuson_tokenizer, device, model_str)
            lit_performance.to_csv(f'{output_dir}/{file_stem}_mutation_recovery_fuson.csv', index=False)

            # One item per line; previously the three values ran together
            with open(f'{output_dir}/{file_stem}_mutation_recovery_fuson_summary.txt', 'w') as f:
                f.write(f"{mut_seq}\n")
                f.write(f"number of mutations: {mut_count}\n")
                f.write(f"percentage of seq mutated: {mut_pcnt}\n")

            performances.append((label, lit_performance, lit_csv))

        # Summaries are logged after both evaluations have finished
        for label, lit_performance, lit_csv in performances:
            log_update(f"\nSummarizing FusOn-pLM performance on {label}")
            summarize_individual_performance(lit_performance, lit_csv)
|
|
|
# Run the benchmark only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()