|
import torch |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import os |
|
import pandas as pd |
|
import pickle |
|
from transformers import AutoTokenizer |
|
from fuson_plm.utils.visualizing import set_font |
|
import fuson_plm.benchmarking.mutation_prediction.discovery.config as config |
|
|
|
def get_x_tick_labels(start, end):
    """Return x-axis tick positions and labels for the residue window ``start``..``end``.

    Positions are 0-based (``start - 1`` up to, but excluding, ``end``). Labels are the
    matching 1-based residue numbers as strings.

    The spacing depends on ``span = end - start``:
        - span >= 500: every 100 residues
        - 100 < span < 500: every 50
        - 10 <= span <= 100: every 10
        - span < 10: every residue

    Returns:
        (positions, labels): a numpy array from ``np.arange`` and a list of str.
    """
    span = end - start

    if span >= 500:
        stride = 100
    elif span > 100:
        stride = 50
    elif span >= 10:
        stride = 10
    else:
        stride = 1

    ticks = np.arange(start - 1, end, stride)
    # Shift back to 1-based residue numbering for display.
    labels = list(map(str, ticks + 1))
    return ticks, labels
|
|
|
|
|
def plot_conservation_heatmap(mutation_results, fusion_name="Fusion Oncoprotein", save_path="conservation_heatmap.png"):
    """Plot original-residue logits and conservation likelihoods as a two-row heatmap and save it.

    Args:
        mutation_results: dict of mutation-prediction results. This function reads the
            ``'start'``, ``'end'``, ``'originals_logits'`` and ``'conservation_likelihoods'`` keys.
            ``start``/``end`` are treated as 1-based residue positions (index ``start - 1``
            through ``end``).
        fusion_name: name shown in the plot title.
        save_path: destination PNG file (saved at 300 dpi).

    The figure is saved before ``plt.show()`` is called, and it is closed afterwards, so
    repeated calls (e.g. in a loop) do not accumulate open figures.
    """
    start = mutation_results['start']
    end = mutation_results['end']
    originals_logits = mutation_results['originals_logits']
    conservation_likelihoods = mutation_results['conservation_likelihoods']

    start_index = start - 1
    end_index = end

    # Stack the original-residue logits and transpose to a row, then append the
    # conservation likelihoods as a second row. The y tick labels below assume
    # exactly two rows, i.e. one original logit per residue (NOTE(review): confirm).
    transposed_logits_array = np.vstack(originals_logits).T
    # Relies on the dict's insertion order matching residue order.
    conservation_likelihoods_array = np.array(list(conservation_likelihoods.values())).reshape(1, -1)
    combined_array = np.vstack((transposed_logits_array, conservation_likelihoods_array))

    x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)

    set_font()

    sequence_length = end_index - start_index
    # NOTE(review): very short windows give a very narrow figure (width = length / 10 inches).
    fig = plt.figure(figsize=(min(15, sequence_length / 10), 3))
    try:
        ax = sns.heatmap(
            combined_array,
            cmap='viridis',
            xticklabels=x_tick_labels,
            yticklabels=['Original Logits', 'Conserved'],
            cbar=True,
            cbar_kws={'aspect': 2,
                      'pad': 0.02,
                      'shrink': 1.0,
                      },
        )

        cbar = ax.collections[0].colorbar
        cbar.ax.tick_params(labelsize=20)

        # Tick positions are absolute 0-based indices; shift them into heatmap-cell
        # coordinates (column 0 == start_index) and center them on each cell (+0.5).
        plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=90, fontsize=20)
        plt.yticks(fontsize=20, rotation=0)
        plt.title(f'{fusion_name} Residues {start}-{end}', fontsize=30)
        plt.xlabel('Residue Index', fontsize=30)
        plt.tight_layout()

        # Save BEFORE show(): a blocking show() closes the figure, and saving afterwards
        # would write a new, empty figure.
        fig.savefig(save_path, format='png', dpi=300)
        plt.show()
    finally:
        # Release the figure so batch runs do not leak memory or hit matplotlib's
        # "too many open figures" warning.
        plt.close(fig)
