Kseniia-Kholina committed
Commit 7f75a38 · verified · 1 Parent(s): 7ca3799

Update app.py

Files changed (1)
  1. app.py +42 -7
app.py CHANGED
@@ -1,18 +1,53 @@
 import gradio as gr
 import pandas as pd
+import torch
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+# Initialize tokenizer and model globally
+model_name = "ChatterjeeLab/FusOn-pLM"
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+model.eval()
 
 def topn_tokens(sequence, domain_bounds, n):
-    example_dict = {}
-    chars = list(sequence)
-    # Convert to integer after extracting from domain_bounds
     start_index = int(domain_bounds['start'][0]) - 1
     end_index = int(domain_bounds['end'][0]) - 1
 
+    top_n_mutations = {}
+
     for i in range(len(sequence)):
+        # Only mask and unmask the residues within the specified domain boundaries
         if start_index <= i <= end_index:
-            example_dict[chars[i]] = 'yo'
-
-    df = pd.DataFrame(list(example_dict.items()), columns=['Original Residue', 'Predicted Residues'])
+            masked_seq = sequence[:i] + '<mask>' + sequence[i+1:]
+            inputs = tokenizer(masked_seq, return_tensors="pt", padding=True, truncation=True, max_length=2000)
+            inputs = {k: v.to(device) for k, v in inputs.items()}
+            with torch.no_grad():
+                logits = model(**inputs).logits
+            mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
+            mask_token_logits = logits[0, mask_token_index, :]
+            # Decode top n tokens
+            top_n_tokens = torch.topk(mask_token_logits, n, dim=1).indices[0].tolist()
+            mutation = [tokenizer.decode([token]) for token in top_n_tokens]
+            top_n_mutations[(sequence[i], i)] = mutation
+
+    original_residues = []
+    mutations = []
+    positions = []
+
+    for key, value in top_n_mutations.items():
+        original_residue, position = key
+        original_residues.append(original_residue)
+        mutations.append(value)
+        positions.append(position)
+
+    df = pd.DataFrame({
+        'Original Residue': original_residues,
+        'Mutation': mutations,
+        'Position': positions
+    })
+
     return df
 
 demo = gr.Interface(
@@ -25,7 +60,7 @@ demo = gr.Interface(
             row_count=(1, "fixed"),
             col_count=(2, "fixed"),
         ),
-        gr.Dropdown([str(i) for i in range(1, 21)]),  # Dropdown with numbers from 1 to 20 as strings
+        gr.Dropdown([i for i in range(1, 21)]),  # Dropdown with numbers from 1 to 20 as integers
     ],
     outputs="dataframe",
     description="Choose a number between 1-20 to predict n tokens for each position. Choose the start and end index of the domain of interest (indexing starts at 1).",