Kseniia-Kholina's picture
Rename app.py to app_heatmap.py
5d1d3c4 verified
import transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import gradio as gr
from io import BytesIO
from PIL import Image
# Hugging Face identifier of the masked language model queried for logits.
_MODEL_NAME = "ChatterjeeLab/FusOn-pLM"
# Upper bound passed to the tokenizer as ``max_length`` (truncation limit).
_MAX_TOKENS = 2000
# Vocabulary ids kept as heatmap rows; everything outside 4..23 is treated as
# a non-amino-acid token and dropped.
_AMINO_ACID_TOKEN_IDS = list(range(4, 23 + 1))
# Lazily filled cache so the tokenizer/model are loaded once per process
# instead of once per request.
_MODEL_CACHE = {}


def _load_model():
    """Return ``(tokenizer, model, device)``, loading them on first use only."""
    if "bundle" not in _MODEL_CACHE:
        logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")
        tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME, trust_remote_code=True)
        model = AutoModelForMaskedLM.from_pretrained(_MODEL_NAME, trust_remote_code=True)
        model.to(device)
        model.eval()
        _MODEL_CACHE["bundle"] = (tokenizer, model, device)
    return _MODEL_CACHE["bundle"]


def _masked_position_logits(sequence, tokenizer, model, device):
    """Mask each position of *sequence* in turn and stack the filtered logits.

    Returns a 2-D NumPy array with one row per masked position and one column
    per id in ``_AMINO_ACID_TOKEN_IDS``.
    """
    # Fail fast: if the tokenized sequence exceeds the truncation limit, the
    # mask token would be cut off for late positions.
    if len(tokenizer(sequence)["input_ids"]) > _MAX_TOKENS:
        raise ValueError(
            f"Sequence is too long: it must tokenize to at most {_MAX_TOKENS} tokens."
        )

    rows = []
    for position in range(len(sequence)):
        # Replace exactly one character with the mask token.
        masked_seq = sequence[:position] + "<mask>" + sequence[position + 1:]
        inputs = tokenizer(
            masked_seq,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=_MAX_TOKENS,
        )
        inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

        with torch.no_grad():
            logits = model(**inputs).logits

        mask_positions = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        # Backstop for the check above: without exactly one mask token the row
        # for this position would be missing or duplicated.
        if mask_positions.numel() != 1:
            raise ValueError(
                f"Expected exactly one mask token at position {position}, "
                f"found {mask_positions.numel()}."
            )

        mask_logits = logits[0, mask_positions, :].cpu().numpy()
        rows.append(mask_logits[:, _AMINO_ACID_TOKEN_IDS])
    return np.vstack(rows)


def _min_max_normalize(values):
    """Scale *values* into [0, 1] using the global min and max of the array."""
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        # All entries identical: avoid a division by zero and return zeros.
        return np.zeros_like(values)
    return (values - lowest) / span


def _render_heatmap(matrix, row_labels, n_positions):
    """Draw *matrix* as a seaborn heatmap and return it as a PIL image."""
    step = 50
    x_tick_positions = np.arange(0, n_positions, step)
    x_tick_labels = [str(pos) for pos in x_tick_positions]

    fig = plt.figure(figsize=(15, 8))
    try:
        # x tick labels are set explicitly below, so seaborn's own are disabled.
        sns.heatmap(matrix, cmap='plasma', xticklabels=False, yticklabels=row_labels)
        plt.title('Logits for masked per residue tokens')
        plt.ylabel('Token')
        plt.xlabel('Residue Index')
        plt.yticks(rotation=0)
        plt.xticks(x_tick_positions, x_tick_labels, rotation=0)

        # Save the figure to an in-memory PNG.
        buf = BytesIO()
        plt.savefig(buf, format='png')
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


def get_heatmap(sequence: str):
    """Return a heatmap image of masked-token logits for every position.

    Each character of ``sequence`` is replaced by ``<mask>`` in turn, the
    model's logits at the mask are collected for the ids in
    ``_AMINO_ACID_TOKEN_IDS``, the stacked logits are min-max normalised over
    the whole array, and the result is plotted with tokens on the y axis and
    positions on the x axis.

    Args:
        sequence: Input sequence text. All whitespace is removed before use.

    Returns:
        A PIL image containing the rendered PNG heatmap.

    Raises:
        ValueError: If the sequence is empty after whitespace removal, is too
            long for the tokenizer limit, or a masked input does not contain
            exactly one mask token.
    """
    # Validate before loading the model so bad input fails immediately.
    cleaned = "".join((sequence or "").split())
    if not cleaned:
        raise ValueError("Please provide a non-empty sequence.")

    tokenizer, model, device = _load_model()

    all_logits_array = _masked_position_logits(cleaned, tokenizer, model, device)
    normalized = _min_max_normalize(all_logits_array)

    # Row labels: decode only the ids that are actually plotted.
    filtered_tokens = [tokenizer.decode([idx]) for idx in _AMINO_ACID_TOKEN_IDS]

    # Transpose so tokens are rows and sequence positions are columns.
    return _render_heatmap(normalized.T, filtered_tokens, len(cleaned))
# Gradio interface: a single text input is passed to get_heatmap and the
# returned image is displayed as the output.
demo = gr.Interface(
    fn=get_heatmap,
    inputs="text",
    outputs="image",
)

# Launch the interface when this file is executed.
demo.launch()