File size: 2,739 Bytes
be331b6
a0ba8b5
7f75a38
 
f7a575a
7f75a38
f7a575a
 
 
 
 
7f75a38
 
 
 
 
be331b6
1b1c116
7ca3799
 
1b1c116
7f75a38
 
1b1c116
7f75a38
1b1c116
7f75a38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447ecb9
7f75a38
 
 
6b081f8
7f75a38
 
 
a0ba8b5
 
be331b6
1b1c116
 
be331b6
1b1c116
 
 
d8c7623
1b1c116
 
7f75a38
be331b6
1b1c116
 
be331b6
 
 
a0ba8b5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging

model_name = "ChatterjeeLab/FusOn-pLM"

# Only let errors through from transformers' model-loading module (hides load-time warnings).
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Choose the compute device once; the model and every input batch are moved onto it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Tokenizer and masked-language-model head for FusOn-pLM.
# NOTE: trust_remote_code=True executes Python code shipped with the model repository.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
model.to(device)
model.eval()  # evaluation mode (e.g. dropout disabled) for inference

def _top_n_at_mask(masked_seq, n):
    """Return the ``n`` highest-scoring decoded tokens for the ``<mask>`` in ``masked_seq``.

    Uses the module-level ``tokenizer``, ``model`` and ``device``; runs one
    forward pass without gradient tracking.

    Raises:
        ValueError: if no mask token is present after tokenization (e.g. it was
            cut off by truncation at ``max_length=2000``).
    """
    inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    # Column indices (token positions) of the mask token in the single-row batch.
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    if mask_token_index.numel() == 0:
        raise ValueError(
            "The masked position was lost during tokenization; the sequence may "
            "exceed the 2000-token limit."
        )
    mask_token_logits = logits[0, mask_token_index, :]
    # topk returns indices sorted by decreasing score; take the row for the (single) mask.
    top_n_ids = torch.topk(mask_token_logits, n, dim=1).indices[0].tolist()
    return [tokenizer.decode([token_id]) for token_id in top_n_ids]


def topn_tokens(sequence, domain_bounds, n):
    """Predict the top-``n`` residues for every position inside a domain.

    Each position in the domain is masked in turn and the model's ``n`` most
    likely tokens for that position are recorded (one forward pass per position).

    Args:
        sequence: Protein sequence as a string. Whitespace (e.g. from a pasted,
            line-wrapped sequence) is removed before indexing.
        domain_bounds: Table-like object indexed as ``domain_bounds['start'][0]``
            and ``domain_bounds['end'][0]`` giving the 1-based, inclusive domain
            boundaries (presumably the DataFrame produced by the Gradio table input).
        n: Number of predictions per position (castable to ``int``).

    Returns:
        pandas.DataFrame with columns 'Original Residue',
        'Predicted Residues (in order of decreasing likelihood)' and 'Position'
        (1-based), one row per domain position that lies within the sequence.
        Bounds outside the sequence are clamped; ``start > end`` yields no rows.

    Raises:
        ValueError: if ``n`` is missing, or a masked position is lost to truncation.
    """
    if n is None:
        raise ValueError("Please choose how many tokens to predict (1-20).")
    n = int(n)
    # Drop all whitespace so 1-based positions line up with actual residues.
    sequence = "".join((sequence or "").split())

    start_index = int(domain_bounds['start'][0]) - 1
    end_index = int(domain_bounds['end'][0]) - 1
    # Clamp to the sequence so only valid residue positions are visited.
    first = max(start_index, 0)
    last = min(end_index, len(sequence) - 1)

    original_residues = []
    mutations = []
    positions = []
    for i in range(first, last + 1):
        masked_seq = sequence[:i] + '<mask>' + sequence[i + 1:]
        original_residues.append(sequence[i])
        mutations.append(_top_n_at_mask(masked_seq, n))
        positions.append(i + 1)

    return pd.DataFrame({
        'Original Residue': original_residues,
        'Predicted Residues (in order of decreasing likelihood)': mutations,
        'Position': positions
    })

def _build_demo():
    """Assemble the Gradio interface that wraps ``topn_tokens``.

    Inputs are passed to ``topn_tokens`` positionally, in this order: the
    sequence (free text), a fixed one-row table with ``start``/``end`` numeric
    columns for the domain bounds, and a dropdown offering ``n`` from 1 to 20.
    The result is rendered as a dataframe.
    """
    domain_table = gr.Dataframe(
        headers=["start", "end"],
        datatype=["number", "number"],
        row_count=(1, "fixed"),
        col_count=(2, "fixed"),
    )
    n_choice = gr.Dropdown(list(range(1, 21)))  # integer choices 1..20
    return gr.Interface(
        fn=topn_tokens,
        inputs=["text", domain_table, n_choice],
        outputs="dataframe",
        description="Choose a number between 1-20 to predict n tokens for each position. Choose the start and end index of the domain of interest (indexing starts at 1).",
    )


demo = _build_demo()

if __name__ == "__main__":
    demo.launch()