# Commit 7623bc9 (verified) by Kseniia-Kholina: "Update app.py"
import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch.nn.functional as F
import logging
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
from PIL import Image
from contextlib import contextmanager
import warnings
import sys
import os
import zipfile
# Only let ERROR-and-above records from transformers.modeling_utils through.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Prefer a CUDA device whenever torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# FusOn-pLM masked-language model + tokenizer, fetched from the Hugging Face Hub.
# NOTE: trust_remote_code=True executes Python code shipped with the checkpoint.
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model.to(device)
# Inference only: switch modules such as dropout to evaluation behaviour.
model.eval()
@contextmanager
def suppress_output():
    """Temporarily redirect ``sys.stdout`` to the null device.

    Anything printed inside the ``with`` body is discarded. The previous
    ``sys.stdout`` is always restored on exit, including when the body
    raises, and the null-device handle is closed afterwards.
    """
    previous_stdout = sys.stdout
    with open(os.devnull, 'w') as sink:
        sys.stdout = sink
        try:
            yield
        finally:
            # Restore even on exceptions so later prints are not lost.
            sys.stdout = previous_stdout
# One-letter amino-acid codes accepted in the input sequence, paired with the
# vocabulary ids this app assumes for them in the FusOn-pLM tokenizer
# (contiguous ids 4..23, in this order).
_AA_TOKENS = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
              'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
_AA_TOKEN_INDICES = {aa: idx for idx, aa in enumerate(_AA_TOKENS, start=4)}
# Vocabulary ids kept when slicing the model's logits down to amino acids.
_AA_VOCAB_IDS = list(range(4, 23 + 1))
# A residue is marked "conserved" (1) when the model gives its original
# identity more than this probability at the masked position.
_CONSERVATION_THRESHOLD = 0.7
# The tokenizer truncates inputs to this many tokens.
_MAX_TOKENS = 2000


def _validate_inputs(sequence, domain_bounds, n):
    """Check the UI inputs and return the 1-based, inclusive ``(start, end)`` bounds.

    Raises:
        gr.Error: with a user-facing message describing the first problem found.
    """
    if not sequence.strip():
        raise gr.Error("Error: The sequence input is empty. Please enter a valid protein sequence.")
    if any(char not in _AA_TOKEN_INDICES for char in sequence):
        raise gr.Error("Error: The sequence input contains non-amino acid characters. Please enter a valid protein sequence.")

    try:
        start = int(domain_bounds['start'][0])
        end = int(domain_bounds['end'][0])
    except (TypeError, ValueError, KeyError, IndexError):
        # Empty cells / NaN / missing columns should produce a UI error, not a crash.
        raise gr.Error("Error: Start and end indices must be integers.") from None

    if start >= end:
        raise gr.Error("Start index must be smaller than end index.")
    if start == 0 and end != 0:
        raise gr.Error("Indexing starts at 1. Please enter valid domain bounds.")
    if start <= 0 or end <= 0:
        raise gr.Error("Domain bounds must be positive integers. Please enter valid domain bounds.")
    if start > len(sequence) or end > len(sequence):
        raise gr.Error("Domain bounds exceed sequence length.")

    if n is None:
        raise gr.Error("Choose Top N Tokens from the dropdown menu.")
    return start, end


def _predict_position(sequence, i, n):
    """Mask residue ``i`` (0-based) and score that position with the model.

    Returns:
        ``(top_n, aa_logits, original_prob)`` where ``top_n`` is a list of up to
        ``n`` amino-acid letters ordered by decreasing logit, ``aa_logits`` is a
        (1, 20) numpy array of raw logits for ``_AA_VOCAB_IDS``, and
        ``original_prob`` is the softmax probability (over the full vocabulary)
        of the residue originally at ``i``.

    Raises:
        gr.Error: if tokenizer truncation removed the ``<mask>`` token.
    """
    masked_seq = sequence[:i] + '<mask>' + sequence[i + 1:]
    inputs = tokenizer(masked_seq, return_tensors="pt", padding=True,
                       truncation=True, max_length=_MAX_TOKENS)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits

    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    if mask_token_index.numel() == 0:
        # The mask fell beyond the truncation point; scoring would silently be wrong.
        raise gr.Error(
            f"Position {i + 1} exceeds the model's maximum input length "
            f"({_MAX_TOKENS} tokens) and cannot be scored."
        )
    mask_token_logits = logits[0, mask_token_index, :]

    # Walk the vocabulary from most to least likely, keeping only amino acids
    # (special tokens such as <cls>/<eos> are skipped).
    ranked_ids = torch.argsort(mask_token_logits.squeeze(0), dim=0, descending=True)
    top_n = []
    for token_id in ranked_ids:
        token = tokenizer.decode([token_id.item()])
        if token in _AA_TOKEN_INDICES:
            top_n.append(token)
            if len(top_n) == n:
                break

    logits_cpu = mask_token_logits.cpu()
    aa_logits = logits_cpu.numpy()[:, _AA_VOCAB_IDS]
    probs = F.softmax(logits_cpu, dim=-1).numpy().squeeze()
    original_prob = probs[_AA_TOKEN_INDICES[sequence[i]]]
    return top_n, aa_logits, original_prob


def _tick_step(domain_len):
    """Spacing between x-axis tick labels for a domain spanning ``domain_len``."""
    if domain_len >= 500:
        return 100
    if domain_len > 100:
        return 50
    if domain_len < 10:
        return 1
    return 10


