# Page metadata captured alongside this file (not part of the program):
#   Fill-Mask | Transformers | Safetensors | esm
#   svincoff's picture
#   mutation prediction discovery and recovery
#   3efa812
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
import numpy as np
import os
import pandas as pd
import pickle
from transformers import AutoTokenizer
from fuson_plm.utils.visualizing import set_font
import fuson_plm.benchmarking.mutation_prediction.discovery.config as config
def get_x_tick_labels(start, end):
    """Pick x-axis tick positions and labels for a 1-based residue range.

    The spacing between ticks is chosen from ``end - start``:
    100 when it is at least 500, 50 when it is strictly between 100 and 500,
    1 when it is below 10, and 10 otherwise.

    Args:
        start: 1-based index of the first residue in the range.
        end: 1-based index of the last residue in the range.

    Returns:
        A tuple ``(positions, labels)``. ``positions`` is a NumPy array of
        0-based indices beginning at ``start - 1`` and stopping before
        ``end``; ``labels`` is a list with the matching 1-based residue
        numbers as strings.
    """
    span = end - start

    # Wider ranges get sparser ticks so the axis stays readable.
    if span >= 500:
        stride = 100
    elif span > 100:
        stride = 50
    elif span < 10:
        stride = 1
    else:
        stride = 10

    # Positions are 0-based (usable as sequence indices); labels shift them
    # back to the 1-based residue numbering.
    tick_positions = np.arange(start - 1, end, stride)
    tick_labels = [str(residue_number) for residue_number in tick_positions + 1]
    return tick_positions, tick_labels
def plot_conservation_heatmap(mutation_results, fusion_name="Fusion Oncoprotein", save_path="conservation_heatmap.png"):
    """Draw the per-residue conservation heatmap and save it as a PNG.

    The heatmap stacks the transposed ``originals_logits`` values on top of
    the ``conservation_likelihoods`` values; the y tick labels are
    'Original Logits' and 'Conserved'.

    Args:
        mutation_results: Mapping that must provide the keys ``'start'``,
            ``'end'``, ``'originals_logits'`` and
            ``'conservation_likelihoods'``. ``'start'``/``'end'`` are the
            1-based residue bounds used for the title and x ticks.
            ``'conservation_likelihoods'`` must expose ``.values()``; the
            values are plotted in iteration order.
            NOTE(review): presumably both entries hold one value per residue
            in ``start..end`` -- confirm against the code that produces the
            results.
        fusion_name: Name shown in the figure title.
        save_path: Destination of the PNG file (written at 300 dpi).
    """
    start = mutation_results['start']
    end = mutation_results['end']
    originals_logits = mutation_results['originals_logits']
    conservation_likelihoods = mutation_results['conservation_likelihoods']

    # 0-based index of the first residue; used to shift tick positions so
    # that they are relative to the first heatmap column.
    start_index = start - 1
    sequence_length = end - start_index

    # Build the 2D array: original-logit row(s) first, conservation row last.
    original_rows = np.vstack(originals_logits).T
    conservation_row = np.array(list(conservation_likelihoods.values())).reshape(1, -1)
    combined_array = np.vstack((original_rows, conservation_row))

    x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)

    set_font()
    # Constant height; width grows with the sequence length, capped at 15.
    fig = plt.figure(figsize=(min(15, sequence_length / 10), 3))
    ax = sns.heatmap(
        combined_array,
        cmap='viridis',
        xticklabels=x_tick_labels,
        yticklabels=['Original Logits', 'Conserved'],
        cbar=True,
        cbar_kws={
            'aspect': 2,
            'pad': 0.02,
            'shrink': 1.0,
        },
    )

    # Enlarge the tick labels on the color bar that seaborn attached.
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=20)

    # +0.5 centers each tick on its heatmap cell.
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=90, fontsize=20)
    plt.yticks(fontsize=20, rotation=0)
    plt.title(f'{fusion_name} Residues {start}-{end}', fontsize=30)
    plt.xlabel('Residue Index', fontsize=30)
    plt.tight_layout()

    # Save BEFORE showing: with a blocking interactive backend the figure is
    # destroyed once the window is closed, and a later savefig would write an
    # empty image.
    plt.savefig(save_path, format='png', dpi=300)
    plt.show()
    # Release the figure so repeated calls (one per fusion) do not pile up
    # open figures in pyplot's registry.
    plt.close(fig)
