import logging
import os
import sys
import zipfile
from contextlib import contextmanager
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Silence non-critical warnings emitted by transformers during model loading.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# FusOn-pLM is distributed with custom model code, so trust_remote_code=True
# is required; eval() disables dropout for deterministic inference.
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model.to(device)
model.eval()


@contextmanager
def suppress_output():
    """Temporarily redirect stdout to os.devnull."""
    with open(os.devnull, 'w') as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout


def process_sequence(sequence, domain_bounds, n):
    """Mask each residue in the selected domain, predict replacement residues
    with FusOn-pLM, and return a predictions table, a probability heatmap,
    and a zip archive containing both."""
    # The 20 canonical amino acids; any other character is rejected.
    AAs_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']

    sequence = sequence.strip()
    if not sequence:
        raise gr.Error("The sequence input is empty. Please enter a valid protein sequence.")
    if any(char not in AAs_tokens for char in sequence):
        raise gr.Error("The sequence input contains non-amino-acid characters. Please enter a valid protein sequence.")

    # int() raises ValueError for non-numeric text and TypeError for empty
    # (None) cells in the bounds table.
    try:
        start = int(domain_bounds['start'][0])
        end = int(domain_bounds['end'][0])
    except (ValueError, TypeError):
        raise gr.Error("Start and end indices must be integers.")
    if start <= 0 or end <= 0:
        raise gr.Error("Indexing starts at 1, so domain bounds must be positive integers.")
    if start >= end:
        raise gr.Error("Start index must be smaller than end index.")
    if start > len(sequence) or end > len(sequence):
        raise gr.Error("Domain bounds exceed sequence length.")

    if n is None:
        raise gr.Error("Choose Top N Tokens from the dropdown menu.")

    # Convert the 1-based inclusive bounds to 0-based indices for slicing.
    start_index = start - 1
    end_index = end
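
    # Scanning strategy: each residue in the domain is replaced in turn with
    # the tokenizer's <mask> token, and the model's logits at the masked slot
    # are read out. This per-position "masked-marginal" readout is a common
    # scoring recipe for ESM-style masked language models.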
    top_n_mutations = {}
    all_logits = []

    for i in range(start_index, end_index):
        # Mask position i and run a single forward pass.
        masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
        inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        mask_token_logits = logits[0, mask_token_index, :]

        # Rank the whole vocabulary by logit and keep the first n tokens that
        # decode to canonical amino acids.
        all_tokens_logits = mask_token_logits.squeeze(0)
        top_tokens_indices = torch.argsort(all_tokens_logits, dim=0, descending=True)
        mutation = []
        for token_index in top_tokens_indices:
            decoded_token = tokenizer.decode([token_index.item()])
            if decoded_token in AAs_tokens:
                mutation.append(decoded_token)
            if len(mutation) == n:
                break
        top_n_mutations[(sequence[i], i)] = mutation

        # Keep only the 20 amino-acid columns (vocabulary indices 4-23 in
        # this tokenizer) for the heatmap.
        logits_array = mask_token_logits.cpu().numpy()
        filtered_indices = list(range(4, 24))
        filtered_logits = logits_array[:, filtered_indices]
        all_logits.append(filtered_logits)
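
    # Note: this loop runs one forward pass per masked position; batching
    # several masked variants per call would speed up long domains.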
    # Decode the filtered vocabulary indices into amino-acid labels for the
    # heatmap's y-axis (logits retains its value from the final loop pass).
    token_indices = torch.arange(logits.size(-1))
    tokens = [tokenizer.decode([idx]) for idx in token_indices]
    filtered_tokens = [tokens[i] for i in filtered_indices]

    # Stack per-position logits into a (positions, 20) array, convert to
    # probabilities, and transpose so rows are tokens and columns positions.
    all_logits_array = np.vstack(all_logits)
    normalized_logits_array = F.softmax(torch.tensor(all_logits_array), dim=-1).numpy()
    transposed_logits_array = normalized_logits_array.T
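
    # Because the softmax runs over only the 20 amino-acid logits, the heatmap
    # shows probability mass renormalized over amino acids, with special
    # tokens (<cls>, <pad>, <mask>, ...) excluded.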
    # Thin the x-axis ticks so labels stay legible as the domain grows.
    domain_len = end - start
    if domain_len < 10:
        step_size = 1
    elif domain_len <= 100:
        step_size = 9
    elif domain_len < 500:
        step_size = 49
    else:
        step_size = 99
    x_tick_positions = np.arange(start_index, end_index, step_size)
    x_tick_labels = [str(pos + 1) for pos in x_tick_positions]

    plt.figure(figsize=(15, 8))
    plt.rcParams.update({'font.size': 18})

    sns.heatmap(transposed_logits_array, cmap='plasma', xticklabels=x_tick_labels, yticklabels=filtered_tokens)
    plt.title('Token Probability Heatmap')
    plt.ylabel('Token')
    plt.xlabel('Residue Index')
    plt.yticks(rotation=0)
    # Offset the ticks by 0.5 so each label sits at the center of its cell.
    plt.xticks(x_tick_positions - start_index + 0.5, x_tick_labels, rotation=0)

    # Render the figure to an in-memory PNG rather than a file on disk.
    buf = BytesIO()
    plt.savefig(buf, format='png', dpi=300)
    buf.seek(0)
    plt.close()

    img = Image.open(buf)

    # Flatten {(residue, 0-based position): top-n list} into table columns,
    # converting positions back to 1-based indexing for display.
    original_residues = []
    mutations = []
    positions = []
    for (original_residue, position), predicted in top_n_mutations.items():
        original_residues.append(original_residue)
        mutations.append(predicted)
        positions.append(position + 1)

    df = pd.DataFrame({
        'Original Residue': original_residues,
        'Predicted Residues': mutations,
        'Position': positions,
    })

    # Bundle the CSV and the heatmap into a single downloadable archive.
    df.to_csv("predicted_tokens.csv", index=False)
    img.save("heatmap.png", dpi=(300, 300))
    zip_path = "outputs.zip"
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        zipf.write("predicted_tokens.csv")
        zipf.write("heatmap.png")

    return df, img, zip_path
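
# Hypothetical direct call, for illustration only (the app normally invokes
# process_sequence through the Gradio interface below; this sequence and
# these bounds are made-up examples):
#
#   bounds = pd.DataFrame({"start": [1], "end": [10]})
#   table, heatmap, archive = process_sequence("MKTAYIAKQR", bounds, n=5)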

demo = gr.Interface(
    fn=process_sequence,
    inputs=[
        gr.Textbox(label="Sequence", placeholder="Enter the protein sequence here"),
        gr.Dataframe(
            headers=["start", "end"],
            datatype=["number", "number"],
            row_count=(1, "fixed"),
            col_count=(2, "fixed"),
            label="Domain Bounds",
        ),
        gr.Dropdown(list(range(1, 21)), label="Top N Tokens"),
    ],
    outputs=[
        gr.Dataframe(label="Predicted Tokens (in order of decreasing likelihood)"),
        gr.Image(type="pil", label="Heatmap"),
        gr.File(label="Download Outputs"),
    ],
)
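
# gr.Interface maps the three inputs positionally onto process_sequence's
# parameters and the three returned values onto the outputs list.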

if __name__ == "__main__":
    # suppress_output() hides Gradio's startup messages (including the local
    # URL banner) when the demo is launched.
    with suppress_output():
        demo.launch()