def _figure_to_image():
    """Render the current pyplot figure to a 300-dpi PNG-backed PIL image and close it."""
    buf = BytesIO()
    plt.savefig(buf, format='png', dpi=300)
    buf.seek(0)
    plt.close()
    return Image.open(buf)


def process_sequence(sequence, domain_bounds, n):
    """Predict likely substitutions for every residue of a protein domain.

    Each residue in the 1-based, inclusive domain ``[start, end]`` is masked in
    turn and scored with the FusOn-pLM masked-language model.

    Args:
        sequence: protein sequence of uppercase one-letter amino-acid codes
            (no whitespace or other characters).
        domain_bounds: table with ``'start'`` and ``'end'`` columns (presumably
            the pandas DataFrame delivered by ``gr.Dataframe``); only the first
            row is read.
        n: number of top-ranked amino acids to report per position.

    Returns:
        ``(df, img_1, img_2, zip_path)``. ``df`` has one row per domain residue
        with its top-``n`` predictions. ``img_1`` is a heatmap of per-position
        probabilities over the 20 amino acids. ``img_2`` shows
        P(original residue) and a 0/1 conservation call. ``zip_path`` is the
        path of an archive bundling the CSV and both heatmaps.

    Raises:
        gr.Error: on invalid input, or if a domain position is truncated away.

    Side effects:
        Writes predicted_tokens.csv, heatmap.png, heatmap_2.png and outputs.zip
        to the current working directory.
    """
    start, end = _validate_inputs(sequence, domain_bounds, n)
    start_index, end_index = start - 1, end  # 0-based, end-exclusive
    positions = range(start_index, end_index)

    top_n_per_position = []
    aa_logit_rows = []
    original_probs = []
    for i in positions:
        top_n, aa_logits, original_prob = _predict_position(sequence, i, n)
        top_n_per_position.append(top_n)
        aa_logit_rows.append(aa_logits)
        original_probs.append(original_prob)

    # Shared x-axis ticks; heatmap columns are centred on 0.5 offsets.
    step_size = _tick_step(end - start)
    x_tick_positions = np.arange(start_index, end_index, step_size)
    x_tick_labels = [str(pos + 1) for pos in x_tick_positions]

    # Heatmap 2: P(original residue) on the top row, conservation call below.
    probs_row = np.vstack(original_probs).T
    conservation_row = np.array(
        [1 if p > _CONSERVATION_THRESHOLD else 0 for p in original_probs]
    ).reshape(1, -1)
    combined_array = np.vstack((probs_row, conservation_row))
    plt.figure(figsize=(15, 5))
    plt.rcParams.update({'font.size': 16.5})
    sns.heatmap(combined_array, cmap='viridis', xticklabels=x_tick_labels,
                yticklabels=['Residue \nLogits', 'Residue \nConservation'], cbar=True)
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
    plt.title('Original Residue Probability and Conservation')
    plt.xlabel('Residue Index')
    # No plt.show(): the figure is only rendered to a buffer.
    img_2 = _figure_to_image()

    # Heatmap 1: softmax over the 20 amino-acid logits at each position.
    filtered_tokens = [tokenizer.decode([idx]) for idx in _AA_VOCAB_IDS]
    aa_probs = F.softmax(torch.tensor(np.vstack(aa_logit_rows)), dim=-1).numpy()
    plt.figure(figsize=(15, 8))
    plt.rcParams.update({'font.size': 16.5})
    sns.heatmap(aa_probs.T, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
    plt.title('Token Probability')
    plt.ylabel('Amino Acid')
    plt.xlabel('Residue Index')
    plt.yticks(rotation=0)
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)
    img_1 = _figure_to_image()

    df = pd.DataFrame({
        'Original Residue': [sequence[i] for i in positions],
        'Predicted Residues': top_n_per_position,
        'Position': [i + 1 for i in positions],
    })

    df.to_csv("predicted_tokens.csv", index=False)
    img_1.save("heatmap.png", dpi=(300, 300))
    img_2.save("heatmap_2.png", dpi=(300, 300))
    zip_path = "outputs.zip"
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for path in ("predicted_tokens.csv", "heatmap.png", "heatmap_2.png"):
            zipf.write(path)
    return df, img_1, img_2, zip_path
# Gradio UI wiring. Inputs: sequence text, a one-row (start, end) table, and
# the Top-N choice. Outputs map onto process_sequence's 4-tuple return.
_input_components = [
    gr.Textbox(label="Sequence", placeholder="Enter the protein sequence here"),
    gr.Dataframe(
        value=[[1, 1]],
        headers=["start", "end"],
        datatype=["number", "number"],
        row_count=(1, "fixed"),
        col_count=(2, "fixed"),
        label="Domain Bounds",
    ),
    gr.Dropdown(list(range(1, 21)), label="Top N Tokens"),
]
_output_components = [
    gr.Dataframe(label="Predicted Tokens (in order of decreasing likelihood)"),
    gr.Image(type="pil", label="Probability Distribution for All Tokens"),
    gr.Image(type="pil", label="Residue Conservation"),
    gr.File(label="Download Outputs"),
]
demo = gr.Interface(
    fn=process_sequence,
    inputs=_input_components,
    outputs=_output_components,
)

if __name__ == "__main__":
    # Launch with stdout redirected to os.devnull via suppress_output().
    with suppress_output():
        demo.launch()