Spaces:
Running
Running
add codes
Browse files- models.py +392 -0
- new_dataloader.py +349 -0
- requirements.txt +8 -0
- trainer.py +892 -0
- training_data.py +50 -0
- utils.py +462 -0
models.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from layers import TransformerEncoder, TransformerDecoder
|
| 5 |
+
|
| 6 |
+
class Generator(nn.Module):
|
| 7 |
+
"""Generator network."""
|
| 8 |
+
def __init__(self, z_dim, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio, submodel):
|
| 9 |
+
super(Generator, self).__init__()
|
| 10 |
+
|
| 11 |
+
self.submodel = submodel
|
| 12 |
+
self.vertexes = vertexes
|
| 13 |
+
self.edges = edges
|
| 14 |
+
self.nodes = nodes
|
| 15 |
+
self.depth = depth
|
| 16 |
+
self.dim = dim
|
| 17 |
+
self.heads = heads
|
| 18 |
+
self.mlp_ratio = mlp_ratio
|
| 19 |
+
|
| 20 |
+
self.dropout = dropout
|
| 21 |
+
self.z_dim = z_dim
|
| 22 |
+
|
| 23 |
+
if act == "relu":
|
| 24 |
+
act = nn.ReLU()
|
| 25 |
+
elif act == "leaky":
|
| 26 |
+
act = nn.LeakyReLU()
|
| 27 |
+
elif act == "sigmoid":
|
| 28 |
+
act = nn.Sigmoid()
|
| 29 |
+
elif act == "tanh":
|
| 30 |
+
act = nn.Tanh()
|
| 31 |
+
self.features = vertexes * vertexes * edges + vertexes * nodes
|
| 32 |
+
self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
|
| 33 |
+
self.pos_enc_dim = 5
|
| 34 |
+
#self.pos_enc = nn.Linear(self.pos_enc_dim, self.dim)
|
| 35 |
+
|
| 36 |
+
self.node_layers = nn.Sequential(nn.Linear(nodes, 64), act, nn.Linear(64,dim), act, nn.Dropout(self.dropout))
|
| 37 |
+
self.edge_layers = nn.Sequential(nn.Linear(edges, 64), act, nn.Linear(64,dim), act, nn.Dropout(self.dropout))
|
| 38 |
+
|
| 39 |
+
self.TransformerEncoder = TransformerEncoder(dim=self.dim, depth=self.depth, heads=self.heads, act = act,
|
| 40 |
+
mlp_ratio=self.mlp_ratio, drop_rate=self.dropout)
|
| 41 |
+
|
| 42 |
+
self.readout_e = nn.Linear(self.dim, edges)
|
| 43 |
+
self.readout_n = nn.Linear(self.dim, nodes)
|
| 44 |
+
self.softmax = nn.Softmax(dim = -1)
|
| 45 |
+
|
| 46 |
+
def _generate_square_subsequent_mask(self, sz):
|
| 47 |
+
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
| 48 |
+
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
| 49 |
+
return mask
|
| 50 |
+
|
| 51 |
+
def laplacian_positional_enc(self, adj):
|
| 52 |
+
|
| 53 |
+
A = adj
|
| 54 |
+
D = torch.diag(torch.count_nonzero(A, dim=-1))
|
| 55 |
+
L = torch.eye(A.shape[0], device=A.device) - D * A * D
|
| 56 |
+
|
| 57 |
+
EigVal, EigVec = torch.linalg.eig(L)
|
| 58 |
+
|
| 59 |
+
idx = torch.argsort(torch.real(EigVal))
|
| 60 |
+
EigVal, EigVec = EigVal[idx], torch.real(EigVec[:,idx])
|
| 61 |
+
pos_enc = EigVec[:,1:self.pos_enc_dim + 1]
|
| 62 |
+
|
| 63 |
+
return pos_enc
|
| 64 |
+
|
| 65 |
+
def forward(self, z_e, z_n):
|
| 66 |
+
b, n, c = z_n.shape
|
| 67 |
+
_, _, _ , d = z_e.shape
|
| 68 |
+
#random_mask_e = torch.randint(low=0,high=2,size=(b,n,n,d)).to(z_e.device).float()
|
| 69 |
+
#random_mask_n = torch.randint(low=0,high=2,size=(b,n,c)).to(z_n.device).float()
|
| 70 |
+
#z_e = F.relu(z_e - random_mask_e)
|
| 71 |
+
#z_n = F.relu(z_n - random_mask_n)
|
| 72 |
+
|
| 73 |
+
#mask = self._generate_square_subsequent_mask(self.vertexes).to(z_e.device)
|
| 74 |
+
|
| 75 |
+
node = self.node_layers(z_n)
|
| 76 |
+
|
| 77 |
+
edge = self.edge_layers(z_e)
|
| 78 |
+
|
| 79 |
+
edge = (edge + edge.permute(0,2,1,3))/2
|
| 80 |
+
|
| 81 |
+
#lap = [self.laplacian_positional_enc(torch.max(x,-1)[1]) for x in edge]
|
| 82 |
+
|
| 83 |
+
#lap = torch.stack(lap).to(node.device)
|
| 84 |
+
|
| 85 |
+
#pos_enc = self.pos_enc(lap)
|
| 86 |
+
|
| 87 |
+
#node = node + pos_enc
|
| 88 |
+
|
| 89 |
+
node, edge = self.TransformerEncoder(node,edge)
|
| 90 |
+
|
| 91 |
+
node_sample = self.softmax(self.readout_n(node))
|
| 92 |
+
|
| 93 |
+
edge_sample = self.softmax(self.readout_e(edge))
|
| 94 |
+
|
| 95 |
+
return node, edge, node_sample, edge_sample
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class Generator2(nn.Module):
|
| 100 |
+
def __init__(self, dim, dec_dim, depth, heads, mlp_ratio, drop_rate, drugs_m_dim, drugs_b_dim, submodel):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.submodel = submodel
|
| 103 |
+
self.depth = depth
|
| 104 |
+
self.dim = dim
|
| 105 |
+
self.mlp_ratio = mlp_ratio
|
| 106 |
+
self.heads = heads
|
| 107 |
+
self.dropout_rate = drop_rate
|
| 108 |
+
self.drugs_m_dim = drugs_m_dim
|
| 109 |
+
self.drugs_b_dim = drugs_b_dim
|
| 110 |
+
|
| 111 |
+
self.pos_enc_dim = 5
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
if self.submodel == "Prot":
|
| 115 |
+
self.prot_n = torch.nn.Linear(3822, 45) ## exact dimension of protein features
|
| 116 |
+
self.prot_e = torch.nn.Linear(298116, 2025) ## exact dimension of protein features
|
| 117 |
+
|
| 118 |
+
self.protn_dim = torch.nn.Linear(1, dec_dim)
|
| 119 |
+
self.prote_dim = torch.nn.Linear(1, dec_dim)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
self.mol_nodes = nn.Linear(dim, dec_dim)
|
| 123 |
+
self.mol_edges = nn.Linear(dim, dec_dim)
|
| 124 |
+
|
| 125 |
+
self.drug_nodes = nn.Linear(self.drugs_m_dim, dec_dim)
|
| 126 |
+
self.drug_edges = nn.Linear(self.drugs_b_dim, dec_dim)
|
| 127 |
+
|
| 128 |
+
self.TransformerDecoder = TransformerDecoder(dec_dim, depth, heads, mlp_ratio, drop_rate=self.dropout_rate)
|
| 129 |
+
|
| 130 |
+
self.nodes_output_layer = nn.Linear(dec_dim, self.drugs_m_dim)
|
| 131 |
+
self.edges_output_layer = nn.Linear(dec_dim, self.drugs_b_dim)
|
| 132 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 133 |
+
|
| 134 |
+
def laplacian_positional_enc(self, adj):
|
| 135 |
+
|
| 136 |
+
A = adj
|
| 137 |
+
D = torch.diag(torch.count_nonzero(A, dim=-1))
|
| 138 |
+
L = torch.eye(A.shape[0], device=A.device) - D * A * D
|
| 139 |
+
|
| 140 |
+
EigVal, EigVec = torch.linalg.eig(L)
|
| 141 |
+
|
| 142 |
+
idx = torch.argsort(torch.real(EigVal))
|
| 143 |
+
EigVal, EigVec = EigVal[idx], torch.real(EigVec[:,idx])
|
| 144 |
+
pos_enc = EigVec[:,1:self.pos_enc_dim + 1]
|
| 145 |
+
|
| 146 |
+
return pos_enc
|
| 147 |
+
|
| 148 |
+
def _generate_square_subsequent_mask(self, sz):
|
| 149 |
+
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
| 150 |
+
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
| 151 |
+
return mask
|
| 152 |
+
|
| 153 |
+
def forward(self, edges_logits, nodes_logits ,akt1_adj,akt1_annot):
|
| 154 |
+
|
| 155 |
+
edges_logits = self.mol_edges(edges_logits)
|
| 156 |
+
nodes_logits = self.mol_nodes(nodes_logits)
|
| 157 |
+
|
| 158 |
+
if self.submodel != "Prot":
|
| 159 |
+
akt1_annot = self.drug_nodes(akt1_annot)
|
| 160 |
+
akt1_adj = self.drug_edges(akt1_adj)
|
| 161 |
+
|
| 162 |
+
else:
|
| 163 |
+
akt1_adj = self.prote_dim(self.prot_e(akt1_adj).view(1,45,45,1))
|
| 164 |
+
akt1_annot = self.protn_dim(self.prot_n(akt1_annot).view(1,45,1))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
#lap = [self.laplacian_positional_enc(torch.max(x,-1)[1]) for x in drug_e]
|
| 168 |
+
#lap = torch.stack(lap).to(drug_e.device)
|
| 169 |
+
#pos_enc = self.pos_enc(lap)
|
| 170 |
+
#drug_n = drug_n + pos_enc
|
| 171 |
+
|
| 172 |
+
nodes_logits,akt1_annot, edges_logits, akt1_adj = self.TransformerDecoder(nodes_logits,akt1_annot,edges_logits,akt1_adj)
|
| 173 |
+
|
| 174 |
+
edges_logits = self.edges_output_layer(edges_logits)
|
| 175 |
+
nodes_logits = self.nodes_output_layer(nodes_logits)
|
| 176 |
+
|
| 177 |
+
edges_logits = self.softmax(edges_logits)
|
| 178 |
+
nodes_logits = self.softmax(nodes_logits)
|
| 179 |
+
|
| 180 |
+
return edges_logits, nodes_logits
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class simple_disc(nn.Module):
|
| 184 |
+
def __init__(self, act, m_dim, vertexes, b_dim):
|
| 185 |
+
super().__init__()
|
| 186 |
+
if act == "relu":
|
| 187 |
+
act = nn.ReLU()
|
| 188 |
+
elif act == "leaky":
|
| 189 |
+
act = nn.LeakyReLU()
|
| 190 |
+
elif act == "sigmoid":
|
| 191 |
+
act = nn.Sigmoid()
|
| 192 |
+
elif act == "tanh":
|
| 193 |
+
act = nn.Tanh()
|
| 194 |
+
features = vertexes * m_dim + vertexes * vertexes * b_dim
|
| 195 |
+
|
| 196 |
+
self.predictor = nn.Sequential(nn.Linear(features,256), act, nn.Linear(256,128), act, nn.Linear(128,64), act,
|
| 197 |
+
nn.Linear(64,32), act, nn.Linear(32,16), act,
|
| 198 |
+
nn.Linear(16,1))
|
| 199 |
+
|
| 200 |
+
def forward(self, x):
|
| 201 |
+
|
| 202 |
+
prediction = self.predictor(x)
|
| 203 |
+
|
| 204 |
+
#prediction = F.softmax(prediction,dim=-1)
|
| 205 |
+
|
| 206 |
+
return prediction
|
| 207 |
+
|
| 208 |
+
"""class Discriminator(nn.Module):
|
| 209 |
+
|
| 210 |
+
def __init__(self,deg,agg,sca,pna_in_ch,pna_out_ch,edge_dim,towers,pre_lay,post_lay,pna_layer_num, graph_add):
|
| 211 |
+
super(Discriminator, self).__init__()
|
| 212 |
+
self.degree = deg
|
| 213 |
+
self.aggregators = agg
|
| 214 |
+
self.scalers = sca
|
| 215 |
+
self.pna_in_channels = pna_in_ch
|
| 216 |
+
self.pna_out_channels = pna_out_ch
|
| 217 |
+
self.edge_dimension = edge_dim
|
| 218 |
+
self.towers = towers
|
| 219 |
+
self.pre_layers_num = pre_lay
|
| 220 |
+
self.post_layers_num = post_lay
|
| 221 |
+
self.pna_layer_num = pna_layer_num
|
| 222 |
+
self.graph_add = graph_add
|
| 223 |
+
self.PNA_layer = PNA(deg=self.degree, agg =self.aggregators,sca = self.scalers,
|
| 224 |
+
pna_in_ch= self.pna_in_channels, pna_out_ch = self.pna_out_channels, edge_dim = self.edge_dimension,
|
| 225 |
+
towers = self.towers, pre_lay = self.pre_layers_num, post_lay = self.post_layers_num,
|
| 226 |
+
pna_layer_num = self.pna_layer_num, graph_add = self.graph_add)
|
| 227 |
+
|
| 228 |
+
def forward(self, x, edge_index, edge_attr, batch, activation=None):
|
| 229 |
+
|
| 230 |
+
h = self.PNA_layer(x, edge_index, edge_attr, batch)
|
| 231 |
+
|
| 232 |
+
h = activation(h) if activation is not None else h
|
| 233 |
+
|
| 234 |
+
return h"""
|
| 235 |
+
|
| 236 |
+
"""class Discriminator2(nn.Module):
|
| 237 |
+
|
| 238 |
+
def __init__(self,deg,agg,sca,pna_in_ch,pna_out_ch,edge_dim,towers,pre_lay,post_lay,pna_layer_num, graph_add):
|
| 239 |
+
super(Discriminator2, self).__init__()
|
| 240 |
+
self.degree = deg
|
| 241 |
+
self.aggregators = agg
|
| 242 |
+
self.scalers = sca
|
| 243 |
+
self.pna_in_channels = pna_in_ch
|
| 244 |
+
self.pna_out_channels = pna_out_ch
|
| 245 |
+
self.edge_dimension = edge_dim
|
| 246 |
+
self.towers = towers
|
| 247 |
+
self.pre_layers_num = pre_lay
|
| 248 |
+
self.post_layers_num = post_lay
|
| 249 |
+
self.pna_layer_num = pna_layer_num
|
| 250 |
+
self.graph_add = graph_add
|
| 251 |
+
self.PNA_layer = PNA(deg=self.degree, agg =self.aggregators,sca = self.scalers,
|
| 252 |
+
pna_in_ch= self.pna_in_channels, pna_out_ch = self.pna_out_channels, edge_dim = self.edge_dimension,
|
| 253 |
+
towers = self.towers, pre_lay = self.pre_layers_num, post_lay = self.post_layers_num,
|
| 254 |
+
pna_layer_num = self.pna_layer_num, graph_add = self.graph_add)
|
| 255 |
+
|
| 256 |
+
def forward(self, x, edge_index, edge_attr, batch, activation=None):
|
| 257 |
+
|
| 258 |
+
h = self.PNA_layer(x, edge_index, edge_attr, batch)
|
| 259 |
+
|
| 260 |
+
h = activation(h) if activation is not None else h
|
| 261 |
+
|
| 262 |
+
return h"""
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
"""class Discriminator_old(nn.Module):
|
| 266 |
+
|
| 267 |
+
def __init__(self, conv_dim, m_dim, b_dim, dropout, gcn_depth):
|
| 268 |
+
super(Discriminator_old, self).__init__()
|
| 269 |
+
|
| 270 |
+
graph_conv_dim, aux_dim, linear_dim = conv_dim
|
| 271 |
+
|
| 272 |
+
# discriminator
|
| 273 |
+
self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout,gcn_depth)
|
| 274 |
+
self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
|
| 275 |
+
|
| 276 |
+
# multi dense layer
|
| 277 |
+
layers = []
|
| 278 |
+
for c0, c1 in zip([aux_dim]+linear_dim[:-1], linear_dim):
|
| 279 |
+
layers.append(nn.Linear(c0,c1))
|
| 280 |
+
layers.append(nn.Dropout(dropout))
|
| 281 |
+
self.linear_layer = nn.Sequential(*layers)
|
| 282 |
+
|
| 283 |
+
self.output_layer = nn.Linear(linear_dim[-1], 1)
|
| 284 |
+
|
| 285 |
+
def forward(self, adj, hidden, node, activation=None):
|
| 286 |
+
|
| 287 |
+
adj = adj[:,:,:,1:].permute(0,3,1,2)
|
| 288 |
+
|
| 289 |
+
annotations = torch.cat((hidden, node), -1) if hidden is not None else node
|
| 290 |
+
|
| 291 |
+
h = self.gcn_layer(annotations, adj)
|
| 292 |
+
annotations = torch.cat((h, hidden, node) if hidden is not None\
|
| 293 |
+
else (h, node), -1)
|
| 294 |
+
|
| 295 |
+
h = self.agg_layer(annotations, torch.tanh)
|
| 296 |
+
h = self.linear_layer(h)
|
| 297 |
+
|
| 298 |
+
# Need to implement batch discriminator #
|
| 299 |
+
#########################################
|
| 300 |
+
|
| 301 |
+
output = self.output_layer(h)
|
| 302 |
+
output = activation(output) if activation is not None else output
|
| 303 |
+
|
| 304 |
+
return output, h"""
|
| 305 |
+
|
| 306 |
+
"""class Discriminator_old2(nn.Module):
|
| 307 |
+
|
| 308 |
+
def __init__(self, conv_dim, m_dim, b_dim, dropout, gcn_depth):
|
| 309 |
+
super(Discriminator_old2, self).__init__()
|
| 310 |
+
|
| 311 |
+
graph_conv_dim, aux_dim, linear_dim = conv_dim
|
| 312 |
+
|
| 313 |
+
# discriminator
|
| 314 |
+
self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout, gcn_depth)
|
| 315 |
+
self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
|
| 316 |
+
|
| 317 |
+
# multi dense layer
|
| 318 |
+
layers = []
|
| 319 |
+
for c0, c1 in zip([aux_dim]+linear_dim[:-1], linear_dim):
|
| 320 |
+
layers.append(nn.Linear(c0,c1))
|
| 321 |
+
layers.append(nn.Dropout(dropout))
|
| 322 |
+
self.linear_layer = nn.Sequential(*layers)
|
| 323 |
+
|
| 324 |
+
self.output_layer = nn.Linear(linear_dim[-1], 1)
|
| 325 |
+
|
| 326 |
+
def forward(self, adj, hidden, node, activation=None):
|
| 327 |
+
|
| 328 |
+
adj = adj[:,:,:,1:].permute(0,3,1,2)
|
| 329 |
+
|
| 330 |
+
annotations = torch.cat((hidden, node), -1) if hidden is not None else node
|
| 331 |
+
|
| 332 |
+
h = self.gcn_layer(annotations, adj)
|
| 333 |
+
annotations = torch.cat((h, hidden, node) if hidden is not None\
|
| 334 |
+
else (h, node), -1)
|
| 335 |
+
|
| 336 |
+
h = self.agg_layer(annotations, torch.tanh)
|
| 337 |
+
h = self.linear_layer(h)
|
| 338 |
+
|
| 339 |
+
# Need to implement batch discriminator #
|
| 340 |
+
#########################################
|
| 341 |
+
|
| 342 |
+
output = self.output_layer(h)
|
| 343 |
+
output = activation(output) if activation is not None else output
|
| 344 |
+
|
| 345 |
+
return output, h"""
|
| 346 |
+
|
| 347 |
+
"""class Discriminator3(nn.Module):
|
| 348 |
+
|
| 349 |
+
def __init__(self,in_ch):
|
| 350 |
+
super(Discriminator3, self).__init__()
|
| 351 |
+
self.dim = in_ch
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
self.TraConv_layer = TransformerConv(in_channels = self.dim,out_channels = self.dim//4,edge_dim = self.dim)
|
| 355 |
+
self.mlp = torch.nn.Sequential(torch.nn.Tanh(), torch.nn.Linear(self.dim//4,1))
|
| 356 |
+
def forward(self, x, edge_index, edge_attr, batch, activation=None):
|
| 357 |
+
|
| 358 |
+
h = self.TraConv_layer(x, edge_index, edge_attr)
|
| 359 |
+
h = global_add_pool(h,batch)
|
| 360 |
+
h = self.mlp(h)
|
| 361 |
+
h = activation(h) if activation is not None else h
|
| 362 |
+
|
| 363 |
+
return h"""
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
"""class PNA_Net(nn.Module):
|
| 367 |
+
def __init__(self,deg):
|
| 368 |
+
super().__init__()
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
self.convs = nn.ModuleList()
|
| 373 |
+
|
| 374 |
+
self.lin = nn.Linear(5, 128)
|
| 375 |
+
for _ in range(1):
|
| 376 |
+
conv = DenseGCNConv(128, 128, improved=False, bias=True)
|
| 377 |
+
self.convs.append(conv)
|
| 378 |
+
|
| 379 |
+
self.agg_layer = GraphAggregation(128, 128, 0, dropout=0.1)
|
| 380 |
+
self.mlp = nn.Sequential(nn.Linear(128, 64), nn.Tanh(), nn.Linear(64, 32), nn.Tanh(),
|
| 381 |
+
nn.Linear(32, 1))
|
| 382 |
+
|
| 383 |
+
def forward(self, x, adj,mask=None):
|
| 384 |
+
x = self.lin(x)
|
| 385 |
+
|
| 386 |
+
for conv in self.convs:
|
| 387 |
+
x = F.relu(conv(x, adj,mask=None))
|
| 388 |
+
|
| 389 |
+
x = self.agg_layer(x,torch.tanh)
|
| 390 |
+
|
| 391 |
+
return self.mlp(x) """
|
| 392 |
+
|
new_dataloader.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from rdkit import Chem
|
| 9 |
+
from rdkit import RDLogger
|
| 10 |
+
from torch_geometric.data import (Data, InMemoryDataset)
|
| 11 |
+
|
| 12 |
+
RDLogger.DisableLog('rdApp.*')
|
| 13 |
+
class DruggenDataset(InMemoryDataset):
|
| 14 |
+
|
| 15 |
+
def __init__(self, root, dataset_file, raw_files, max_atom, features, transform=None, pre_transform=None, pre_filter=None):
|
| 16 |
+
self.dataset_name = dataset_file.split(".")[0]
|
| 17 |
+
self.dataset_file = dataset_file
|
| 18 |
+
self.raw_files = raw_files
|
| 19 |
+
self.max_atom = max_atom
|
| 20 |
+
self.features = features
|
| 21 |
+
|
| 22 |
+
super().__init__(root, transform, pre_transform, pre_filter)
|
| 23 |
+
self.data, self.slices = torch.load(osp.join(root, dataset_file))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@property
|
| 27 |
+
def raw_file_names(self):
|
| 28 |
+
return self.raw_files
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
def processed_file_names(self):
|
| 32 |
+
'''
|
| 33 |
+
Return the processed file names. If these names are not present, they will be automatically processed using process function of this class.
|
| 34 |
+
'''
|
| 35 |
+
return self.dataset_file
|
| 36 |
+
|
| 37 |
+
def _generate_encoders_decoders(self, data):
|
| 38 |
+
"""
|
| 39 |
+
Generates the encoders and decoders for the atoms and bonds.
|
| 40 |
+
"""
|
| 41 |
+
self.data = data
|
| 42 |
+
print('Creating atoms encoder and decoder..')
|
| 43 |
+
|
| 44 |
+
atom_labels = set()
|
| 45 |
+
# bond_labels = set()
|
| 46 |
+
self.max_atom_size_in_data = 0
|
| 47 |
+
|
| 48 |
+
for smile in data:
|
| 49 |
+
mol = Chem.MolFromSmiles(smile)
|
| 50 |
+
atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
|
| 51 |
+
# bond_labels.update([bond.GetBondType() for bond in mol.GetBonds()])
|
| 52 |
+
self.max_atom_size_in_data = max(self.max_atom_size_in_data, mol.GetNumAtoms())
|
| 53 |
+
atom_labels.update([0]) # add PAD symbol (for unknown atoms)
|
| 54 |
+
atom_labels = sorted(atom_labels) # turn set into list and sort it
|
| 55 |
+
|
| 56 |
+
# atom_labels = sorted(set([atom.GetAtomicNum() for mol in self.data for atom in mol.GetAtoms()] + [0]))
|
| 57 |
+
self.atom_encoder_m = {l: i for i, l in enumerate(atom_labels)}
|
| 58 |
+
self.atom_decoder_m = {i: l for i, l in enumerate(atom_labels)}
|
| 59 |
+
self.atom_num_types = len(atom_labels)
|
| 60 |
+
print(f'Created atoms encoder and decoder with {self.atom_num_types - 1} atom types and 1 PAD symbol!')
|
| 61 |
+
print("atom_labels", atom_labels)
|
| 62 |
+
print('Creating bonds encoder and decoder..')
|
| 63 |
+
# bond_labels = [Chem.rdchem.BondType.ZERO] + list(sorted(set(bond.GetBondType()
|
| 64 |
+
# for mol in self.data
|
| 65 |
+
# for bond in mol.GetBonds())))
|
| 66 |
+
bond_labels = [
|
| 67 |
+
Chem.rdchem.BondType.ZERO,
|
| 68 |
+
Chem.rdchem.BondType.SINGLE,
|
| 69 |
+
Chem.rdchem.BondType.DOUBLE,
|
| 70 |
+
Chem.rdchem.BondType.TRIPLE,
|
| 71 |
+
Chem.rdchem.BondType.AROMATIC,
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
print("bond labels", bond_labels)
|
| 75 |
+
self.bond_encoder_m = {l: i for i, l in enumerate(bond_labels)}
|
| 76 |
+
self.bond_decoder_m = {i: l for i, l in enumerate(bond_labels)}
|
| 77 |
+
self.bond_num_types = len(bond_labels)
|
| 78 |
+
print(f'Created bonds encoder and decoder with {self.bond_num_types - 1} bond types and 1 PAD symbol!')
|
| 79 |
+
#dataset_names = str(self.dataset_name)
|
| 80 |
+
with open("DrugGEN/data/encoders/" +"atom_" + self.dataset_name + ".pkl","wb") as atom_encoders:
|
| 81 |
+
pickle.dump(self.atom_encoder_m,atom_encoders)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
with open("DrugGEN/data/decoders/" +"atom_" + self.dataset_name + ".pkl","wb") as atom_decoders:
|
| 85 |
+
pickle.dump(self.atom_decoder_m,atom_decoders)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
with open("DrugGEN/data/encoders/" +"bond_" + self.dataset_name + ".pkl","wb") as bond_encoders:
|
| 89 |
+
pickle.dump(self.bond_encoder_m,bond_encoders)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
with open("DrugGEN/data/decoders/" +"bond_" + self.dataset_name + ".pkl","wb") as bond_decoders:
|
| 93 |
+
pickle.dump(self.bond_decoder_m,bond_decoders)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def generate_adjacency_matrix(self, mol, connected=True, max_length=None):
|
| 98 |
+
"""
|
| 99 |
+
Generates the adjacency matrix for a molecule.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
mol (Molecule): The molecule object.
|
| 103 |
+
connected (bool): Whether to check for connectivity in the molecule. Defaults to True.
|
| 104 |
+
max_length (int): The maximum length of the adjacency matrix. Defaults to the number of atoms in the molecule.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
numpy.ndarray or None: The adjacency matrix if connected and all atoms have a degree greater than 0,
|
| 108 |
+
otherwise None.
|
| 109 |
+
"""
|
| 110 |
+
max_length = max_length if max_length is not None else mol.GetNumAtoms()
|
| 111 |
+
|
| 112 |
+
A = np.zeros(shape=(max_length, max_length))
|
| 113 |
+
|
| 114 |
+
begin, end = [b.GetBeginAtomIdx() for b in mol.GetBonds()], [b.GetEndAtomIdx() for b in mol.GetBonds()]
|
| 115 |
+
bond_type = [self.bond_encoder_m[b.GetBondType()] for b in mol.GetBonds()]
|
| 116 |
+
|
| 117 |
+
A[begin, end] = bond_type
|
| 118 |
+
A[end, begin] = bond_type
|
| 119 |
+
|
| 120 |
+
degree = np.sum(A[:mol.GetNumAtoms(), :mol.GetNumAtoms()], axis=-1)
|
| 121 |
+
|
| 122 |
+
return A if connected and (degree > 0).all() else None
|
| 123 |
+
|
| 124 |
+
def generate_node_features(self, mol, max_length=None):
|
| 125 |
+
"""
|
| 126 |
+
Generates the node features for a molecule.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
mol (Molecule): The molecule object.
|
| 130 |
+
max_length (int): The maximum length of the node features. Defaults to the number of atoms in the molecule.
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
numpy.ndarray: The node features matrix.
|
| 134 |
+
"""
|
| 135 |
+
max_length = max_length if max_length is not None else mol.GetNumAtoms()
|
| 136 |
+
|
| 137 |
+
return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] + [0] * (
|
| 138 |
+
max_length - mol.GetNumAtoms()))
|
| 139 |
+
|
| 140 |
+
def generate_additional_features(self, mol, max_length=None):
|
| 141 |
+
"""
|
| 142 |
+
Generates additional features for a molecule.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
mol (Molecule): The molecule object.
|
| 146 |
+
max_length (int): The maximum length of the additional features. Defaults to the number of atoms in the molecule.
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
numpy.ndarray: The additional features matrix.
|
| 150 |
+
"""
|
| 151 |
+
max_length = max_length if max_length is not None else mol.GetNumAtoms()
|
| 152 |
+
|
| 153 |
+
features = np.array([[*[a.GetDegree() == i for i in range(5)],
|
| 154 |
+
*[a.GetExplicitValence() == i for i in range(9)],
|
| 155 |
+
*[int(a.GetHybridization()) == i for i in range(1, 7)],
|
| 156 |
+
*[a.GetImplicitValence() == i for i in range(9)],
|
| 157 |
+
a.GetIsAromatic(),
|
| 158 |
+
a.GetNoImplicit(),
|
| 159 |
+
*[a.GetNumExplicitHs() == i for i in range(5)],
|
| 160 |
+
*[a.GetNumImplicitHs() == i for i in range(5)],
|
| 161 |
+
*[a.GetNumRadicalElectrons() == i for i in range(5)],
|
| 162 |
+
a.IsInRing(),
|
| 163 |
+
*[a.IsInRingSize(i) for i in range(2, 9)]] for a in mol.GetAtoms()], dtype=np.int32)
|
| 164 |
+
|
| 165 |
+
return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))
|
| 166 |
+
|
| 167 |
+
def decoder_load(self, dictionary_name):
|
| 168 |
+
with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
| 169 |
+
return pickle.load(f)
|
| 170 |
+
|
| 171 |
+
def drugs_decoder_load(self, dictionary_name):
|
| 172 |
+
with open("DrugGEN/data/decoders/" + dictionary_name +'.pkl', 'rb') as f:
|
| 173 |
+
return pickle.load(f)
|
| 174 |
+
|
| 175 |
+
def matrices2mol(self, node_labels, edge_labels, strict=True):
|
| 176 |
+
mol = Chem.RWMol()
|
| 177 |
+
RDLogger.DisableLog('rdApp.*')
|
| 178 |
+
atom_decoders = self.decoder_load("atom")
|
| 179 |
+
bond_decoders = self.decoder_load("bond")
|
| 180 |
+
|
| 181 |
+
for node_label in node_labels:
|
| 182 |
+
mol.AddAtom(Chem.Atom(atom_decoders[node_label]))
|
| 183 |
+
|
| 184 |
+
for start, end in zip(*np.nonzero(edge_labels)):
|
| 185 |
+
if start > end:
|
| 186 |
+
mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])
|
| 187 |
+
mol = self.correct_mol(mol)
|
| 188 |
+
if strict:
|
| 189 |
+
try:
|
| 190 |
+
|
| 191 |
+
Chem.SanitizeMol(mol)
|
| 192 |
+
except:
|
| 193 |
+
mol = None
|
| 194 |
+
|
| 195 |
+
return mol
|
| 196 |
+
|
| 197 |
+
def drug_decoder_load(self, dictionary_name):
|
| 198 |
+
|
| 199 |
+
''' Loading the atom and bond decoders '''
|
| 200 |
+
|
| 201 |
+
with open("DrugGEN/data/decoders/" + dictionary_name +"_" + "akt_train" +'.pkl', 'rb') as f:
|
| 202 |
+
|
| 203 |
+
return pickle.load(f)
|
| 204 |
+
def matrices2mol_drugs(self, node_labels, edge_labels, strict=True):
|
| 205 |
+
mol = Chem.RWMol()
|
| 206 |
+
RDLogger.DisableLog('rdApp.*')
|
| 207 |
+
atom_decoders = self.drug_decoder_load("atom")
|
| 208 |
+
bond_decoders = self.drug_decoder_load("bond")
|
| 209 |
+
|
| 210 |
+
for node_label in node_labels:
|
| 211 |
+
|
| 212 |
+
mol.AddAtom(Chem.Atom(atom_decoders[node_label]))
|
| 213 |
+
|
| 214 |
+
for start, end in zip(*np.nonzero(edge_labels)):
|
| 215 |
+
if start > end:
|
| 216 |
+
mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])
|
| 217 |
+
mol = self.correct_mol(mol)
|
| 218 |
+
if strict:
|
| 219 |
+
try:
|
| 220 |
+
Chem.SanitizeMol(mol)
|
| 221 |
+
except:
|
| 222 |
+
mol = None
|
| 223 |
+
|
| 224 |
+
return mol
|
| 225 |
+
def check_valency(self,mol):
|
| 226 |
+
"""
|
| 227 |
+
Checks that no atoms in the mol have exceeded their possible
|
| 228 |
+
valency
|
| 229 |
+
:return: True if no valency issues, False otherwise
|
| 230 |
+
"""
|
| 231 |
+
try:
|
| 232 |
+
Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
|
| 233 |
+
return True, None
|
| 234 |
+
except ValueError as e:
|
| 235 |
+
e = str(e)
|
| 236 |
+
p = e.find('#')
|
| 237 |
+
e_sub = e[p:]
|
| 238 |
+
atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
|
| 239 |
+
return False, atomid_valence
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def correct_mol(self,x):
|
| 243 |
+
# xsm = Chem.MolToSmiles(x, isomericSmiles=True)
|
| 244 |
+
mol = x
|
| 245 |
+
while True:
|
| 246 |
+
flag, atomid_valence = self.check_valency(mol)
|
| 247 |
+
if flag:
|
| 248 |
+
break
|
| 249 |
+
else:
|
| 250 |
+
assert len (atomid_valence) == 2
|
| 251 |
+
idx = atomid_valence[0]
|
| 252 |
+
v = atomid_valence[1]
|
| 253 |
+
queue = []
|
| 254 |
+
for b in mol.GetAtomWithIdx(idx).GetBonds():
|
| 255 |
+
queue.append(
|
| 256 |
+
(b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
|
| 257 |
+
)
|
| 258 |
+
queue.sort(key=lambda tup: tup[1], reverse=True)
|
| 259 |
+
if len(queue) > 0:
|
| 260 |
+
start = queue[0][2]
|
| 261 |
+
end = queue[0][3]
|
| 262 |
+
t = queue[0][1] - 1
|
| 263 |
+
mol.RemoveBond(start, end)
|
| 264 |
+
|
| 265 |
+
#if t >= 1:
|
| 266 |
+
|
| 267 |
+
#mol.AddBond(start, end, self.decoder_load('bond_decoders')[t])
|
| 268 |
+
# if '.' in Chem.MolToSmiles(mol, isomericSmiles=True):
|
| 269 |
+
# mol.AddBond(start, end, self.decoder_load('bond_decoders')[t])
|
| 270 |
+
# print(tt)
|
| 271 |
+
# print(Chem.MolToSmiles(mol, isomericSmiles=True))
|
| 272 |
+
|
| 273 |
+
return mol
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def label2onehot(self, labels, dim):
|
| 278 |
+
|
| 279 |
+
"""Convert label indices to one-hot vectors."""
|
| 280 |
+
|
| 281 |
+
out = torch.zeros(list(labels.size())+[dim])
|
| 282 |
+
out.scatter_(len(out.size())-1,labels.unsqueeze(-1),1.)
|
| 283 |
+
|
| 284 |
+
return out.float()
|
| 285 |
+
|
| 286 |
+
def process(self, size= None):
|
| 287 |
+
'''
|
| 288 |
+
Process the dataset. This function will be only run if processed_file_names does not exist in the data folder already.
|
| 289 |
+
'''
|
| 290 |
+
# mols = [Chem.MolFromSmiles(line) for line in open(self.raw_files, 'r').readlines()]
|
| 291 |
+
# mols = list(filter(lambda x: x.GetNumAtoms() <= self.max_atom, mols))
|
| 292 |
+
# mols = mols[:size] # i
|
| 293 |
+
# indices = range(len(mols))
|
| 294 |
+
|
| 295 |
+
smiles = pd.read_csv(self.raw_files, header=None)[0].tolist()
|
| 296 |
+
self._generate_encoders_decoders(smiles)
|
| 297 |
+
|
| 298 |
+
# pbar.set_description(f'Processing chembl dataset')
|
| 299 |
+
# max_length = max(mol.GetNumAtoms() for mol in mols)
|
| 300 |
+
data_list = []
|
| 301 |
+
max_length = min(self.max_atom_size_in_data, self.max_atom)
|
| 302 |
+
self.m_dim = len(self.atom_decoder_m)
|
| 303 |
+
# for idx in indices:
|
| 304 |
+
for smiles in tqdm(smiles, desc='Processing chembl dataset', total=len(smiles)):
|
| 305 |
+
# mol = mols[idx]
|
| 306 |
+
|
| 307 |
+
mol = Chem.MolFromSmiles(smile)
|
| 308 |
+
|
| 309 |
+
# filter by max atom size
|
| 310 |
+
if mol.GetNumAtoms() > max_length:
|
| 311 |
+
continue
|
| 312 |
+
|
| 313 |
+
A = self.generate_adjacency_matrix(mol, connected=True, max_length=max_length)
|
| 314 |
+
if A is not None:
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
x = torch.from_numpy(self.generate_node_features(mol, max_length=max_length)).to(torch.long).view(1, -1)
|
| 318 |
+
|
| 319 |
+
x = self.label2onehot(x,self.m_dim).squeeze()
|
| 320 |
+
if self.features:
|
| 321 |
+
f = torch.from_numpy(self.generate_additional_features(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
|
| 322 |
+
x = torch.concat((x,f), dim=-1)
|
| 323 |
+
|
| 324 |
+
adjacency = torch.from_numpy(A)
|
| 325 |
+
|
| 326 |
+
edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
|
| 327 |
+
edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)
|
| 328 |
+
|
| 329 |
+
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
|
| 330 |
+
|
| 331 |
+
if self.pre_filter is not None and not self.pre_filter(data):
|
| 332 |
+
continue
|
| 333 |
+
|
| 334 |
+
if self.pre_transform is not None:
|
| 335 |
+
data = self.pre_transform(data)
|
| 336 |
+
|
| 337 |
+
data_list.append(data)
|
| 338 |
+
# pbar.update(1)
|
| 339 |
+
|
| 340 |
+
# pbar.close()
|
| 341 |
+
|
| 342 |
+
torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
if __name__ == '__main__':
|
| 348 |
+
data = DruggenDataset("DrugGEN/data")
|
| 349 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
rdkit-pypi
|
| 3 |
+
tqdm
|
| 4 |
+
numpy
|
| 5 |
+
seaborn
|
| 6 |
+
matplotlib
|
| 7 |
+
pandas
|
| 8 |
+
torch_geometric
|
trainer.py
ADDED
|
@@ -0,0 +1,892 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import torch.nn
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from utils import *
|
| 7 |
+
from models import Generator, Generator2, simple_disc
|
| 8 |
+
import torch_geometric.utils as geoutils
|
| 9 |
+
#import #wandb
|
| 10 |
+
import re
|
| 11 |
+
from torch_geometric.loader import DataLoader
|
| 12 |
+
from new_dataloader import DruggenDataset
|
| 13 |
+
import torch.utils.data
|
| 14 |
+
from rdkit import RDLogger
|
| 15 |
+
import pickle
|
| 16 |
+
from rdkit.Chem.Scaffolds import MurckoScaffold
|
| 17 |
+
torch.set_num_threads(5)
|
| 18 |
+
RDLogger.DisableLog('rdApp.*')
|
| 19 |
+
from loss import discriminator_loss, generator_loss, discriminator2_loss, generator2_loss
|
| 20 |
+
from training_data import load_data
|
| 21 |
+
import random
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Trainer(object):
|
| 25 |
+
|
| 26 |
+
"""Trainer for training and testing DrugGEN."""
|
| 27 |
+
|
| 28 |
+
def __init__(self, config):
|
| 29 |
+
|
| 30 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
|
| 31 |
+
"""Initialize configurations."""
|
| 32 |
+
self.submodel = config.submodel
|
| 33 |
+
self.inference_model = config.inference_model
|
| 34 |
+
# Data loader.
|
| 35 |
+
self.raw_file = config.raw_file # SMILES containing text file for first dataset.
|
| 36 |
+
# Write the full path to file.
|
| 37 |
+
|
| 38 |
+
self.drug_raw_file = config.drug_raw_file # SMILES containing text file for second dataset.
|
| 39 |
+
# Write the full path to file.
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
self.dataset_file = config.dataset_file # Dataset file name for the first GAN.
|
| 43 |
+
# Contains large number of molecules.
|
| 44 |
+
|
| 45 |
+
self.drugs_dataset_file = config.drug_dataset_file # Drug dataset file name for the second GAN.
|
| 46 |
+
# Contains drug molecules only. (In this case AKT1 inhibitors.)
|
| 47 |
+
|
| 48 |
+
self.inf_raw_file = config.inf_raw_file # SMILES containing text file for first dataset.
|
| 49 |
+
# Write the full path to file.
|
| 50 |
+
|
| 51 |
+
self.inf_drug_raw_file = config.inf_drug_raw_file # SMILES containing text file for second dataset.
|
| 52 |
+
# Write the full path to file.
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
self.inf_dataset_file = config.inf_dataset_file # Dataset file name for the first GAN.
|
| 56 |
+
# Contains large number of molecules.
|
| 57 |
+
|
| 58 |
+
self.inf_drugs_dataset_file = config.inf_drug_dataset_file # Drug dataset file name for the second GAN.
|
| 59 |
+
# Contains drug molecules only. (In this case AKT1 inhibitors.)
|
| 60 |
+
|
| 61 |
+
self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
|
| 62 |
+
|
| 63 |
+
self.drug_data_dir = config.drug_data_dir # Directory where the drug dataset files are stored.
|
| 64 |
+
|
| 65 |
+
self.dataset_name = self.dataset_file.split(".")[0]
|
| 66 |
+
self.drugs_name = self.drugs_dataset_file.split(".")[0]
|
| 67 |
+
|
| 68 |
+
self.max_atom = config.max_atom # Model is based on one-shot generation.
|
| 69 |
+
# Max atom number for molecules must be specified.
|
| 70 |
+
|
| 71 |
+
self.features = config.features # Small model uses atom types as node features. (Boolean, False uses atom types only.)
|
| 72 |
+
# Additional node features can be added. Please check new_dataloarder.py Line 102.
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
self.batch_size = config.batch_size # Batch size for training.
|
| 76 |
+
|
| 77 |
+
self.dataset = DruggenDataset(self.mol_data_dir,
|
| 78 |
+
self.dataset_file,
|
| 79 |
+
self.raw_file,
|
| 80 |
+
self.max_atom,
|
| 81 |
+
self.features) # Dataset for the first GAN. Custom dataset class from PyG parent class.
|
| 82 |
+
# Can create any molecular graph dataset given smiles string.
|
| 83 |
+
# Nonisomeric SMILES are suggested but not necessary.
|
| 84 |
+
# Uses sparse matrix representation for graphs,
|
| 85 |
+
# For computational and speed efficiency.
|
| 86 |
+
|
| 87 |
+
self.loader = DataLoader(self.dataset,
|
| 88 |
+
shuffle=True,
|
| 89 |
+
batch_size=self.batch_size,
|
| 90 |
+
drop_last=True) # PyG dataloader for the first GAN.
|
| 91 |
+
|
| 92 |
+
self.drugs = DruggenDataset(self.drug_data_dir,
|
| 93 |
+
self.drugs_dataset_file,
|
| 94 |
+
self.drug_raw_file,
|
| 95 |
+
self.max_atom,
|
| 96 |
+
self.features) # Dataset for the second GAN. Custom dataset class from PyG parent class.
|
| 97 |
+
# Can create any molecular graph dataset given smiles string.
|
| 98 |
+
# Nonisomeric SMILES are suggested but not necessary.
|
| 99 |
+
# Uses sparse matrix representation for graphs,
|
| 100 |
+
# For computational and speed efficiency.
|
| 101 |
+
|
| 102 |
+
self.drugs_loader = DataLoader(self.drugs,
|
| 103 |
+
shuffle=True,
|
| 104 |
+
batch_size=self.batch_size,
|
| 105 |
+
drop_last=True) # PyG dataloader for the second GAN.
|
| 106 |
+
|
| 107 |
+
# Atom and bond type dimensions for the construction of the model.
|
| 108 |
+
|
| 109 |
+
self.atom_decoders = self.decoder_load("atom") # Atom type decoders for first GAN.
|
| 110 |
+
# eg. 0:0, 1:6 (C), 2:7 (N), 3:8 (O), 4:9 (F)
|
| 111 |
+
|
| 112 |
+
self.bond_decoders = self.decoder_load("bond") # Bond type decoders for first GAN.
|
| 113 |
+
# eg. 0: (no-bond), 1: (single), 2: (double), 3: (triple), 4: (aromatic)
|
| 114 |
+
|
| 115 |
+
self.m_dim = len(self.atom_decoders) if not self.features else int(self.loader.dataset[0].x.shape[1]) # Atom type dimension.
|
| 116 |
+
|
| 117 |
+
self.b_dim = len(self.bond_decoders) # Bond type dimension.
|
| 118 |
+
|
| 119 |
+
self.vertexes = int(self.loader.dataset[0].x.shape[0]) # Number of nodes in the graph.
|
| 120 |
+
|
| 121 |
+
self.drugs_atom_decoders = self.drug_decoder_load("atom") # Atom type decoders for second GAN.
|
| 122 |
+
# eg. 0:0, 1:6 (C), 2:7 (N), 3:8 (O), 4:9 (F)
|
| 123 |
+
|
| 124 |
+
self.drugs_bond_decoders = self.drug_decoder_load("bond") # Bond type decoders for second GAN.
|
| 125 |
+
# eg. 0: (no-bond), 1: (single), 2: (double), 3: (triple), 4: (aromatic)
|
| 126 |
+
|
| 127 |
+
self.drugs_m_dim = len(self.drugs_atom_decoders) if not self.features else int(self.drugs_loader.dataset[0].x.shape[1]) # Atom type dimension.
|
| 128 |
+
|
| 129 |
+
self.drugs_b_dim = len(self.drugs_bond_decoders) # Bond type dimension.
|
| 130 |
+
|
| 131 |
+
self.drug_vertexes = int(self.drugs_loader.dataset[0].x.shape[0]) # Number of nodes in the graph.
|
| 132 |
+
|
| 133 |
+
# Transformer and Convolution configurations.
|
| 134 |
+
|
| 135 |
+
self.act = config.act
|
| 136 |
+
|
| 137 |
+
self.z_dim = config.z_dim
|
| 138 |
+
|
| 139 |
+
self.lambda_gp = config.lambda_gp
|
| 140 |
+
|
| 141 |
+
self.dim = config.dim
|
| 142 |
+
|
| 143 |
+
self.depth = config.depth
|
| 144 |
+
|
| 145 |
+
self.heads = config.heads
|
| 146 |
+
|
| 147 |
+
self.mlp_ratio = config.mlp_ratio
|
| 148 |
+
|
| 149 |
+
self.dec_depth = config.dec_depth
|
| 150 |
+
|
| 151 |
+
self.dec_heads = config.dec_heads
|
| 152 |
+
|
| 153 |
+
self.dec_dim = config.dec_dim
|
| 154 |
+
|
| 155 |
+
self.dis_select = config.dis_select
|
| 156 |
+
|
| 157 |
+
"""self.la = config.la
|
| 158 |
+
self.la2 = config.la2
|
| 159 |
+
self.gcn_depth = config.gcn_depth
|
| 160 |
+
self.g_conv_dim = config.g_conv_dim
|
| 161 |
+
self.d_conv_dim = config.d_conv_dim"""
|
| 162 |
+
"""# PNA config
|
| 163 |
+
|
| 164 |
+
self.agg = config.aggregators
|
| 165 |
+
self.sca = config.scalers
|
| 166 |
+
self.pna_in_ch = config.pna_in_ch
|
| 167 |
+
self.pna_out_ch = config.pna_out_ch
|
| 168 |
+
self.edge_dim = config.edge_dim
|
| 169 |
+
self.towers = config.towers
|
| 170 |
+
self.pre_lay = config.pre_lay
|
| 171 |
+
self.post_lay = config.post_lay
|
| 172 |
+
self.pna_layer_num = config.pna_layer_num
|
| 173 |
+
self.graph_add = config.graph_add"""
|
| 174 |
+
|
| 175 |
+
# Training configurations.
|
| 176 |
+
|
| 177 |
+
self.epoch = config.epoch
|
| 178 |
+
|
| 179 |
+
self.g_lr = config.g_lr
|
| 180 |
+
|
| 181 |
+
self.d_lr = config.d_lr
|
| 182 |
+
|
| 183 |
+
self.g2_lr = config.g2_lr
|
| 184 |
+
|
| 185 |
+
self.d2_lr = config.d2_lr
|
| 186 |
+
|
| 187 |
+
self.dropout = config.dropout
|
| 188 |
+
|
| 189 |
+
self.dec_dropout = config.dec_dropout
|
| 190 |
+
|
| 191 |
+
self.n_critic = config.n_critic
|
| 192 |
+
|
| 193 |
+
self.beta1 = config.beta1
|
| 194 |
+
|
| 195 |
+
self.beta2 = config.beta2
|
| 196 |
+
|
| 197 |
+
self.resume_iters = config.resume_iters
|
| 198 |
+
|
| 199 |
+
self.warm_up_steps = config.warm_up_steps
|
| 200 |
+
|
| 201 |
+
# Test configurations.
|
| 202 |
+
|
| 203 |
+
self.num_test_epoch = config.num_test_epoch
|
| 204 |
+
|
| 205 |
+
self.test_iters = config.test_iters
|
| 206 |
+
|
| 207 |
+
self.inference_sample_num = config.inference_sample_num
|
| 208 |
+
|
| 209 |
+
# Directories.
|
| 210 |
+
|
| 211 |
+
self.log_dir = config.log_dir
|
| 212 |
+
self.sample_dir = config.sample_dir
|
| 213 |
+
self.model_save_dir = config.model_save_dir
|
| 214 |
+
self.result_dir = config.result_dir
|
| 215 |
+
|
| 216 |
+
# Step size.
|
| 217 |
+
|
| 218 |
+
self.log_step = config.log_sample_step
|
| 219 |
+
self.clipping_value = config.clipping_value
|
| 220 |
+
# Miscellaneous.
|
| 221 |
+
|
| 222 |
+
self.mode = config.mode
|
| 223 |
+
|
| 224 |
+
self.noise_strength_0 = torch.nn.Parameter(torch.zeros([]))
|
| 225 |
+
self.noise_strength_1 = torch.nn.Parameter(torch.zeros([]))
|
| 226 |
+
self.noise_strength_2 = torch.nn.Parameter(torch.zeros([]))
|
| 227 |
+
self.noise_strength_3 = torch.nn.Parameter(torch.zeros([]))
|
| 228 |
+
|
| 229 |
+
self.init_type = config.init_type
|
| 230 |
+
self.build_model()
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def build_model(self):
|
| 235 |
+
"""Create generators and discriminators."""
|
| 236 |
+
|
| 237 |
+
''' Generator is based on Transformer Encoder:
|
| 238 |
+
|
| 239 |
+
@ g_conv_dim: Dimensions for first MLP layers before Transformer Encoder
|
| 240 |
+
@ vertexes: maximum length of generated molecules (atom length)
|
| 241 |
+
@ b_dim: number of bond types
|
| 242 |
+
@ m_dim: number of atom types (or number of features used)
|
| 243 |
+
@ dropout: dropout possibility
|
| 244 |
+
@ dim: Hidden dimension of Transformer Encoder
|
| 245 |
+
@ depth: Transformer layer number
|
| 246 |
+
@ heads: Number of multihead-attention heads
|
| 247 |
+
@ mlp_ratio: Read-out layer dimension of Transformer
|
| 248 |
+
@ drop_rate: depricated
|
| 249 |
+
@ tra_conv: Whether module creates output for TransformerConv discriminator
|
| 250 |
+
'''
|
| 251 |
+
|
| 252 |
+
self.G = Generator(self.z_dim,
|
| 253 |
+
self.act,
|
| 254 |
+
self.vertexes,
|
| 255 |
+
self.b_dim,
|
| 256 |
+
self.m_dim,
|
| 257 |
+
self.dropout,
|
| 258 |
+
dim=self.dim,
|
| 259 |
+
depth=self.depth,
|
| 260 |
+
heads=self.heads,
|
| 261 |
+
mlp_ratio=self.mlp_ratio,
|
| 262 |
+
submodel = self.submodel)
|
| 263 |
+
|
| 264 |
+
self.G2 = Generator2(self.dim,
|
| 265 |
+
self.dec_dim,
|
| 266 |
+
self.dec_depth,
|
| 267 |
+
self.dec_heads,
|
| 268 |
+
self.mlp_ratio,
|
| 269 |
+
self.dec_dropout,
|
| 270 |
+
self.drugs_m_dim,
|
| 271 |
+
self.drugs_b_dim,
|
| 272 |
+
self.submodel)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
''' Discriminator implementation with PNA:
|
| 277 |
+
|
| 278 |
+
@ deg: Degree distribution based on used data. (Created with _genDegree() function)
|
| 279 |
+
@ agg: aggregators used in PNA
|
| 280 |
+
@ sca: scalers used in PNA
|
| 281 |
+
@ pna_in_ch: First PNA hidden dimension
|
| 282 |
+
@ pna_out_ch: Last PNA hidden dimension
|
| 283 |
+
@ edge_dim: Edge hidden dimension
|
| 284 |
+
@ towers: Number of towers (Splitting the hidden dimension to multiple parallel processes)
|
| 285 |
+
@ pre_lay: Pre-transformation layer
|
| 286 |
+
@ post_lay: Post-transformation layer
|
| 287 |
+
@ pna_layer_num: number of PNA layers
|
| 288 |
+
@ graph_add: global pooling layer selection
|
| 289 |
+
'''
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
''' Discriminator implementation with Graph Convolution:
|
| 293 |
+
|
| 294 |
+
@ d_conv_dim: convolution dimensions for GCN
|
| 295 |
+
@ m_dim: number of atom types (or number of features used)
|
| 296 |
+
@ b_dim: number of bond types
|
| 297 |
+
@ dropout: dropout possibility
|
| 298 |
+
'''
|
| 299 |
+
|
| 300 |
+
''' Discriminator implementation with MLP:
|
| 301 |
+
|
| 302 |
+
@ act: Activation function for MLP
|
| 303 |
+
@ m_dim: number of atom types (or number of features used)
|
| 304 |
+
@ b_dim: number of bond types
|
| 305 |
+
@ dropout: dropout possibility
|
| 306 |
+
@ vertexes: maximum length of generated molecules (molecule length)
|
| 307 |
+
'''
|
| 308 |
+
|
| 309 |
+
#self.D = Discriminator_old(self.d_conv_dim, self.m_dim , self.b_dim, self.dropout, self.gcn_depth)
|
| 310 |
+
self.D2 = simple_disc("tanh", self.drugs_m_dim, self.drug_vertexes, self.drugs_b_dim)
|
| 311 |
+
self.D = simple_disc("tanh", self.m_dim, self.vertexes, self.b_dim)
|
| 312 |
+
self.V = simple_disc("tanh", self.m_dim, self.vertexes, self.b_dim)
|
| 313 |
+
self.V2 = simple_disc("tanh", self.drugs_m_dim, self.drug_vertexes, self.drugs_b_dim)
|
| 314 |
+
|
| 315 |
+
''' Optimizers for G1, G2, D1, and D2:
|
| 316 |
+
|
| 317 |
+
Adam Optimizer is used and different beta1 and beta2s are used for GAN1 and GAN2
|
| 318 |
+
'''
|
| 319 |
+
|
| 320 |
+
self.g_optimizer = torch.optim.AdamW(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
|
| 321 |
+
self.g2_optimizer = torch.optim.AdamW(self.G2.parameters(), self.g2_lr, [self.beta1, self.beta2])
|
| 322 |
+
|
| 323 |
+
self.d_optimizer = torch.optim.AdamW(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
|
| 324 |
+
self.d2_optimizer = torch.optim.AdamW(self.D2.parameters(), self.d2_lr, [self.beta1, self.beta2])
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
self.v_optimizer = torch.optim.AdamW(self.V.parameters(), self.d_lr, [self.beta1, self.beta2])
|
| 329 |
+
self.v2_optimizer = torch.optim.AdamW(self.V2.parameters(), self.d2_lr, [self.beta1, self.beta2])
|
| 330 |
+
''' Learning rate scheduler:
|
| 331 |
+
|
| 332 |
+
Changes learning rate based on loss.
|
| 333 |
+
'''
|
| 334 |
+
|
| 335 |
+
#self.scheduler_g = ReduceLROnPlateau(self.g_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
#self.scheduler_d = ReduceLROnPlateau(self.d_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
|
| 339 |
+
|
| 340 |
+
#self.scheduler_v = ReduceLROnPlateau(self.v_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
|
| 341 |
+
#self.scheduler_g2 = ReduceLROnPlateau(self.g2_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
|
| 342 |
+
#self.scheduler_d2 = ReduceLROnPlateau(self.d2_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
|
| 343 |
+
#self.scheduler_v2 = ReduceLROnPlateau(self.v2_optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)
|
| 344 |
+
self.print_network(self.G, 'G')
|
| 345 |
+
self.print_network(self.D, 'D')
|
| 346 |
+
|
| 347 |
+
self.print_network(self.G2, 'G2')
|
| 348 |
+
self.print_network(self.D2, 'D2')
|
| 349 |
+
|
| 350 |
+
self.G.to(self.device)
|
| 351 |
+
self.D.to(self.device)
|
| 352 |
+
|
| 353 |
+
self.V.to(self.device)
|
| 354 |
+
self.V2.to(self.device)
|
| 355 |
+
self.G2.to(self.device)
|
| 356 |
+
self.D2.to(self.device)
|
| 357 |
+
|
| 358 |
+
#self.V2.to(self.device)
|
| 359 |
+
#self.modules_of_the_model = (self.G, self.D, self.G2, self.D2)
|
| 360 |
+
"""for p in self.G.parameters():
|
| 361 |
+
if p.dim() > 1:
|
| 362 |
+
if self.init_type == 'uniform':
|
| 363 |
+
torch.nn.init.xavier_uniform_(p)
|
| 364 |
+
elif self.init_type == 'normal':
|
| 365 |
+
torch.nn.init.xavier_normal_(p)
|
| 366 |
+
elif self.init_type == 'random_normal':
|
| 367 |
+
torch.nn.init.normal_(p, 0.0, 0.02)
|
| 368 |
+
for p in self.G2.parameters():
|
| 369 |
+
if p.dim() > 1:
|
| 370 |
+
if self.init_type == 'uniform':
|
| 371 |
+
torch.nn.init.xavier_uniform_(p)
|
| 372 |
+
elif self.init_type == 'normal':
|
| 373 |
+
torch.nn.init.xavier_normal_(p)
|
| 374 |
+
elif self.init_type == 'random_normal':
|
| 375 |
+
torch.nn.init.normal_(p, 0.0, 0.02)
|
| 376 |
+
if self.dis_select == "conv":
|
| 377 |
+
for p in self.D.parameters():
|
| 378 |
+
if p.dim() > 1:
|
| 379 |
+
if self.init_type == 'uniform':
|
| 380 |
+
torch.nn.init.xavier_uniform_(p)
|
| 381 |
+
elif self.init_type == 'normal':
|
| 382 |
+
torch.nn.init.xavier_normal_(p)
|
| 383 |
+
elif self.init_type == 'random_normal':
|
| 384 |
+
torch.nn.init.normal_(p, 0.0, 0.02)
|
| 385 |
+
|
| 386 |
+
if self.dis_select == "conv":
|
| 387 |
+
for p in self.D2.parameters():
|
| 388 |
+
if p.dim() > 1:
|
| 389 |
+
if self.init_type == 'uniform':
|
| 390 |
+
torch.nn.init.xavier_uniform_(p)
|
| 391 |
+
elif self.init_type == 'normal':
|
| 392 |
+
torch.nn.init.xavier_normal_(p)
|
| 393 |
+
elif self.init_type == 'random_normal':
|
| 394 |
+
torch.nn.init.normal_(p, 0.0, 0.02)"""
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def decoder_load(self, dictionary_name):
|
| 398 |
+
|
| 399 |
+
''' Loading the atom and bond decoders'''
|
| 400 |
+
|
| 401 |
+
with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
| 402 |
+
|
| 403 |
+
return pickle.load(f)
|
| 404 |
+
|
| 405 |
+
def drug_decoder_load(self, dictionary_name):
|
| 406 |
+
|
| 407 |
+
''' Loading the atom and bond decoders'''
|
| 408 |
+
|
| 409 |
+
with open("DrugGEN/data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
|
| 410 |
+
|
| 411 |
+
return pickle.load(f)
|
| 412 |
+
|
| 413 |
+
def print_network(self, model, name):
|
| 414 |
+
|
| 415 |
+
"""Print out the network information."""
|
| 416 |
+
|
| 417 |
+
num_params = 0
|
| 418 |
+
for p in model.parameters():
|
| 419 |
+
num_params += p.numel()
|
| 420 |
+
print(model)
|
| 421 |
+
print(name)
|
| 422 |
+
print("The number of parameters: {}".format(num_params))
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def restore_model(self, epoch, iteration, model_directory):
|
| 426 |
+
|
| 427 |
+
"""Restore the trained generator and discriminator."""
|
| 428 |
+
|
| 429 |
+
print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
|
| 430 |
+
|
| 431 |
+
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
|
| 432 |
+
#D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
|
| 433 |
+
|
| 434 |
+
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
| 435 |
+
#self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(epoch, iteration))
|
| 439 |
+
#D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(epoch, iteration))
|
| 440 |
+
|
| 441 |
+
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
| 442 |
+
#self.D2.load_state_dict(torch.load(D2_path, map_location=lambda storage, loc: storage))
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def save_model(self, model_directory, idx,i):
|
| 446 |
+
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(idx+1,i+1))
|
| 447 |
+
D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(idx+1,i+1))
|
| 448 |
+
torch.save(self.G.state_dict(), G_path)
|
| 449 |
+
torch.save(self.D.state_dict(), D_path)
|
| 450 |
+
|
| 451 |
+
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
|
| 452 |
+
G2_path = os.path.join(model_directory, '{}-{}-G2.ckpt'.format(idx+1,i+1))
|
| 453 |
+
D2_path = os.path.join(model_directory, '{}-{}-D2.ckpt'.format(idx+1,i+1))
|
| 454 |
+
|
| 455 |
+
torch.save(self.G2.state_dict(), G2_path)
|
| 456 |
+
torch.save(self.D2.state_dict(), D2_path)
|
| 457 |
+
|
| 458 |
+
def reset_grad(self):
|
| 459 |
+
|
| 460 |
+
"""Reset the gradient buffers."""
|
| 461 |
+
|
| 462 |
+
self.g_optimizer.zero_grad()
|
| 463 |
+
self.v_optimizer.zero_grad()
|
| 464 |
+
self.g2_optimizer.zero_grad()
|
| 465 |
+
self.v2_optimizer.zero_grad()
|
| 466 |
+
|
| 467 |
+
self.d_optimizer.zero_grad()
|
| 468 |
+
self.d2_optimizer.zero_grad()
|
| 469 |
+
|
| 470 |
+
def gradient_penalty(self, y, x):
|
| 471 |
+
|
| 472 |
+
"""Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
|
| 473 |
+
|
| 474 |
+
weight = torch.ones(y.size(),requires_grad=False).to(self.device)
|
| 475 |
+
dydx = torch.autograd.grad(outputs=y,
|
| 476 |
+
inputs=x,
|
| 477 |
+
grad_outputs=weight,
|
| 478 |
+
retain_graph=True,
|
| 479 |
+
create_graph=True,
|
| 480 |
+
only_inputs=True)[0]
|
| 481 |
+
|
| 482 |
+
dydx = dydx.view(dydx.size(0), -1)
|
| 483 |
+
gradient_penalty = ((dydx.norm(2, dim=1) - 1) ** 2).mean()
|
| 484 |
+
|
| 485 |
+
return gradient_penalty
|
| 486 |
+
|
| 487 |
+
def train(self):
|
| 488 |
+
|
| 489 |
+
''' Training Script starts from here'''
|
| 490 |
+
|
| 491 |
+
#wandb.config = {'beta2': 0.999}
|
| 492 |
+
#wandb.init(project="DrugGEN2", entity="atabeyunlu")
|
| 493 |
+
|
| 494 |
+
# Defining sampling paths and creating logger
|
| 495 |
+
|
| 496 |
+
self.arguments = "{}_glr{}_dlr{}_g2lr{}_d2lr{}_dim{}_depth{}_heads{}_decdepth{}_decheads{}_ncritic{}_batch{}_epoch{}_warmup{}_dataset{}_dropout{}".format(self.submodel,self.g_lr,self.d_lr,self.g2_lr,self.d2_lr,self.dim,self.depth,self.heads,self.dec_depth,self.dec_heads,self.n_critic,self.batch_size,self.epoch,self.warm_up_steps,self.dataset_name,self.dropout)
|
| 497 |
+
|
| 498 |
+
self.model_directory= os.path.join(self.model_save_dir,self.arguments)
|
| 499 |
+
self.sample_directory=os.path.join(self.sample_dir,self.arguments)
|
| 500 |
+
self.log_path = os.path.join(self.log_dir, "{}.txt".format(self.arguments))
|
| 501 |
+
if not os.path.exists(self.model_directory):
|
| 502 |
+
os.makedirs(self.model_directory)
|
| 503 |
+
if not os.path.exists(self.sample_directory):
|
| 504 |
+
os.makedirs(self.sample_directory)
|
| 505 |
+
|
| 506 |
+
# Learning rate cache for decaying.
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
# protein data
|
| 510 |
+
full_smiles = [line for line in open("DrugGEN/data/chembl_train.smi", 'r').read().splitlines()]
|
| 511 |
+
drug_smiles = [line for line in open("DrugGEN/data/akt_train.smi", 'r').read().splitlines()]
|
| 512 |
+
|
| 513 |
+
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 514 |
+
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 515 |
+
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 516 |
+
|
| 517 |
+
akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
| 518 |
+
akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
| 519 |
+
|
| 520 |
+
# Start training.
|
| 521 |
+
|
| 522 |
+
print('Start training...')
|
| 523 |
+
self.start_time = time.time()
|
| 524 |
+
for idx in range(self.epoch):
|
| 525 |
+
|
| 526 |
+
# =================================================================================== #
|
| 527 |
+
# 1. Preprocess input data #
|
| 528 |
+
# =================================================================================== #
|
| 529 |
+
|
| 530 |
+
# Load the data
|
| 531 |
+
|
| 532 |
+
dataloader_iterator = iter(self.drugs_loader)
|
| 533 |
+
|
| 534 |
+
for i, data in enumerate(self.loader):
|
| 535 |
+
try:
|
| 536 |
+
drugs = next(dataloader_iterator)
|
| 537 |
+
except StopIteration:
|
| 538 |
+
dataloader_iterator = iter(self.drugs_loader)
|
| 539 |
+
drugs = next(dataloader_iterator)
|
| 540 |
+
|
| 541 |
+
# Preprocess both dataset
|
| 542 |
+
|
| 543 |
+
bulk_data = load_data(data,
|
| 544 |
+
drugs,
|
| 545 |
+
self.batch_size,
|
| 546 |
+
self.device,
|
| 547 |
+
self.b_dim,
|
| 548 |
+
self.m_dim,
|
| 549 |
+
self.drugs_b_dim,
|
| 550 |
+
self.drugs_m_dim,
|
| 551 |
+
self.z_dim,
|
| 552 |
+
self.vertexes)
|
| 553 |
+
|
| 554 |
+
drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
|
| 555 |
+
|
| 556 |
+
if self.submodel == "CrossLoss":
|
| 557 |
+
GAN1_input_e = drugs_a_tensor
|
| 558 |
+
GAN1_input_x = drugs_x_tensor
|
| 559 |
+
GAN1_disc_e = a_tensor
|
| 560 |
+
GAN1_disc_x = x_tensor
|
| 561 |
+
elif self.submodel == "Ligand":
|
| 562 |
+
GAN1_input_e = a_tensor
|
| 563 |
+
GAN1_input_x = x_tensor
|
| 564 |
+
GAN1_disc_e = a_tensor
|
| 565 |
+
GAN1_disc_x = x_tensor
|
| 566 |
+
GAN2_input_e = drugs_a_tensor
|
| 567 |
+
GAN2_input_x = drugs_x_tensor
|
| 568 |
+
GAN2_disc_e = drugs_a_tensor
|
| 569 |
+
GAN2_disc_x = drugs_x_tensor
|
| 570 |
+
elif self.submodel == "Prot":
|
| 571 |
+
GAN1_input_e = a_tensor
|
| 572 |
+
GAN1_input_x = x_tensor
|
| 573 |
+
GAN1_disc_e = a_tensor
|
| 574 |
+
GAN1_disc_x = x_tensor
|
| 575 |
+
GAN2_input_e = akt1_human_adj
|
| 576 |
+
GAN2_input_x = akt1_human_annot
|
| 577 |
+
GAN2_disc_e = drugs_a_tensor
|
| 578 |
+
GAN2_disc_x = drugs_x_tensor
|
| 579 |
+
elif self.submodel == "RL":
|
| 580 |
+
GAN1_input_e = z_edge
|
| 581 |
+
GAN1_input_x = z_node
|
| 582 |
+
GAN1_disc_e = a_tensor
|
| 583 |
+
GAN1_disc_x = x_tensor
|
| 584 |
+
GAN2_input_e = drugs_a_tensor
|
| 585 |
+
GAN2_input_x = drugs_x_tensor
|
| 586 |
+
GAN2_disc_e = drugs_a_tensor
|
| 587 |
+
GAN2_disc_x = drugs_x_tensor
|
| 588 |
+
elif self.submodel == "NoTarget":
|
| 589 |
+
GAN1_input_e = z_edge
|
| 590 |
+
GAN1_input_x = z_node
|
| 591 |
+
GAN1_disc_e = a_tensor
|
| 592 |
+
GAN1_disc_x = x_tensor
|
| 593 |
+
|
| 594 |
+
# =================================================================================== #
|
| 595 |
+
# 2. Train the discriminator #
|
| 596 |
+
# =================================================================================== #
|
| 597 |
+
loss = {}
|
| 598 |
+
self.reset_grad()
|
| 599 |
+
|
| 600 |
+
# Compute discriminator loss.
|
| 601 |
+
|
| 602 |
+
node, edge, d_loss = discriminator_loss(self.G,
|
| 603 |
+
self.D,
|
| 604 |
+
real_graphs,
|
| 605 |
+
GAN1_disc_e,
|
| 606 |
+
GAN1_disc_x,
|
| 607 |
+
self.batch_size,
|
| 608 |
+
self.device,
|
| 609 |
+
self.gradient_penalty,
|
| 610 |
+
self.lambda_gp,
|
| 611 |
+
GAN1_input_e,
|
| 612 |
+
GAN1_input_x)
|
| 613 |
+
|
| 614 |
+
d_total = d_loss
|
| 615 |
+
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
|
| 616 |
+
d2_loss = discriminator2_loss(self.G2,
|
| 617 |
+
self.D2,
|
| 618 |
+
drug_graphs,
|
| 619 |
+
edge,
|
| 620 |
+
node,
|
| 621 |
+
self.batch_size,
|
| 622 |
+
self.device,
|
| 623 |
+
self.gradient_penalty,
|
| 624 |
+
self.lambda_gp,
|
| 625 |
+
GAN2_input_e,
|
| 626 |
+
GAN2_input_x)
|
| 627 |
+
d_total = d_loss + d2_loss
|
| 628 |
+
|
| 629 |
+
loss["d_total"] = d_total.item()
|
| 630 |
+
d_total.backward()
|
| 631 |
+
self.d_optimizer.step()
|
| 632 |
+
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
|
| 633 |
+
self.d2_optimizer.step()
|
| 634 |
+
self.reset_grad()
|
| 635 |
+
generator_output = generator_loss(self.G,
|
| 636 |
+
self.D,
|
| 637 |
+
self.V,
|
| 638 |
+
GAN1_input_e,
|
| 639 |
+
GAN1_input_x,
|
| 640 |
+
self.batch_size,
|
| 641 |
+
sim_reward,
|
| 642 |
+
self.dataset.matrices2mol_drugs,
|
| 643 |
+
fps_r,
|
| 644 |
+
self.submodel)
|
| 645 |
+
|
| 646 |
+
g_loss, fake_mol, g_edges_hat_sample, g_nodes_hat_sample, node, edge = generator_output
|
| 647 |
+
|
| 648 |
+
self.reset_grad()
|
| 649 |
+
g_total = g_loss
|
| 650 |
+
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
|
| 651 |
+
output = generator2_loss(self.G2,
|
| 652 |
+
self.D2,
|
| 653 |
+
self.V2,
|
| 654 |
+
edge,
|
| 655 |
+
node,
|
| 656 |
+
self.batch_size,
|
| 657 |
+
sim_reward,
|
| 658 |
+
self.dataset.matrices2mol_drugs,
|
| 659 |
+
fps_r,
|
| 660 |
+
GAN2_input_e,
|
| 661 |
+
GAN2_input_x,
|
| 662 |
+
self.submodel)
|
| 663 |
+
|
| 664 |
+
g2_loss, fake_mol_g, dr_g_edges_hat_sample, dr_g_nodes_hat_sample = output
|
| 665 |
+
|
| 666 |
+
g_total = g_loss + g2_loss
|
| 667 |
+
|
| 668 |
+
loss["g_total"] = g_total.item()
|
| 669 |
+
g_total.backward()
|
| 670 |
+
self.g_optimizer.step()
|
| 671 |
+
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
|
| 672 |
+
self.g2_optimizer.step()
|
| 673 |
+
|
| 674 |
+
if self.submodel == "RL":
|
| 675 |
+
self.v_optimizer.step()
|
| 676 |
+
self.v2_optimizer.step()
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
if (i+1) % self.log_step == 0:
|
| 680 |
+
|
| 681 |
+
logging(self.log_path, self.start_time, fake_mol, full_smiles, i, idx, loss, 1,self.sample_directory)
|
| 682 |
+
mol_sample(self.sample_directory,"GAN1",fake_mol, g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), idx, i)
|
| 683 |
+
if self.submodel != "NoTarget" and self.submodel != "CrossLoss":
|
| 684 |
+
logging(self.log_path, self.start_time, fake_mol_g, drug_smiles, i, idx, loss, 2,self.sample_directory)
|
| 685 |
+
mol_sample(self.sample_directory,"GAN2",fake_mol_g, dr_g_edges_hat_sample.detach(), dr_g_nodes_hat_sample.detach(), idx, i)
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
if (idx+1) % 10 == 0:
|
| 689 |
+
self.save_model(self.model_directory,idx,i)
|
| 690 |
+
print("model saved at epoch {} and iteration {}".format(idx,i))
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
def inference(self):
|
| 695 |
+
|
| 696 |
+
# Load the trained generator.
|
| 697 |
+
self.G.to(self.device)
|
| 698 |
+
#self.D.to(self.device)
|
| 699 |
+
self.G2.to(self.device)
|
| 700 |
+
#self.D2.to(self.device)
|
| 701 |
+
|
| 702 |
+
G_path = os.path.join(self.inference_model, '{}-G.ckpt'.format(self.submodel))
|
| 703 |
+
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
| 704 |
+
G2_path = os.path.join(self.inference_model, '{}-G2.ckpt'.format(self.submodel))
|
| 705 |
+
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
drug_smiles = [line for line in open("DrugGEN/data/akt_test.smi", 'r').read().splitlines()]
|
| 709 |
+
|
| 710 |
+
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 711 |
+
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 712 |
+
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 713 |
+
|
| 714 |
+
akt1_human_adj = torch.load("DrugGEN/data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
| 715 |
+
akt1_human_annot = torch.load("DrugGEN/data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
| 716 |
+
|
| 717 |
+
self.G.eval()
|
| 718 |
+
#self.D.eval()
|
| 719 |
+
self.G2.eval()
|
| 720 |
+
#self.D2.eval()
|
| 721 |
+
|
| 722 |
+
self.inf_batch_size =256
|
| 723 |
+
self.inf_dataset = DruggenDataset(self.mol_data_dir,
|
| 724 |
+
self.inf_dataset_file,
|
| 725 |
+
self.inf_raw_file,
|
| 726 |
+
self.max_atom,
|
| 727 |
+
self.features) # Dataset for the first GAN. Custom dataset class from PyG parent class.
|
| 728 |
+
# Can create any molecular graph dataset given smiles string.
|
| 729 |
+
# Nonisomeric SMILES are suggested but not necessary.
|
| 730 |
+
# Uses sparse matrix representation for graphs,
|
| 731 |
+
# For computational and speed efficiency.
|
| 732 |
+
|
| 733 |
+
self.inf_loader = DataLoader(self.inf_dataset,
|
| 734 |
+
shuffle=True,
|
| 735 |
+
batch_size=self.inf_batch_size,
|
| 736 |
+
drop_last=True) # PyG dataloader for the first GAN.
|
| 737 |
+
|
| 738 |
+
self.inf_drugs = DruggenDataset(self.drug_data_dir,
|
| 739 |
+
self.inf_drugs_dataset_file,
|
| 740 |
+
self.inf_drug_raw_file,
|
| 741 |
+
self.max_atom,
|
| 742 |
+
self.features) # Dataset for the second GAN. Custom dataset class from PyG parent class.
|
| 743 |
+
# Can create any molecular graph dataset given smiles string.
|
| 744 |
+
# Nonisomeric SMILES are suggested but not necessary.
|
| 745 |
+
# Uses sparse matrix representation for graphs,
|
| 746 |
+
# For computational and speed efficiency.
|
| 747 |
+
|
| 748 |
+
self.inf_drugs_loader = DataLoader(self.inf_drugs,
|
| 749 |
+
shuffle=True,
|
| 750 |
+
batch_size=self.inf_batch_size,
|
| 751 |
+
drop_last=True) # PyG dataloader for the second GAN.
|
| 752 |
+
start_time = time.time()
|
| 753 |
+
#metric_calc_mol = []
|
| 754 |
+
metric_calc_dr = []
|
| 755 |
+
date = time.time()
|
| 756 |
+
if not os.path.exists("DrugGEN/experiments/inference/{}".format(self.submodel)):
|
| 757 |
+
os.makedirs("DrugGEN/experiments/inference/{}".format(self.submodel))
|
| 758 |
+
with torch.inference_mode():
|
| 759 |
+
|
| 760 |
+
dataloader_iterator = iter(self.drugs_loader)
|
| 761 |
+
|
| 762 |
+
for i, data in enumerate(self.loader):
|
| 763 |
+
try:
|
| 764 |
+
drugs = next(dataloader_iterator)
|
| 765 |
+
except StopIteration:
|
| 766 |
+
dataloader_iterator = iter(self.drugs_loader)
|
| 767 |
+
drugs = next(dataloader_iterator)
|
| 768 |
+
|
| 769 |
+
# Preprocess both dataset
|
| 770 |
+
|
| 771 |
+
bulk_data = load_data(data,
|
| 772 |
+
drugs,
|
| 773 |
+
self.batch_size,
|
| 774 |
+
self.device,
|
| 775 |
+
self.b_dim,
|
| 776 |
+
self.m_dim,
|
| 777 |
+
self.drugs_b_dim,
|
| 778 |
+
self.drugs_m_dim,
|
| 779 |
+
self.z_dim,
|
| 780 |
+
self.vertexes)
|
| 781 |
+
|
| 782 |
+
drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node = bulk_data
|
| 783 |
+
|
| 784 |
+
if self.submodel == "CrossLoss":
|
| 785 |
+
GAN1_input_e = a_tensor
|
| 786 |
+
GAN1_input_x = x_tensor
|
| 787 |
+
GAN1_disc_e = drugs_a_tensor
|
| 788 |
+
GAN1_disc_x = drugs_x_tensor
|
| 789 |
+
GAN2_input_e = drugs_a_tensor
|
| 790 |
+
GAN2_input_x = drugs_x_tensor
|
| 791 |
+
GAN2_disc_e = a_tensor
|
| 792 |
+
GAN2_disc_x = x_tensor
|
| 793 |
+
elif self.submodel == "Ligand":
|
| 794 |
+
GAN1_input_e = a_tensor
|
| 795 |
+
GAN1_input_x = x_tensor
|
| 796 |
+
GAN1_disc_e = a_tensor
|
| 797 |
+
GAN1_disc_x = x_tensor
|
| 798 |
+
GAN2_input_e = drugs_a_tensor
|
| 799 |
+
GAN2_input_x = drugs_x_tensor
|
| 800 |
+
GAN2_disc_e = drugs_a_tensor
|
| 801 |
+
GAN2_disc_x = drugs_x_tensor
|
| 802 |
+
elif self.submodel == "Prot":
|
| 803 |
+
GAN1_input_e = a_tensor
|
| 804 |
+
GAN1_input_x = x_tensor
|
| 805 |
+
GAN1_disc_e = a_tensor
|
| 806 |
+
GAN1_disc_x = x_tensor
|
| 807 |
+
GAN2_input_e = akt1_human_adj
|
| 808 |
+
GAN2_input_x = akt1_human_annot
|
| 809 |
+
GAN2_disc_e = drugs_a_tensor
|
| 810 |
+
GAN2_disc_x = drugs_x_tensor
|
| 811 |
+
elif self.submodel == "RL":
|
| 812 |
+
GAN1_input_e = z_edge
|
| 813 |
+
GAN1_input_x = z_node
|
| 814 |
+
GAN1_disc_e = a_tensor
|
| 815 |
+
GAN1_disc_x = x_tensor
|
| 816 |
+
GAN2_input_e = drugs_a_tensor
|
| 817 |
+
GAN2_input_x = drugs_x_tensor
|
| 818 |
+
GAN2_disc_e = drugs_a_tensor
|
| 819 |
+
GAN2_disc_x = drugs_x_tensor
|
| 820 |
+
elif self.submodel == "NoTarget":
|
| 821 |
+
GAN1_input_e = z_edge
|
| 822 |
+
GAN1_input_x = z_node
|
| 823 |
+
GAN1_disc_e = a_tensor
|
| 824 |
+
GAN1_disc_x = x_tensor
|
| 825 |
+
# =================================================================================== #
|
| 826 |
+
# 2. GAN1 Inference #
|
| 827 |
+
# =================================================================================== #
|
| 828 |
+
generator_output = generator_loss(self.G,
|
| 829 |
+
self.D,
|
| 830 |
+
self.V,
|
| 831 |
+
GAN1_input_e,
|
| 832 |
+
GAN1_input_x,
|
| 833 |
+
self.batch_size,
|
| 834 |
+
sim_reward,
|
| 835 |
+
self.dataset.matrices2mol_drugs,
|
| 836 |
+
fps_r,
|
| 837 |
+
self.submodel)
|
| 838 |
+
|
| 839 |
+
_, fake_mol, _, _, node, edge = generator_output
|
| 840 |
+
|
| 841 |
+
# =================================================================================== #
|
| 842 |
+
# 3. GAN2 Inference #
|
| 843 |
+
# =================================================================================== #
|
| 844 |
+
|
| 845 |
+
output = generator2_loss(self.G2,
|
| 846 |
+
self.D2,
|
| 847 |
+
self.V2,
|
| 848 |
+
edge,
|
| 849 |
+
node,
|
| 850 |
+
self.batch_size,
|
| 851 |
+
sim_reward,
|
| 852 |
+
self.dataset.matrices2mol_drugs,
|
| 853 |
+
fps_r,
|
| 854 |
+
GAN2_input_e,
|
| 855 |
+
GAN2_input_x,
|
| 856 |
+
self.submodel)
|
| 857 |
+
|
| 858 |
+
_, fake_mol_g, _, _ = output
|
| 859 |
+
|
| 860 |
+
inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
#inference_smiles = [Chem.MolToSmiles(line) for line in fake_mol]
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
print("molecule batch {} inferred".format(i))
|
| 869 |
+
|
| 870 |
+
with open("DrugGEN/experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
|
| 871 |
+
for molecules in inference_drugs:
|
| 872 |
+
|
| 873 |
+
f.write(molecules)
|
| 874 |
+
f.write("\n")
|
| 875 |
+
metric_calc_dr.append(molecules)
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
|
| 879 |
+
if i == 120:
|
| 880 |
+
break
|
| 881 |
+
|
| 882 |
+
et = time.time() - start_time
|
| 883 |
+
|
| 884 |
+
print("Inference mode is lasted for {:.2f} seconds".format(et))
|
| 885 |
+
|
| 886 |
+
print("Metrics calculation started using MOSES.")
|
| 887 |
+
|
| 888 |
+
print("Validity: ", fraction_valid(inference_drugs), "\n")
|
| 889 |
+
print("Uniqueness: ", fraction_unique(inference_drugs), "\n")
|
| 890 |
+
print("Validity: ", novelty(inference_drugs, drug_smiles), "\n")
|
| 891 |
+
|
| 892 |
+
print("Metrics are calculated.")
|
training_data.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch_geometric.utils as geoutils
|
| 3 |
+
from utils import *
|
| 4 |
+
|
| 5 |
+
def load_data(data, drugs, batch_size, device, b_dim, m_dim, drugs_b_dim, drugs_m_dim,z_dim,vertexes):
|
| 6 |
+
|
| 7 |
+
z = sample_z(batch_size, z_dim) # (batch,max_len)
|
| 8 |
+
|
| 9 |
+
z = torch.from_numpy(z).to(device).float().requires_grad_(True)
|
| 10 |
+
data = data.to(device)
|
| 11 |
+
drugs = drugs.to(device)
|
| 12 |
+
z_e = sample_z_edge(batch_size,vertexes,b_dim) # (batch,max_len,max_len)
|
| 13 |
+
z_n = sample_z_node(batch_size,vertexes,m_dim) # (batch,max_len)
|
| 14 |
+
z_edge = torch.from_numpy(z_e).to(device).float().requires_grad_(True) # Edge noise.(batch,max_len,max_len)
|
| 15 |
+
z_node = torch.from_numpy(z_n).to(device).float().requires_grad_(True) # Node noise.(batch,max_len)
|
| 16 |
+
a = geoutils.to_dense_adj(edge_index = data.edge_index,batch=data.batch,edge_attr=data.edge_attr, max_num_nodes=int(data.batch.shape[0]/batch_size))
|
| 17 |
+
x = data.x.view(batch_size,int(data.batch.shape[0]/batch_size),-1)
|
| 18 |
+
|
| 19 |
+
a_tensor = label2onehot(a, b_dim, device)
|
| 20 |
+
#x_tensor = label2onehot(x, m_dim)
|
| 21 |
+
x_tensor = x
|
| 22 |
+
|
| 23 |
+
a_tensor = a_tensor #+ torch.randn([a_tensor.size(0), a_tensor.size(1), a_tensor.size(2),1], device=a_tensor.device) * noise_strength_0
|
| 24 |
+
x_tensor = x_tensor #+ torch.randn([x_tensor.size(0), x_tensor.size(1),1], device=x_tensor.device) * noise_strength_1
|
| 25 |
+
|
| 26 |
+
drugs_a = geoutils.to_dense_adj(edge_index = drugs.edge_index,batch=drugs.batch,edge_attr=drugs.edge_attr, max_num_nodes=int(drugs.batch.shape[0]/batch_size))
|
| 27 |
+
|
| 28 |
+
drugs_x = drugs.x.view(batch_size,int(drugs.batch.shape[0]/batch_size),-1)
|
| 29 |
+
|
| 30 |
+
drugs_a = drugs_a.to(device).long()
|
| 31 |
+
drugs_x = drugs_x.to(device)
|
| 32 |
+
drugs_a_tensor = label2onehot(drugs_a, drugs_b_dim,device).float()
|
| 33 |
+
drugs_x_tensor = drugs_x
|
| 34 |
+
|
| 35 |
+
drugs_a_tensor = drugs_a_tensor #+ torch.randn([drugs_a_tensor.size(0), drugs_a_tensor.size(1), drugs_a_tensor.size(2),1], device=drugs_a_tensor.device) * noise_strength_2
|
| 36 |
+
drugs_x_tensor = drugs_x_tensor #+ torch.randn([drugs_x_tensor.size(0), drugs_x_tensor.size(1),1], device=drugs_x_tensor.device) * noise_strength_3
|
| 37 |
+
#prot_n = akt1_human_annot[None,:].to(device).float()
|
| 38 |
+
#prot_e = akt1_human_adj[None,None,:].view(1,546,546,1).to(device).float()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
a_tensor_vec = a_tensor.reshape(batch_size,-1)
|
| 43 |
+
x_tensor_vec = x_tensor.reshape(batch_size,-1)
|
| 44 |
+
real_graphs = torch.concat((x_tensor_vec,a_tensor_vec),dim=-1)
|
| 45 |
+
|
| 46 |
+
a_drug_vec = drugs_a_tensor.reshape(batch_size,-1)
|
| 47 |
+
x_drug_vec = drugs_x_tensor.reshape(batch_size,-1)
|
| 48 |
+
drug_graphs = torch.concat((x_drug_vec,a_drug_vec),dim=-1)
|
| 49 |
+
|
| 50 |
+
return drug_graphs, real_graphs, a_tensor, x_tensor, drugs_a_tensor, drugs_x_tensor, z, z_edge, z_node
|
utils.py
ADDED
|
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from statistics import mean
|
| 2 |
+
from rdkit import DataStructs
|
| 3 |
+
from rdkit import Chem
|
| 4 |
+
from rdkit.Chem import AllChem
|
| 5 |
+
from rdkit.Chem import Draw
|
| 6 |
+
import os
|
| 7 |
+
import numpy as np
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from matplotlib.lines import Line2D
|
| 11 |
+
from rdkit import RDLogger
|
| 12 |
+
import torch
|
| 13 |
+
from rdkit.Chem.Scaffolds import MurckoScaffold
|
| 14 |
+
import math
|
| 15 |
+
import time
|
| 16 |
+
import datetime
|
| 17 |
+
import re
|
| 18 |
+
RDLogger.DisableLog('rdApp.*')
|
| 19 |
+
import warnings
|
| 20 |
+
from multiprocessing import Pool
|
| 21 |
+
class Metrics(object):
|
| 22 |
+
|
| 23 |
+
@staticmethod
|
| 24 |
+
def valid(x):
|
| 25 |
+
return x is not None and Chem.MolToSmiles(x) != ''
|
| 26 |
+
|
| 27 |
+
@staticmethod
|
| 28 |
+
def tanimoto_sim_1v2(data1, data2):
|
| 29 |
+
min_len = data1.size if data1.size > data2.size else data2
|
| 30 |
+
sims = []
|
| 31 |
+
for i in range(min_len):
|
| 32 |
+
sim = DataStructs.FingerprintSimilarity(data1[i], data2[i])
|
| 33 |
+
sims.append(sim)
|
| 34 |
+
mean_sim = mean(sim)
|
| 35 |
+
return mean_sim
|
| 36 |
+
|
| 37 |
+
@staticmethod
|
| 38 |
+
def mol_length(x):
|
| 39 |
+
if x is not None:
|
| 40 |
+
return len([char for char in max(Chem.MolToSmiles(x).split(sep =".")).upper() if char.isalpha()])
|
| 41 |
+
else:
|
| 42 |
+
return 0
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def max_component(data, max_len):
|
| 46 |
+
|
| 47 |
+
return (np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)/max_len).mean()
|
| 48 |
+
|
| 49 |
+
def sim_reward(mol_gen, fps_r):
|
| 50 |
+
|
| 51 |
+
gen_scaf = []
|
| 52 |
+
|
| 53 |
+
for x in mol_gen:
|
| 54 |
+
if x is not None:
|
| 55 |
+
try:
|
| 56 |
+
|
| 57 |
+
gen_scaf.append(MurckoScaffold.GetScaffoldForMol(x))
|
| 58 |
+
except:
|
| 59 |
+
pass
|
| 60 |
+
|
| 61 |
+
if len(gen_scaf) == 0:
|
| 62 |
+
|
| 63 |
+
rew = 1
|
| 64 |
+
else:
|
| 65 |
+
fps = [Chem.RDKFingerprint(x) for x in gen_scaf]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
fps = np.array(fps)
|
| 69 |
+
fps_r = np.array(fps_r)
|
| 70 |
+
|
| 71 |
+
rew = average_agg_tanimoto(fps_r,fps)[0]
|
| 72 |
+
if math.isnan(rew):
|
| 73 |
+
rew = 1
|
| 74 |
+
|
| 75 |
+
return rew ## change this to penalty
|
| 76 |
+
|
| 77 |
+
##########################################
|
| 78 |
+
##########################################
|
| 79 |
+
##########################################
|
| 80 |
+
|
| 81 |
+
def mols2grid_image(mols,path):
|
| 82 |
+
mols = [e if e is not None else Chem.RWMol() for e in mols]
|
| 83 |
+
|
| 84 |
+
for i in range(len(mols)):
|
| 85 |
+
if Metrics.valid(mols[i]):
|
| 86 |
+
#if Chem.MolToSmiles(mols[i]) != '':
|
| 87 |
+
AllChem.Compute2DCoords(mols[i])
|
| 88 |
+
Draw.MolToFile(mols[i], os.path.join(path,"{}.png".format(i+1)), size=(1200,1200))
|
| 89 |
+
else:
|
| 90 |
+
continue
|
| 91 |
+
|
| 92 |
+
def save_smiles_matrices(mols,edges_hard, nodes_hard,path,data_source = None):
|
| 93 |
+
mols = [e if e is not None else Chem.RWMol() for e in mols]
|
| 94 |
+
|
| 95 |
+
for i in range(len(mols)):
|
| 96 |
+
if Metrics.valid(mols[i]):
|
| 97 |
+
#m0= all_scores_for_print(mols[i], data_source, norm=False)
|
| 98 |
+
#if Chem.MolToSmiles(mols[i]) != '':
|
| 99 |
+
save_path = os.path.join(path,"{}.txt".format(i+1))
|
| 100 |
+
with open(save_path, "a") as f:
|
| 101 |
+
np.savetxt(f, edges_hard[i].cpu().numpy(), header="edge matrix:\n",fmt='%1.2f')
|
| 102 |
+
f.write("\n")
|
| 103 |
+
np.savetxt(f, nodes_hard[i].cpu().numpy(), header="node matrix:\n", footer="\nsmiles:",fmt='%1.2f')
|
| 104 |
+
f.write("\n")
|
| 105 |
+
#f.write(m0)
|
| 106 |
+
f.write("\n")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
print(Chem.MolToSmiles(mols[i]), file=open(save_path,"a"))
|
| 110 |
+
else:
|
| 111 |
+
continue
|
| 112 |
+
|
| 113 |
+
##########################################
|
| 114 |
+
##########################################
|
| 115 |
+
##########################################
|
| 116 |
+
|
| 117 |
+
def dense_to_sparse_with_attr(adj):
|
| 118 |
+
###
|
| 119 |
+
assert adj.dim() >= 2 and adj.dim() <= 3
|
| 120 |
+
assert adj.size(-1) == adj.size(-2)
|
| 121 |
+
|
| 122 |
+
index = adj.nonzero(as_tuple=True)
|
| 123 |
+
edge_attr = adj[index]
|
| 124 |
+
|
| 125 |
+
if len(index) == 3:
|
| 126 |
+
batch = index[0] * adj.size(-1)
|
| 127 |
+
index = (batch + index[1], batch + index[2])
|
| 128 |
+
#index = torch.stack(index, dim=0)
|
| 129 |
+
return index, edge_attr
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def label2onehot(labels, dim, device):
|
| 133 |
+
|
| 134 |
+
"""Convert label indices to one-hot vectors."""
|
| 135 |
+
|
| 136 |
+
out = torch.zeros(list(labels.size())+[dim]).to(device)
|
| 137 |
+
out.scatter_(len(out.size())-1,labels.unsqueeze(-1),1.)
|
| 138 |
+
|
| 139 |
+
return out.float()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def sample_z_node(batch_size, vertexes, nodes):
|
| 143 |
+
|
| 144 |
+
''' Random noise for nodes logits. '''
|
| 145 |
+
|
| 146 |
+
return np.random.normal(0,1, size=(batch_size,vertexes, nodes)) # 128, 9, 5
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def sample_z_edge(batch_size, vertexes, edges):
|
| 150 |
+
|
| 151 |
+
''' Random noise for edges logits. '''
|
| 152 |
+
|
| 153 |
+
return np.random.normal(0,1, size=(batch_size, vertexes, vertexes, edges)) # 128, 9, 9, 5
|
| 154 |
+
|
| 155 |
+
def sample_z( batch_size, z_dim):
|
| 156 |
+
|
| 157 |
+
''' Random noise. '''
|
| 158 |
+
|
| 159 |
+
return np.random.normal(0,1, size=(batch_size,z_dim)) # 128, 9, 5
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def mol_sample(sample_directory, model_name, mol, edges, nodes, idx, i):
|
| 163 |
+
sample_path = os.path.join(sample_directory,"{}-{}_{}-epoch_iteration".format(model_name,idx+1, i+1))
|
| 164 |
+
|
| 165 |
+
if not os.path.exists(sample_path):
|
| 166 |
+
os.makedirs(sample_path)
|
| 167 |
+
|
| 168 |
+
mols2grid_image(mol,sample_path)
|
| 169 |
+
|
| 170 |
+
save_smiles_matrices(mol,edges.detach(), nodes.detach(), sample_path)
|
| 171 |
+
|
| 172 |
+
if len(os.listdir(sample_path)) == 0:
|
| 173 |
+
os.rmdir(sample_path)
|
| 174 |
+
|
| 175 |
+
print("Valid molecules are saved.")
|
| 176 |
+
print("Valid matrices and smiles are saved")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, save_path):
|
| 183 |
+
|
| 184 |
+
gen_smiles = []
|
| 185 |
+
for line in mols:
|
| 186 |
+
if line is not None:
|
| 187 |
+
gen_smiles.append(Chem.MolToSmiles(line))
|
| 188 |
+
elif line is None:
|
| 189 |
+
gen_smiles.append(None)
|
| 190 |
+
|
| 191 |
+
#gen_smiles_saves = [None if x is None else re.sub('\*', '', x) for x in gen_smiles]
|
| 192 |
+
#gen_smiles_saves = [None if x is None else re.sub('\.', '', x) for x in gen_smiles_saves]
|
| 193 |
+
gen_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in gen_smiles]
|
| 194 |
+
|
| 195 |
+
sample_save_dir = os.path.join(save_path, "samples-GAN{}.txt".format(model_num))
|
| 196 |
+
with open(sample_save_dir, "a") as f:
|
| 197 |
+
for idxs in range(len(gen_smiles_saves)):
|
| 198 |
+
if gen_smiles_saves[idxs] is not None:
|
| 199 |
+
|
| 200 |
+
f.write(gen_smiles_saves[idxs])
|
| 201 |
+
f.write("\n")
|
| 202 |
+
|
| 203 |
+
k = len(set(gen_smiles_saves) - {None})
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
et = time.time() - start_time
|
| 207 |
+
et = str(datetime.timedelta(seconds=et))[:-7]
|
| 208 |
+
log = "Elapsed [{}], Epoch/Iteration [{}/{}] for GAN{}".format(et, idx, i+1, model_num)
|
| 209 |
+
|
| 210 |
+
# Log update
|
| 211 |
+
#m0 = get_all_metrics(gen = gen_smiles, train = train_smiles, batch_size=batch_size, k = valid_mol_num, device=self.device)
|
| 212 |
+
valid = fraction_valid(gen_smiles_saves)
|
| 213 |
+
unique = fraction_unique(gen_smiles_saves, k, check_validity=False)
|
| 214 |
+
novel = novelty(gen_smiles_saves, train_smiles)
|
| 215 |
+
|
| 216 |
+
#qed = [QED(mol) for mol in mols if mol is not None]
|
| 217 |
+
#sa = [SA(mol) for mol in mols if mol is not None]
|
| 218 |
+
#logp = [logP(mol) for mol in mols if mol is not None]
|
| 219 |
+
|
| 220 |
+
#IntDiv = internal_diversity(gen_smiles)
|
| 221 |
+
#m0= all_scores_val(fake_mol, mols, full_mols, full_smiles, vert, norm=True) # 'mols' is output of Fake Reward
|
| 222 |
+
#m1 =all_scores_chem(fake_mol, mols, vert, norm=True)
|
| 223 |
+
#m0.update(m1)
|
| 224 |
+
|
| 225 |
+
#maxlen = MolecularMetrics.max_component(mols, 45)
|
| 226 |
+
|
| 227 |
+
#m0 = {k: np.array(v).mean() for k, v in m0.items()}
|
| 228 |
+
#loss.update(m0)
|
| 229 |
+
loss.update({'Valid': valid})
|
| 230 |
+
loss.update({'Unique@{}'.format(k): unique})
|
| 231 |
+
loss.update({'Novel': novel})
|
| 232 |
+
#loss.update({'QED': statistics.mean(qed)})
|
| 233 |
+
#loss.update({'SA': statistics.mean(sa)})
|
| 234 |
+
#loss.update({'LogP': statistics.mean(logp)})
|
| 235 |
+
#loss.update({'IntDiv': IntDiv})
|
| 236 |
+
|
| 237 |
+
#wandb.log({"maxlen": maxlen})
|
| 238 |
+
|
| 239 |
+
for tag, value in loss.items():
|
| 240 |
+
|
| 241 |
+
log += ", {}: {:.4f}".format(tag, value)
|
| 242 |
+
with open(log_path, "a") as f:
|
| 243 |
+
f.write(log)
|
| 244 |
+
f.write("\n")
|
| 245 |
+
print(log)
|
| 246 |
+
print("\n")
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def plot_attn(dataset_name, heads,attn_w, model, iter, epoch):
|
| 251 |
+
|
| 252 |
+
cols = 4
|
| 253 |
+
rows = int(heads/cols)
|
| 254 |
+
|
| 255 |
+
fig, axes = plt.subplots( rows,cols, figsize = (30, 14))
|
| 256 |
+
axes = axes.flat
|
| 257 |
+
attentions_pos = attn_w[0]
|
| 258 |
+
attentions_pos = attentions_pos.cpu().detach().numpy()
|
| 259 |
+
for i,att in enumerate(attentions_pos):
|
| 260 |
+
|
| 261 |
+
#im = axes[i].imshow(att, cmap='gray')
|
| 262 |
+
sns.heatmap(att,vmin = 0, vmax = 1,ax = axes[i])
|
| 263 |
+
axes[i].set_title(f'head - {i} ')
|
| 264 |
+
axes[i].set_ylabel('layers')
|
| 265 |
+
pltsavedir = "/home/atabey/attn/second"
|
| 266 |
+
plt.savefig(os.path.join(pltsavedir, "attn" + model + "_" + dataset_name + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def plot_grad_flow(named_parameters, model, iter, epoch):
|
| 270 |
+
|
| 271 |
+
# Based on https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10
|
| 272 |
+
'''Plots the gradients flowing through different layers in the net during training.
|
| 273 |
+
Can be used for checking for possible gradient vanishing / exploding problems.
|
| 274 |
+
|
| 275 |
+
Usage: Plug this function in Trainer class after loss.backwards() as
|
| 276 |
+
"plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
|
| 277 |
+
ave_grads = []
|
| 278 |
+
max_grads= []
|
| 279 |
+
layers = []
|
| 280 |
+
for n, p in named_parameters:
|
| 281 |
+
if(p.requires_grad) and ("bias" not in n):
|
| 282 |
+
print(p.grad,n)
|
| 283 |
+
layers.append(n)
|
| 284 |
+
ave_grads.append(p.grad.abs().mean().cpu())
|
| 285 |
+
max_grads.append(p.grad.abs().max().cpu())
|
| 286 |
+
plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
|
| 287 |
+
plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
|
| 288 |
+
plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
|
| 289 |
+
plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
|
| 290 |
+
plt.xlim(left=0, right=len(ave_grads))
|
| 291 |
+
plt.ylim(bottom = -0.001, top=1) # zoom in on the lower gradient regions
|
| 292 |
+
plt.xlabel("Layers")
|
| 293 |
+
plt.ylabel("average gradient")
|
| 294 |
+
plt.title("Gradient flow")
|
| 295 |
+
plt.grid(True)
|
| 296 |
+
plt.legend([Line2D([0], [0], color="c", lw=4),
|
| 297 |
+
Line2D([0], [0], color="b", lw=4),
|
| 298 |
+
Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
|
| 299 |
+
pltsavedir = "/home/atabey/gradients/tryout"
|
| 300 |
+
plt.savefig(os.path.join(pltsavedir, "weights_" + model + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
|
| 301 |
+
|
| 302 |
+
"""
|
| 303 |
+
def _genDegree():
|
| 304 |
+
|
| 305 |
+
''' Generates the Degree distribution tensor for PNA, should be used everytime a different
|
| 306 |
+
dataset is used.
|
| 307 |
+
Can be called without arguments and saves the tensor for later use. If tensor was created
|
| 308 |
+
before, it just loads the degree tensor.
|
| 309 |
+
'''
|
| 310 |
+
|
| 311 |
+
degree_path = os.path.join(self.degree_dir, self.dataset_name + '-degree.pt')
|
| 312 |
+
if not os.path.exists(degree_path):
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
max_degree = -1
|
| 316 |
+
for data in self.dataset:
|
| 317 |
+
d = geoutils.degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
|
| 318 |
+
max_degree = max(max_degree, int(d.max()))
|
| 319 |
+
|
| 320 |
+
# Compute the in-degree histogram tensor
|
| 321 |
+
deg = torch.zeros(max_degree + 1, dtype=torch.long)
|
| 322 |
+
for data in self.dataset:
|
| 323 |
+
d = geoutils.degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
|
| 324 |
+
deg += torch.bincount(d, minlength=deg.numel())
|
| 325 |
+
torch.save(deg, 'DrugGEN/data/' + self.dataset_name + '-degree.pt')
|
| 326 |
+
else:
|
| 327 |
+
deg = torch.load(degree_path, map_location=lambda storage, loc: storage)
|
| 328 |
+
|
| 329 |
+
return deg
|
| 330 |
+
"""
|
| 331 |
+
def get_mol(smiles_or_mol):
|
| 332 |
+
'''
|
| 333 |
+
Loads SMILES/molecule into RDKit's object
|
| 334 |
+
'''
|
| 335 |
+
if isinstance(smiles_or_mol, str):
|
| 336 |
+
if len(smiles_or_mol) == 0:
|
| 337 |
+
return None
|
| 338 |
+
mol = Chem.MolFromSmiles(smiles_or_mol)
|
| 339 |
+
if mol is None:
|
| 340 |
+
return None
|
| 341 |
+
try:
|
| 342 |
+
Chem.SanitizeMol(mol)
|
| 343 |
+
except ValueError:
|
| 344 |
+
return None
|
| 345 |
+
return mol
|
| 346 |
+
return smiles_or_mol
|
| 347 |
+
|
| 348 |
+
def mapper(n_jobs):
|
| 349 |
+
'''
|
| 350 |
+
Returns function for map call.
|
| 351 |
+
If n_jobs == 1, will use standard map
|
| 352 |
+
If n_jobs > 1, will use multiprocessing pool
|
| 353 |
+
If n_jobs is a pool object, will return its map function
|
| 354 |
+
'''
|
| 355 |
+
if n_jobs == 1:
|
| 356 |
+
def _mapper(*args, **kwargs):
|
| 357 |
+
return list(map(*args, **kwargs))
|
| 358 |
+
|
| 359 |
+
return _mapper
|
| 360 |
+
if isinstance(n_jobs, int):
|
| 361 |
+
pool = Pool(n_jobs)
|
| 362 |
+
|
| 363 |
+
def _mapper(*args, **kwargs):
|
| 364 |
+
try:
|
| 365 |
+
result = pool.map(*args, **kwargs)
|
| 366 |
+
finally:
|
| 367 |
+
pool.terminate()
|
| 368 |
+
return result
|
| 369 |
+
|
| 370 |
+
return _mapper
|
| 371 |
+
return n_jobs.map
|
| 372 |
+
def remove_invalid(gen, canonize=True, n_jobs=1):
|
| 373 |
+
"""
|
| 374 |
+
Removes invalid molecules from the dataset
|
| 375 |
+
"""
|
| 376 |
+
if not canonize:
|
| 377 |
+
mols = mapper(n_jobs)(get_mol, gen)
|
| 378 |
+
return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
|
| 379 |
+
return [x for x in mapper(n_jobs)(canonic_smiles, gen) if
|
| 380 |
+
x is not None]
|
| 381 |
+
def fraction_valid(gen, n_jobs=1):
|
| 382 |
+
"""
|
| 383 |
+
Computes a number of valid molecules
|
| 384 |
+
Parameters:
|
| 385 |
+
gen: list of SMILES
|
| 386 |
+
n_jobs: number of threads for calculation
|
| 387 |
+
"""
|
| 388 |
+
gen = mapper(n_jobs)(get_mol, gen)
|
| 389 |
+
return 1 - gen.count(None) / len(gen)
|
| 390 |
+
def canonic_smiles(smiles_or_mol):
|
| 391 |
+
mol = get_mol(smiles_or_mol)
|
| 392 |
+
if mol is None:
|
| 393 |
+
return None
|
| 394 |
+
return Chem.MolToSmiles(mol)
|
| 395 |
+
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
|
| 396 |
+
"""
|
| 397 |
+
Computes a number of unique molecules
|
| 398 |
+
Parameters:
|
| 399 |
+
gen: list of SMILES
|
| 400 |
+
k: compute unique@k
|
| 401 |
+
n_jobs: number of threads for calculation
|
| 402 |
+
check_validity: raises ValueError if invalid molecules are present
|
| 403 |
+
"""
|
| 404 |
+
if k is not None:
|
| 405 |
+
if len(gen) < k:
|
| 406 |
+
warnings.warn(
|
| 407 |
+
"Can't compute unique@{}.".format(k) +
|
| 408 |
+
"gen contains only {} molecules".format(len(gen))
|
| 409 |
+
)
|
| 410 |
+
gen = gen[:k]
|
| 411 |
+
canonic = set(mapper(n_jobs)(canonic_smiles, gen))
|
| 412 |
+
if None in canonic and check_validity:
|
| 413 |
+
raise ValueError("Invalid molecule passed to unique@k")
|
| 414 |
+
return 0 if len(gen) == 0 else len(canonic) / len(gen)
|
| 415 |
+
|
| 416 |
+
def novelty(gen, train, n_jobs=1):
|
| 417 |
+
gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
|
| 418 |
+
gen_smiles_set = set(gen_smiles) - {None}
|
| 419 |
+
train_set = set(train)
|
| 420 |
+
return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def average_agg_tanimoto(stock_vecs, gen_vecs,
|
| 425 |
+
batch_size=5000, agg='max',
|
| 426 |
+
device='cpu', p=1):
|
| 427 |
+
"""
|
| 428 |
+
For each molecule in gen_vecs finds closest molecule in stock_vecs.
|
| 429 |
+
Returns average tanimoto score for between these molecules
|
| 430 |
+
|
| 431 |
+
Parameters:
|
| 432 |
+
stock_vecs: numpy array <n_vectors x dim>
|
| 433 |
+
gen_vecs: numpy array <n_vectors' x dim>
|
| 434 |
+
agg: max or mean
|
| 435 |
+
p: power for averaging: (mean x^p)^(1/p)
|
| 436 |
+
"""
|
| 437 |
+
assert agg in ['max', 'mean'], "Can aggregate only max or mean"
|
| 438 |
+
agg_tanimoto = np.zeros(len(gen_vecs))
|
| 439 |
+
total = np.zeros(len(gen_vecs))
|
| 440 |
+
for j in range(0, stock_vecs.shape[0], batch_size):
|
| 441 |
+
x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
|
| 442 |
+
for i in range(0, gen_vecs.shape[0], batch_size):
|
| 443 |
+
|
| 444 |
+
y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
|
| 445 |
+
y_gen = y_gen.transpose(0, 1)
|
| 446 |
+
tp = torch.mm(x_stock, y_gen)
|
| 447 |
+
jac = (tp / (x_stock.sum(1, keepdim=True) +
|
| 448 |
+
y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
|
| 449 |
+
jac[np.isnan(jac)] = 1
|
| 450 |
+
if p != 1:
|
| 451 |
+
jac = jac**p
|
| 452 |
+
if agg == 'max':
|
| 453 |
+
agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
|
| 454 |
+
agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
|
| 455 |
+
elif agg == 'mean':
|
| 456 |
+
agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
|
| 457 |
+
total[i:i + y_gen.shape[1]] += jac.shape[0]
|
| 458 |
+
if agg == 'mean':
|
| 459 |
+
agg_tanimoto /= total
|
| 460 |
+
if p != 1:
|
| 461 |
+
agg_tanimoto = (agg_tanimoto)**(1/p)
|
| 462 |
+
return np.mean(agg_tanimoto)
|