File size: 2,768 Bytes
71532a4
 
 
 
 
 
 
6af3fc8
88f36c7
 
6af3fc8
71532a4
 
 
 
6af3fc8
71532a4
 
 
 
 
 
6af3fc8
71532a4
 
 
 
6af3fc8
71532a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88f36c7
 
 
 
 
 
 
 
 
71532a4
 
 
6af3fc8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import functools
import logging
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import transformers
from PIL import Image
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Hugging Face Hub id of the masked protein language model used for scoring.
_MODEL_NAME = "ChatterjeeLab/FusOn-pLM"
# Tokenizer truncation limit, in tokens.
_MAX_LENGTH = 2000
# Vocabulary ids shown as heatmap rows. NOTE(review): presumably the 20
# canonical amino-acid tokens of an ESM-2-style vocabulary (ids 4..23) --
# confirm against tokenizer.get_vocab().
_AA_TOKEN_IDS = list(range(4, 23 + 1))
# Spacing (in residues) between labelled ticks on the x axis.
_TICK_STEP = 50


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the tokenizer and model once and reuse them for every request.

    Returns:
        tuple: ``(tokenizer, model, device)``; the model is already moved to
        ``device`` and switched to eval mode.
    """
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME, trust_remote_code=True)
    model = AutoModelForMaskedLM.from_pretrained(_MODEL_NAME, trust_remote_code=True)
    model.to(device)
    model.eval()
    return tokenizer, model, device


def _masked_residue_logits(sequence, tokenizer, model, device):
    """Return raw logits of shape ``(len(sequence), len(_AA_TOKEN_IDS))``.

    Row ``i`` holds the model's logits for the selected tokens when only
    residue ``i`` is replaced by the tokenizer's mask token.

    Raises:
        ValueError: if a masked sequence does not contain exactly one mask
            token after tokenization (e.g. it was truncated away because the
            sequence is too long, or the input itself contains the mask string).
    """
    rows = []
    for i in range(len(sequence)):
        masked_seq = sequence[:i] + tokenizer.mask_token + sequence[i + 1:]
        inputs = tokenizer(
            masked_seq, return_tensors="pt", truncation=True, max_length=_MAX_LENGTH
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits

        mask_positions = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
        # Without this check a truncated (0 masks) or doubly-masked input would
        # silently add the wrong number of rows and shift every later column.
        if mask_positions.numel() != 1:
            raise ValueError(
                f"Residue {i} could not be scored: expected exactly one mask token "
                f"after tokenization, found {mask_positions.numel()} (is the sequence "
                f"longer than the {_MAX_LENGTH}-token limit, or does it contain "
                f"'{tokenizer.mask_token}'?)"
            )
        vocab_logits = logits[0, mask_positions[0].item()].cpu().numpy()
        rows.append(vocab_logits[_AA_TOKEN_IDS])
    return np.vstack(rows)


def _min_max_normalize(values):
    """Scale ``values`` linearly into [0, 1]; a constant array maps to all zeros."""
    lo = values.min()
    span = values.max() - lo
    if span == 0:
        return np.zeros_like(values)
    return (values - lo) / span


def _render_heatmap(matrix, token_labels, n_residues):
    """Draw ``matrix`` (tokens x residues) as a heatmap and return a PIL image."""
    fig, ax = plt.subplots(figsize=(15, 8))
    try:
        sns.heatmap(matrix, cmap="plasma", xticklabels=False,
                    yticklabels=token_labels, ax=ax)
        ax.set_title("Logits for masked per residue tokens")
        ax.set_ylabel("Token")
        ax.set_xlabel("Residue Index")
        ax.tick_params(axis="y", labelrotation=0)
        # Heatmap column j spans [j, j + 1]; +0.5 centres each tick on its residue.
        tick_positions = np.arange(0, n_residues, _TICK_STEP)
        ax.set_xticks(tick_positions + 0.5)
        ax.set_xticklabels([str(pos) for pos in tick_positions], rotation=0)
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if drawing fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def get_heatmap(sequence):
    """Render a heatmap of min-max-normalised masked-LM logits per residue.

    Each residue is masked in turn. The logits for the selected amino-acid
    tokens at the masked position form one column of the heatmap.

    Args:
        sequence: Protein sequence as text. All whitespace is removed first.

    Returns:
        PIL.Image.Image: PNG rendering of the heatmap.

    Raises:
        ValueError: if the sequence is empty after removing whitespace, or a
            residue cannot be scored (see ``_masked_residue_logits``).
    """
    # Whitespace is never part of a protein sequence; drop pasted spaces and
    # newlines so they do not become bogus masked positions.
    sequence = "".join((sequence or "").split())
    if not sequence:
        raise ValueError("Please enter a non-empty protein sequence.")

    tokenizer, model, device = _load_model()
    logits = _masked_residue_logits(sequence, tokenizer, model, device)
    token_labels = [tokenizer.decode([idx]) for idx in _AA_TOKEN_IDS]
    normalized = _min_max_normalize(logits)
    return _render_heatmap(normalized.T, token_labels, len(sequence))


# Gradio UI: a text box (protein sequence) mapped to the heatmap image
# produced by get_heatmap.
demo = gr.Interface(fn=get_heatmap, inputs="text", outputs="image")

# Only start the web server when this file is run as a script
# (e.g. `python app.py`), so importing the module has no side effects.
if __name__ == "__main__":
    demo.launch()