Upload 6 files
- metl/__init__.py +2 -0
- metl/encode.py +58 -0
- metl/main.py +139 -0
- metl/models.py +1064 -0
- metl/relative_attention.py +586 -0
- metl/structure.py +184 -0
metl/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .main import *
__version__ = "0.1"
metl/encode.py
ADDED
@@ -0,0 +1,58 @@
""" Encodes data in different formats """
from enum import Enum, auto

import numpy as np


class Encoding(Enum):
    INT_SEQS = auto()
    ONE_HOT = auto()


class DataEncoder:
    chars = ["*", "A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"]
    num_chars = len(chars)
    mapping = {c: i for i, c in enumerate(chars)}

    def __init__(self, encoding: Encoding = Encoding.INT_SEQS):
        self.encoding = encoding

    def _encode_from_int_seqs(self, seq_ints):
        if self.encoding == Encoding.INT_SEQS:
            return seq_ints
        elif self.encoding == Encoding.ONE_HOT:
            one_hot = np.eye(self.num_chars)[seq_ints]
            return one_hot.astype(np.float32)

    def encode_sequences(self, char_seqs):
        seq_ints = []
        for char_seq in char_seqs:
            int_seq = [self.mapping[c] for c in char_seq]
            seq_ints.append(int_seq)
        seq_ints = np.array(seq_ints).astype(int)
        return self._encode_from_int_seqs(seq_ints)

    def encode_variants(self, wt, variants):
        # convert wild type seq to integer encoding
        wt_int = np.zeros(len(wt), dtype=np.uint8)
        for i, c in enumerate(wt):
            wt_int[i] = self.mapping[c]

        # tile the wild-type seq
        seq_ints = np.tile(wt_int, (len(variants), 1))

        for i, variant in enumerate(variants):
            # special handling if we want to encode the wild-type seq (it's already correct!)
            if variant == "_wt":
                continue

            # variants are a list of mutations [mutation1, mutation2, ....]
            variant = variant.split(",")
            for mutation in variant:
                # mutations are in the form <original char><position><replacement char>
                position = int(mutation[1:-1])
                replacement = self.mapping[mutation[-1]]
                seq_ints[i, position] = replacement

        seq_ints = seq_ints.astype(int)
        return self._encode_from_int_seqs(seq_ints)
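
As a quick illustration of the encoder API above, a minimal sketch: the wild-type sequence and mutation strings below are made up for this example (not taken from the repository), and per encode_variants above, positions are 0-based and the original character in each mutation string is not validated.

from metl.encode import DataEncoder, Encoding

wt = "MKTAYIAK"                       # hypothetical wild-type sequence, illustrative only
variants = ["_wt", "K1A", "T2S,Y4F"]  # "_wt" keeps the wild type; multi-mutants are comma-separated

encoder = DataEncoder(Encoding.ONE_HOT)
encoded = encoder.encode_variants(wt, variants)
print(encoded.shape)  # (3, 8, 21): num variants x seq len x num chars
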
metl/main.py
ADDED
@@ -0,0 +1,139 @@
import torch
import torch.hub

import metl.models as models
from metl.encode import DataEncoder, Encoding

UUID_URL_MAP = {
    # global source models
    "D72M9aEp": "https://zenodo.org/records/11051645/files/METL-G-20M-1D-D72M9aEp.pt?download=1",
    "Nr9zCKpR": "https://zenodo.org/records/11051645/files/METL-G-20M-3D-Nr9zCKpR.pt?download=1",
    "auKdzzwX": "https://zenodo.org/records/11051645/files/METL-G-50M-1D-auKdzzwX.pt?download=1",
    "6PSAzdfv": "https://zenodo.org/records/11051645/files/METL-G-50M-3D-6PSAzdfv.pt?download=1",

    # local source models
    "8gMPQJy4": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-GFP-8gMPQJy4.pt?download=1",
    "Hr4GNHws": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-GFP-Hr4GNHws.pt?download=1",
    "8iFoiYw2": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-DLG4_2022-8iFoiYw2.pt?download=1",
    "kt5DdWTa": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-DLG4_2022-kt5DdWTa.pt?download=1",
    "DMfkjVzT": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-GB1-DMfkjVzT.pt?download=1",
    "epegcFiH": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-GB1-epegcFiH.pt?download=1",
    "kS3rUS7h": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-GRB2-kS3rUS7h.pt?download=1",
    "X7w83g6S": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-GRB2-X7w83g6S.pt?download=1",
    "UKebCQGz": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-Pab1-UKebCQGz.pt?download=1",
    "2rr8V4th": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-Pab1-2rr8V4th.pt?download=1",
    "PREhfC22": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-TEM-1-PREhfC22.pt?download=1",
    "9ASvszux": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-TEM-1-9ASvszux.pt?download=1",
    "HscFFkAb": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-Ube4b-HscFFkAb.pt?download=1",
    "H48oiNZN": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-Ube4b-H48oiNZN.pt?download=1",

    # metl bind source models
    "K6mw24Rg": "https://zenodo.org/records/11051645/files/METL-BIND-2M-3D-GB1-STANDARD-K6mw24Rg.pt?download=1",
    "Bo5wn2SG": "https://zenodo.org/records/11051645/files/METL-BIND-2M-3D-GB1-BINDING-Bo5wn2SG.pt?download=1",

    # finetuned models from GFP design experiment
    "YoQkzoLD": "https://zenodo.org/records/11051645/files/FT-METL-L-2M-1D-GFP-YoQkzoLD.pt?download=1",
    "PEkeRuxb": "https://zenodo.org/records/11051645/files/FT-METL-L-2M-3D-GFP-PEkeRuxb.pt?download=1",

}

IDENT_UUID_MAP = {
    # the keys should be all lowercase
    "metl-g-20m-1d": "D72M9aEp",
    "metl-g-20m-3d": "Nr9zCKpR",
    "metl-g-50m-1d": "auKdzzwX",
    "metl-g-50m-3d": "6PSAzdfv",

    # GFP local source models
    "metl-l-2m-1d-gfp": "8gMPQJy4",
    "metl-l-2m-3d-gfp": "Hr4GNHws",

    # DLG4 local source models
    "metl-l-2m-1d-dlg4": "8iFoiYw2",
    "metl-l-2m-3d-dlg4": "kt5DdWTa",

    # GB1 local source models
    "metl-l-2m-1d-gb1": "DMfkjVzT",
    "metl-l-2m-3d-gb1": "epegcFiH",

    # GRB2 local source models
    "metl-l-2m-1d-grb2": "kS3rUS7h",
    "metl-l-2m-3d-grb2": "X7w83g6S",

    # Pab1 local source models
    "metl-l-2m-1d-pab1": "UKebCQGz",
    "metl-l-2m-3d-pab1": "2rr8V4th",

    # TEM-1 local source models
    "metl-l-2m-1d-tem-1": "PREhfC22",
    "metl-l-2m-3d-tem-1": "9ASvszux",

    # Ube4b local source models
    "metl-l-2m-1d-ube4b": "HscFFkAb",
    "metl-l-2m-3d-ube4b": "H48oiNZN",

    # METL-Bind for GB1
    "metl-bind-2m-3d-gb1-standard": "K6mw24Rg",
    "metl-bind-2m-3d-gb1-binding": "Bo5wn2SG",

    # GFP design models, giving them an ident
    "metl-l-2m-1d-gfp-ft-design": "YoQkzoLD",
    "metl-l-2m-3d-gfp-ft-design": "PEkeRuxb",

}


def download_checkpoint(uuid):
    ckpt = torch.hub.load_state_dict_from_url(UUID_URL_MAP[uuid],
                                              map_location="cpu", file_name=f"{uuid}.pt")
    state_dict = ckpt["state_dict"]
    hyper_parameters = ckpt["hyper_parameters"]

    return state_dict, hyper_parameters


def _get_data_encoding(hparams):
    if "encoding" in hparams and hparams["encoding"] == "int_seqs":
        encoding = Encoding.INT_SEQS
    elif "encoding" in hparams and hparams["encoding"] == "one_hot":
        encoding = Encoding.ONE_HOT
    elif (("encoding" in hparams and hparams["encoding"] == "auto") or "encoding" not in hparams) and \
            hparams["model_name"] in ["transformer_encoder"]:
        encoding = Encoding.INT_SEQS
    else:
        raise ValueError("Detected unsupported encoding in hyperparameters")

    return encoding


def load_model_and_data_encoder(state_dict, hparams):
    model = models.Model[hparams["model_name"]].cls(**hparams)
    model.load_state_dict(state_dict)

    data_encoder = DataEncoder(_get_data_encoding(hparams))

    return model, data_encoder


def get_from_uuid(uuid):
    if uuid in UUID_URL_MAP:
        state_dict, hparams = download_checkpoint(uuid)
        return load_model_and_data_encoder(state_dict, hparams)
    else:
        raise ValueError(f"UUID {uuid} not found in UUID_URL_MAP")


def get_from_ident(ident):
    ident = ident.lower()
    if ident in IDENT_UUID_MAP:
        state_dict, hparams = download_checkpoint(IDENT_UUID_MAP[ident])
        return load_model_and_data_encoder(state_dict, hparams)
    else:
        raise ValueError(f"Identifier {ident} not found in IDENT_UUID_MAP")


def get_from_checkpoint(ckpt_fn):
    ckpt = torch.load(ckpt_fn, map_location="cpu")
    state_dict = ckpt["state_dict"]
    hyper_parameters = ckpt["hyper_parameters"]
    return load_model_and_data_encoder(state_dict, hyper_parameters)
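
To show how these loading functions combine with the encoder, a minimal sketch under stated assumptions: a 1D (sequence-only) checkpoint such as "metl-l-2m-1d-gfp" so no pdb_fn keyword is needed, a hypothetical file gfp_wt.txt holding the full wild-type sequence the checkpoint was trained on, and illustrative 0-based mutations. None of these specifics come from the repository.

import torch
import metl  # get_from_ident / get_from_uuid / get_from_checkpoint are re-exported via metl/__init__.py

model, data_encoder = metl.get_from_ident("metl-l-2m-1d-gfp")
model.eval()

wt = open("gfp_wt.txt").read().strip()  # hypothetical file with the full wild-type sequence
variants = ["_wt", "E3K,G102S"]         # illustrative 0-based mutations, comma-separated for multi-mutants

encoded = data_encoder.encode_variants(wt, variants)
with torch.no_grad():
    predictions = model(torch.tensor(encoded))
print(predictions.shape)  # (num variants, num tasks)
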
metl/models.py
ADDED
|
@@ -0,0 +1,1064 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import math
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
import enum
|
| 5 |
+
from os.path import isfile
|
| 6 |
+
from typing import List, Tuple, Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
|
| 13 |
+
import metl.relative_attention as ra
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def reset_parameters_helper(m: nn.Module):
|
| 17 |
+
""" helper function for resetting model parameters, meant to be used with model.apply() """
|
| 18 |
+
|
| 19 |
+
# the PyTorch MultiHeadAttention has a private function _reset_parameters()
|
| 20 |
+
# other layers have a public reset_parameters()... go figure
|
| 21 |
+
reset_parameters = getattr(m, "reset_parameters", None)
|
| 22 |
+
reset_parameters_private = getattr(m, "_reset_parameters", None)
|
| 23 |
+
|
| 24 |
+
if callable(reset_parameters) and callable(reset_parameters_private):
|
| 25 |
+
raise RuntimeError("Module has both public and private methods for resetting parameters. "
|
| 26 |
+
"This is unexpected... probably should just call the public one.")
|
| 27 |
+
|
| 28 |
+
if callable(reset_parameters):
|
| 29 |
+
m.reset_parameters()
|
| 30 |
+
|
| 31 |
+
if callable(reset_parameters_private):
|
| 32 |
+
m._reset_parameters()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SequentialWithArgs(nn.Sequential):
|
| 36 |
+
def forward(self, x, **kwargs):
|
| 37 |
+
for module in self:
|
| 38 |
+
if isinstance(module, ra.RelativeTransformerEncoder) or isinstance(module, SequentialWithArgs):
|
| 39 |
+
# for relative transformer encoders, pass in kwargs (pdb_fn)
|
| 40 |
+
x = module(x, **kwargs)
|
| 41 |
+
else:
|
| 42 |
+
# for all modules, don't pass in kwargs
|
| 43 |
+
x = module(x)
|
| 44 |
+
return x
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class PositionalEncoding(nn.Module):
|
| 48 |
+
# originally from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
|
| 49 |
+
# they have since updated their implementation, but it is functionally equivalent
|
| 50 |
+
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
| 51 |
+
super(PositionalEncoding, self).__init__()
|
| 52 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 53 |
+
|
| 54 |
+
pe = torch.zeros(max_len, d_model)
|
| 55 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 56 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
| 57 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 58 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 59 |
+
# note the implementation on Pytorch's website expects [seq_len, batch_size, embedding_dim]
|
| 60 |
+
# however our data is in [batch_size, seq_len, embedding_dim] (i.e. batch_first)
|
| 61 |
+
# fixed by changing pe = pe.unsqueeze(0).transpose(0, 1) to pe = pe.unsqueeze(0)
|
| 62 |
+
# also down below, changing our indexing into the position encoding to reflect new dimensions
|
| 63 |
+
# pe = pe.unsqueeze(0).transpose(0, 1)
|
| 64 |
+
pe = pe.unsqueeze(0)
|
| 65 |
+
self.register_buffer('pe', pe)
|
| 66 |
+
|
| 67 |
+
def forward(self, x, **kwargs):
|
| 68 |
+
# note the implementation on Pytorch's website expects [seq_len, batch_size, embedding_dim]
|
| 69 |
+
# however our data is in [batch_size, seq_len, embedding_dim] (i.e. batch_first)
|
| 70 |
+
# fixed by changing x = x + self.pe[:x.size(0)] to x = x + self.pe[:, :x.size(1), :]
|
| 71 |
+
# x = x + self.pe[:x.size(0), :]
|
| 72 |
+
x = x + self.pe[:, :x.size(1), :]
|
| 73 |
+
return self.dropout(x)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class ScaledEmbedding(nn.Module):
|
| 77 |
+
# https://pytorch.org/tutorials/beginner/translation_transformer.html
|
| 78 |
+
# a helper function for embedding that scales by sqrt(d_model) in the forward()
|
| 79 |
+
# makes it, so we don't have to do the scaling in the main AttnModel forward()
|
| 80 |
+
|
| 81 |
+
# todo: be aware of embedding scaling factor
|
| 82 |
+
# regarding the scaling factor, it's unclear exactly what the purpose is and whether it is needed
|
| 83 |
+
# there are several theories on why it is used, and it shows up in all the transformer reference implementations
|
| 84 |
+
# https://datascience.stackexchange.com/questions/87906/transformer-model-why-are-word-embeddings-scaled-before-adding-positional-encod
|
| 85 |
+
# 1. Has something to do with weight sharing between the embedding and the decoder output
|
| 86 |
+
# 2. Scales up the embeddings so the signal doesn't get overwhelmed when adding the absolute positional encoding
|
| 87 |
+
# 3. It cancels out with the scaling factor in scaled dot product attention, and helps make the model robust
|
| 88 |
+
# to the choice of embedding_len
|
| 89 |
+
# 4. It's not actually needed
|
| 90 |
+
|
| 91 |
+
# Regarding #1, not really sure about this. In section 3.4 of attention is all you need,
|
| 92 |
+
# that's where they state they multiply the embedding weights by sqrt(d_model), and the context is that they
|
| 93 |
+
# are sharing the same weight matrix between the two embedding layers and the pre-softmax linear transformation.
|
| 94 |
+
# there may be a reason that we want those weights scaled differently for the embedding layers vs. the linear
|
| 95 |
+
# transformation. It might have something to do with the scale at which embedding weights are initialized
|
| 96 |
+
# is more appropriate for the decoder linear transform vs how they are used in the attention function. Might have
|
| 97 |
+
# something to do with computing the correct next-token probabilities. Overall, I'm really not sure about this,
|
| 98 |
+
# but we aren't using a decoder anyway. So if this is the reason, then we don't need to perform the multiply.
|
| 99 |
+
|
| 100 |
+
# Regarding #2, it seems like in one implementation of transformers (fairseq), the sinusoidal positional encoding
|
| 101 |
+
# has a range of (-1.0, 1.0), but the word embedding are initialized with mean 0 and s.d embedding_dim ** -0.5,
|
| 102 |
+
# which for embedding_dim=512, is a range closer to (-0.10, 0.10). Thus, the positional embedding would overwhelm
|
| 103 |
+
# the word embeddings when they are added together. The scaling factor increases the signal of the word embeddings.
|
| 104 |
+
# for embedding_dim=512, it scales word embeddings by 22, increasing range of the word embeddings to (-2.2, 2.2).
|
| 105 |
+
# link to fairseq implementation, search for nn.init to see them do the initialization
|
| 106 |
+
# https://fairseq.readthedocs.io/en/v0.7.1/_modules/fairseq/models/transformer.html
|
| 107 |
+
#
|
| 108 |
+
# For PyTorch, PyTorch initializes nn.Embedding with a standard normal distribution mean 0, variance 1: N(0,1).
|
| 109 |
+
# this puts the range for the word embeddings around (-3, 3). the pytorch implementation for positional encoding
|
| 110 |
+
# also has a range of (-1.0, 1.0). So already, these are much closer in scale, and it doesn't seem like we need
|
| 111 |
+
# to increase the scale of the word embeddings. However, PyTorch example still multiply by the scaling factor
|
| 112 |
+
# unclear whether this is just a carryover that is not actually needed, or if there is a different reason
|
| 113 |
+
#
|
| 114 |
+
# EDIT! I just realized that even though nn.Embedding defaults to a range of around (-3, 3), the PyTorch
|
| 115 |
+
# transformer example actually re-initializes them using a uniform distribution in the range of (-0.1, 0.1)
|
| 116 |
+
# that makes it very similar to the fairseq implementation, so the scaling factor that PyTorch uses actually would
|
| 117 |
+
# bring the word embedding and positional encodings much closer in scale. So this could be the reason why pytorch
|
| 118 |
+
# does it
|
| 119 |
+
|
| 120 |
+
# Regarding #3, I don't think so. Firstly, does it actually cancel there? Secondly, the purpose of the scaling
|
| 121 |
+
# factor in scaled dot product attention, according to attention is all you need, is to counteract dot products
|
| 122 |
+
# that are very high in magnitude due to choice of large mbedding length (aka d_k). The problem with high magnitude
|
| 123 |
+
# dot products is that potentially, the softmax is pushed into regions where it has extremely small gradients,
|
| 124 |
+
# making learning difficult. If the scaling factor in the embedding was meant to counteract the scaling factor in
|
| 125 |
+
# scaled dot product attention, then what would be the point of doing all that?
|
| 126 |
+
|
| 127 |
+
# Regarding #4, I don't think the scaling will have any effects in practice, it's probably not needed
|
| 128 |
+
|
| 129 |
+
# Overall, I think #2 is the most likely reason why this scaling is performed. In theory, I think
|
| 130 |
+
# even if the scaling wasn't performed, the network might learn to up-scale the word embedding weights to increase
|
| 131 |
+
# word embedding signal vs. the position signal on its own. Another question I have is why not just initialize
|
| 132 |
+
# the embedding weights to have higher initial values? Why put it in the range (-0.1, 0.1)?
|
| 133 |
+
#
|
| 134 |
+
# The fact that most implementations have this scaling concerns me, makes me think I might be missing something.
|
| 135 |
+
# For our purposes, we can train a couple models to see if scaling has any positive or negative effect.
|
| 136 |
+
# Still need to think about potential effects of this scaling on relative position embeddings.
|
| 137 |
+
|
| 138 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, scale: bool):
|
| 139 |
+
super(ScaledEmbedding, self).__init__()
|
| 140 |
+
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
|
| 141 |
+
self.emb_size = embedding_dim
|
| 142 |
+
self.embed_scale = math.sqrt(self.emb_size)
|
| 143 |
+
|
| 144 |
+
self.scale = scale
|
| 145 |
+
|
| 146 |
+
self.init_weights()
|
| 147 |
+
|
| 148 |
+
def init_weights(self):
|
| 149 |
+
# todo: not sure why PyTorch example initializes weights like this
|
| 150 |
+
# might have something to do with word embedding scaling factor (see above)
|
| 151 |
+
# could also just try the default weight initialization for nn.Embedding()
|
| 152 |
+
init_range = 0.1
|
| 153 |
+
self.embedding.weight.data.uniform_(-init_range, init_range)
|
| 154 |
+
|
| 155 |
+
def forward(self, tokens: Tensor, **kwargs):
|
| 156 |
+
if self.scale:
|
| 157 |
+
return self.embedding(tokens.long()) * self.embed_scale
|
| 158 |
+
else:
|
| 159 |
+
return self.embedding(tokens.long())
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class FCBlock(nn.Module):
|
| 163 |
+
""" a fully connected block with options for batchnorm and dropout
|
| 164 |
+
can extend in the future with option for different activation, etc """
|
| 165 |
+
|
| 166 |
+
def __init__(self,
|
| 167 |
+
in_features: int,
|
| 168 |
+
num_hidden_nodes: int = 64,
|
| 169 |
+
use_batchnorm: bool = False,
|
| 170 |
+
use_layernorm: bool = False,
|
| 171 |
+
norm_before_activation: bool = False,
|
| 172 |
+
use_dropout: bool = False,
|
| 173 |
+
dropout_rate: float = 0.2,
|
| 174 |
+
activation: str = "relu"):
|
| 175 |
+
|
| 176 |
+
super().__init__()
|
| 177 |
+
|
| 178 |
+
if use_batchnorm and use_layernorm:
|
| 179 |
+
raise ValueError("Only one of use_batchnorm or use_layernorm can be set to True")
|
| 180 |
+
|
| 181 |
+
self.use_batchnorm = use_batchnorm
|
| 182 |
+
self.use_dropout = use_dropout
|
| 183 |
+
self.use_layernorm = use_layernorm
|
| 184 |
+
self.norm_before_activation = norm_before_activation
|
| 185 |
+
|
| 186 |
+
self.fc = nn.Linear(in_features=in_features, out_features=num_hidden_nodes)
|
| 187 |
+
|
| 188 |
+
self.activation = get_activation_fn(activation, functional=False)
|
| 189 |
+
|
| 190 |
+
if use_batchnorm:
|
| 191 |
+
self.norm = nn.BatchNorm1d(num_hidden_nodes)
|
| 192 |
+
|
| 193 |
+
if use_layernorm:
|
| 194 |
+
self.norm = nn.LayerNorm(num_hidden_nodes)
|
| 195 |
+
|
| 196 |
+
if use_dropout:
|
| 197 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
| 198 |
+
|
| 199 |
+
def forward(self, x, **kwargs):
|
| 200 |
+
x = self.fc(x)
|
| 201 |
+
|
| 202 |
+
# norm can be before or after activation, using flag
|
| 203 |
+
if (self.use_batchnorm or self.use_layernorm) and self.norm_before_activation:
|
| 204 |
+
x = self.norm(x)
|
| 205 |
+
|
| 206 |
+
x = self.activation(x)
|
| 207 |
+
|
| 208 |
+
# batchnorm being applied after activation, there is some discussion on this online
|
| 209 |
+
if (self.use_batchnorm or self.use_layernorm) and not self.norm_before_activation:
|
| 210 |
+
x = self.norm(x)
|
| 211 |
+
|
| 212 |
+
# dropout being applied last
|
| 213 |
+
if self.use_dropout:
|
| 214 |
+
x = self.dropout(x)
|
| 215 |
+
|
| 216 |
+
return x
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class TaskSpecificPredictionLayers(nn.Module):
|
| 220 |
+
""" Constructs num_tasks [dense(num_hidden_nodes)+relu+dense(1)] layers, each independently transforming input
|
| 221 |
+
into a single output node. All num_tasks outputs are then concatenated into a single tensor. """
|
| 222 |
+
|
| 223 |
+
# todo: the independent layers are run in sequence rather than in parallel, causing a slowdown that
|
| 224 |
+
# scales with the number of tasks. might be able to run in parallel by hacking convolution operation
|
| 225 |
+
# https://stackoverflow.com/questions/58374980/run-multiple-models-of-an-ensemble-in-parallel-with-pytorch
|
| 226 |
+
# https://github.com/pytorch/pytorch/issues/54147
|
| 227 |
+
# https://github.com/pytorch/pytorch/issues/36459
|
| 228 |
+
|
| 229 |
+
def __init__(self,
|
| 230 |
+
num_tasks: int,
|
| 231 |
+
in_features: int,
|
| 232 |
+
num_hidden_nodes: int = 64,
|
| 233 |
+
use_batchnorm: bool = False,
|
| 234 |
+
use_dropout: bool = False,
|
| 235 |
+
dropout_rate: float = 0.2,
|
| 236 |
+
activation: str = "relu"):
|
| 237 |
+
|
| 238 |
+
super().__init__()
|
| 239 |
+
|
| 240 |
+
# each task-specific layer outputs a single node,
|
| 241 |
+
# which can be combined with torch.cat into prediction vector
|
| 242 |
+
self.task_specific_pred_layers = nn.ModuleList()
|
| 243 |
+
for i in range(num_tasks):
|
| 244 |
+
layers = [FCBlock(in_features=in_features,
|
| 245 |
+
num_hidden_nodes=num_hidden_nodes,
|
| 246 |
+
use_batchnorm=use_batchnorm,
|
| 247 |
+
use_dropout=use_dropout,
|
| 248 |
+
dropout_rate=dropout_rate,
|
| 249 |
+
activation=activation),
|
| 250 |
+
nn.Linear(in_features=num_hidden_nodes, out_features=1)]
|
| 251 |
+
self.task_specific_pred_layers.append(nn.Sequential(*layers))
|
| 252 |
+
|
| 253 |
+
def forward(self, x, **kwargs):
|
| 254 |
+
# run each task-specific layer and concatenate outputs into a single output vector
|
| 255 |
+
task_specific_outputs = []
|
| 256 |
+
for layer in self.task_specific_pred_layers:
|
| 257 |
+
task_specific_outputs.append(layer(x))
|
| 258 |
+
|
| 259 |
+
output = torch.cat(task_specific_outputs, dim=1)
|
| 260 |
+
return output
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class GlobalAveragePooling(nn.Module):
|
| 264 |
+
""" helper class for global average pooling """
|
| 265 |
+
|
| 266 |
+
def __init__(self, dim=1):
|
| 267 |
+
super().__init__()
|
| 268 |
+
# our data is in [batch_size, sequence_length, embedding_length]
|
| 269 |
+
# with global pooling, we want to pool over the sequence dimension (dim=1)
|
| 270 |
+
self.dim = dim
|
| 271 |
+
|
| 272 |
+
def forward(self, x, **kwargs):
|
| 273 |
+
return torch.mean(x, dim=self.dim)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class CLSPooling(nn.Module):
|
| 277 |
+
""" helper class for CLS token extraction """
|
| 278 |
+
|
| 279 |
+
def __init__(self, cls_position=0):
|
| 280 |
+
super().__init__()
|
| 281 |
+
|
| 282 |
+
# the position of the CLS token in the sequence dimension
|
| 283 |
+
# currently, the CLS token is in the first position, but may move it to the last position
|
| 284 |
+
self.cls_position = cls_position
|
| 285 |
+
|
| 286 |
+
def forward(self, x, **kwargs):
|
| 287 |
+
# assumes input is in [batch_size, sequence_len, embedding_len]
|
| 288 |
+
# thus sequence dimension is dimension 1
|
| 289 |
+
return x[:, self.cls_position, :]
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class TransformerEncoderWrapper(nn.TransformerEncoder):
|
| 293 |
+
""" wrapper around PyTorch's TransformerEncoder that re-initializes layer parameters,
|
| 294 |
+
so each transformer encoder layer has a different initialization """
|
| 295 |
+
|
| 296 |
+
# todo: PyTorch is changing its transformer API... check up on and see if there is a better way
|
| 297 |
+
def __init__(self, encoder_layer, num_layers, norm=None, reset_params=True):
|
| 298 |
+
super().__init__(encoder_layer, num_layers, norm)
|
| 299 |
+
if reset_params:
|
| 300 |
+
self.apply(reset_parameters_helper)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class AttnModel(nn.Module):
|
| 304 |
+
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
|
| 305 |
+
|
| 306 |
+
@staticmethod
|
| 307 |
+
def add_model_specific_args(parent_parser):
|
| 308 |
+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
| 309 |
+
|
| 310 |
+
parser.add_argument('--pos_encoding', type=str, default="absolute",
|
| 311 |
+
choices=["none", "absolute", "relative", "relative_3D"],
|
| 312 |
+
help="what type of positional encoding to use")
|
| 313 |
+
parser.add_argument('--pos_encoding_dropout', type=float, default=0.1,
|
| 314 |
+
help="out much dropout to use in positional encoding, for pos_encoding==absolute")
|
| 315 |
+
parser.add_argument('--clipping_threshold', type=int, default=3,
|
| 316 |
+
help="clipping threshold for relative position embedding, for relative and relative_3D")
|
| 317 |
+
parser.add_argument('--contact_threshold', type=int, default=7,
|
| 318 |
+
help="threshold, in angstroms, for contact map, for relative_3D")
|
| 319 |
+
parser.add_argument('--embedding_len', type=int, default=128)
|
| 320 |
+
parser.add_argument('--num_heads', type=int, default=2)
|
| 321 |
+
parser.add_argument('--num_hidden', type=int, default=64)
|
| 322 |
+
parser.add_argument('--num_enc_layers', type=int, default=2)
|
| 323 |
+
parser.add_argument('--enc_layer_dropout', type=float, default=0.1)
|
| 324 |
+
parser.add_argument('--use_final_encoder_norm', action="store_true", default=False)
|
| 325 |
+
|
| 326 |
+
parser.add_argument('--global_average_pooling', action="store_true", default=False)
|
| 327 |
+
parser.add_argument('--cls_pooling', action="store_true", default=False)
|
| 328 |
+
|
| 329 |
+
parser.add_argument('--use_task_specific_layers', action="store_true", default=False,
|
| 330 |
+
help="exclusive with use_final_hidden_layer; takes priority over use_final_hidden_layer"
|
| 331 |
+
" if both flags are set")
|
| 332 |
+
parser.add_argument('--task_specific_hidden_nodes', type=int, default=64)
|
| 333 |
+
parser.add_argument('--use_final_hidden_layer', action="store_true", default=False)
|
| 334 |
+
parser.add_argument('--final_hidden_size', type=int, default=64)
|
| 335 |
+
parser.add_argument('--use_final_hidden_layer_norm', action="store_true", default=False)
|
| 336 |
+
parser.add_argument('--final_hidden_layer_norm_before_activation', action="store_true", default=False)
|
| 337 |
+
parser.add_argument('--use_final_hidden_layer_dropout', action="store_true", default=False)
|
| 338 |
+
parser.add_argument('--final_hidden_layer_dropout_rate', type=float, default=0.2)
|
| 339 |
+
|
| 340 |
+
parser.add_argument('--activation', type=str, default="relu",
|
| 341 |
+
help="activation function used for all activations in the network")
|
| 342 |
+
return parser
|
| 343 |
+
|
| 344 |
+
def __init__(self,
|
| 345 |
+
# data args
|
| 346 |
+
num_tasks: int,
|
| 347 |
+
aa_seq_len: int,
|
| 348 |
+
num_tokens: int,
|
| 349 |
+
# transformer encoder model args
|
| 350 |
+
pos_encoding: str = "absolute",
|
| 351 |
+
pos_encoding_dropout: float = 0.1,
|
| 352 |
+
clipping_threshold: int = 3,
|
| 353 |
+
contact_threshold: int = 7,
|
| 354 |
+
pdb_fns: List[str] = None,
|
| 355 |
+
embedding_len: int = 64,
|
| 356 |
+
num_heads: int = 2,
|
| 357 |
+
num_hidden: int = 64,
|
| 358 |
+
num_enc_layers: int = 2,
|
| 359 |
+
enc_layer_dropout: float = 0.1,
|
| 360 |
+
use_final_encoder_norm: bool = False,
|
| 361 |
+
# pooling to fixed-length representation
|
| 362 |
+
global_average_pooling: bool = True,
|
| 363 |
+
cls_pooling: bool = False,
|
| 364 |
+
# prediction layers
|
| 365 |
+
use_task_specific_layers: bool = False,
|
| 366 |
+
task_specific_hidden_nodes: int = 64,
|
| 367 |
+
use_final_hidden_layer: bool = False,
|
| 368 |
+
final_hidden_size: int = 64,
|
| 369 |
+
use_final_hidden_layer_norm: bool = False,
|
| 370 |
+
final_hidden_layer_norm_before_activation: bool = False,
|
| 371 |
+
use_final_hidden_layer_dropout: bool = False,
|
| 372 |
+
final_hidden_layer_dropout_rate: float = 0.2,
|
| 373 |
+
# activation function
|
| 374 |
+
activation: str = "relu",
|
| 375 |
+
*args, **kwargs):
|
| 376 |
+
|
| 377 |
+
super().__init__()
|
| 378 |
+
|
| 379 |
+
# store embedding length for use in the forward function
|
| 380 |
+
self.embedding_len = embedding_len
|
| 381 |
+
self.aa_seq_len = aa_seq_len
|
| 382 |
+
|
| 383 |
+
# build up layers
|
| 384 |
+
layers = collections.OrderedDict()
|
| 385 |
+
|
| 386 |
+
# amino acid embedding
|
| 387 |
+
layers["embedder"] = ScaledEmbedding(num_embeddings=num_tokens, embedding_dim=embedding_len, scale=True)
|
| 388 |
+
|
| 389 |
+
# absolute positional encoding
|
| 390 |
+
if pos_encoding == "absolute":
|
| 391 |
+
layers["pos_encoder"] = PositionalEncoding(embedding_len, dropout=pos_encoding_dropout, max_len=512)
|
| 392 |
+
|
| 393 |
+
# transformer encoder layer for none or absolute positional encoding
|
| 394 |
+
if pos_encoding in ["none", "absolute"]:
|
| 395 |
+
encoder_layer = torch.nn.TransformerEncoderLayer(d_model=embedding_len,
|
| 396 |
+
nhead=num_heads,
|
| 397 |
+
dim_feedforward=num_hidden,
|
| 398 |
+
dropout=enc_layer_dropout,
|
| 399 |
+
activation=get_activation_fn(activation),
|
| 400 |
+
norm_first=True,
|
| 401 |
+
batch_first=True)
|
| 402 |
+
|
| 403 |
+
# layer norm that is used after the transformer encoder layers
|
| 404 |
+
# if the norm_first is False, this is *redundant* and not needed
|
| 405 |
+
# but if norm_first is True, this can be used to normalize outputs from
|
| 406 |
+
# the transformer encoder before inputting to the final fully connected layer
|
| 407 |
+
encoder_norm = None
|
| 408 |
+
if use_final_encoder_norm:
|
| 409 |
+
encoder_norm = nn.LayerNorm(embedding_len)
|
| 410 |
+
|
| 411 |
+
layers["tr_encoder"] = TransformerEncoderWrapper(encoder_layer=encoder_layer,
|
| 412 |
+
num_layers=num_enc_layers,
|
| 413 |
+
norm=encoder_norm)
|
| 414 |
+
|
| 415 |
+
# transformer encoder layer for relative position encoding
|
| 416 |
+
elif pos_encoding in ["relative", "relative_3D"]:
|
| 417 |
+
relative_encoder_layer = ra.RelativeTransformerEncoderLayer(d_model=embedding_len,
|
| 418 |
+
nhead=num_heads,
|
| 419 |
+
pos_encoding=pos_encoding,
|
| 420 |
+
clipping_threshold=clipping_threshold,
|
| 421 |
+
contact_threshold=contact_threshold,
|
| 422 |
+
pdb_fns=pdb_fns,
|
| 423 |
+
dim_feedforward=num_hidden,
|
| 424 |
+
dropout=enc_layer_dropout,
|
| 425 |
+
activation=get_activation_fn(activation),
|
| 426 |
+
norm_first=True)
|
| 427 |
+
|
| 428 |
+
encoder_norm = None
|
| 429 |
+
if use_final_encoder_norm:
|
| 430 |
+
encoder_norm = nn.LayerNorm(embedding_len)
|
| 431 |
+
|
| 432 |
+
layers["tr_encoder"] = ra.RelativeTransformerEncoder(encoder_layer=relative_encoder_layer,
|
| 433 |
+
num_layers=num_enc_layers,
|
| 434 |
+
norm=encoder_norm)
|
| 435 |
+
|
| 436 |
+
# GLOBAL AVERAGE POOLING OR CLS TOKEN
|
| 437 |
+
# set up the layers and output shapes (i.e. input shapes for the pred layer)
|
| 438 |
+
if global_average_pooling:
|
| 439 |
+
# pool over the sequence dimension
|
| 440 |
+
layers["avg_pooling"] = GlobalAveragePooling(dim=1)
|
| 441 |
+
pred_layer_input_features = embedding_len
|
| 442 |
+
elif cls_pooling:
|
| 443 |
+
layers["cls_pooling"] = CLSPooling(cls_position=0)
|
| 444 |
+
pred_layer_input_features = embedding_len
|
| 445 |
+
else:
|
| 446 |
+
# no global average pooling or CLS token
|
| 447 |
+
# sequence dimension is still there, just flattened
|
| 448 |
+
layers["flatten"] = nn.Flatten()
|
| 449 |
+
pred_layer_input_features = embedding_len * aa_seq_len
|
| 450 |
+
|
| 451 |
+
# PREDICTION
|
| 452 |
+
if use_task_specific_layers:
|
| 453 |
+
# task specific prediction layers (nonlinear transform for each task)
|
| 454 |
+
layers["prediction"] = TaskSpecificPredictionLayers(num_tasks=num_tasks,
|
| 455 |
+
in_features=pred_layer_input_features,
|
| 456 |
+
num_hidden_nodes=task_specific_hidden_nodes,
|
| 457 |
+
activation=activation)
|
| 458 |
+
elif use_final_hidden_layer:
|
| 459 |
+
# combined prediction linear (linear transform for each task)
|
| 460 |
+
layers["fc1"] = FCBlock(in_features=pred_layer_input_features,
|
| 461 |
+
num_hidden_nodes=final_hidden_size,
|
| 462 |
+
use_batchnorm=False,
|
| 463 |
+
use_layernorm=use_final_hidden_layer_norm,
|
| 464 |
+
norm_before_activation=final_hidden_layer_norm_before_activation,
|
| 465 |
+
use_dropout=use_final_hidden_layer_dropout,
|
| 466 |
+
dropout_rate=final_hidden_layer_dropout_rate,
|
| 467 |
+
activation=activation)
|
| 468 |
+
|
| 469 |
+
layers["prediction"] = nn.Linear(in_features=final_hidden_size, out_features=num_tasks)
|
| 470 |
+
else:
|
| 471 |
+
layers["prediction"] = nn.Linear(in_features=pred_layer_input_features, out_features=num_tasks)
|
| 472 |
+
|
| 473 |
+
# FINAL MODEL
|
| 474 |
+
self.model = SequentialWithArgs(layers)
|
| 475 |
+
|
| 476 |
+
def forward(self, x, **kwargs):
|
| 477 |
+
return self.model(x, **kwargs)
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
class Transpose(nn.Module):
|
| 481 |
+
""" helper layer to swap data from (batch, seq, channels) to (batch, channels, seq)
|
| 482 |
+
used as a helper in the convolutional network which pytorch defaults to channels-first """
|
| 483 |
+
|
| 484 |
+
def __init__(self, dims: Tuple[int, ...] = (1, 2)):
|
| 485 |
+
super().__init__()
|
| 486 |
+
self.dims = dims
|
| 487 |
+
|
| 488 |
+
def forward(self, x, **kwargs):
|
| 489 |
+
x = x.transpose(*self.dims).contiguous()
|
| 490 |
+
return x
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def conv1d_out_shape(seq_len, kernel_size, stride=1, pad=0, dilation=1):
|
| 494 |
+
return (seq_len + (2 * pad) - (dilation * (kernel_size - 1)) - 1 // stride) + 1
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class ConvBlock(nn.Module):
|
| 498 |
+
def __init__(self,
|
| 499 |
+
in_channels: int,
|
| 500 |
+
out_channels: int,
|
| 501 |
+
kernel_size: int,
|
| 502 |
+
dilation: int = 1,
|
| 503 |
+
padding: str = "same",
|
| 504 |
+
use_batchnorm: bool = False,
|
| 505 |
+
use_layernorm: bool = False,
|
| 506 |
+
norm_before_activation: bool = False,
|
| 507 |
+
use_dropout: bool = False,
|
| 508 |
+
dropout_rate: float = 0.2,
|
| 509 |
+
activation: str = "relu"):
|
| 510 |
+
|
| 511 |
+
super().__init__()
|
| 512 |
+
|
| 513 |
+
if use_batchnorm and use_layernorm:
|
| 514 |
+
raise ValueError("Only one of use_batchnorm or use_layernorm can be set to True")
|
| 515 |
+
|
| 516 |
+
self.use_batchnorm = use_batchnorm
|
| 517 |
+
self.use_layernorm = use_layernorm
|
| 518 |
+
self.norm_before_activation = norm_before_activation
|
| 519 |
+
self.use_dropout = use_dropout
|
| 520 |
+
|
| 521 |
+
self.conv = nn.Conv1d(in_channels=in_channels,
|
| 522 |
+
out_channels=out_channels,
|
| 523 |
+
kernel_size=kernel_size,
|
| 524 |
+
padding=padding,
|
| 525 |
+
dilation=dilation)
|
| 526 |
+
|
| 527 |
+
self.activation = get_activation_fn(activation, functional=False)
|
| 528 |
+
|
| 529 |
+
if use_batchnorm:
|
| 530 |
+
self.norm = nn.BatchNorm1d(out_channels)
|
| 531 |
+
|
| 532 |
+
if use_layernorm:
|
| 533 |
+
self.norm = nn.LayerNorm(out_channels)
|
| 534 |
+
|
| 535 |
+
if use_dropout:
|
| 536 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
| 537 |
+
|
| 538 |
+
def forward(self, x, **kwargs):
|
| 539 |
+
x = self.conv(x)
|
| 540 |
+
|
| 541 |
+
# norm can be before or after activation, using flag
|
| 542 |
+
if self.use_batchnorm and self.norm_before_activation:
|
| 543 |
+
x = self.norm(x)
|
| 544 |
+
elif self.use_layernorm and self.norm_before_activation:
|
| 545 |
+
x = self.norm(x.transpose(1, 2)).transpose(1, 2)
|
| 546 |
+
|
| 547 |
+
x = self.activation(x)
|
| 548 |
+
|
| 549 |
+
# batchnorm being applied after activation, there is some discussion on this online
|
| 550 |
+
if self.use_batchnorm and not self.norm_before_activation:
|
| 551 |
+
x = self.norm(x)
|
| 552 |
+
elif self.use_layernorm and not self.norm_before_activation:
|
| 553 |
+
x = self.norm(x.transpose(1, 2)).transpose(1, 2)
|
| 554 |
+
|
| 555 |
+
# dropout being applied after batchnorm, there is some discussion on this online
|
| 556 |
+
if self.use_dropout:
|
| 557 |
+
x = self.dropout(x)
|
| 558 |
+
|
| 559 |
+
return x
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
class ConvModel2(nn.Module):
|
| 563 |
+
""" convolutional source model that supports padded inputs, pooling, etc """
|
| 564 |
+
|
| 565 |
+
@staticmethod
|
| 566 |
+
def add_model_specific_args(parent_parser):
|
| 567 |
+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
| 568 |
+
parser.add_argument('--use_embedding', action="store_true", default=False)
|
| 569 |
+
parser.add_argument('--embedding_len', type=int, default=128)
|
| 570 |
+
|
| 571 |
+
parser.add_argument('--num_conv_layers', type=int, default=1)
|
| 572 |
+
parser.add_argument('--kernel_sizes', type=int, nargs="+", default=[7])
|
| 573 |
+
parser.add_argument('--out_channels', type=int, nargs="+", default=[128])
|
| 574 |
+
parser.add_argument('--dilations', type=int, nargs="+", default=[1])
|
| 575 |
+
parser.add_argument('--padding', type=str, default="valid", choices=["valid", "same"])
|
| 576 |
+
parser.add_argument('--use_conv_layer_norm', action="store_true", default=False)
|
| 577 |
+
parser.add_argument('--conv_layer_norm_before_activation', action="store_true", default=False)
|
| 578 |
+
parser.add_argument('--use_conv_layer_dropout', action="store_true", default=False)
|
| 579 |
+
parser.add_argument('--conv_layer_dropout_rate', type=float, default=0.2)
|
| 580 |
+
|
| 581 |
+
parser.add_argument('--global_average_pooling', action="store_true", default=False)
|
| 582 |
+
|
| 583 |
+
parser.add_argument('--use_task_specific_layers', action="store_true", default=False)
|
| 584 |
+
parser.add_argument('--task_specific_hidden_nodes', type=int, default=64)
|
| 585 |
+
parser.add_argument('--use_final_hidden_layer', action="store_true", default=False)
|
| 586 |
+
parser.add_argument('--final_hidden_size', type=int, default=64)
|
| 587 |
+
parser.add_argument('--use_final_hidden_layer_norm', action="store_true", default=False)
|
| 588 |
+
parser.add_argument('--final_hidden_layer_norm_before_activation', action="store_true", default=False)
|
| 589 |
+
parser.add_argument('--use_final_hidden_layer_dropout', action="store_true", default=False)
|
| 590 |
+
parser.add_argument('--final_hidden_layer_dropout_rate', type=float, default=0.2)
|
| 591 |
+
|
| 592 |
+
parser.add_argument('--activation', type=str, default="relu",
|
| 593 |
+
help="activation function used for all activations in the network")
|
| 594 |
+
|
| 595 |
+
return parser
|
| 596 |
+
|
| 597 |
+
def __init__(self,
|
| 598 |
+
# data
|
| 599 |
+
num_tasks: int,
|
| 600 |
+
aa_seq_len: int,
|
| 601 |
+
aa_encoding_len: int,
|
| 602 |
+
num_tokens: int,
|
| 603 |
+
# convolutional model args
|
| 604 |
+
use_embedding: bool = False,
|
| 605 |
+
embedding_len: int = 64,
|
| 606 |
+
num_conv_layers: int = 1,
|
| 607 |
+
kernel_sizes: List[int] = (7,),
|
| 608 |
+
out_channels: List[int] = (128,),
|
| 609 |
+
dilations: List[int] = (1,),
|
| 610 |
+
padding: str = "valid",
|
| 611 |
+
use_conv_layer_norm: bool = False,
|
| 612 |
+
conv_layer_norm_before_activation: bool = False,
|
| 613 |
+
use_conv_layer_dropout: bool = False,
|
| 614 |
+
conv_layer_dropout_rate: float = 0.2,
|
| 615 |
+
# pooling
|
| 616 |
+
global_average_pooling: bool = True,
|
| 617 |
+
# prediction layers
|
| 618 |
+
use_task_specific_layers: bool = False,
|
| 619 |
+
task_specific_hidden_nodes: int = 64,
|
| 620 |
+
use_final_hidden_layer: bool = False,
|
| 621 |
+
final_hidden_size: int = 64,
|
| 622 |
+
use_final_hidden_layer_norm: bool = False,
|
| 623 |
+
final_hidden_layer_norm_before_activation: bool = False,
|
| 624 |
+
use_final_hidden_layer_dropout: bool = False,
|
| 625 |
+
final_hidden_layer_dropout_rate: float = 0.2,
|
| 626 |
+
# activation function
|
| 627 |
+
activation: str = "relu",
|
| 628 |
+
*args, **kwargs):
|
| 629 |
+
|
| 630 |
+
super(ConvModel2, self).__init__()
|
| 631 |
+
|
| 632 |
+
# build up the layers
|
| 633 |
+
layers = collections.OrderedDict()
|
| 634 |
+
|
| 635 |
+
# amino acid embedding
|
| 636 |
+
if use_embedding:
|
| 637 |
+
layers["embedder"] = ScaledEmbedding(num_embeddings=num_tokens, embedding_dim=embedding_len, scale=False)
|
| 638 |
+
|
| 639 |
+
# transpose the input to match PyTorch's expected format
|
| 640 |
+
layers["transpose"] = Transpose(dims=(1, 2))
|
| 641 |
+
|
| 642 |
+
# build up the convolutional layers
|
| 643 |
+
for layer_num in range(num_conv_layers):
|
| 644 |
+
# determine the number of input channels for the first convolutional layer
|
| 645 |
+
if layer_num == 0 and use_embedding:
|
| 646 |
+
# for the first convolutional layer, the in_channels is the embedding_len
|
| 647 |
+
in_channels = embedding_len
|
| 648 |
+
elif layer_num == 0 and not use_embedding:
|
| 649 |
+
# for the first convolutional layer, the in_channels is the aa_encoding_len
|
| 650 |
+
in_channels = aa_encoding_len
|
| 651 |
+
else:
|
| 652 |
+
in_channels = out_channels[layer_num - 1]
|
| 653 |
+
|
| 654 |
+
layers[f"conv{layer_num}"] = ConvBlock(in_channels=in_channels,
|
| 655 |
+
out_channels=out_channels[layer_num],
|
| 656 |
+
kernel_size=kernel_sizes[layer_num],
|
| 657 |
+
dilation=dilations[layer_num],
|
| 658 |
+
padding=padding,
|
| 659 |
+
use_batchnorm=False,
|
| 660 |
+
use_layernorm=use_conv_layer_norm,
|
| 661 |
+
norm_before_activation=conv_layer_norm_before_activation,
|
| 662 |
+
use_dropout=use_conv_layer_dropout,
|
| 663 |
+
dropout_rate=conv_layer_dropout_rate,
|
| 664 |
+
activation=activation)
|
| 665 |
+
|
| 666 |
+
# handle transition from convolutional layers to fully connected layer
|
| 667 |
+
# either use global average pooling or flatten
|
| 668 |
+
# take into consideration whether we are using valid or same padding
|
| 669 |
+
if global_average_pooling:
|
| 670 |
+
# global average pooling (mean across the seq len dimension)
|
| 671 |
+
# the seq len dimensions is the last dimension (batch_size, num_filters, seq_len)
|
| 672 |
+
layers["avg_pooling"] = GlobalAveragePooling(dim=-1)
|
| 673 |
+
# the prediction layers will take num_filters input features
|
| 674 |
+
pred_layer_input_features = out_channels[-1]
|
| 675 |
+
|
| 676 |
+
else:
|
| 677 |
+
# no global average pooling. flatten instead.
|
| 678 |
+
layers["flatten"] = nn.Flatten()
|
| 679 |
+
# calculate the final output len of the convolutional layers
|
| 680 |
+
# and the number of input features for the prediction layers
|
| 681 |
+
if padding == "valid":
|
| 682 |
+
# valid padding (aka no padding) results in shrinking length in progressive layers
|
| 683 |
+
conv_out_len = conv1d_out_shape(aa_seq_len, kernel_size=kernel_sizes[0], dilation=dilations[0])
|
| 684 |
+
for layer_num in range(1, num_conv_layers):
|
| 685 |
+
conv_out_len = conv1d_out_shape(conv_out_len,
|
| 686 |
+
kernel_size=kernel_sizes[layer_num],
|
| 687 |
+
dilation=dilations[layer_num])
|
| 688 |
+
pred_layer_input_features = conv_out_len * out_channels[-1]
|
| 689 |
+
else:
|
| 690 |
+
# padding == "same"
|
| 691 |
+
pred_layer_input_features = aa_seq_len * out_channels[-1]
|
| 692 |
+
|
| 693 |
+
# prediction layer
|
| 694 |
+
if use_task_specific_layers:
|
| 695 |
+
layers["prediction"] = TaskSpecificPredictionLayers(num_tasks=num_tasks,
|
| 696 |
+
in_features=pred_layer_input_features,
|
| 697 |
+
num_hidden_nodes=task_specific_hidden_nodes,
|
| 698 |
+
activation=activation)
|
| 699 |
+
|
| 700 |
+
# final hidden layer (with potential additional dropout)
|
| 701 |
+
elif use_final_hidden_layer:
|
| 702 |
+
layers["fc1"] = FCBlock(in_features=pred_layer_input_features,
|
| 703 |
+
num_hidden_nodes=final_hidden_size,
|
| 704 |
+
use_batchnorm=False,
|
| 705 |
+
use_layernorm=use_final_hidden_layer_norm,
|
| 706 |
+
norm_before_activation=final_hidden_layer_norm_before_activation,
|
| 707 |
+
use_dropout=use_final_hidden_layer_dropout,
|
| 708 |
+
dropout_rate=final_hidden_layer_dropout_rate,
|
| 709 |
+
activation=activation)
|
| 710 |
+
layers["prediction"] = nn.Linear(in_features=final_hidden_size, out_features=num_tasks)
|
| 711 |
+
|
| 712 |
+
else:
|
| 713 |
+
layers["prediction"] = nn.Linear(in_features=pred_layer_input_features, out_features=num_tasks)
|
| 714 |
+
|
| 715 |
+
self.model = nn.Sequential(layers)
|
| 716 |
+
|
| 717 |
+
def forward(self, x, **kwargs):
|
| 718 |
+
output = self.model(x)
|
| 719 |
+
return output
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
class ConvModel(nn.Module):
|
| 723 |
+
""" a convolutional network with convolutional layers followed by a fully connected layer """
|
| 724 |
+
|
| 725 |
+
@staticmethod
|
| 726 |
+
def add_model_specific_args(parent_parser):
|
| 727 |
+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
| 728 |
+
parser.add_argument('--num_conv_layers', type=int, default=1)
|
| 729 |
+
parser.add_argument('--kernel_sizes', type=int, nargs="+", default=[7])
|
| 730 |
+
parser.add_argument('--out_channels', type=int, nargs="+", default=[128])
|
| 731 |
+
parser.add_argument('--padding', type=str, default="valid", choices=["valid", "same"])
|
| 732 |
+
parser.add_argument('--use_final_hidden_layer', action="store_true",
|
| 733 |
+
help="whether to use a final hidden layer")
|
| 734 |
+
parser.add_argument('--final_hidden_size', type=int, default=128,
|
| 735 |
+
help="number of nodes in the final hidden layer")
|
| 736 |
+
parser.add_argument('--use_dropout', action="store_true",
|
| 737 |
+
help="whether to use dropout in the final hidden layer")
|
| 738 |
+
parser.add_argument('--dropout_rate', type=float, default=0.2,
|
| 739 |
+
help="dropout rate in the final hidden layer")
|
| 740 |
+
parser.add_argument('--use_task_specific_layers', action="store_true", default=False)
|
| 741 |
+
parser.add_argument('--task_specific_hidden_nodes', type=int, default=64)
|
| 742 |
+
return parser
|
| 743 |
+
|
| 744 |
+
def __init__(self,
|
| 745 |
+
num_tasks: int,
|
| 746 |
+
aa_seq_len: int,
|
| 747 |
+
aa_encoding_len: int,
|
| 748 |
+
num_conv_layers: int = 1,
|
| 749 |
+
kernel_sizes: List[int] = (7,),
|
| 750 |
+
out_channels: List[int] = (128,),
|
| 751 |
+
padding: str = "valid",
|
| 752 |
+
use_final_hidden_layer: bool = True,
|
| 753 |
+
final_hidden_size: int = 128,
|
| 754 |
+
use_dropout: bool = False,
|
| 755 |
+
dropout_rate: float = 0.2,
|
| 756 |
+
use_task_specific_layers: bool = False,
|
| 757 |
+
task_specific_hidden_nodes: int = 64,
|
| 758 |
+
*args, **kwargs):
|
| 759 |
+
|
| 760 |
+
super(ConvModel, self).__init__()
|
| 761 |
+
|
| 762 |
+
# set up the model as a Sequential block (less to do in forward())
|
| 763 |
+
layers = collections.OrderedDict()
|
| 764 |
+
|
| 765 |
+
layers["transpose"] = Transpose(dims=(1, 2))
|
| 766 |
+
|
| 767 |
+
for layer_num in range(num_conv_layers):
|
| 768 |
+
# for the first convolutional layer, the in_channels is the feature_len
|
| 769 |
+
in_channels = aa_encoding_len if layer_num == 0 else out_channels[layer_num - 1]
|
| 770 |
+
|
| 771 |
+
layers["conv{}".format(layer_num)] = nn.Sequential(
|
| 772 |
+
nn.Conv1d(in_channels=in_channels,
|
| 773 |
+
out_channels=out_channels[layer_num],
|
| 774 |
+
kernel_size=kernel_sizes[layer_num],
|
| 775 |
+
padding=padding),
|
| 776 |
+
nn.ReLU()
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
layers["flatten"] = nn.Flatten()
|
| 780 |
+
|
| 781 |
+
# calculate the final output len of the convolutional layers
|
| 782 |
+
# and the number of input features for the prediction layers
|
| 783 |
+
if padding == "valid":
|
| 784 |
+
# valid padding (aka no padding) results in shrinking length in progressive layers
|
| 785 |
+
conv_out_len = conv1d_out_shape(aa_seq_len, kernel_size=kernel_sizes[0])
|
| 786 |
+
for layer_num in range(1, num_conv_layers):
|
| 787 |
+
conv_out_len = conv1d_out_shape(conv_out_len, kernel_size=kernel_sizes[layer_num])
|
| 788 |
+
next_dim = conv_out_len * out_channels[-1]
|
| 789 |
+
elif padding == "same":
|
| 790 |
+
next_dim = aa_seq_len * out_channels[-1]
|
| 791 |
+
else:
|
| 792 |
+
raise ValueError("unexpected value for padding: {}".format(padding))
|
| 793 |
+
|
| 794 |
+
# final hidden layer (with potential additional dropout)
|
| 795 |
+
if use_final_hidden_layer:
|
| 796 |
+
layers["fc1"] = FCBlock(in_features=next_dim,
|
| 797 |
+
num_hidden_nodes=final_hidden_size,
|
| 798 |
+
use_batchnorm=False,
|
| 799 |
+
use_dropout=use_dropout,
|
| 800 |
+
dropout_rate=dropout_rate)
|
| 801 |
+
next_dim = final_hidden_size
|
| 802 |
+
|
| 803 |
+
# final prediction layer
|
| 804 |
+
# either task specific nonlinear layers or a single linear layer
|
| 805 |
+
if use_task_specific_layers:
|
| 806 |
+
layers["prediction"] = TaskSpecificPredictionLayers(num_tasks=num_tasks,
|
| 807 |
+
in_features=next_dim,
|
| 808 |
+
num_hidden_nodes=task_specific_hidden_nodes)
|
| 809 |
+
else:
|
| 810 |
+
layers["prediction"] = nn.Linear(in_features=next_dim, out_features=num_tasks)
|
| 811 |
+
|
| 812 |
+
self.model = nn.Sequential(layers)
|
| 813 |
+
|
| 814 |
+
def forward(self, x, **kwargs):
|
| 815 |
+
output = self.model(x)
|
| 816 |
+
return output
|
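As a quick shape sanity check, here is a minimal usage sketch of ConvModel; the hyperparameter values and tensor sizes are illustrative only, and it assumes the module's existing torch import plus the Transpose, FCBlock, and conv1d_out_shape helpers defined earlier in this file.

# illustrative sketch: 2 conv layers over a 50-residue sequence with a 21-dim encoding
model = ConvModel(num_tasks=1,
                  aa_seq_len=50,
                  aa_encoding_len=21,
                  num_conv_layers=2,
                  kernel_sizes=[7, 5],
                  out_channels=[64, 32],
                  padding="valid",
                  use_final_hidden_layer=True,
                  final_hidden_size=128)
x = torch.randn(8, 50, 21)   # [batch_size, aa_seq_len, aa_encoding_len]
y = model(x)                 # [8, 1], i.e. [batch_size, num_tasks]
# with "valid" padding the conv output length shrinks: 50 -> 44 -> 40, so the
# flattened input to fc1 is 40 * 32 = 1280 features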
| 817 |
+
|
| 818 |
+
|
| 819 |
+
class FCModel(nn.Module):
|
| 820 |
+
|
| 821 |
+
@staticmethod
|
| 822 |
+
def add_model_specific_args(parent_parser):
|
| 823 |
+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
| 824 |
+
parser.add_argument('--num_layers', type=int, default=1)
|
| 825 |
+
parser.add_argument('--num_hidden', nargs="+", type=int, default=[128])
|
| 826 |
+
parser.add_argument('--use_batchnorm', action="store_true", default=False)
|
| 827 |
+
parser.add_argument('--use_layernorm', action="store_true", default=False)
|
| 828 |
+
parser.add_argument('--norm_before_activation', action="store_true", default=False)
|
| 829 |
+
parser.add_argument('--use_dropout', action="store_true", default=False)
|
| 830 |
+
parser.add_argument('--dropout_rate', type=float, default=0.2)
|
| 831 |
+
return parser
|
| 832 |
+
|
| 833 |
+
def __init__(self,
|
| 834 |
+
num_tasks: int,
|
| 835 |
+
seq_encoding_len: int,
|
| 836 |
+
num_layers: int = 1,
|
| 837 |
+
num_hidden: List[int] = (128,),
|
| 838 |
+
use_batchnorm: bool = False,
|
| 839 |
+
use_layernorm: bool = False,
|
| 840 |
+
norm_before_activation: bool = False,
|
| 841 |
+
use_dropout: bool = False,
|
| 842 |
+
dropout_rate: float = 0.2,
|
| 843 |
+
activation: str = "relu",
|
| 844 |
+
*args, **kwargs):
|
| 845 |
+
super().__init__()
|
| 846 |
+
|
| 847 |
+
# set up the model as a Sequential block (less to do in forward())
|
| 848 |
+
layers = collections.OrderedDict()
|
| 849 |
+
|
| 850 |
+
# flatten inputs as this is all fully connected
|
| 851 |
+
layers["flatten"] = nn.Flatten()
|
| 852 |
+
|
| 853 |
+
# build up the variable number of hidden layers (fully connected + ReLU + dropout (if set))
|
| 854 |
+
for layer_num in range(num_layers):
|
| 855 |
+
# for the first layer (layer_num == 0), in_features is determined by given input
|
| 856 |
+
# for subsequent layers, the in_features is the previous layer's num_hidden
|
| 857 |
+
in_features = seq_encoding_len if layer_num == 0 else num_hidden[layer_num - 1]
|
| 858 |
+
|
| 859 |
+
layers["fc{}".format(layer_num)] = FCBlock(in_features=in_features,
|
| 860 |
+
num_hidden_nodes=num_hidden[layer_num],
|
| 861 |
+
use_batchnorm=use_batchnorm,
|
| 862 |
+
use_layernorm=use_layernorm,
|
| 863 |
+
norm_before_activation=norm_before_activation,
|
| 864 |
+
use_dropout=use_dropout,
|
| 865 |
+
dropout_rate=dropout_rate,
|
| 866 |
+
activation=activation)
|
| 867 |
+
|
| 868 |
+
# finally, the linear output layer
|
| 869 |
+
in_features = num_hidden[-1] if num_layers > 0 else seq_encoding_len
|
| 870 |
+
layers["output"] = nn.Linear(in_features=in_features, out_features=num_tasks)
|
| 871 |
+
|
| 872 |
+
self.model = nn.Sequential(layers)
|
| 873 |
+
|
| 874 |
+
def forward(self, x, **kwargs):
|
| 875 |
+
output = self.model(x)
|
| 876 |
+
return output
|
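For comparison, a minimal FCModel sketch (sizes are illustrative; the input is flattened internally, so seq_encoding_len must equal seq_len * encoding_len):

fc = FCModel(num_tasks=1,
             seq_encoding_len=50 * 21,
             num_layers=2,
             num_hidden=[256, 128],
             use_dropout=True,
             dropout_rate=0.2)
out = fc(torch.randn(8, 50, 21))   # flattened to [8, 1050] internally -> output [8, 1]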
| 877 |
+
|
| 878 |
+
|
| 879 |
+
class LRModel(nn.Module):
|
| 880 |
+
""" a simple linear model """
|
| 881 |
+
|
| 882 |
+
def __init__(self, num_tasks, seq_encoding_len, *args, **kwargs):
|
| 883 |
+
super().__init__()
|
| 884 |
+
|
| 885 |
+
self.model = nn.Sequential(
|
| 886 |
+
nn.Flatten(),
|
| 887 |
+
nn.Linear(seq_encoding_len, out_features=num_tasks))
|
| 888 |
+
|
| 889 |
+
def forward(self, x, **kwargs):
|
| 890 |
+
output = self.model(x)
|
| 891 |
+
return output
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
class TransferModel(nn.Module):
|
| 895 |
+
""" transfer learning model """
|
| 896 |
+
|
| 897 |
+
@staticmethod
|
| 898 |
+
def add_model_specific_args(parent_parser):
|
| 899 |
+
|
| 900 |
+
def none_or_int(value: str):
|
| 901 |
+
return None if value.lower() == "none" else int(value)
|
| 902 |
+
|
| 903 |
+
p = ArgumentParser(parents=[parent_parser], add_help=False)
|
| 904 |
+
|
| 905 |
+
# for model set up
|
| 906 |
+
p.add_argument('--pretrained_ckpt_path', type=str, default=None)
|
| 907 |
+
|
| 908 |
+
# where to cut off the backbone
|
| 909 |
+
p.add_argument("--backbone_cutoff", type=none_or_int, default=-1,
|
| 910 |
+
help="where to cut off the backbone. can be a negative int, indexing back from "
|
| 911 |
+
"pretrained_model.model.model. a value of -1 would chop off the backbone prediction head. "
|
| 912 |
+
"a value of -2 chops the prediction head and FC layer. a value of -3 chops"
|
| 913 |
+
"the above, as well as the global average pooling layer. all depends on architecture.")
|
| 914 |
+
|
| 915 |
+
p.add_argument("--pred_layer_input_features", type=int, default=None,
|
| 916 |
+
help="if None, number of features will be determined based on backbone_cutoff and standard "
|
| 917 |
+
"architecture. otherwise, specify the number of input features for the prediction layer")
|
| 918 |
+
|
| 919 |
+
# top net args
|
| 920 |
+
p.add_argument("--top_net_type", type=str, default="linear", choices=["linear", "nonlinear", "sklearn"])
|
| 921 |
+
p.add_argument("--top_net_hidden_nodes", type=int, default=256)
|
| 922 |
+
p.add_argument("--top_net_use_batchnorm", action="store_true")
|
| 923 |
+
p.add_argument("--top_net_use_dropout", action="store_true")
|
| 924 |
+
p.add_argument("--top_net_dropout_rate", type=float, default=0.1)
|
| 925 |
+
|
| 926 |
+
return p
|
| 927 |
+
|
| 928 |
+
def __init__(self,
|
| 929 |
+
# pretrained model
|
| 930 |
+
pretrained_ckpt_path: Optional[str] = None,
|
| 931 |
+
pretrained_hparams: Optional[dict] = None,
|
| 932 |
+
backbone_cutoff: Optional[int] = -1,
|
| 933 |
+
# top net
|
| 934 |
+
pred_layer_input_features: Optional[int] = None,
|
| 935 |
+
top_net_type: str = "linear",
|
| 936 |
+
top_net_hidden_nodes: int = 256,
|
| 937 |
+
top_net_use_batchnorm: bool = False,
|
| 938 |
+
top_net_use_dropout: bool = False,
|
| 939 |
+
top_net_dropout_rate: float = 0.1,
|
| 940 |
+
*args, **kwargs):
|
| 941 |
+
|
| 942 |
+
super().__init__()
|
| 943 |
+
|
| 944 |
+
# error checking: if pretrained_ckpt_path is None, then pretrained_hparams must be specified
|
| 945 |
+
if pretrained_ckpt_path is None and pretrained_hparams is None:
|
| 946 |
+
raise ValueError("Either pretrained_ckpt_path or pretrained_hparams must be specified")
|
| 947 |
+
|
| 948 |
+
# note: pdb_fns is loaded from transfer model arguments rather than original source model hparams
|
| 949 |
+
# if pdb_fns is specified as a kwarg, pass it on for structure-based RPE
|
| 950 |
+
# otherwise, can just set pdb_fns to None, and structure-based RPE will handle new PDBs on the fly
|
| 951 |
+
pdb_fns = kwargs["pdb_fns"] if "pdb_fns" in kwargs else None
|
| 952 |
+
|
| 953 |
+
# generate a fresh backbone using pretrained_hparams if specified
|
| 954 |
+
# otherwise load the backbone from the pretrained checkpoint
|
| 955 |
+
# we prioritize pretrained_hparams over pretrained_ckpt_path because
|
| 956 |
+
# pretrained_hparams will only really be specified if we are loading from a DMSTask checkpoint
|
| 957 |
+
# meaning the TransferModel has already been fine-tuned on DMS data, and we are likely loading
|
| 958 |
+
# weights from that finetuning (including weights for the backbone)
|
| 959 |
+
# whereas if pretrained_hparams is not specified but pretrained_ckpt_path is, then we are
|
| 960 |
+
# likely finetuning the TransferModel for the first time, and we need the pretrained weights for the
|
| 961 |
+
# backbone from the RosettaTask checkpoint
|
| 962 |
+
if pretrained_hparams is not None:
|
| 963 |
+
# pretrained_hparams will only be specified if we are loading from a DMSTask checkpoint
|
| 964 |
+
pretrained_hparams["pdb_fns"] = pdb_fns
|
| 965 |
+
pretrained_model = Model[pretrained_hparams["model_name"]].cls(**pretrained_hparams)
|
| 966 |
+
self.pretrained_hparams = pretrained_hparams
|
| 967 |
+
else:
|
| 968 |
+
# not supported in metl-pretrained
|
| 969 |
+
raise NotImplementedError("Loading pretrained weights from RosettaTask checkpoint not supported")
|
| 970 |
+
|
| 971 |
+
layers = collections.OrderedDict()
|
| 972 |
+
|
| 973 |
+
# set the backbone to all layers except the last layer (the pre-trained prediction layer)
|
| 974 |
+
if backbone_cutoff is None:
|
| 975 |
+
layers["backbone"] = SequentialWithArgs(*list(pretrained_model.model.children()))
|
| 976 |
+
else:
|
| 977 |
+
layers["backbone"] = SequentialWithArgs(*list(pretrained_model.model.children())[0:backbone_cutoff])
|
| 978 |
+
|
| 979 |
+
if top_net_type == "sklearn":
|
| 980 |
+
# sklearn top net doesn't require any more layers, just return the model for the repr layer
|
| 981 |
+
self.model = SequentialWithArgs(layers)
|
| 982 |
+
return
|
| 983 |
+
|
| 984 |
+
# figure out dimensions of input into the prediction layer
|
| 985 |
+
if pred_layer_input_features is None:
|
| 986 |
+
# todo: can make this more robust by checking the pretrained_model hparams for use_final_hidden_layer,
|
| 987 |
+
# global_average_pooling, etc. then can determine what the layer will be based on backbone_cutoff.
|
| 988 |
+
# currently, assumes that pretrained_model uses global average pooling and a final_hidden_layer
|
| 989 |
+
if backbone_cutoff is None:
|
| 990 |
+
# no backbone cutoff... use the full network (including tasks) as the backbone
|
| 991 |
+
pred_layer_input_features = self.pretrained_hparams["num_tasks"]
|
| 992 |
+
elif backbone_cutoff == -1:
|
| 993 |
+
pred_layer_input_features = self.pretrained_hparams["final_hidden_size"]
|
| 994 |
+
elif backbone_cutoff == -2:
|
| 995 |
+
pred_layer_input_features = self.pretrained_hparams["embedding_len"]
|
| 996 |
+
elif backbone_cutoff == -3:
|
| 997 |
+
pred_layer_input_features = self.pretrained_hparams["embedding_len"] * kwargs["aa_seq_len"]
|
| 998 |
+
else:
|
| 999 |
+
raise ValueError("can't automatically determine pred_layer_input_features for given backbone_cutoff")
|
| 1000 |
+
|
| 1001 |
+
layers["flatten"] = nn.Flatten(start_dim=1)
|
| 1002 |
+
|
| 1003 |
+
# create a new prediction layer on top of the backbone
|
| 1004 |
+
if top_net_type == "linear":
|
| 1005 |
+
# linear layer for prediction
|
| 1006 |
+
layers["prediction"] = nn.Linear(in_features=pred_layer_input_features, out_features=1)
|
| 1007 |
+
elif top_net_type == "nonlinear":
|
| 1008 |
+
# fully connected with hidden layer
|
| 1009 |
+
fc_block = FCBlock(in_features=pred_layer_input_features,
|
| 1010 |
+
num_hidden_nodes=top_net_hidden_nodes,
|
| 1011 |
+
use_batchnorm=top_net_use_batchnorm,
|
| 1012 |
+
use_dropout=top_net_use_dropout,
|
| 1013 |
+
dropout_rate=top_net_dropout_rate)
|
| 1014 |
+
|
| 1015 |
+
pred_layer = nn.Linear(in_features=top_net_hidden_nodes, out_features=1)
|
| 1016 |
+
|
| 1017 |
+
layers["prediction"] = SequentialWithArgs(fc_block, pred_layer)
|
| 1018 |
+
else:
|
| 1019 |
+
raise ValueError("Unexpected type of top net layer: {}".format(top_net_type))
|
| 1020 |
+
|
| 1021 |
+
self.model = SequentialWithArgs(layers)
|
| 1022 |
+
|
| 1023 |
+
def forward(self, x, **kwargs):
|
| 1024 |
+
return self.model(x, **kwargs)
|
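The backbone_cutoff behavior above is plain list slicing over the children of the pretrained nn.Sequential. The toy stand-in below is not the real METL backbone, just an assumed pooling -> hidden -> head layout, to show what each cutoff keeps:

# toy stand-in for pretrained_model.model (assumed layout, for illustration only)
toy = nn.Sequential(
    nn.AdaptiveAvgPool1d(1),   # "global average pooling"
    nn.Flatten(),
    nn.Linear(64, 128),        # "final hidden layer"
    nn.Linear(128, 10),        # "prediction head"
)
backbone_none = nn.Sequential(*list(toy.children()))        # backbone_cutoff=None: keep everything
backbone_m1 = nn.Sequential(*list(toy.children())[0:-1])    # -1: drop the prediction head
backbone_m2 = nn.Sequential(*list(toy.children())[0:-2])    # -2: also drop the final hidden layer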
| 1025 |
+
|
| 1026 |
+
|
| 1027 |
+
def get_activation_fn(activation, functional=True):
|
| 1028 |
+
if activation == "relu":
|
| 1029 |
+
return F.relu if functional else nn.ReLU()
|
| 1030 |
+
elif activation == "gelu":
|
| 1031 |
+
return F.gelu if functional else nn.GELU()
|
| 1032 |
+
elif activation == "silo" or activation == "swish":
|
| 1033 |
+
return F.silu if functional else nn.SiLU()
|
| 1034 |
+
elif activation == "leaky_relu" or activation == "lrelu":
|
| 1035 |
+
return F.leaky_relu if functional else nn.LeakyReLU()
|
| 1036 |
+
else:
|
| 1037 |
+
raise RuntimeError("unknown activation: {}".format(activation))
|
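A short usage sketch of get_activation_fn showing the functional vs. module forms (assumes the module's existing torch imports):

act_fn = get_activation_fn("gelu", functional=True)    # F.gelu, handy inside forward()
act_mod = get_activation_fn("gelu", functional=False)  # nn.GELU(), handy inside nn.Sequential
x = torch.randn(4, 8)
assert torch.allclose(act_fn(x), act_mod(x))            # same values either way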
| 1038 |
+
|
| 1039 |
+
|
| 1040 |
+
class Model(enum.Enum):
|
| 1041 |
+
def __new__(cls, *args, **kwds):
|
| 1042 |
+
value = len(cls.__members__) + 1
|
| 1043 |
+
obj = object.__new__(cls)
|
| 1044 |
+
obj._value_ = value
|
| 1045 |
+
return obj
|
| 1046 |
+
|
| 1047 |
+
def __init__(self, cls, transfer_model):
|
| 1048 |
+
self.cls = cls
|
| 1049 |
+
self.transfer_model = transfer_model
|
| 1050 |
+
|
| 1051 |
+
linear = LRModel, False
|
| 1052 |
+
fully_connected = FCModel, False
|
| 1053 |
+
cnn = ConvModel, False
|
| 1054 |
+
cnn2 = ConvModel2, False
|
| 1055 |
+
transformer_encoder = AttnModel, False
|
| 1056 |
+
transfer_model = TransferModel, True
|
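Because each member stores its class in .cls, model construction can be driven by a string name, mirroring how Model[...] is used in TransferModel above; the constructor arguments below are illustrative:

model_cls = Model["fully_connected"].cls              # -> FCModel
net = model_cls(num_tasks=1, seq_encoding_len=21 * 50, num_layers=1, num_hidden=[64])
print(Model["transfer_model"].transfer_model)         # True: flags the transfer-learning wrapper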
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
def main():
|
| 1060 |
+
pass
|
| 1061 |
+
|
| 1062 |
+
|
| 1063 |
+
if __name__ == "__main__":
|
| 1064 |
+
main()
|
metl/relative_attention.py
ADDED
|
@@ -0,0 +1,586 @@
| 1 |
+
""" implementation of transformer encoder with relative attention
|
| 2 |
+
references:
|
| 3 |
+
- https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a
|
| 4 |
+
- https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
|
| 5 |
+
- https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py
|
| 6 |
+
- https://github.com/jiezouguihuafu/ClassicalModelreproduced/blob/main/Transformer/transfor_rpe.py
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import copy
|
| 10 |
+
from os.path import basename, dirname, join, isfile
|
| 11 |
+
from typing import Optional, Union
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from torch import Tensor
|
| 17 |
+
from torch.nn import Linear, Dropout, LayerNorm
|
| 18 |
+
import time
|
| 19 |
+
import networkx as nx
|
| 20 |
+
|
| 21 |
+
import metl.structure as structure
|
| 22 |
+
import metl.models as models
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class RelativePosition3D(nn.Module):
|
| 26 |
+
""" Contact map-based relative position embeddings """
|
| 27 |
+
|
| 28 |
+
# need to compute a bucket_mtx for each structure
|
| 29 |
+
# need to know which bucket_mtx to use when grabbing the embeddings in forward()
|
| 30 |
+
# - on init, get a list of all PDB files we will be using
|
| 31 |
+
# - use a dictionary to store PDB files --> bucket_mtxs
|
| 32 |
+
# - forward() gets a new arg: the pdb file, which indexes into the dictionary to grab the right bucket_mtx
|
| 33 |
+
def __init__(self,
|
| 34 |
+
embedding_len: int,
|
| 35 |
+
contact_threshold: int,
|
| 36 |
+
clipping_threshold: int,
|
| 37 |
+
pdb_fns: Optional[Union[str, list, tuple]] = None,
|
| 38 |
+
default_pdb_dir: str = "data/pdb_files"):
|
| 39 |
+
|
| 40 |
+
# preferably, pdb_fns contains full paths to the PDBs, but if just the PDB filename is given
|
| 41 |
+
# then it defaults to the path data/pdb_files/<pdb_fn>
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.embedding_len = embedding_len
|
| 44 |
+
self.clipping_threshold = clipping_threshold
|
| 45 |
+
self.contact_threshold = contact_threshold
|
| 46 |
+
self.default_pdb_dir = default_pdb_dir
|
| 47 |
+
|
| 48 |
+
# dummy buffer for getting correct device for on-the-fly bucket matrix generation
|
| 49 |
+
self.register_buffer("dummy_buffer", torch.empty(0), persistent=False)
|
| 50 |
+
|
| 51 |
+
# for 3D-based positions, the number of embeddings is generally the number of buckets
|
| 52 |
+
# for contact map-based distances, that is clipping_threshold + 1
|
| 53 |
+
num_embeddings = clipping_threshold + 1
|
| 54 |
+
|
| 55 |
+
# this is the embedding lookup table E_r
|
| 56 |
+
self.embeddings_table = nn.Embedding(num_embeddings, embedding_len)
|
| 57 |
+
|
| 58 |
+
# set up pdb_fns that were passed in on init (can also be set up during runtime in forward())
|
| 59 |
+
# todo: i'm using a hacky workaround to move the bucket_mtxs to the correct device
|
| 60 |
+
# i tried to make it more efficient by registering bucket matrices as buffers, but i was
|
| 61 |
+
# having problems with DDP syncing the buffers across processes
|
| 62 |
+
self.bucket_mtxs = {}
|
| 63 |
+
self.bucket_mtxs_device = self.dummy_buffer.device
|
| 64 |
+
self._init_pdbs(pdb_fns)
|
| 65 |
+
|
| 66 |
+
def forward(self, pdb_fn):
|
| 67 |
+
# compute matrix R by grabbing the embeddings from the embeddings lookup table
|
| 68 |
+
embeddings = self.embeddings_table(self._get_bucket_mtx(pdb_fn))
|
| 69 |
+
return embeddings
|
| 70 |
+
|
| 71 |
+
# def _get_bucket_mtx(self, pdb_fn):
|
| 72 |
+
# """ retrieve a bucket matrix given the pdb_fn.
|
| 73 |
+
# if the pdb_fn was provided at init or has already been computed, then the bucket matrix will be
|
| 74 |
+
# retrieved from the object buffer. if the bucket matrix has not been computed yet, it will be here """
|
| 75 |
+
# pdb_attr = self._pdb_key(pdb_fn)
|
| 76 |
+
# if hasattr(self, pdb_attr):
|
| 77 |
+
# return getattr(self, pdb_attr)
|
| 78 |
+
# else:
|
| 79 |
+
# # encountering a new PDB at runtime... process it
|
| 80 |
+
# # todo: if there's a new PDB at runtime, it will be initialized separately in each instance
|
| 81 |
+
# # of RelativePosition3D, for each layer. It would be more efficient to have a global
|
| 82 |
+
# # bucket_mtx registry... perhaps in the RelativeTransformerEncoder class, that can be passed through
|
| 83 |
+
# self._init_pdb(pdb_fn)
|
| 84 |
+
# return getattr(self, pdb_attr)
|
| 85 |
+
|
| 86 |
+
def _move_bucket_mtxs(self, device):
|
| 87 |
+
for k, v in self.bucket_mtxs.items():
|
| 88 |
+
self.bucket_mtxs[k] = v.to(device)
|
| 89 |
+
self.bucket_mtxs_device = device
|
| 90 |
+
|
| 91 |
+
def _get_bucket_mtx(self, pdb_fn):
|
| 92 |
+
""" retrieve a bucket matrix given the pdb_fn.
|
| 93 |
+
if the pdb_fn was provided at init or has already been computed, then the bucket matrix will be
|
| 94 |
+
retrieved from the bucket_mtxs dictionary. else, it will be computed now on-the-fly """
|
| 95 |
+
|
| 96 |
+
# ensure that all the bucket matrices are on the same device as the nn.Embedding
|
| 97 |
+
if self.bucket_mtxs_device != self.dummy_buffer.device:
|
| 98 |
+
self._move_bucket_mtxs(self.dummy_buffer.device)
|
| 99 |
+
|
| 100 |
+
pdb_attr = self._pdb_key(pdb_fn)
|
| 101 |
+
if pdb_attr in self.bucket_mtxs:
|
| 102 |
+
return self.bucket_mtxs[pdb_attr]
|
| 103 |
+
else:
|
| 104 |
+
# encountering a new PDB at runtime... process it
|
| 105 |
+
# todo: if there's a new PDB at runtime, it will be initialized separately in each instance
|
| 106 |
+
# of RelativePosition3D, for each layer. It would be more efficient to have a global
|
| 107 |
+
# bucket_mtx registry... perhaps in the RelativeTransformerEncoder class, that can be passed through
|
| 108 |
+
self._init_pdb(pdb_fn)
|
| 109 |
+
return self.bucket_mtxs[pdb_attr]
|
| 110 |
+
|
| 111 |
+
# def _set_bucket_mtx(self, pdb_fn, bucket_mtx):
|
| 112 |
+
# """ store a bucket matrix as a buffer """
|
| 113 |
+
# # if PyTorch ever implements a BufferDict, we could use it here efficiently
|
| 114 |
+
# # there is also BufferDict from https://botorch.org/api/_modules/botorch/utils/torch.html
|
| 115 |
+
# # would just need to modify it to have an option for persistent=False
|
| 116 |
+
# bucket_mtx = bucket_mtx.to(self.dummy_buffer.device)
|
| 117 |
+
#
|
| 118 |
+
# self.register_buffer(self._pdb_key(pdb_fn), bucket_mtx, persistent=False)
|
| 119 |
+
|
| 120 |
+
def _set_bucket_mtx(self, pdb_fn, bucket_mtx):
|
| 121 |
+
""" store a bucket matrix in the bucket dict """
|
| 122 |
+
|
| 123 |
+
# move the bucket_mtx to the same device that the other bucket matrices are on
|
| 124 |
+
bucket_mtx = bucket_mtx.to(self.bucket_mtxs_device)
|
| 125 |
+
|
| 126 |
+
self.bucket_mtxs[self._pdb_key(pdb_fn)] = bucket_mtx
|
| 127 |
+
|
| 128 |
+
@staticmethod
|
| 129 |
+
def _pdb_key(pdb_fn):
|
| 130 |
+
""" return a unique key for the given pdb_fn, used to map unique PDBs """
|
| 131 |
+
# note this key does NOT currently support PDBs with the same basename but different paths
|
| 132 |
+
# assumes every PDB is in the format <pdb_name>.pdb
|
| 133 |
+
# should be compatible with being a class attribute, as it is used as a pytorch buffer name
|
| 134 |
+
return f"pdb_{basename(pdb_fn).split('.')[0]}"
|
| 135 |
+
|
| 136 |
+
def _init_pdbs(self, pdb_fns):
|
| 137 |
+
start = time.time()
|
| 138 |
+
|
| 139 |
+
if pdb_fns is None:
|
| 140 |
+
# nothing to initialize if pdb_fns is None
|
| 141 |
+
return
|
| 142 |
+
|
| 143 |
+
# make sure pdb_fns is a list
|
| 144 |
+
if not isinstance(pdb_fns, list) and not isinstance(pdb_fns, tuple):
|
| 145 |
+
pdb_fns = [pdb_fns]
|
| 146 |
+
|
| 147 |
+
# init each pdb fn in the list
|
| 148 |
+
for pdb_fn in pdb_fns:
|
| 149 |
+
self._init_pdb(pdb_fn)
|
| 150 |
+
|
| 151 |
+
print("Initialized PDB bucket matrices in: {:.3f}".format(time.time() - start))
|
| 152 |
+
|
| 153 |
+
def _init_pdb(self, pdb_fn):
|
| 154 |
+
""" process a pdb file for use with structure-based relative attention """
|
| 155 |
+
# if pdb_fn is not a full path, default to the path data/pdb_files/<pdb_fn>
|
| 156 |
+
if dirname(pdb_fn) == "":
|
| 157 |
+
# handle the case where the pdb file is in the current working directory
|
| 158 |
+
# if there is a PDB file in the cwd.... then just use it as is. otherwise, append the default.
|
| 159 |
+
if not isfile(pdb_fn):
|
| 160 |
+
pdb_fn = join(self.default_pdb_dir, pdb_fn)
|
| 161 |
+
|
| 162 |
+
# create a structure graph from the pdb_fn and contact threshold
|
| 163 |
+
cbeta_mtx = structure.cbeta_distance_matrix(pdb_fn)
|
| 164 |
+
structure_graph = structure.dist_thresh_graph(cbeta_mtx, self.contact_threshold)
|
| 165 |
+
|
| 166 |
+
# bucket_mtx indexes into the embedding lookup table to create the final distance matrix
|
| 167 |
+
bucket_mtx = self._compute_bucket_mtx(structure_graph)
|
| 168 |
+
|
| 169 |
+
self._set_bucket_mtx(pdb_fn, bucket_mtx)
|
| 170 |
+
|
| 171 |
+
def _compute_bucketed_neighbors(self, structure_graph, source_node):
|
| 172 |
+
""" gets the bucketed neighbors from the given source node and structure graph"""
|
| 173 |
+
if self.clipping_threshold < 0:
|
| 174 |
+
raise ValueError("Clipping threshold must be >= 0")
|
| 175 |
+
|
| 176 |
+
sspl = _inv_dict(nx.single_source_shortest_path_length(structure_graph, source_node))
|
| 177 |
+
|
| 178 |
+
if self.clipping_threshold is not None:
|
| 179 |
+
num_buckets = 1 + self.clipping_threshold
|
| 180 |
+
sspl = _combine_d(sspl, self.clipping_threshold, num_buckets - 1)
|
| 181 |
+
|
| 182 |
+
return sspl
|
| 183 |
+
|
| 184 |
+
def _compute_bucket_mtx(self, structure_graph):
|
| 185 |
+
""" get the bucket_mtx for the given structure_graph
|
| 186 |
+
calls _get_bucketed_neighbors for every node in the structure_graph """
|
| 187 |
+
num_residues = len(list(structure_graph))
|
| 188 |
+
|
| 189 |
+
# index into the embedding lookup table to create the final distance matrix
|
| 190 |
+
bucket_mtx = torch.zeros(num_residues, num_residues, dtype=torch.long)
|
| 191 |
+
|
| 192 |
+
for node_num in sorted(list(structure_graph)):
|
| 193 |
+
bucketed_neighbors = self._compute_bucketed_neighbors(structure_graph, node_num)
|
| 194 |
+
|
| 195 |
+
for bucket_num, neighbors in bucketed_neighbors.items():
|
| 196 |
+
bucket_mtx[node_num, neighbors] = bucket_num
|
| 197 |
+
|
| 198 |
+
return bucket_mtx
|
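To make the bucketing concrete, here is a toy version of the same computation on a 4-node path graph, reusing this file's _inv_dict and _combine_d helpers with a clipping threshold of 2 (so all nodes 2 or more hops away share the final bucket):

g = nx.path_graph(4)                                   # 0 - 1 - 2 - 3
sspl = nx.single_source_shortest_path_length(g, 0)     # {0: 0, 1: 1, 2: 2, 3: 3}
buckets = _combine_d(_inv_dict(sspl), 2, 2)            # {0: [0], 1: [1], 2: [2, 3]}
# row 0 of the resulting bucket_mtx would be tensor([0, 1, 2, 2])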
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class RelativePosition(nn.Module):
|
| 202 |
+
""" creates the embedding lookup table E_r and computes R
|
| 203 |
+
note this is implemented as a plain nn.Module; a registered dummy buffer
|
| 204 |
+
(dummy_buffer) is used to determine the correct device at runtime,
|
| 205 |
+
which avoids needing pl.LightningModule's `self.device` or a hacky dummy_param """
|
| 206 |
+
|
| 207 |
+
def __init__(self, embedding_len: int, clipping_threshold: int):
|
| 208 |
+
"""
|
| 209 |
+
embedding_len: the length of the embedding, may be d_model, or d_model // num_heads for multihead
|
| 210 |
+
clipping_threshold: the maximum relative position, referred to as k by Shaw et al.
|
| 211 |
+
"""
|
| 212 |
+
super().__init__()
|
| 213 |
+
self.embedding_len = embedding_len
|
| 214 |
+
self.clipping_threshold = clipping_threshold
|
| 215 |
+
# for sequence-based distances, the number of embeddings is 2*k+1, where k is the clipping threshold
|
| 216 |
+
num_embeddings = 2 * clipping_threshold + 1
|
| 217 |
+
|
| 218 |
+
# this is the embedding lookup table E_r
|
| 219 |
+
self.embeddings_table = nn.Embedding(num_embeddings, embedding_len)
|
| 220 |
+
|
| 221 |
+
# for getting the correct device for range vectors in forward
|
| 222 |
+
self.register_buffer("dummy_buffer", torch.empty(0), persistent=False)
|
| 223 |
+
|
| 224 |
+
def forward(self, length_q, length_k):
|
| 225 |
+
# supports different length sequences, but in self-attention length_q and length_k are the same
|
| 226 |
+
range_vec_q = torch.arange(length_q, device=self.dummy_buffer.device)
|
| 227 |
+
range_vec_k = torch.arange(length_k, device=self.dummy_buffer.device)
|
| 228 |
+
|
| 229 |
+
# this sets up the standard sequence-based distance matrix for relative positions
|
| 230 |
+
# the current position is 0, positions to the right are +1, +2, etc, and to the left -1, -2, etc
|
| 231 |
+
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
|
| 232 |
+
distance_mat_clipped = torch.clamp(distance_mat, -self.clipping_threshold, self.clipping_threshold)
|
| 233 |
+
|
| 234 |
+
# convert to indices, indexing into the embedding table
|
| 235 |
+
final_mat = (distance_mat_clipped + self.clipping_threshold).long()
|
| 236 |
+
|
| 237 |
+
# compute matrix R by grabbing the embeddings from the embedding lookup table
|
| 238 |
+
embeddings = self.embeddings_table(final_mat)
|
| 239 |
+
|
| 240 |
+
return embeddings
|
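For a concrete picture of the index matrix built in forward(), take length_q = length_k = 4 and clipping_threshold = 2:

# distance_mat[i][j] = j - i          after clipping to [-2, 2] and shifting by +2:
# [[ 0,  1,  2,  3],                  [[2, 3, 4, 4],
#  [-1,  0,  1,  2],                   [1, 2, 3, 4],
#  [-2, -1,  0,  1],                   [0, 1, 2, 3],
#  [-3, -2, -1,  0]]                   [0, 0, 1, 2]]
rp = RelativePosition(embedding_len=16, clipping_threshold=2)
r = rp(4, 4)       # R has shape [4, 4, 16]: one embedding per (query, key) offset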
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class RelativeMultiHeadAttention(nn.Module):
|
| 244 |
+
def __init__(self, embed_dim, num_heads, dropout, pos_encoding, clipping_threshold, contact_threshold, pdb_fns):
|
| 245 |
+
"""
|
| 246 |
+
Multi-head attention with relative position embeddings. Input data should be in batch_first format.
|
| 247 |
+
:param embed_dim: aka d_model, aka hid_dim
|
| 248 |
+
:param num_heads: number of heads
|
| 249 |
+
:param dropout: how much dropout for scaled dot product attention
|
| 250 |
+
|
| 251 |
+
:param pos_encoding: what type of positional encoding to use, relative or relative3D
|
| 252 |
+
:param clipping_threshold: clipping threshold for relative position embedding
|
| 253 |
+
:param contact_threshold: for relative_3D, the threshold in angstroms for the contact map
|
| 254 |
+
:param pdb_fns: pdb file(s) to set up the relative position object
|
| 255 |
+
|
| 256 |
+
"""
|
| 257 |
+
super().__init__()
|
| 258 |
+
|
| 259 |
+
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
| 260 |
+
|
| 261 |
+
# model dimensions
|
| 262 |
+
self.embed_dim = embed_dim
|
| 263 |
+
self.num_heads = num_heads
|
| 264 |
+
self.head_dim = embed_dim // num_heads
|
| 265 |
+
|
| 266 |
+
# pos encoding stuff
|
| 267 |
+
self.pos_encoding = pos_encoding
|
| 268 |
+
self.clipping_threshold = clipping_threshold
|
| 269 |
+
self.contact_threshold = contact_threshold
|
| 270 |
+
if pdb_fns is not None and not isinstance(pdb_fns, list):
|
| 271 |
+
pdb_fns = [pdb_fns]
|
| 272 |
+
self.pdb_fns = pdb_fns
|
| 273 |
+
|
| 274 |
+
# relative position embeddings for use with keys and values
|
| 275 |
+
# Shaw et al. uses relative position information for both keys and values
|
| 276 |
+
# Huang et al. only uses it for the keys, which is probably enough
|
| 277 |
+
if pos_encoding == "relative":
|
| 278 |
+
self.relative_position_k = RelativePosition(self.head_dim, self.clipping_threshold)
|
| 279 |
+
self.relative_position_v = RelativePosition(self.head_dim, self.clipping_threshold)
|
| 280 |
+
elif pos_encoding == "relative_3D":
|
| 281 |
+
self.relative_position_k = RelativePosition3D(self.head_dim, self.contact_threshold,
|
| 282 |
+
self.clipping_threshold, self.pdb_fns)
|
| 283 |
+
self.relative_position_v = RelativePosition3D(self.head_dim, self.contact_threshold,
|
| 284 |
+
self.clipping_threshold, self.pdb_fns)
|
| 285 |
+
else:
|
| 286 |
+
raise ValueError("unrecognized pos_encoding: {}".format(pos_encoding))
|
| 287 |
+
|
| 288 |
+
# WQ, WK, and WV from attention is all you need
|
| 289 |
+
# note these default to bias=True, same as PyTorch implementation
|
| 290 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
| 291 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
| 292 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
| 293 |
+
|
| 294 |
+
# WO from attention is all you need
|
| 295 |
+
# used for the final projection when computing multi-head attention
|
| 296 |
+
# PyTorch uses NonDynamicallyQuantizableLinear instead of Linear to avoid triggering an obscure
|
| 297 |
+
# error quantizing the model https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L122
|
| 298 |
+
# todo: if quantizing the model, explore if the above is a concern for us
|
| 299 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim)
|
| 300 |
+
|
| 301 |
+
# dropout for scaled dot product attention
|
| 302 |
+
self.dropout = nn.Dropout(dropout)
|
| 303 |
+
|
| 304 |
+
# scaling factor for scaled dot product attention
|
| 305 |
+
scale = torch.sqrt(torch.FloatTensor([self.head_dim]))
|
| 306 |
+
# persistent=False if you don't want to save it inside state_dict
|
| 307 |
+
self.register_buffer('scale', scale)
|
| 308 |
+
|
| 309 |
+
# toggles meant to be set directly by user
|
| 310 |
+
self.need_weights = False
|
| 311 |
+
self.average_attn_weights = True
|
| 312 |
+
|
| 313 |
+
def _compute_attn_weights(self, query, key, len_q, len_k, batch_size, mask, pdb_fn):
|
| 314 |
+
""" computes the attention weights (a "compatability function" of queries with corresponding keys) """
|
| 315 |
+
|
| 316 |
+
# calculate the first term in the numerator attn1, which is Q*K
|
| 317 |
+
# todo: pytorch reshapes q,k and v to 3 dimensions (similar to how r_q2 is below)
|
| 318 |
+
# is that functionally equivalent to what we're doing? is their way faster?
|
| 319 |
+
# r_q1 = [batch_size, num_heads, len_q, head_dim]
|
| 320 |
+
r_q1 = query.view(batch_size, len_q, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 321 |
+
# todo: we could directly permute r_k1 to [batch_size, num_heads, head_dim, len_k]
|
| 322 |
+
# to make it compatible for matrix multiplication with r_q1, instead of 2-step approach
|
| 323 |
+
# r_k1 = [batch_size, num_heads, len_k, head_dim]
|
| 324 |
+
r_k1 = key.view(batch_size, len_k, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 325 |
+
# attn1 = [batch_size, num_heads, len_q, len_k]
|
| 326 |
+
attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))
|
| 327 |
+
|
| 328 |
+
# calculate the second term in the numerator attn2, which is Q*R
|
| 329 |
+
# r_q2 = [query_len, batch_size * num_heads, head_dim]
|
| 330 |
+
r_q2 = query.permute(1, 0, 2).contiguous().view(len_q, batch_size * self.num_heads, self.head_dim)
|
| 331 |
+
|
| 332 |
+
# todo: support multiple different PDB base structures per batch
|
| 333 |
+
# one option:
|
| 334 |
+
# - require batches to be all the same protein
|
| 335 |
+
# - add argument to forward() to accept the PDB file for the protein in the batch
|
| 336 |
+
# - then we just pass in the PDB file to relative position's forward()
|
| 337 |
+
# to support multiple different structures per batch:
|
| 338 |
+
# - add argument to forward() to accept PDB files, one for each item in batch
|
| 339 |
+
# - make corresponding changes in the relative_position object to return R for each structure
|
| 340 |
+
# - note: if there are a lot of different structures, and the sequence lengths are long,
|
| 341 |
+
# this could be memory prohibitive because R (rel_pos_k) can take up a lot of mem for long seqs
|
| 342 |
+
# - adjust the attn2 calculation to factor in the multiple different R matrices.
|
| 343 |
+
# the way to do this might have to be to do multiple matmuls, one for each structure.
|
| 344 |
+
# basically, would split up r_q2 into several matrices grouped by structure, and then
|
| 345 |
+
# multiply with corresponding R, then combine back into the exact same order of the original r_q2
|
| 346 |
+
# note: this may be computationally intensive (splitting, more matrix multiplies, joining)
|
| 347 |
+
# another option would be to create views(?), repeating the different Rs so we can do a
|
| 348 |
+
# matrix multiply directly with r_q2
|
| 349 |
+
# - would shapes be affected if there was padding in the queries, keys, values?
|
| 350 |
+
|
| 351 |
+
if self.pos_encoding == "relative":
|
| 352 |
+
# rel_pos_k = [len_q, len_k, head_dim]
|
| 353 |
+
rel_pos_k = self.relative_position_k(len_q, len_k)
|
| 354 |
+
elif self.pos_encoding == "relative_3D":
|
| 355 |
+
# rel_pos_k = [seq_len, seq_len, head_dim], where seq_len comes from the PDB structure
|
| 356 |
+
rel_pos_k = self.relative_position_k(pdb_fn)
|
| 357 |
+
else:
|
| 358 |
+
raise ValueError("unrecognized pos_encoding: {}".format(self.pos_encoding))
|
| 359 |
+
|
| 360 |
+
# the matmul basically computes the dot product between each input position’s query vector and
|
| 361 |
+
# its corresponding relative position embeddings across all input sequences in the heads and batch
|
| 362 |
+
# attn2 = [batch_size * num_heads, len_q, len_k]
|
| 363 |
+
attn2 = torch.matmul(r_q2, rel_pos_k.transpose(1, 2)).transpose(0, 1)
|
| 364 |
+
# attn2 = [batch_size, num_heads, len_q, len_k]
|
| 365 |
+
attn2 = attn2.contiguous().view(batch_size, self.num_heads, len_q, len_k)
|
| 366 |
+
|
| 367 |
+
# calculate attention weights
|
| 368 |
+
attn_weights = (attn1 + attn2) / self.scale
|
| 369 |
+
|
| 370 |
+
# apply mask if given
|
| 371 |
+
if mask is not None:
|
| 372 |
+
# todo: pytorch uses float("-inf") instead of -1e10
|
| 373 |
+
attn_weights = attn_weights.masked_fill(mask == 0, -1e10)
|
| 374 |
+
|
| 375 |
+
# softmax over the last dim turns the scores into attention weights
|
| 376 |
+
attn_weights = torch.softmax(attn_weights, dim=-1)
|
| 377 |
+
# attn_weights = [batch_size, num_heads, len_q, len_k]
|
| 378 |
+
attn_weights = self.dropout(attn_weights)
|
| 379 |
+
|
| 380 |
+
return attn_weights
|
| 381 |
+
|
| 382 |
+
def _compute_avg_val(self, value, len_q, len_k, len_v, attn_weights, batch_size, pdb_fn):
|
| 383 |
+
# todo: add option to not factor in relative position embeddings in value calculation
|
| 384 |
+
# calculate the first term, the attn*values
|
| 385 |
+
# r_v1 = [batch_size, num_heads, len_v, head_dim]
|
| 386 |
+
r_v1 = value.view(batch_size, len_v, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 387 |
+
# avg1 = [batch_size, num_heads, len_q, head_dim]
|
| 388 |
+
avg1 = torch.matmul(attn_weights, r_v1)
|
| 389 |
+
|
| 390 |
+
# calculate the second term, the attn*R
|
| 391 |
+
# similar to how relative embeddings are factored in the attention weights calculation
|
| 392 |
+
if self.pos_encoding == "relative":
|
| 393 |
+
# rel_pos_v = [query_len, value_len, head_dim]
|
| 394 |
+
rel_pos_v = self.relative_position_v(len_q, len_v)
|
| 395 |
+
elif self.pos_encoding == "relative_3D":
|
| 396 |
+
# rel_pos_v = [seq_len, seq_len, head_dim], where seq_len comes from the PDB structure
|
| 397 |
+
rel_pos_v = self.relative_position_v(pdb_fn)
|
| 398 |
+
else:
|
| 399 |
+
raise ValueError("unrecognized pos_encoding: {}".format(self.pos_encoding))
|
| 400 |
+
|
| 401 |
+
# r_attn_weights = [len_q, batch_size * num_heads, len_k]
|
| 402 |
+
r_attn_weights = attn_weights.permute(2, 0, 1, 3).contiguous().view(len_q, batch_size * self.num_heads, len_k)
|
| 403 |
+
avg2 = torch.matmul(r_attn_weights, rel_pos_v)
|
| 404 |
+
# avg2 = [batch_size, num_heads, len_q, head_dim]
|
| 405 |
+
avg2 = avg2.transpose(0, 1).contiguous().view(batch_size, self.num_heads, len_q, self.head_dim)
|
| 406 |
+
|
| 407 |
+
# calculate avg value
|
| 408 |
+
x = avg1 + avg2 # [batch_size, num_heads, len_q, head_dim]
|
| 409 |
+
x = x.permute(0, 2, 1, 3).contiguous() # [batch_size, len_q, num_heads, head_dim]
|
| 410 |
+
# x = [batch_size, len_q, embed_dim]
|
| 411 |
+
x = x.view(batch_size, len_q, self.embed_dim)
|
| 412 |
+
|
| 413 |
+
return x
|
| 414 |
+
|
| 415 |
+
def forward(self, query, key, value, pdb_fn=None, mask=None):
|
| 416 |
+
# query = [batch_size, q_len, embed_dim]
|
| 417 |
+
# key = [batch_size, k_len, embed_dim]
|
| 418 |
+
# value = [batch_size, v_len, embed_dim]
|
| 419 |
+
batch_size = query.shape[0]
|
| 420 |
+
len_k, len_q, len_v = (key.shape[1], query.shape[1], value.shape[1])
|
| 421 |
+
|
| 422 |
+
# in projection (multiply inputs by WQ, WK, WV)
|
| 423 |
+
query = self.q_proj(query)
|
| 424 |
+
key = self.k_proj(key)
|
| 425 |
+
value = self.v_proj(value)
|
| 426 |
+
|
| 427 |
+
# first compute the attention weights, then multiply with values
|
| 428 |
+
# attn = [batch size, num_heads, len_q, len_k]
|
| 429 |
+
attn_weights = self._compute_attn_weights(query, key, len_q, len_k, batch_size, mask, pdb_fn)
|
| 430 |
+
|
| 431 |
+
# take weighted average of values (weighted by attention weights)
|
| 432 |
+
attn_output = self._compute_avg_val(value, len_q, len_k, len_v, attn_weights, batch_size, pdb_fn)
|
| 433 |
+
|
| 434 |
+
# output projection
|
| 435 |
+
# attn_output = [batch_size, len_q, embed_dim]
|
| 436 |
+
attn_output = self.out_proj(attn_output)
|
| 437 |
+
|
| 438 |
+
if self.need_weights:
|
| 439 |
+
# return attention weights in addition to attention
|
| 440 |
+
# average the weights over the heads (to get overall attention)
|
| 441 |
+
# attn_weights = [batch_size, len_q, len_k]
|
| 442 |
+
if self.average_attn_weights:
|
| 443 |
+
attn_weights = attn_weights.sum(dim=1) / self.num_heads
|
| 444 |
+
return {"attn_output": attn_output, "attn_weights": attn_weights}
|
| 445 |
+
else:
|
| 446 |
+
return attn_output
|
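A minimal shape-check sketch for the sequence-based variant (all sizes are illustrative):

mha = RelativeMultiHeadAttention(embed_dim=64, num_heads=4, dropout=0.1,
                                 pos_encoding="relative", clipping_threshold=8,
                                 contact_threshold=7, pdb_fns=None)
x = torch.randn(2, 20, 64)    # batch_first: [batch_size, seq_len, embed_dim]
out = mha(x, x, x)            # self-attention -> [2, 20, 64]
mha.need_weights = True
out = mha(x, x, x)            # now a dict with "attn_output" and "attn_weights"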
| 447 |
+
|
| 448 |
+
|
| 449 |
+
class RelativeTransformerEncoderLayer(nn.Module):
|
| 450 |
+
"""
|
| 451 |
+
d_model: the number of expected features in the input (required).
|
| 452 |
+
nhead: the number of heads in the MultiHeadAttention models (required).
|
| 453 |
+
clipping_threshold: the clipping threshold for relative position embeddings
|
| 454 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
| 455 |
+
dropout: the dropout value (default=0.1).
|
| 456 |
+
activation: the activation function of the intermediate layer, can be a string
|
| 457 |
+
("relu" or "gelu") or a unary callable. Default: relu
|
| 458 |
+
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
| 459 |
+
norm_first: if ``True``, layer norm is done prior to attention and feedforward
|
| 460 |
+
operations, respectively. Otherwise, it's done after. Default: ``False`` (after).
|
| 461 |
+
"""
|
| 462 |
+
|
| 463 |
+
# __constants__ is a TorchScript/jit hint marking these attributes as constants (they should not change after init)
|
| 464 |
+
__constants__ = ['batch_first', 'norm_first']
|
| 465 |
+
|
| 466 |
+
def __init__(self,
|
| 467 |
+
d_model,
|
| 468 |
+
nhead,
|
| 469 |
+
pos_encoding="relative",
|
| 470 |
+
clipping_threshold=3,
|
| 471 |
+
contact_threshold=7,
|
| 472 |
+
pdb_fns=None,
|
| 473 |
+
dim_feedforward=2048,
|
| 474 |
+
dropout=0.1,
|
| 475 |
+
activation=F.relu,
|
| 476 |
+
layer_norm_eps=1e-5,
|
| 477 |
+
norm_first=False) -> None:
|
| 478 |
+
|
| 479 |
+
self.batch_first = True
|
| 480 |
+
|
| 481 |
+
super(RelativeTransformerEncoderLayer, self).__init__()
|
| 482 |
+
|
| 483 |
+
self.self_attn = RelativeMultiHeadAttention(d_model, nhead, dropout,
|
| 484 |
+
pos_encoding, clipping_threshold, contact_threshold, pdb_fns)
|
| 485 |
+
|
| 486 |
+
# feed forward model
|
| 487 |
+
self.linear1 = Linear(d_model, dim_feedforward)
|
| 488 |
+
self.dropout = Dropout(dropout)
|
| 489 |
+
self.linear2 = Linear(dim_feedforward, d_model)
|
| 490 |
+
|
| 491 |
+
self.norm_first = norm_first
|
| 492 |
+
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
|
| 493 |
+
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
|
| 494 |
+
self.dropout1 = Dropout(dropout)
|
| 495 |
+
self.dropout2 = Dropout(dropout)
|
| 496 |
+
|
| 497 |
+
# Legacy string support for activation function.
|
| 498 |
+
if isinstance(activation, str):
|
| 499 |
+
self.activation = models.get_activation_fn(activation)
|
| 500 |
+
else:
|
| 501 |
+
self.activation = activation
|
| 502 |
+
|
| 503 |
+
def forward(self, src: Tensor, pdb_fn=None) -> Tensor:
|
| 504 |
+
x = src
|
| 505 |
+
if self.norm_first:
|
| 506 |
+
x = x + self._sa_block(self.norm1(x), pdb_fn=pdb_fn)
|
| 507 |
+
x = x + self._ff_block(self.norm2(x))
|
| 508 |
+
else:
|
| 509 |
+
x = self.norm1(x + self._sa_block(x))
|
| 510 |
+
x = self.norm2(x + self._ff_block(x))
|
| 511 |
+
|
| 512 |
+
return x
|
| 513 |
+
|
| 514 |
+
# self-attention block
|
| 515 |
+
def _sa_block(self, x: Tensor, pdb_fn=None) -> Tensor:
|
| 516 |
+
x = self.self_attn(x, x, x, pdb_fn=pdb_fn)
|
| 517 |
+
if isinstance(x, dict):
|
| 518 |
+
# handle the case where we are returning attention weights
|
| 519 |
+
x = x["attn_output"]
|
| 520 |
+
return self.dropout1(x)
|
| 521 |
+
|
| 522 |
+
# feed forward block
|
| 523 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
| 524 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
| 525 |
+
return self.dropout2(x)
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class RelativeTransformerEncoder(nn.Module):
|
| 529 |
+
def __init__(self, encoder_layer, num_layers, norm=None, reset_params=True):
|
| 530 |
+
super(RelativeTransformerEncoder, self).__init__()
|
| 531 |
+
# using get_clones means all layers have the same initialization
|
| 532 |
+
# this is also a problem in PyTorch's TransformerEncoder implementation, which this is based on
|
| 533 |
+
# todo: PyTorch is changing its transformer API... check up on and see if there is a better way
|
| 534 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
| 535 |
+
self.num_layers = num_layers
|
| 536 |
+
self.norm = norm
|
| 537 |
+
|
| 538 |
+
# important because get_clones means all layers have same initialization
|
| 539 |
+
# should recursively reset parameters for all submodules
|
| 540 |
+
if reset_params:
|
| 541 |
+
self.apply(models.reset_parameters_helper)
|
| 542 |
+
|
| 543 |
+
def forward(self, src: Tensor, pdb_fn=None) -> Tensor:
|
| 544 |
+
output = src
|
| 545 |
+
|
| 546 |
+
for mod in self.layers:
|
| 547 |
+
output = mod(output, pdb_fn=pdb_fn)
|
| 548 |
+
|
| 549 |
+
if self.norm is not None:
|
| 550 |
+
output = self.norm(output)
|
| 551 |
+
|
| 552 |
+
return output
|
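Putting the pieces together, a small relative-position encoder stack can be assembled like this (dimensions are illustrative; relies on models.reset_parameters_helper being available via the import above):

layer = RelativeTransformerEncoderLayer(d_model=64, nhead=4,
                                        pos_encoding="relative",
                                        clipping_threshold=8,
                                        dim_feedforward=256,
                                        dropout=0.1)
encoder = RelativeTransformerEncoder(layer, num_layers=3, norm=LayerNorm(64))
x = torch.randn(2, 20, 64)    # [batch_size, seq_len, d_model]
out = encoder(x)              # [2, 20, 64]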
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def _get_clones(module, num_clones):
|
| 556 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(num_clones)])
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def _inv_dict(d):
|
| 560 |
+
""" helper function for contact map-based position embeddings """
|
| 561 |
+
inv = dict()
|
| 562 |
+
for k, v in d.items():
|
| 563 |
+
# collect dict keys into lists based on value
|
| 564 |
+
inv.setdefault(v, list()).append(k)
|
| 565 |
+
for k, v in inv.items():
|
| 566 |
+
# put in sorted order
|
| 567 |
+
inv[k] = sorted(v)
|
| 568 |
+
return inv
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def _combine_d(d, threshold, combined_key):
|
| 572 |
+
""" helper function for contact map-based position embeddings
|
| 573 |
+
d is a dictionary with ints as keys and lists as values.
|
| 574 |
+
for all keys >= threshold, this function combines the values of those keys into a single list """
|
| 575 |
+
out_d = {}
|
| 576 |
+
for k, v in d.items():
|
| 577 |
+
if k < threshold:
|
| 578 |
+
out_d[k] = v
|
| 579 |
+
elif k >= threshold:
|
| 580 |
+
if combined_key not in out_d:
|
| 581 |
+
out_d[combined_key] = v
|
| 582 |
+
else:
|
| 583 |
+
out_d[combined_key] += v
|
| 584 |
+
if combined_key in out_d:
|
| 585 |
+
out_d[combined_key] = sorted(out_d[combined_key])
|
| 586 |
+
return out_d
|
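A tiny worked example of these two helpers (values chosen arbitrarily):

d = {"a": 1, "b": 2, "c": 1, "d": 3}
_inv_dict(d)                                            # {1: ["a", "c"], 2: ["b"], 3: ["d"]}
_combine_d(_inv_dict(d), threshold=2, combined_key=2)   # {1: ["a", "c"], 2: ["b", "d"]}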
metl/structure.py
ADDED
|
@@ -0,0 +1,184 @@
|
| 1 |
+
import os
|
| 2 |
+
from os.path import isfile
|
| 3 |
+
from enum import Enum, auto
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from scipy.spatial.distance import cdist
|
| 7 |
+
import networkx as nx
|
| 8 |
+
from biopandas.pdb import PandasPdb
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class GraphType(Enum):
|
| 12 |
+
LINEAR = auto()
|
| 13 |
+
COMPLETE = auto()
|
| 14 |
+
DISCONNECTED = auto()
|
| 15 |
+
DIST_THRESH = auto()
|
| 16 |
+
DIST_THRESH_SHUFFLED = auto()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def save_graph(g, fn):
|
| 20 |
+
""" Saves graph to file """
|
| 21 |
+
nx.write_gexf(g, fn)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_graph(fn):
|
| 25 |
+
""" Loads graph from file """
|
| 26 |
+
g = nx.read_gexf(fn, node_type=int)
|
| 27 |
+
return g
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def shuffle_nodes(g, seed=7):
|
| 31 |
+
""" Shuffles the nodes of the given graph and returns a copy of the shuffled graph """
|
| 32 |
+
# get the list of nodes in this graph
|
| 33 |
+
nodes = g.nodes()
|
| 34 |
+
|
| 35 |
+
# create a permuted list of nodes
|
| 36 |
+
np.random.seed(seed)
|
| 37 |
+
nodes_shuffled = np.random.permutation(nodes)
|
| 38 |
+
|
| 39 |
+
# create a dictionary mapping from old node label to new node label
|
| 40 |
+
mapping = {n: ns for n, ns in zip(nodes, nodes_shuffled)}
|
| 41 |
+
|
| 42 |
+
g_shuffled = nx.relabel_nodes(g, mapping, copy=True)
|
| 43 |
+
|
| 44 |
+
return g_shuffled
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def linear_graph(num_residues):
|
| 48 |
+
""" Creates a linear graph where each node is connected to its sequence neighbor in order """
|
| 49 |
+
g = nx.Graph()
|
| 50 |
+
g.add_nodes_from(np.arange(0, num_residues))
|
| 51 |
+
for i in range(num_residues-1):
|
| 52 |
+
g.add_edge(i, i+1)
|
| 53 |
+
return g
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def complete_graph(num_residues):
|
| 57 |
+
""" Creates a graph where each node is connected to all other nodes"""
|
| 58 |
+
g = nx.complete_graph(num_residues)
|
| 59 |
+
return g
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def disconnected_graph(num_residues):
|
| 63 |
+
g = nx.Graph()
|
| 64 |
+
g.add_nodes_from(np.arange(0, num_residues))
|
| 65 |
+
return g
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def dist_thresh_graph(dist_mtx, threshold):
|
| 69 |
+
""" Creates undirected graph based on a distance threshold """
|
| 70 |
+
g = nx.Graph()
|
| 71 |
+
g.add_nodes_from(np.arange(0, dist_mtx.shape[0]))
|
| 72 |
+
|
| 73 |
+
# loop through each residue
|
| 74 |
+
for rn1 in range(len(dist_mtx)):
|
| 75 |
+
# find all residues that are within threshold distance of current
|
| 76 |
+
rns_within_threshold = np.where(dist_mtx[rn1] < threshold)[0]
|
| 77 |
+
|
| 78 |
+
# add edges from current residue to those that are within threshold
|
| 79 |
+
for rn2 in rns_within_threshold:
|
| 80 |
+
# don't add self edges
|
| 81 |
+
if rn1 != rn2:
|
| 82 |
+
g.add_edge(rn1, rn2)
|
| 83 |
+
return g
|
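A small concrete example of the thresholding, using a 3-residue toy distance matrix and a 7 Å threshold:

toy_dist = np.array([[0.0, 5.0, 9.0],
                     [5.0, 0.0, 6.5],
                     [9.0, 6.5, 0.0]])
g = dist_thresh_graph(toy_dist, threshold=7)
sorted(g.edges())   # [(0, 1), (1, 2)] -- residues 0 and 2 are more than 7 apart, so no edge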
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def ordered_adjacency_matrix(g):
|
| 87 |
+
""" returns the adjacency matrix ordered by node label in increasing order as a numpy array """
|
| 88 |
+
node_order = sorted(g.nodes())
|
| 89 |
+
adj_mtx = nx.to_numpy_matrix(g, nodelist=node_order)
|
| 90 |
+
return np.asarray(adj_mtx).astype(np.float32)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def cbeta_distance_matrix(pdb_fn, start=0, end=None):
|
| 94 |
+
# note that start and end do not refer to PDB residue numbers;
|
| 95 |
+
# they index residues by their order of appearance in the pdb file
|
| 96 |
+
|
| 97 |
+
# read the pdb file into a biopandas object
|
| 98 |
+
ppdb = PandasPdb().read_pdb(pdb_fn)
|
| 99 |
+
|
| 100 |
+
# group by residue number
|
| 101 |
+
# important to specify sort=True so that group keys (residue number) are in order
|
| 102 |
+
# the reason is we loop through group keys below, and assume that residues are in order
|
| 103 |
+
# the pandas function has sort=True by default, but we specify it anyway because it is important
|
| 104 |
+
grouped = ppdb.df["ATOM"].groupby("residue_number", sort=True)
|
| 105 |
+
|
| 106 |
+
# a list of coords for the cbeta or calpha of each residue
|
| 107 |
+
coords = []
|
| 108 |
+
|
| 109 |
+
# loop through each residue and find the coordinates of cbeta
|
| 110 |
+
for i, (residue_number, values) in enumerate(grouped):
|
| 111 |
+
|
| 112 |
+
# skip residues not in the range
|
| 113 |
+
end_index = (len(grouped) if end is None else end)
|
| 114 |
+
if i not in range(start, end_index):
|
| 115 |
+
continue
|
| 116 |
+
|
| 117 |
+
residue_group = grouped.get_group(residue_number)
|
| 118 |
+
|
| 119 |
+
atom_names = residue_group["atom_name"]
|
| 120 |
+
if "CB" in atom_names.values:
|
| 121 |
+
# print("Using CB...")
|
| 122 |
+
atom_name = "CB"
|
| 123 |
+
elif "CA" in atom_names.values:
|
| 124 |
+
# print("Using CA...")
|
| 125 |
+
atom_name = "CA"
|
| 126 |
+
else:
|
| 127 |
+
raise ValueError("Couldn't find CB or CA for residue {}".format(residue_number))
|
| 128 |
+
|
| 129 |
+
# get the coordinates of cbeta (or calpha)
|
| 130 |
+
coords.append(
|
| 131 |
+
residue_group[residue_group["atom_name"] == atom_name][["x_coord", "y_coord", "z_coord"]].values[0])
|
| 132 |
+
|
| 133 |
+
# stack the coords into a numpy array where each row has the x,y,z coords for a different residue
|
| 134 |
+
coords = np.stack(coords)
|
| 135 |
+
|
| 136 |
+
# compute pairwise euclidean distance between all cbetas
|
| 137 |
+
dist_mtx = cdist(coords, coords, metric="euclidean")
|
| 138 |
+
|
| 139 |
+
return dist_mtx
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def get_neighbors(g, nodes):
|
| 143 |
+
""" returns a list (set) of neighbors of all given nodes """
|
| 144 |
+
neighbors = set()
|
| 145 |
+
for n in nodes:
|
| 146 |
+
neighbors.update(g.neighbors(n))
|
| 147 |
+
return sorted(list(neighbors))
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def gen_graph(graph_type, res_dist_mtx, dist_thresh=7, shuffle_seed=7, graph_save_dir=None, save=False):
|
| 151 |
+
""" generate the specified structure graph using the specified residue distance matrix """
|
| 152 |
+
if graph_type is GraphType.LINEAR:
|
| 153 |
+
g = linear_graph(len(res_dist_mtx))
|
| 154 |
+
save_fn = None if not save else os.path.join(graph_save_dir, "linear.graph")
|
| 155 |
+
|
| 156 |
+
elif graph_type is GraphType.COMPLETE:
|
| 157 |
+
g = complete_graph(len(res_dist_mtx))
|
| 158 |
+
save_fn = None if not save else os.path.join(graph_save_dir, "complete.graph")
|
| 159 |
+
|
| 160 |
+
elif graph_type is GraphType.DISCONNECTED:
|
| 161 |
+
g = disconnected_graph(len(res_dist_mtx))
|
| 162 |
+
save_fn = None if not save else os.path.join(graph_save_dir, "disconnected.graph")
|
| 163 |
+
|
| 164 |
+
elif graph_type is GraphType.DIST_THRESH:
|
| 165 |
+
g = dist_thresh_graph(res_dist_mtx, dist_thresh)
|
| 166 |
+
save_fn = None if not save else os.path.join(graph_save_dir, "dist_thresh_{}.graph".format(dist_thresh))
|
| 167 |
+
|
| 168 |
+
elif graph_type is GraphType.DIST_THRESH_SHUFFLED:
|
| 169 |
+
g = dist_thresh_graph(res_dist_mtx, dist_thresh)
|
| 170 |
+
g = shuffle_nodes(g, seed=shuffle_seed)
|
| 171 |
+
save_fn = None if not save else \
|
| 172 |
+
os.path.join(graph_save_dir, "dist_thresh_{}_shuffled_r{}.graph".format(dist_thresh, shuffle_seed))
|
| 173 |
+
|
| 174 |
+
else:
|
| 175 |
+
raise ValueError("Graph type {} is not implemented".format(graph_type))
|
| 176 |
+
|
| 177 |
+
if save:
|
| 178 |
+
if isfile(save_fn):
|
| 179 |
+
print("err: graph already exists: {}. to overwrite, delete the existing file first".format(save_fn))
|
| 180 |
+
else:
|
| 181 |
+
os.makedirs(graph_save_dir, exist_ok=True)
|
| 182 |
+
save_graph(g, save_fn)
|
| 183 |
+
|
| 184 |
+
return g
|
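End to end, the typical call pattern looks like the sketch below; the PDB path is a hypothetical placeholder, not a file shipped with this package:

dist_mtx = cbeta_distance_matrix("data/pdb_files/example.pdb")   # placeholder path
g = gen_graph(GraphType.DIST_THRESH, dist_mtx, dist_thresh=7)
print(g.number_of_nodes(), g.number_of_edges())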