Add ruff, run ruff and black
- hexviz/attention.py +20 -37
- hexviz/ec_number.py +1 -3
- hexviz/models.py +1 -3
- hexviz/pages/1_🗺️Identify_Interesting_Heads.py +7 -10
- hexviz/pages/2_📄Documentation.py +42 -19
- hexviz/plot.py +4 -12
- hexviz/view.py +9 -8
- hexviz/🧬Attention_Visualization.py +19 -40
- poetry.lock +10 -1
- pyproject.toml +7 -0
- tests/test_attention.py +22 -15
- tests/test_models.py +1 -2
hexviz/attention.py
CHANGED
@@ -68,18 +68,14 @@ def res_to_1letter(residues: list[Residue]) -> str:
     Residues not in the standard 20 amino acids are replaced with X
     """
     res_names = [residue.get_resname() for residue in residues]
-    residues_single_letter = map(
-        lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), res_names
-    )
+    residues_single_letter = map(lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), res_names)
 
     return "".join(list(residues_single_letter))
 
 
 def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
     lines = sequence.split("\n")
-    cleaned_sequence = "".join(
-        line.upper() for line in lines if not line.startswith(">")
-    )
+    cleaned_sequence = "".join(line.upper() for line in lines if not line.startswith(">"))
     cleaned_sequence = cleaned_sequence.replace(" ", "")
     valid_residues = set(Polypeptide.protein_letters_3to1.values())
     residues_in_sequence = set(cleaned_sequence)

@@ -87,7 +83,9 @@ def clean_and_validate_sequence(sequence: str) -> tuple[str, str | None]:
     # Check if the sequence exceeds the max allowed length
     max_sequence_length = 400
     if len(cleaned_sequence) > max_sequence_length:
-        error_message = f"Sequence exceeds the max allowed length of {max_sequence_length} characters"
+        error_message = (
+            f"Sequence exceeds the max allowed length of {max_sequence_length} characters"
+        )
         return cleaned_sequence, error_message
 
     illegal_residues = residues_in_sequence - valid_residues

@@ -103,9 +101,7 @@ def remove_special_tokens_and_periods(attentions_tuple, sequence, tokenizer):
     tokens = tokenizer.tokenize(sequence)
 
     indices_to_remove = [
-        i
-        for i, token in enumerate(tokens)
-        if token in {".", "<sep>", "<start>", "<end>", "<pad>"}
+        i for i, token in enumerate(tokens) if token in {".", "<sep>", "<start>", "<end>", "<pad>"}
     ]
 
     new_attentions = []

@@ -113,9 +109,7 @@ def remove_special_tokens_and_periods(attentions_tuple, sequence, tokenizer):
     for attentions in attentions_tuple:
         # Remove rows and columns corresponding to special tokens and periods
         for idx in sorted(indices_to_remove, reverse=True):
-            attentions = torch.cat(
-                (attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2
-            )
+            attentions = torch.cat((attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2)
             attentions = torch.cat(
                 (attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3
             )

@@ -131,7 +125,7 @@ def get_attention(
     sequence: str,
     model_type: ModelType = ModelType.TAPE_BERT,
     remove_special_tokens: bool = True,
-    ec_number:
+    ec_number: str = None,
 ):
     """
     Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights

@@ -153,24 +147,18 @@ def get_attention(
         tokenizer, model = get_zymctrl()
 
         if ec_number:
-            sequence = f"{
+            sequence = f"{ec_number}<sep><start>{sequence}<end><pad>"
 
         inputs = tokenizer(sequence, return_tensors="pt").input_ids.to(device)
-        attention_mask = tokenizer(sequence, return_tensors="pt").attention_mask.to(
-            device
-        )
+        attention_mask = tokenizer(sequence, return_tensors="pt").attention_mask.to(device)
 
         with torch.no_grad():
-            outputs = model(
-                inputs, attention_mask=attention_mask, output_attentions=True
-            )
+            outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
         attentions = outputs.attentions
 
         if ec_number:
             # Remove attention to special tokens and periods separating EC number components
-            attentions = remove_special_tokens_and_periods(
-                attentions, sequence, tokenizer
-            )
+            attentions = remove_special_tokens_and_periods(attentions, sequence, tokenizer)
 
         # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
         attention_squeezed = [torch.squeeze(attention) for attention in attentions]

@@ -196,9 +184,7 @@ def get_attention(
         token_idxs = tokenizer.encode(sequence_separated)
         inputs = torch.tensor(token_idxs).unsqueeze(0).to(device)
         with torch.no_grad():
-            attentions = model(inputs, output_attentions=True)[
-                -1
-            ]  # Do you need an attention mask?
+            attentions = model(inputs, output_attentions=True)[-1]  # Do you need an attention mask?
 
         if remove_special_tokens:
             # Remove attention to </s> (last) token

@@ -262,17 +248,16 @@ def get_attention_pairs(
     top_residues = []
 
     ec_tag_length = 4
+
+    def is_tag(x):
+        return x < ec_tag_length
 
     for i, chain in enumerate(chains):
         ec_number = ec_numbers[i] if ec_numbers else None
+        ec_string = ".".join([ec.number for ec in ec_number]) if ec_number else ""
         sequence = res_to_1letter(chain)
-        attention = get_attention(
-        )
-        attention_unidirectional = unidirectional_avg_filtered(
-            attention, layer, head, threshold
-        )
+        attention = get_attention(sequence=sequence, model_type=model_type, ec_number=ec_string)
+        attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
 
         # Store sum of attention in to a resiue (from the unidirectional attention)
         residue_attention = {}

@@ -305,9 +290,7 @@ def get_attention_pairs(
                 residue_attention.get(res - ec_tag_length, 0) + attn_value
             )
 
-        top_n_residues = sorted(
-            residue_attention.items(), key=lambda x: x[1], reverse=True
-        )[:top_n]
+        top_n_residues = sorted(residue_attention.items(), key=lambda x: x[1], reverse=True)[:top_n]
 
         for res, attn_sum in top_n_residues:
             coord = chain[res]["CA"].coord.tolist()
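The reflowed pruning logic in remove_special_tokens_and_periods is easier to see in isolation. Below is a minimal, self-contained sketch (not part of the commit) of that step: for every token index flagged as a period or one of the ZymCTRL special tokens, the matching row (dim 2) and column (dim 3) are cut out of each [1, n_heads, n_tokens, n_tokens] attention tensor. The helper name prune_token_indices and the toy tensor are illustrative only; in the repository the indices come from tokenizer.tokenize and the set {".", "<sep>", "<start>", "<end>", "<pad>"}.

import torch


def prune_token_indices(attentions_tuple, indices_to_remove):
    """Drop the rows and columns for the given token indices from each attention tensor."""
    new_attentions = []
    for attentions in attentions_tuple:
        for idx in sorted(indices_to_remove, reverse=True):
            # Remove the row (dim=2), then the column (dim=3), for this token index
            attentions = torch.cat((attentions[:, :, :idx], attentions[:, :, idx + 1 :]), dim=2)
            attentions = torch.cat((attentions[:, :, :, :idx], attentions[:, :, :, idx + 1 :]), dim=3)
        new_attentions.append(attentions)
    return tuple(new_attentions)


# Toy check: one layer, 2 heads, 6 tokens; dropping indices 0 and 5 leaves a 4x4 attention map.
dummy = (torch.rand(1, 2, 6, 6),)
pruned = prune_token_indices(dummy, [0, 5])
assert pruned[0].shape == torch.Size([1, 2, 4, 4])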
hexviz/ec_number.py
CHANGED
@@ -6,6 +6,4 @@ class ECNumber:
         self.radius = radius
 
     def __str__(self):
-        return (
-            f"(EC: {self.number}, Coordinate: {self.coordinate}, Color: {self.color})"
-        )
+        return f"(EC: {self.number}, Coordinate: {self.coordinate}, Color: {self.color})"
hexviz/models.py
CHANGED
@@ -60,7 +60,5 @@ def get_prot_t5():
     tokenizer = T5Tokenizer.from_pretrained(
         "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
     )
-    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(
-        device
-    )
+    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device)
     return tokenizer, model
hexviz/pages/1_🗺️Identify_Interesting_Heads.py
CHANGED
@@ -27,14 +27,10 @@ models = [
     Model(name=ModelType.PROT_T5, layers=24, heads=32),
 ]
 
-with st.expander(
-    "Input a PDB id, upload a PDB file or input a sequence", expanded=True
-):
+with st.expander("Input a PDB id, upload a PDB file or input a sequence", expanded=True):
     pdb_id = select_pdb()
     uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
-    input_sequence = st.text_area(
-        "3.Input sequence", "", key="input_sequence", max_chars=400
-    )
+    input_sequence = st.text_area("3.Input sequence", "", key="input_sequence", max_chars=400)
     sequence, error = clean_and_validate_sequence(input_sequence)
     if error:
         st.error(error)

@@ -65,7 +61,9 @@ truncated_sequence = sequence[slice_start - 1 : slice_end]
 layer_sequence, head_sequence = select_heads_and_layers(st.sidebar, selected_model)
 
 st.markdown(
-    f"Each tile is a heatmap of attention for a section of the {source} chain
+    f"""Each tile is a heatmap of attention for a section of the {source} chain
+    ({chain_selection}) from residue {slice_start} to {slice_end}. Adjust the
+    section length and starting point in the sidebar."""
 )
 
 # TODO: Decide if you should get attention for the full sequence or just the truncated sequence

@@ -74,11 +72,10 @@ attention = get_attention(
     sequence=truncated_sequence,
     model_type=selected_model.name,
     remove_special_tokens=True,
+    ec_number=ec_number,
 )
 
-fig = plot_tiled_heatmap(
-    attention, layer_sequence=layer_sequence, head_sequence=head_sequence
-)
+fig = plot_tiled_heatmap(attention, layer_sequence=layer_sequence, head_sequence=head_sequence)
 
 
 st.pyplot(fig)
hexviz/pages/2_📄Documentation.py
CHANGED
@@ -5,42 +5,65 @@ from hexviz.config import URL
 st.markdown(
     f"""
 ## Protein language models
+There has been an explosion of capabilities in natural language processing
+models in the last few years. These architectural advances from NLP have proven
+to work very well for protein sequences, and we now have protein language models
+(pLMs) that can generate novel functional proteins sequences
+[ProtGPT2](https://www.nature.com/articles/s42256-022-00499-z) and auto-encoding
+models that excel at capturing biophysical features of protein sequences
+[ProtTrans](https://www.biorxiv.org/content/10.1101/2020.07.12.199554v3).
 
+For an introduction to protein language models for protein design check out
+[Controllable protein design with language
+models](https://www.nature.com/articles/s42256-022-00499-z).
 
 ## Interpreting protein language models by visualizing attention patterns
+With these impressive capabilities it is natural to ask what protein language
+models are learning and how they work -- we want to **interpret** the models.
+In natural language processing **attention analysis** has proven to be a useful
+tool for interpreting transformer model internals see fex ([Abnar et al.
+2020](https://arxiv.org/abs/2005.00928v2)). [BERTology meets
+biology](https://arxiv.org/abs/2006.15222) provides a thorough introduction to
+how we can analyze Transformer protein models through the lens of attention,
+they show exciting findings such as: > Attention: (1) captures the folding
+structure of proteins, connecting amino acids that are far apart in the
+underlying sequence, but spatially close in the three-dimensional structure, (2)
+targets binding sites, a key functional component of proteins, and (3) focuses
+on progressively more complex biophysical properties with increasing layer depth
 
+Most existing tools for analyzing and visualizing attention patterns focus on
+models trained on text. It can be hard to analyze protein sequences using these
+tools as sequences can be long and we lack intuition about how the language of
+proteins work. BERTology meets biology shows visualizing attention patterns in
+the context of protein structure can facilitate novel discoveries about what
+models learn. [**Hexviz**](https://huggingface.co/spaces/aksell/hexviz) is a
+tool to simplify analyzing attention patterns in the context of protein
+structure. We hope this can enable domain experts to explore and interpret the
+knowledge contained in pLMs.
 
 ## How to use Hexviz
 There are two views:
 1. <a href="{URL}Attention_Visualization" target="_self">🧬Attention Visualization</a> Shows attention weights from a single head as red bars between residues on a protein structure.
 2. <a href="{URL}Identify_Interesting_Heads" target="_self">🗺️Identify Interesting Heads</a> Plots attention weights between residues as a heatmap for each head in the model.
 
+The first view is the meat of the application and is where you can investigate
+how attention patterns map onto the structure of a protein you're interested in.
+Use the second view to narrow down to a few heads that you want to investigate
+attention patterns from in detail. pLM are large and can have many heads, as an
+example ProtBERT with it's 30 layers and 16 heads has 480 heads, so we need a
+way to identify heads with patterns we're interested in.
 
+The second view is a customizable heatmap plot of attention between residue for
+all heads and layers in a model. From here it is possible to identify heads that
+specialize in a particular attention pattern, such as:
 1. Vertical lines: Paying attention so a single or a few residues
 2. Diagonal: Attention to the same residue or residues in front or behind the current residue.
 3. Block attention: Attention is segmented so parts of the sequence are attended to by one part of the sequence.
 4. Heterogeneous: More complex attention patterns that are not easily categorized.
 TODO: Add examples of attention patterns
 
+Read more about attention patterns in fex [Revealing the dark secrets of
+BERT](https://arxiv.org/abs/1908.08593).
 
 ## Protein Language models in Hexviz
 Hexviz currently supports the following models:
hexviz/plot.py
CHANGED
@@ -15,30 +15,22 @@ def plot_tiled_heatmap(tensor, layer_sequence: List[int], head_sequence: List[int]):
 
     x_size = num_heads * 2
     y_size = num_layers * 2
-    fig, axes = plt.subplots(
-        num_layers, num_heads, figsize=(x_size, y_size), squeeze=False
-    )
+    fig, axes = plt.subplots(num_layers, num_heads, figsize=(x_size, y_size), squeeze=False)
     for i in range(num_layers):
         for j in range(num_heads):
-            axes[i, j].imshow(
-                tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal"
-            )
+            axes[i, j].imshow(tensor[i, j].detach().numpy(), cmap="viridis", aspect="equal")
             axes[i, j].axis("off")
 
             # Enumerate the axes
             if i == 0:
-                axes[i, j].set_title(
-                    f"Head {head_sequence[j] + 1}", fontsize=10, y=1.05
-                )
+                axes[i, j].set_title(f"Head {head_sequence[j] + 1}", fontsize=10, y=1.05)
 
     # Calculate the row label offset based on the number of columns
     offset = 0.02 + (12 - num_heads) * 0.0015
     for i, ax_row in enumerate(axes):
         row_label = f"{layer_sequence[i]+1}"
         row_pos = ax_row[num_heads - 1].get_position()
-        fig.text(
-            row_pos.x1 + offset, (row_pos.y1 + row_pos.y0) / 2, row_label, va="center"
-        )
+        fig.text(row_pos.x1 + offset, (row_pos.y1 + row_pos.y0) / 2, row_label, va="center")
 
     plt.subplots_adjust(wspace=0.1, hspace=0.1)
     return fig
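For reference, a hypothetical call that mirrors how the Identify Interesting Heads page drives the reformatted plot_tiled_heatmap: as on that page, the full attention tensor is passed together with zero-based layer and head index lists (the torch.rand tensor below is a stand-in for a real get_attention result of shape [n_layers, n_heads, n_res, n_res], and the output filename is made up).

import torch
from hexviz.plot import plot_tiled_heatmap

# Stand-in for get_attention(...) output: 12 layers, 12 heads, 50x50 residue attention maps
attention = torch.rand(12, 12, 50, 50)

# Tiles for two layers and three heads; titles show "Head 1", "Head 2", ... because of the +1 offset
fig = plot_tiled_heatmap(attention, layer_sequence=[0, 5], head_sequence=[0, 1, 2])
fig.savefig("tiled_heads.png")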
hexviz/view.py
CHANGED
@@ -18,11 +18,7 @@ def get_selecte_model_index(models):
         return 0
     else:
         return next(
-            (
-                i
-                for i, model in enumerate(models)
-                if model.name.value == selected_model_name
-            ),
+            (i for i, model in enumerate(models) if model.name.value == selected_model_name),
             None,
         )
 

@@ -89,10 +85,10 @@ def select_protein(pdb_code, uploaded_file, input_sequence):
         pdb_str = get_pdb_from_seq(str(input_sequence))
         if "selected_chains" in st.session_state:
             del st.session_state.selected_chains
-        source =
+        source = "Input sequence + ESM-fold"
     elif "uploaded_pdb_str" in st.session_state:
         pdb_str = st.session_state.uploaded_pdb_str
-        source =
+        source = "Uploaded file stored in cache"
     else:
         file = get_pdb_file(pdb_code)
         pdb_str = file.read()

@@ -135,7 +131,12 @@ def select_heads_and_layers(sidebar, model):
 
 
 def select_sequence_slice(sequence_length):
-    st.sidebar.markdown(
+    st.sidebar.markdown(
+        """
+        Sequence segment to plot
+        ---
+        """
+    )
     if "sequence_slice" not in st.session_state:
         st.session_state.sequence_slice = (1, min(50, sequence_length))
     slice = st.sidebar.slider(
hexviz/🧬Attention_Visualization.py
CHANGED
@@ -31,14 +31,10 @@ models = [
     Model(name=ModelType.PROT_T5, layers=24, heads=32),
 ]
 
-with st.expander(
-):
-    pdb_id = select_pdb()
+with st.expander("Input a PDB id, upload a PDB file or input a sequence", expanded=True):
+    pdb_id = select_pdb() or "2WK4"
     uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
-    input_sequence = st.text_area(
-        "3.Input sequence", "", key="input_sequence", max_chars=400
-    )
+    input_sequence = st.text_area("3.Input sequence", "", key="input_sequence", max_chars=400)
     sequence, error = clean_and_validate_sequence(input_sequence)
     if error:
         st.error(error)

@@ -59,9 +55,7 @@ selected_chains = st.sidebar.multiselect(
     label="Select Chain(s)", options=chains, key="selected_chains"
 )
 
-show_ligands = st.sidebar.checkbox(
-    "Show ligands", value=st.session_state.get("show_ligands", True)
-)
+show_ligands = st.sidebar.checkbox("Show ligands", value=st.session_state.get("show_ligands", True))
 st.session_state.show_ligands = show_ligands
 
 

@@ -71,9 +65,7 @@ st.sidebar.markdown(
     ---
     """
 )
-min_attn = st.sidebar.slider(
-    "Minimum attention", min_value=0.0, max_value=0.4, value=0.1
-)
+min_attn = st.sidebar.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
 n_highest_resis = st.sidebar.number_input(
     "Num highest attention resis to label", value=2, min_value=1, max_value=100
 )

@@ -84,9 +76,7 @@ sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)
 
 with st.sidebar.expander("Label residues manually"):
     hl_chain = st.selectbox(label="Chain to label", options=selected_chains, index=0)
-    hl_resi_list = st.multiselect(
-        label="Selected Residues", options=list(range(1, 5000))
-    )
+    hl_resi_list = st.multiselect(label="Selected Residues", options=list(range(1, 5000)))
 
     label_resi = st.checkbox(label="Label Residues", value=True)
 

@@ -97,10 +87,13 @@ with left:
 with mid:
     if "selected_layer" not in st.session_state:
         st.session_state["selected_layer"] = 5
-    layer_one =
+    layer_one = (
+        st.selectbox(
+            "Layer",
+            options=[i for i in range(1, selected_model.layers + 1)],
+            key="selected_layer",
+        )
+        or 5
     )
     layer = layer_one - 1
 with right:

@@ -135,9 +128,7 @@ if selected_model.name == ModelType.ZymCTRL:
 
     if ec_number:
         if selected_chains:
-            shown_chains = [
-                ch for ch in structure.get_chains() if ch.id in selected_chains
-            ]
+            shown_chains = [ch for ch in structure.get_chains() if ch.id in selected_chains]
         else:
             shown_chains = list(structure.get_chains())
 

@@ -163,14 +154,9 @@ if selected_model.name == ModelType.ZymCTRL:
         reverse_vector = [-v for v in vector]
 
         # Normalize the reverse vector
-        reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(
-            reverse_vector
-        )
+        reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(reverse_vector)
         coordinates = [
-            [
-                res_1[j] + i * 2 * radius * reverse_vector_normalized[j]
-                for j in range(3)
-            ]
+            [res_1[j] + i * 2 * radius * reverse_vector_normalized[j] for j in range(3)]
             for i in range(4)
         ]
         EC_tag = [

@@ -213,9 +199,7 @@ def get_3dview(pdb):
     for chain in hidden_chains:
         xyzview.setStyle({"chain": chain}, {"cross": {"hidden": "true"}})
         # Hide ligands for chain too
-        xyzview.addStyle(
-            {"chain": chain, "hetflag": True}, {"cross": {"hidden": "true"}}
-        )
+        xyzview.addStyle({"chain": chain, "hetflag": True}, {"cross": {"hidden": "true"}})
 
     if len(selected_chains) == 1:
         xyzview.zoomTo({"chain": f"{selected_chains[0]}"})

@@ -257,7 +241,6 @@ def get_3dview(pdb):
     for _, _, chain, res in top_residues:
         one_indexed_res = res + 1
         xyzview.addResLabels(
-
             {"chain": chain, "resi": one_indexed_res},
             {
                 "backgroundColor": "lightgray",

@@ -266,9 +249,7 @@ def get_3dview(pdb):
             },
         )
         if sidechain_highest:
-            xyzview.addStyle(
-                {"chain": chain, "resi": res}, {"stick": {"radius": 0.2}}
-            )
+            xyzview.addStyle({"chain": chain, "resi": res}, {"stick": {"radius": 0.2}})
     return xyzview
 
 

@@ -282,9 +263,7 @@ Pick a PDB ID, layer and head to visualize attention from the selected protein language model
     unsafe_allow_html=True,
 )
 
-chain_dict = {
-    f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())
-}
+chain_dict = {f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())}
 data = []
 for att_weight, _, chain, resi in top_residues:
     try:
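The EC-tag placement math in this file is compact after the reflow, so here is a small stand-alone sketch of what the reconstructed lines compute (the values for res_1, vector, and radius are made up for illustration): the four EC-number markers are laid out from the first residue's position along the reversed, normalized direction vector, spaced one sphere diameter (2 * radius) apart.

import numpy as np

res_1 = np.array([10.0, 4.0, -2.0])  # e.g. CA coordinate of the first residue
vector = [1.0, 0.0, 0.0]             # direction pointing from the tag anchor towards the chain
radius = 1.0                         # sphere radius used for the EC-tag markers

reverse_vector = [-v for v in vector]
reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(reverse_vector)
coordinates = [
    [res_1[j] + i * 2 * radius * reverse_vector_normalized[j] for j in range(3)]
    for i in range(4)
]
print(coordinates)  # x decreases 10 -> 8 -> 6 -> 4: four markers spaced by 2 * radius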
poetry.lock
CHANGED
@@ -1609,6 +1609,14 @@ pygments = ">=2.13.0,<3.0.0"
 [package.extras]
 jupyter = ["ipywidgets (>=7.5.1,<9)"]
 
+[[package]]
+name = "ruff"
+version = "0.0.264"
+description = "An extremely fast Python linter, written in Rust."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
 [[package]]
 name = "s3transfer"
 version = "0.6.0"

@@ -2196,7 +2204,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.10"
-content-hash = "
+content-hash = "502949174f23054a4b450dfc0bb16df64c43d7d6c3e60d1adaf2835962223c32"
 
 [metadata.files]
 altair = []

@@ -2428,6 +2436,7 @@ requests = []
 rfc3339-validator = []
 rfc3986-validator = []
 rich = []
+ruff = []
 s3transfer = []
 scipy = []
 semver = []
pyproject.toml
CHANGED
@@ -14,6 +14,7 @@ torch = "^2.0.0"
 sentencepiece = "^0.1.97"
 tape-proteins = "^0.5"
 matplotlib = "^3.7.1"
+ruff = "^0.0.264"
 
 [tool.poetry.dev-dependencies]
 pytest = "^7.2.2"

@@ -21,3 +22,9 @@ pytest = "^7.2.2"
 [build-system]
 requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
+
+[tool.ruff]
+line-length = 100
+
+[tool.black]
+line-length = 100
tests/test_attention.py
CHANGED
@@ -1,8 +1,13 @@
 import torch
 from Bio.PDB.Structure import Structure
 
-from hexviz.attention import (
+from hexviz.attention import (
+    ModelType,
+    get_attention,
+    get_sequences,
+    get_structure,
+    unidirectional_avg_filtered,
+)
 
 
 def test_get_structure():

@@ -12,10 +17,11 @@ def test_get_structure():
     assert structure is not None
     assert isinstance(structure, Structure)
 
+
 def test_get_sequences():
     pdb_id = "1AKE"
     structure = get_structure(pdb_id)
+
     sequences = get_sequences(structure)
 
     assert sequences is not None

@@ -30,26 +36,29 @@ def test_get_attention_zymctrl():
     result = get_attention("GGG", model_type=ModelType.ZymCTRL)
 
     assert result is not None
-    assert result.shape == torch.Size([36,16,3,3])
+    assert result.shape == torch.Size([36, 16, 3, 3])
+
 
 def test_get_attention_zymctrl_long_chain():
-    structure = get_structure(pdb_code="6A5J")
+    structure = get_structure(pdb_code="6A5J")  # 13 residues long
 
     sequences = get_sequences(structure)
 
     result = get_attention(sequences[0], model_type=ModelType.ZymCTRL)
 
     assert result is not None
-    assert result.shape == torch.Size([36,16,13,13])
+    assert result.shape == torch.Size([36, 16, 13, 13])
+
 
 def test_get_attention_tape():
-    structure = get_structure(pdb_code="6A5J")
+    structure = get_structure(pdb_code="6A5J")  # 13 residues long
     sequences = get_sequences(structure)
 
     result = get_attention(sequences[0], model_type=ModelType.TAPE_BERT)
 
     assert result is not None
-    assert result.shape == torch.Size([12,12,13,13])
+    assert result.shape == torch.Size([12, 12, 13, 13])
+
 
 def test_get_attention_prot_bert():
 

@@ -58,21 +67,19 @@ def test_get_attention_prot_bert():
     assert result is not None
     assert result.shape == torch.Size([30, 16, 3, 3])
 
+
 def test_get_unidirection_avg_filtered():
     # 1 head, 1 layer, 4 residues long attention tensor
-    attention= torch.tensor(
-                            [4, 7, 9, 11]]]], dtype=torch.float32)
+    attention = torch.tensor(
+        [[[[1, 2, 3, 4], [2, 5, 6, 7], [3, 6, 8, 9], [4, 7, 9, 11]]]], dtype=torch.float32
+    )
 
     result = unidirectional_avg_filtered(attention, 0, 0, 0)
 
     assert result is not None
     assert len(result) == 10
 
-    attention = torch.tensor([[[[1, 2, 3],
-                                [2, 5, 6],
-                                [4, 7, 91]]]], dtype=torch.float32)
+    attention = torch.tensor([[[[1, 2, 3], [2, 5, 6], [4, 7, 91]]]], dtype=torch.float32)
 
     result = unidirectional_avg_filtered(attention, 0, 0, 0)
 
tests/test_models.py
CHANGED
@@ -1,4 +1,3 @@
-
 from transformers import GPT2LMHeadModel, GPT2TokenizerFast
 
 from hexviz.models import get_zymctrl

@@ -13,4 +12,4 @@ def test_get_zymctrl():
     tokenizer, model = result
 
     assert isinstance(tokenizer, GPT2TokenizerFast)
-    assert isinstance(model, GPT2LMHeadModel)
+    assert isinstance(model, GPT2LMHeadModel)