WIP: Start adding ProtT5
- hexviz/app.py +1 -0
- hexviz/attention.py +14 -4
hexviz/app.py
CHANGED
@@ -10,6 +10,7 @@ st.title("pLM Attention Visualization")
 # Define list of model types
 models = [
     Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
+    # Model(name=ModelType.PROT_T5, layers=24, heads=32),
 ]
 
 selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
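The commented-out line anticipates a second model option next to TAPE-BERT. For context, a minimal sketch of the Model and ModelType definitions this list assumes; the field names and values come from the diff, while the enum value strings and dataclass layout are assumptions:

from dataclasses import dataclass
from enum import Enum


class ModelType(str, Enum):
    TAPE_BERT = "TAPE-BERT"   # value string is assumed, not shown in the diff
    PROT_T5 = "ProtT5"        # value string is assumed, not shown in the diff


@dataclass
class Model:
    name: ModelType
    layers: int
    heads: int


models = [
    Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
    # Enabling ProtT5 in the UI would only require uncommenting this entry:
    # Model(name=ModelType.PROT_T5, layers=24, heads=32),
]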
hexviz/attention.py
CHANGED
@@ -48,6 +48,7 @@ def get_sequences(structure: Structure) -> List[str]:
         sequences.append(list(residues_single_letter))
     return sequences
 
+@st.cache
 def get_protT5() -> Tuple[T5Tokenizer, T5EncoderModel]:
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     tokenizer = T5Tokenizer.from_pretrained(
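The get_protT5 loader is cut off mid-call by the hunk boundary above. A minimal sketch of a complete loader, assuming the Rostlab/prot_t5_xl_uniref50 checkpoint and output_attentions=True; neither detail is visible in this commit:

from typing import Tuple

import torch
from transformers import T5EncoderModel, T5Tokenizer


def get_protT5() -> Tuple[T5Tokenizer, T5EncoderModel]:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Checkpoint name is an assumption; the diff truncates before the argument.
    tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
    # output_attentions=True is assumed so get_attention can read per-layer attentions.
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50", output_attentions=True).to(device)
    model.eval()
    return tokenizer, model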
@@ -69,7 +70,7 @@ def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
 
 @st.cache
 def get_attention(
-    sequence:
+    sequence: str, model_type: ModelType = ModelType.TAPE_BERT
 ):
     if model_type == ModelType.TAPE_BERT:
         tokenizer, model = get_tape_bert()
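With the new signature, callers select the backend through model_type. A hypothetical call site (the actual wiring in app.py is not part of this commit):

sequence = "MKTAYIAKQR"  # illustrative sequence
attention = get_attention(sequence, model_type=ModelType.TAPE_BERT)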
@@ -81,9 +82,18 @@ def get_attention(
         attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]
         attns = torch.stack([attn.squeeze(0) for attn in attns])
     elif model_type == ModelType.PROT_T5:
-
-        #
-
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        # Introduce white-space between all amino acids
+        sequence = " ".join(sequence)
+        # tokenize sequences and pad up to the longest sequence in the batch
+        ids = tokenizer.encode_plus(sequence, add_special_tokens=True, padding="longest")
+
+        input_ids = torch.tensor(ids['input_ids']).to(device)
+        attention_mask = torch.tensor(ids['attention_mask']).to(device)
+
+        with torch.no_grad():
+            attns = model(input_ids=input_ids,attention_mask=attention_mask)[-1]
+
         tokenizer, model = get_protT5()
     else:
         raise ValueError(f"Model {model_type} not supported")
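As committed, the ProtT5 branch is still WIP: it tokenizes and runs the model before tokenizer, model = get_protT5() is reached, so those names are unbound on the first ProtT5 request, and unlike the TAPE branch it never strips special tokens or stacks the per-layer attentions. A hedged sketch of one way the branch might be finished, assuming get_protT5 loads the model with output_attentions=True and that only ProtT5's trailing </s> token needs trimming:

    elif model_type == ModelType.PROT_T5:
        # Load before tokenizing, rather than after as in the WIP diff
        tokenizer, model = get_protT5()
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # ProtT5 expects whitespace-separated residues
        sequence = " ".join(sequence)
        ids = tokenizer.encode_plus(sequence, add_special_tokens=True, padding="longest")
        # Add a batch dimension so the encoder sees shape (1, seq_len)
        input_ids = torch.tensor(ids["input_ids"]).unsqueeze(0).to(device)
        attention_mask = torch.tensor(ids["attention_mask"]).unsqueeze(0).to(device)
        with torch.no_grad():
            attns = model(input_ids=input_ids, attention_mask=attention_mask)[-1]
        # Drop the trailing </s> position and the batch dimension, mirroring the TAPE branch
        attns = [attn[:, :, :-1, :-1] for attn in attns]
        attns = torch.stack([attn.squeeze(0) for attn in attns])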