|
import gradio as gr |
|
import pandas as pd |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForMaskedLM |
|
import logging |
|
|
|
# Only ERROR-and-above records from the transformers.modeling_utils logger are
# emitted (presumably to hide noisy messages while the checkpoint loads).
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Use a CUDA device when torch reports one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
|
|
|
|
|
# Hugging Face Hub checkpoint used for every prediction in this app.
model_name = "ChatterjeeLab/FusOn-pLM"

# NOTE(review): trust_remote_code=True executes Python code shipped in the
# checkpoint repository; consider pinning `revision=` to a reviewed commit.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Loaded once at import time. `.to()` and `.eval()` both return the module
# itself, so the chain leaves `model` on `device` in evaluation mode.
model = (
    AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
    .to(device)
    .eval()
)
|
|
|
def topn_tokens(sequence, domain_bounds, n):
    """Return the model's top-``n`` residue predictions for each domain position.

    Every position inside the 1-based, inclusive ``[start, end]`` domain is
    replaced by ``<mask>`` in turn. The masked sequence is run through the
    masked-LM, and the ``n`` highest-scoring vocabulary tokens at the masked
    position are decoded.

    Args:
        sequence: Protein sequence string; each character is treated as one
            residue when building the masked input.
        domain_bounds: Table with ``start`` and ``end`` columns whose first
            row holds the 1-based domain bounds (the ``gr.Dataframe`` input;
            presumably a pandas DataFrame, Gradio's default — confirm if the
            component type changes).
        n: Number of predictions to return per position (from the dropdown).

    Returns:
        A pandas DataFrame with one row per scored position and the columns
        ``Original Residue``, ``Predicted Residues (in order of decreasing
        likelihood)`` (a list of ``n`` decoded tokens) and ``Position``
        (1-based). Bounds reaching past either end of the sequence are clipped
        to it, so an empty or inverted domain yields an empty table.

    Raises:
        gr.Error: If ``n`` is missing, if the bounds are missing or
            non-numeric, or if tokenizer truncation removes the mask token.
    """
    # Token limit handed to the tokenizer; longer inputs are truncated.
    max_length = 2000

    if n is None:
        raise gr.Error("Choose how many tokens to predict (1-20).")
    n = int(n)

    try:
        # .iloc reads the first row by position, independent of the index labels.
        start_index = int(domain_bounds['start'].iloc[0]) - 1
        end_index = int(domain_bounds['end'].iloc[0]) - 1
    except (KeyError, IndexError, TypeError, ValueError) as err:
        raise gr.Error("Enter numeric start and end positions for the domain.") from err

    # Only positions that exist in the sequence can be scored.
    first = max(start_index, 0)
    last = min(end_index, len(sequence) - 1)

    original_residues = []
    mutations = []
    positions = []

    for i in range(first, last + 1):
        masked_seq = sequence[:i] + '<mask>' + sequence[i + 1:]
        inputs = tokenizer(
            masked_seq,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = model(**inputs).logits

        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        if mask_token_index.numel() == 0:
            # Truncation cut the mask off, so there is nothing to score here.
            raise gr.Error(
                f"Position {i + 1} lies beyond the model's {max_length}-token "
                "input limit; choose a domain nearer the start of the sequence."
            )

        # Score the first mask in token order (the one inserted at position i).
        mask_token_logits = logits[0, mask_token_index[0], :]
        top_n_tokens = torch.topk(mask_token_logits, n).indices.tolist()

        original_residues.append(sequence[i])
        mutations.append([tokenizer.decode([token]) for token in top_n_tokens])
        positions.append(i + 1)

    return pd.DataFrame({
        'Original Residue': original_residues,
        'Predicted Residues (in order of decreasing likelihood)': mutations,
        'Position': positions,
    })
|
|
|
# Input components, listed in the order topn_tokens(sequence, domain_bounds, n)
# receives its arguments; the sequence itself comes from a plain "text" input.
_domain_table = gr.Dataframe(
    headers=["start", "end"],
    datatype=["number", "number"],
    row_count=(1, "fixed"),
    col_count=(2, "fixed"),
)
_n_choices = gr.Dropdown(list(range(1, 21)))

demo = gr.Interface(
    fn=topn_tokens,
    inputs=["text", _domain_table, _n_choices],
    outputs="dataframe",
    description=(
        "Choose a number between 1-20 to predict n tokens for each position. "
        "Choose the start and end index of the domain of interest (indexing starts at 1)."
    ),
)


def _main():
    """Launch the Gradio app defined by ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _main()
|
|