|
|
|
|
|
def plot_full_heatmap(mutation_results, tokenizer, fusion_name="Fusion Oncoprotein", save_path="full_heatmap.png"):
    """Plot per-residue softmax probabilities over the filtered amino-acid tokens and save as PNG.

    Args:
        mutation_results: dict of mutation-prediction results. This function reads the
            ``'start'``, ``'end'``, ``'logits_for_each_AA'`` and ``'filtered_indices'`` keys.
            ``filtered_indices`` are token ids that are decoded with ``tokenizer`` to label the rows.
        tokenizer: object with a ``decode(list_of_ids)`` method (e.g. a Hugging Face tokenizer).
        fusion_name: name shown in the plot title.
        save_path: destination PNG file (saved at 300 dpi).

    The font-size override applies only to this figure; global rcParams are restored afterwards.
    The figure is closed after saving.
    """
    start = mutation_results['start']
    end = mutation_results['end']
    logits_for_each_AA = mutation_results['logits_for_each_AA']
    filtered_indices = mutation_results['filtered_indices']

    start_index = start - 1

    # Decode only the tokens that are plotted instead of the whole vocabulary.
    filtered_tokens = [tokenizer.decode([int(idx)]) for idx in filtered_indices]

    # Softmax along the last axis: each stacked row (presumably one residue) becomes a
    # distribution over the filtered tokens. Transposing puts tokens on the y axis.
    all_logits_array = np.vstack(logits_for_each_AA)
    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
    transposed_logits_array = normalized_logits_array.T

    x_tick_positions, x_tick_labels = get_x_tick_labels(start, end)

    set_font()

    # Scope the font-size override to this plot instead of mutating global rcParams,
    # which would silently change every later figure.
    with plt.rc_context({'font.size': 16.5}):
        fig = plt.figure(figsize=(15, 8))
        try:
            sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
            plt.title(f'{fusion_name} Residues {start}-{end}: Token Probability')
            plt.ylabel('Amino Acid')
            plt.xlabel('Residue Index')
            plt.yticks(rotation=0)
            # Shift absolute 0-based positions into cell coordinates and center on cells.
            plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
            plt.tight_layout()
            fig.savefig(save_path, format='png', dpi=300)
        finally:
            plt.close(fig)
|
|
|
def plot_color_bar(save_path="viridis_color_bar.png"):
    """
    Create a Viridis color bar ranging from 0 to 1 and save it as an image.

    Args:
        save_path: output file path (saved at 300 dpi). Defaults to the original
            hard-coded ``"viridis_color_bar.png"``.

    The image is saved before ``plt.show()`` is called, and the figure is closed afterwards.
    """
    # A single-row gradient; imshow stretches it across the axes (aspect="auto").
    gradient = np.linspace(0, 1, 256).reshape(1, -1)

    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        ax.imshow(gradient, aspect="auto", cmap="viridis")
        # Label only the two ends of the gradient (pixel columns 0 and 255).
        ax.set_xticks([0, 255])
        ax.set_xticklabels(["0\nmost likely\nto mutate", "1\nleast likely\nto mutate"], fontsize=40)
        ax.set_yticks([])
        ax.set_title("Original Residue Logits", fontsize=40)

        fig.tight_layout()
        # Save BEFORE show(): a blocking show() closes the figure, and saving afterwards
        # would write an empty image.
        fig.savefig(save_path, dpi=300)
        plt.show()
    finally:
        plt.close(fig)
|
|
|
def main():
    """Render the shared color bar, then one conservation heatmap per results subfolder.

    Each directory directly under ``results/final`` must contain ``raw_mutation_results.pkl``.
    The heatmap is written next to it as ``conservation_heatmap.png``, and the directory
    name is used as the fusion name in the plot title. Non-directory entries are ignored.
    """
    plot_color_bar()

    results_dir = "results/final"
    for entry in os.listdir(results_dir):
        folder = f"{results_dir}/{entry}"
        if not os.path.isdir(folder):
            continue

        # NOTE(review): pickle.load can execute arbitrary code; only use on trusted,
        # locally generated result files.
        with open(f"{folder}/raw_mutation_results.pkl", "rb") as handle:
            results = pickle.load(handle)

        plot_conservation_heatmap(
            results,
            fusion_name=entry,
            save_path=f"{folder}/conservation_heatmap.png",
        )
|
|
|
|
|
|
|
# Run the batch plotting only when executed as a script, not when imported.
if __name__ == "__main__":
    main()