# plotting heatmap 1
def plot_full_heatmap(mutation_results, tokenizer, fusion_name="Fusion Oncoprotein", save_path="full_heatmap.png"):
    """Plot a token-probability heatmap over the residue range and save it.

    The stacked ``logits_for_each_AA`` values are passed through a softmax
    over their last axis, transposed, and drawn with one row label per entry
    of ``filtered_indices``.

    Args:
        mutation_results: Mapping providing ``'start'``, ``'end'``,
            ``'logits'``, ``'logits_for_each_AA'`` and
            ``'filtered_indices'``. Only the size of the last dimension of
            ``'logits'`` is used (as the number of token ids to decode).
            NOTE(review): presumably ``'logits_for_each_AA'`` holds one row
            per residue with one column per filtered token -- confirm.
        tokenizer: Object with a ``decode(list_of_ids)`` method used to turn
            each token id into its display string.
        fusion_name: Name shown in the figure title.
        save_path: Destination of the PNG file (written at 300 dpi).
    """
    first_residue = mutation_results['start']
    last_residue = mutation_results['end']
    logits = mutation_results['logits']
    per_position_logits = mutation_results['logits_for_each_AA']
    kept_token_ids = mutation_results['filtered_indices']

    # Offset that maps absolute tick positions onto heatmap columns.
    first_index = first_residue - 1

    # Decode every token id once, then keep only the filtered ones as
    # y-axis labels.
    vocabulary = [
        tokenizer.decode([token_id])
        for token_id in torch.arange(logits.size(-1))
    ]
    row_labels = [vocabulary[token_id] for token_id in kept_token_ids]

    # Softmax along the last axis, then transpose for plotting.
    stacked_logits = torch.tensor(np.vstack(per_position_logits))
    probabilities = F.softmax(stacked_logits, dim=-1).numpy().T

    tick_positions, tick_labels = get_x_tick_labels(first_residue, last_residue)

    set_font()
    plt.figure(figsize=(15, 8))
    plt.rcParams.update({'font.size': 16.5})
    sns.heatmap(
        probabilities,
        cmap='plasma',
        xticklabels=tick_labels,
        yticklabels=row_labels,
    )
    plt.title(f'{fusion_name} Residues {first_residue}-{last_residue}: Token Probability')
    plt.ylabel('Amino Acid')
    plt.xlabel('Residue Index')
    plt.yticks(rotation=0)
    # +0.5 centers each tick on its heatmap cell.
    plt.xticks(tick_positions - first_index + 0.5, tick_labels, rotation=0)
    plt.tight_layout()
    plt.savefig(save_path, format='png', dpi=300)
def plot_color_bar():
    """Create a Viridis color bar ranging from 0 to 1.

    The bar is drawn as a 1 x 256 gradient image with its two ends labelled,
    and is written to ``viridis_color_bar.png`` (300 dpi) in the current
    working directory.
    """
    # Gradient from 0 to 1, shaped as a single image row.
    gradient = np.linspace(0, 1, 256).reshape(1, -1)

    fig, ax = plt.subplots(figsize=(12, 3))
    ax.imshow(gradient, aspect="auto", cmap="viridis")
    # Label only the two extremes of the gradient.
    ax.set_xticks([0, 255])
    ax.set_xticklabels(["0\nmost likely\nto mutate", "1\nleast likely\nto mutate"], fontsize=40)
    ax.set_yticks([])
    ax.set_title("Original Residue Logits", fontsize=40)
    plt.tight_layout()

    # Save BEFORE showing: with a blocking interactive backend the figure is
    # destroyed once the window is closed, and a later savefig would write an
    # empty image.
    plt.savefig("viridis_color_bar.png", dpi=300)
    plt.show()
    # Release the figure so it does not stay registered with pyplot.
    plt.close(fig)
def main():
    """Render the color bar, then one conservation heatmap per result folder.

    Every directory directly under ``results/final`` that contains a
    ``raw_mutation_results.pkl`` file gets a ``conservation_heatmap.png``
    written next to it; the directory name is used as the fusion name.
    Directories without that file are reported and skipped so that one
    incomplete folder does not abort the whole batch.
    """
    plot_color_bar()

    results_dir = "results/final"
    # Sorted so that the processing order is deterministic across runs.
    for subfolder in sorted(os.listdir(results_dir)):
        full_path = f"{results_dir}/{subfolder}"
        if not os.path.isdir(full_path):
            continue

        results_file = f"{full_path}/raw_mutation_results.pkl"
        if not os.path.isfile(results_file):
            print(f"Skipping {full_path}: raw_mutation_results.pkl not found")
            continue

        # NOTE(security): pickle.load can execute arbitrary code; only run
        # this on result files produced by a trusted pipeline.
        with open(results_file, "rb") as f:
            mutation_results = pickle.load(f)

        plot_conservation_heatmap(
            mutation_results,
            fusion_name=subfolder,
            save_path=f"{full_path}/conservation_heatmap.png",
        )


if __name__ == "__main__":
    main()