# NOTE(review): removed non-Python web-scrape artifacts that were pasted into
# this file (a "File size" line, git-blame commit hashes, and a line-number
# gutter). They were not part of the program and made the file unparseable.
import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging
# Silence the (expected) "weights not used / newly initialized" warnings that
# transformers emits when loading a model with a custom head.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
# Run on GPU when available; all inputs are moved to this device before inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load the tokenizer and model
# trust_remote_code=True is required because FusOn-pLM ships custom model code
# with its checkpoint — NOTE(review): this executes repo-provided code; confirm
# the source is trusted.
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model.to(device)
# Inference only: disables dropout/batch-norm training behavior.
model.eval()
def topn_tokens(sequence, domain_bounds, n):
    """Predict the top-n most likely residues at each position of a domain.

    For every position inside the user-specified domain boundaries
    (1-indexed, inclusive), the residue is replaced with the tokenizer's
    mask token and the masked language model predicts the n most likely
    substitutions at that position.

    Args:
        sequence: Protein sequence string.
        domain_bounds: Mapping/DataFrame with 'start' and 'end' columns whose
            first entries are the 1-indexed, inclusive domain boundaries.
        n: Number of top predicted tokens to return per position (the Gradio
            dropdown supplies an integer, but we coerce defensively).

    Returns:
        pandas.DataFrame with columns 'Original Residue',
        'Predicted Residues (in order of decreasing likelihood)', 'Position'
        (1-indexed). Empty if the clamped domain is empty.
    """
    n = int(n)
    start_index = int(domain_bounds['start'][0]) - 1  # 1-indexed UI -> 0-indexed
    end_index = int(domain_bounds['end'][0]) - 1
    # Clamp user-supplied bounds to the sequence and iterate only the domain,
    # instead of scanning the whole sequence and filtering each position.
    lo = max(start_index, 0)
    hi = min(end_index, len(sequence) - 1)

    top_n_mutations = {}
    for i in range(lo, hi + 1):
        # Mask residue i so the model predicts it from the surrounding context.
        masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
        # NOTE(review): truncation at 2000 tokens means a mask placed past that
        # point is cut off and no prediction is possible — confirm expected
        # sequence lengths stay below this limit.
        inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        mask_token_logits = logits[0, mask_token_index, :]
        # Decode the n highest-scoring vocabulary tokens for the masked slot.
        top_n_tokens = torch.topk(mask_token_logits, n, dim=1).indices[0].tolist()
        mutation = [tokenizer.decode([token]) for token in top_n_tokens]
        top_n_mutations[(sequence[i], i)] = mutation

    # Unzip the results into parallel columns; report positions 1-indexed again.
    return pd.DataFrame({
        'Original Residue': [residue for residue, _ in top_n_mutations],
        'Predicted Residues (in order of decreasing likelihood)': list(top_n_mutations.values()),
        'Position': [position + 1 for _, position in top_n_mutations],
    })
# Gradio UI: sequence text box, a single-row start/end bounds table, and a
# dropdown selecting how many candidate residues to predict per position.
demo = gr.Interface(
    fn=topn_tokens,
    inputs=[
        "text",
        gr.Dataframe(
            headers=["start", "end"],
            datatype=["number", "number"],
            # Exactly one (start, end) pair — the row/column counts are fixed.
            row_count=(1, "fixed"),
            col_count=(2, "fixed"),
        ),
        gr.Dropdown([i for i in range(1, 21)]),  # Dropdown with numbers from 1 to 20 as integers
    ],
    outputs="dataframe",
    description="Choose a number between 1-20 to predict n tokens for each position. Choose the start and end index of the domain of interest (indexing starts at 1).",
)
# Launch the app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()
# NOTE(review): removed stray trailing "|" left over from the web-scrape gutter.