Trương Gia Bảo committed
Commit a3ea5d3 · 1 Parent(s): 6c75a42

Initial commit

.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ model.iter-685000 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,111 @@
+ import torch
+ import streamlit as st
+ import sys, os
+ import rdkit
+ import rdkit.Chem as Chem
+ from rdkit.Chem.Draw import MolToImage
+ from rdkit.Chem import Descriptors
+ import sascorer
+ import networkx as nx
+
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+
+ sys.path.append('%s/fast_jtnn/' % os.path.dirname(os.path.realpath(__file__)))
+ from mol_tree import Vocab, MolTree
+ from jtprop_vae import JTPropVAE
+ from molbloom import buy
+
+ lg = rdkit.RDLogger.logger()
+ lg.setLevel(rdkit.RDLogger.CRITICAL)
+
+ st.header('Junction Tree Variational Autoencoder for Molecular Graph Generation (JTVAE)')
+ st.subheader('Wengong Jin, Regina Barzilay, Tommi Jaakkola')
+ descrip = '''
+ We seek to automate the design of molecules based on specific chemical properties. In computational terms, this task involves continuous embedding and generation of molecular graphs. Our primary contribution is the direct realization of molecular graphs, a task previously approached by generating linear SMILES strings instead of graphs. Our junction tree variational autoencoder generates molecular graphs in two phases, by first generating a tree-structured scaffold over chemical substructures, and then combining them into a molecule with a graph message passing network. This approach allows us to incrementally expand molecules while maintaining chemical validity at every step. We evaluate our model on multiple tasks ranging from molecular generation to optimization. Across these tasks, our model outperforms previous state-of-the-art baselines by a significant margin.
+
+ [https://arxiv.org/abs/1802.04364](https://arxiv.org/abs/1802.04364)'''
+
+ with st.expander('About'):
+     st.markdown(descrip)
+
+ st.text_input('Enter a SMILES string:', 'CNC(=O)C1=NC=CC(=C1)OC2=CC=C(C=C2)NC(=O)NC3=CC(=C(C=C3)Cl)C(F)(F)F', key='smiles')
+
+ def penalized_logp_standard(mol):
+     logP_mean = 2.4399606244103639873799239
+     logP_std = 0.9293197802518905481505840
+     SA_mean = -2.4485512208785431553792478
+     SA_std = 0.4603110476923852334429910
+     cycle_mean = -0.0307270378623088931402396
+     cycle_std = 0.2163675785228087178335699
+
+     log_p = Descriptors.MolLogP(mol)
+     SA = -sascorer.calculateScore(mol)
+
+     # cycle score
+     cycle_list = nx.cycle_basis(nx.Graph(Chem.rdmolops.GetAdjacencyMatrix(mol)))
+     if len(cycle_list) == 0:
+         cycle_length = 0
+     else:
+         cycle_length = max([len(j) for j in cycle_list])
+     if cycle_length <= 6:
+         cycle_length = 0
+     else:
+         cycle_length = cycle_length - 6
+     cycle_score = -cycle_length
+
+     standardized_log_p = (log_p - logP_mean) / logP_std
+     standardized_SA = (SA - SA_mean) / SA_std
+     standardized_cycle = (cycle_score - cycle_mean) / cycle_std
+     return standardized_log_p + standardized_SA + standardized_cycle
+
+ mol = Chem.MolFromSmiles(st.session_state.smiles)
+ if mol is None:
+     st.write('SMILES is invalid. Please enter a valid SMILES.')
+ else:
+     st.write('Molecule:')
+     st.image(MolToImage(mol, size=(300, 300)))
+     score = penalized_logp_standard(mol)
+     st.write('Penalized logP score: %.5f' % (score))
+
+ if mol is not None:
+     st.slider('Choose learning rate: ', 0.0, 10.0, 0.4, key='lr')
+     st.slider('Choose similarity cutoff: ', 0.0, 3.0, 0.4, key='sim_cutoff')
+     st.slider('Choose number of iterations: ', 1, 100, 80, key='n_iter')
+     vocab = [x.strip("\r\n ") for x in open('./vocab.txt')]
+     vocab = Vocab(vocab)
+     if st.button('Optimize'):
+         st.write('Testing')
+
+         model = JTPropVAE(vocab, 450, 56, 20, 3)
+         model.load_state_dict(torch.load('./model.iter-685000', map_location=torch.device('cpu')))
+
+         new_smiles, sim = model.optimize(st.session_state.smiles, sim_cutoff=st.session_state.sim_cutoff, lr=st.session_state.lr, num_iter=st.session_state.n_iter)
+
+         del model
+         if new_smiles is None:
+             st.write('Cannot optimize.')
+         else:
+             st.write('New SMILES:')
+             st.code(new_smiles)
+             new_mol = Chem.MolFromSmiles(new_smiles)
+             if new_mol is None:
+                 st.write('New SMILES is invalid.')
+             else:
+                 st.write('New SMILES molecule:')
+                 st.image(MolToImage(new_mol, size=(300, 300)))
+                 new_score = penalized_logp_standard(new_mol)
+                 st.write('New penalized logP score: %.5f' % (new_score))
+                 st.write('Caching ZINC20 if necessary...')
+                 if buy(new_smiles, catalog='zinc20', canonicalize=True):
+                     st.write('This molecule exists.')
+                     st.caption('Checked by molbloom.')
+                 else:
+                     st.write('THIS MOLECULE DOES NOT EXIST!')
+                     st.caption('Checked by molbloom.')
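The Space is launched with `streamlit run app.py`. For reference, a minimal sketch of driving the same optimization path without the UI; it reuses exactly the files and calls shown above (`vocab.txt`, the LFS checkpoint `model.iter-685000`, `JTPropVAE.optimize`), and the input SMILES is an arbitrary placeholder:

```python
# Sketch only: mirrors app.py's model setup and optimize call.
import sys, os, torch

sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'fast_jtnn'))
from mol_tree import Vocab
from jtprop_vae import JTPropVAE

vocab = Vocab([x.strip("\r\n ") for x in open('./vocab.txt')])
model = JTPropVAE(vocab, 450, 56, 20, 3)  # same positional arguments as app.py
model.load_state_dict(torch.load('./model.iter-685000', map_location=torch.device('cpu')))
new_smiles, sim = model.optimize('CCO', sim_cutoff=0.4, lr=0.4, num_iter=80)  # arbitrary input SMILES
print(new_smiles, sim)
```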
fast_jtnn/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # import sys
+ # sys.path.append('./')
+ from mol_tree import Vocab, MolTree
+ from jtnn_vae import JTNNVAE
+ from jtnn_enc import JTNNEncoder
+ from jtmpn import JTMPN
+ from mpn import MPN
+ from nnutils import create_var
+ from datautils import MolTreeFolder, PairTreeFolder, MolTreeDataset
fast_jtnn/chemutils.py ADDED
@@ -0,0 +1,429 @@
+ import rdkit
+ import rdkit.Chem as Chem
+ from scipy.sparse import csr_matrix
+ from scipy.sparse.csgraph import minimum_spanning_tree
+ from collections import defaultdict
+ from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
+ from vocab import Vocab
+
+ MST_MAX_WEIGHT = 100
+ MAX_NCAND = 2000
+
+ def set_atommap(mol, num=0):
+     for atom in mol.GetAtoms():
+         atom.SetAtomMapNum(num)
+
+ def get_mol(smiles):
+     mol = Chem.MolFromSmiles(smiles)
+     if mol is None:
+         return None
+     Chem.Kekulize(mol, clearAromaticFlags=True)
+     return mol
+
+ def get_smiles(mol):
+     return Chem.MolToSmiles(mol, kekuleSmiles=True)
+
+ def decode_stereo(smiles2D):
+     mol = Chem.MolFromSmiles(smiles2D)
+     dec_isomers = list(EnumerateStereoisomers(mol))
+
+     dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True)) for mol in dec_isomers]
+     smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers]
+
+     chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
+     if len(chiralN) > 0:
+         for mol in dec_isomers:
+             for idx in chiralN:
+                 mol.GetAtomWithIdx(idx).SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
+             smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))
+
+     return smiles3D
+
+ def sanitize(mol):
+     try:
+         smiles = get_smiles(mol)
+         mol = get_mol(smiles)
+     except Exception as e:
+         return None
+     return mol
+
+ def copy_atom(atom):
+     new_atom = Chem.Atom(atom.GetSymbol())
+     new_atom.SetFormalCharge(atom.GetFormalCharge())
+     new_atom.SetAtomMapNum(atom.GetAtomMapNum())
+     return new_atom
+
+ def copy_edit_mol(mol):
+     new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
+     for atom in mol.GetAtoms():
+         new_atom = copy_atom(atom)
+         new_mol.AddAtom(new_atom)
+     for bond in mol.GetBonds():
+         a1 = bond.GetBeginAtom().GetIdx()
+         a2 = bond.GetEndAtom().GetIdx()
+         bt = bond.GetBondType()
+         new_mol.AddBond(a1, a2, bt)
+     return new_mol
+
+ def get_clique_mol(mol, atoms):
+     smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
+     new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
+     new_mol = copy_edit_mol(new_mol).GetMol()
+     new_mol = sanitize(new_mol)  # We assume this is not None
+     return new_mol
+
+ def tree_decomp(mol):
+     n_atoms = mol.GetNumAtoms()
+     if n_atoms == 1:  # special case
+         return [[0]], []
+
+     cliques = []
+     for bond in mol.GetBonds():
+         a1 = bond.GetBeginAtom().GetIdx()
+         a2 = bond.GetEndAtom().GetIdx()
+         if not bond.IsInRing():
+             cliques.append([a1, a2])
+
+     ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
+     cliques.extend(ssr)
+
+     nei_list = [[] for i in range(n_atoms)]
+     for i in range(len(cliques)):
+         for atom in cliques[i]:
+             nei_list[atom].append(i)
+
+     # Merge rings with intersection > 2 atoms
+     for i in range(len(cliques)):
+         if len(cliques[i]) <= 2: continue
+         for atom in cliques[i]:
+             for j in nei_list[atom]:
+                 if i >= j or len(cliques[j]) <= 2: continue
+                 inter = set(cliques[i]) & set(cliques[j])
+                 if len(inter) > 2:
+                     cliques[i].extend(cliques[j])
+                     cliques[i] = list(set(cliques[i]))
+                     cliques[j] = []
+
+     cliques = [c for c in cliques if len(c) > 0]
+     nei_list = [[] for i in range(n_atoms)]
+     for i in range(len(cliques)):
+         for atom in cliques[i]:
+             nei_list[atom].append(i)
+
+     # Build edges and add singleton cliques
+     edges = defaultdict(int)
+     for atom in range(n_atoms):
+         if len(nei_list[atom]) <= 1:
+             continue
+         cnei = nei_list[atom]
+         bonds = [c for c in cnei if len(cliques[c]) == 2]
+         rings = [c for c in cnei if len(cliques[c]) > 4]
+         if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):  # In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 rings is currently not dealt with.
+             cliques.append([atom])
+             c2 = len(cliques) - 1
+             for c1 in cnei:
+                 edges[(c1, c2)] = 1
+         elif len(rings) > 2:  # Multiple (n > 2) complex rings
+             cliques.append([atom])
+             c2 = len(cliques) - 1
+             for c1 in cnei:
+                 edges[(c1, c2)] = MST_MAX_WEIGHT - 1
+         else:
+             for i in range(len(cnei)):
+                 for j in range(i + 1, len(cnei)):
+                     c1, c2 = cnei[i], cnei[j]
+                     inter = set(cliques[c1]) & set(cliques[c2])
+                     if edges[(c1, c2)] < len(inter):
+                         edges[(c1, c2)] = len(inter)  # cnei[i] < cnei[j] by construction
+
+     edges = [u + (MST_MAX_WEIGHT - v,) for u, v in edges.items()]
+     if len(edges) == 0:
+         return cliques, edges
+
+     # Compute maximum spanning tree
+     row, col, data = zip(*edges)
+     n_clique = len(cliques)
+     clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
+     junc_tree = minimum_spanning_tree(clique_graph)
+     row, col = junc_tree.nonzero()
+     edges = [(row[i], col[i]) for i in range(len(row))]
+     return (cliques, edges)
+
+ def atom_equal(a1, a2):
+     return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()
+
+ # Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
+ def ring_bond_equal(b1, b2, reverse=False):
+     b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
+     if reverse:
+         b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
+     else:
+         b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
+     return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
+
+ def attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap):
+     prev_nids = [node.nid for node in prev_nodes]
+     for nei_node in prev_nodes + neighbors:
+         nei_id, nei_mol = nei_node.nid, nei_node.mol
+         amap = nei_amap[nei_id]
+         for atom in nei_mol.GetAtoms():
+             if atom.GetIdx() not in amap:
+                 new_atom = copy_atom(atom)
+                 amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)
+
+         if nei_mol.GetNumBonds() == 0:
+             nei_atom = nei_mol.GetAtomWithIdx(0)
+             ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
+             ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
+         else:
+             for bond in nei_mol.GetBonds():
+                 a1 = amap[bond.GetBeginAtom().GetIdx()]
+                 a2 = amap[bond.GetEndAtom().GetIdx()]
+                 if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
+                     ctr_mol.AddBond(a1, a2, bond.GetBondType())
+                 elif nei_id in prev_nids:  # father node overrides
+                     ctr_mol.RemoveBond(a1, a2)
+                     ctr_mol.AddBond(a1, a2, bond.GetBondType())
+     return ctr_mol
+
+ def local_attach(ctr_mol, neighbors, prev_nodes, amap_list):
+     ctr_mol = copy_edit_mol(ctr_mol)
+     nei_amap = {nei.nid: {} for nei in prev_nodes + neighbors}
+
+     for nei_id, ctr_atom, nei_atom in amap_list:
+         nei_amap[nei_id][nei_atom] = ctr_atom
+
+     ctr_mol = attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap)
+     return ctr_mol.GetMol()
+
+ # This version records idx mapping between ctr_mol and nei_mol
+ def enum_attach(ctr_mol, nei_node, amap, singletons):
+     nei_mol, nei_idx = nei_node.mol, nei_node.nid
+     att_confs = []
+     black_list = [atom_idx for nei_id, atom_idx, _ in amap if nei_id in singletons]
+     ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx() not in black_list]
+     ctr_bonds = [bond for bond in ctr_mol.GetBonds()]
+
+     if nei_mol.GetNumBonds() == 0:  # neighbor singleton
+         nei_atom = nei_mol.GetAtomWithIdx(0)
+         used_list = [atom_idx for _, atom_idx, _ in amap]
+         for atom in ctr_atoms:
+             if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
+                 new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
+                 att_confs.append(new_amap)
+
+     elif nei_mol.GetNumBonds() == 1:  # neighbor is a bond
+         bond = nei_mol.GetBondWithIdx(0)
+         bond_val = int(bond.GetBondTypeAsDouble())
+         b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()
+
+         for atom in ctr_atoms:
+             # Optimize if atom is carbon (other atoms may change valence)
+             if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
+                 continue
+             if atom_equal(atom, b1):
+                 new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
+                 att_confs.append(new_amap)
+             elif atom_equal(atom, b2):
+                 new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
+                 att_confs.append(new_amap)
+     else:
+         # intersection is an atom
+         for a1 in ctr_atoms:
+             for a2 in nei_mol.GetAtoms():
+                 if atom_equal(a1, a2):
+                     # Optimize if atom is carbon (other atoms may change valence)
+                     if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
+                         continue
+                     new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
+                     att_confs.append(new_amap)
+
+         # intersection is a bond
+         if ctr_mol.GetNumBonds() > 1:
+             for b1 in ctr_bonds:
+                 for b2 in nei_mol.GetBonds():
+                     if ring_bond_equal(b1, b2):
+                         new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
+                         att_confs.append(new_amap)
+
+                     if ring_bond_equal(b1, b2, reverse=True):
+                         new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
+                         att_confs.append(new_amap)
+     return att_confs
+
+ # Try rings first: speed-up
+ def enum_assemble(node, neighbors, prev_nodes=[], prev_amap=[]):
+     all_attach_confs = []
+     singletons = [nei_node.nid for nei_node in neighbors + prev_nodes if nei_node.mol.GetNumAtoms() == 1]
+
+     def search(cur_amap, depth):
+         if len(all_attach_confs) > MAX_NCAND:
+             return
+         if depth == len(neighbors):
+             all_attach_confs.append(cur_amap)
+             return
+
+         nei_node = neighbors[depth]
+         cand_amap = enum_attach(node.mol, nei_node, cur_amap, singletons)
+         cand_smiles = set()
+         candidates = []
+         for amap in cand_amap:
+             cand_mol = local_attach(node.mol, neighbors[:depth + 1], prev_nodes, amap)
+             cand_mol = sanitize(cand_mol)
+             if cand_mol is None:
+                 continue
+             smiles = get_smiles(cand_mol)
+             if smiles in cand_smiles:
+                 continue
+             cand_smiles.add(smiles)
+             candidates.append(amap)
+
+         if len(candidates) == 0:
+             return
+
+         for new_amap in candidates:
+             search(new_amap, depth + 1)
+
+     search(prev_amap, 0)
+     cand_smiles = set()
+     candidates = []
+     aroma_score = []
+     for amap in all_attach_confs:
+         cand_mol = local_attach(node.mol, neighbors, prev_nodes, amap)
+         cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
+         smiles = Chem.MolToSmiles(cand_mol)
+         if smiles in cand_smiles or check_singleton(cand_mol, node, neighbors) == False:
+             continue
+         cand_smiles.add(smiles)
+         candidates.append((smiles, amap))
+         aroma_score.append(check_aroma(cand_mol, node, neighbors))
+
+     return candidates, aroma_score
+
+ def check_singleton(cand_mol, ctr_node, nei_nodes):
+     rings = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() > 2]
+     singletons = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() == 1]
+     if len(singletons) > 0 or len(rings) == 0: return True
+
+     n_leaf2_atoms = 0
+     for atom in cand_mol.GetAtoms():
+         nei_leaf_atoms = [a for a in atom.GetNeighbors() if not a.IsInRing()]  # a.GetDegree() == 1
+         if len(nei_leaf_atoms) > 1:
+             n_leaf2_atoms += 1
+
+     return n_leaf2_atoms == 0
+
+ def check_aroma(cand_mol, ctr_node, nei_nodes):
+     rings = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() >= 3]
+     if len(rings) < 2: return 0  # Only multi-ring systems need to be checked
+
+     get_nid = lambda x: 0 if x.is_leaf else x.nid
+     benzynes = [get_nid(node) for node in nei_nodes + [ctr_node] if node.smiles in Vocab.benzynes]
+     penzynes = [get_nid(node) for node in nei_nodes + [ctr_node] if node.smiles in Vocab.penzynes]
+     if len(benzynes) + len(penzynes) == 0:
+         return 0  # No specific aromatic rings
+
+     n_aroma_atoms = 0
+     for atom in cand_mol.GetAtoms():
+         if atom.GetAtomMapNum() in benzynes + penzynes and atom.GetIsAromatic():
+             n_aroma_atoms += 1
+
+     if n_aroma_atoms >= len(benzynes) * 4 + len(penzynes) * 3:
+         return 1000
+     else:
+         return -0.001
+
+ # Only used for debugging purposes
+ def dfs_assemble(cur_mol, global_amap, fa_amap, cur_node, fa_node):
+     fa_nid = fa_node.nid if fa_node is not None else -1
+     prev_nodes = [fa_node] if fa_node is not None else []
+
+     children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
+     neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
+     neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
+     singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
+     neighbors = singletons + neighbors
+
+     cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node.nid]
+     cands, _ = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)  # enum_assemble returns (candidates, aroma_score)
+
+     cand_smiles, cand_amap = zip(*cands)
+     label_idx = cand_smiles.index(cur_node.label)
+     label_amap = cand_amap[label_idx]
+
+     for nei_id, ctr_atom, nei_atom in label_amap:
+         if nei_id == fa_nid:
+             continue
+         global_amap[nei_id][nei_atom] = global_amap[cur_node.nid][ctr_atom]
+
+     cur_mol = attach_mols(cur_mol, children, [], global_amap)  # father is already attached
+     for nei_node in children:
+         if not nei_node.is_leaf:
+             dfs_assemble(cur_mol, global_amap, label_amap, nei_node, cur_node)
+
+ if __name__ == "__main__":
+     import sys
+     from mol_tree import MolTree
+     lg = rdkit.RDLogger.logger()
+     lg.setLevel(rdkit.RDLogger.CRITICAL)
+
+     smiles = ["O=C1[C@@H]2C=C[C@@H](C=CC2)C1(c1ccccc1)c1ccccc1", "O=C([O-])CC[C@@]12CCCC[C@]1(O)OC(=O)CC2", "ON=C1C[C@H]2CC3(C[C@@H](C1)c1ccccc12)OCCO3", "C[C@H]1CC(=O)[C@H]2[C@@]3(O)C(=O)c4cccc(O)c4[C@@H]4O[C@@]43[C@@H](O)C[C@]2(O)C1", 'Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br', 'CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1', "O=c1c2ccc3c(=O)n(-c4nccs4)c(=O)c4ccc(c(=O)n1-c1nccs1)c2c34", "O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1"]
+
+     def tree_test():
+         for s in sys.stdin:
+             s = s.split()[0]
+             tree = MolTree(s)
+             print('-------------------------------------------')
+             print(s)
+             for node in tree.nodes:
+                 print(node.smiles, [x.smiles for x in node.neighbors])
+
+     def decode_test():
+         wrong = 0
+         for tot, s in enumerate(sys.stdin):
+             s = s.split()[0]
+             tree = MolTree(s)
+             tree.recover()
+
+             cur_mol = copy_edit_mol(tree.nodes[0].mol)
+             global_amap = [{}] + [{} for node in tree.nodes]
+             global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
+
+             dfs_assemble(cur_mol, global_amap, [], tree.nodes[0], None)
+
+             cur_mol = cur_mol.GetMol()
+             cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
+             set_atommap(cur_mol)
+             dec_smiles = Chem.MolToSmiles(cur_mol)
+
+             gold_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(s))
+             if gold_smiles != dec_smiles:
+                 print(gold_smiles, dec_smiles)
+                 wrong += 1
+             print(wrong, tot + 1)
+
+     def enum_test():
+         for s in sys.stdin:
+             s = s.split()[0]
+             tree = MolTree(s)
+             tree.recover()
+             tree.assemble()
+             for node in tree.nodes:
+                 if node.label not in node.cands:
+                     print(tree.smiles)
+                     print(node.smiles, [x.smiles for x in node.neighbors])
+                     print(node.label, len(node.cands))
+
+     def count():
+         cnt, n = 0, 0
+         for s in sys.stdin:
+             s = s.split()[0]
+             tree = MolTree(s)
+             tree.recover()
+             tree.assemble()
+             for node in tree.nodes:
+                 cnt += len(node.cands)
+             n += len(tree.nodes)
+         # print(cnt * 1.0 / n)
+
+     count()
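A small usage sketch for the decomposition utilities above (it assumes `fast_jtnn/` is on `sys.path`, matching the flat imports this file itself uses):

```python
# Sketch: junction-tree decomposition of a molecule via chemutils.
from chemutils import get_mol, tree_decomp, get_clique_mol, get_smiles

mol = get_mol('Cc1ccccc1C(=O)O')   # kekulized RDKit mol
cliques, edges = tree_decomp(mol)  # cliques: lists of atom indices
print([get_smiles(get_clique_mol(mol, c)) for c in cliques])  # clique substructures as SMILES
print(edges)                       # junction-tree edges from the maximum spanning tree
```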
fast_jtnn/datautils.py ADDED
@@ -0,0 +1,213 @@
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from mol_tree import MolTree
+ import numpy as np
+ from jtnn_enc import JTNNEncoder
+ from mpn import MPN
+ from jtmpn import JTMPN
+ import pickle
+ import os, random
+
+ class PairTreeFolder(object):
+
+     def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, y_assm=True, replicate=None):
+         self.data_folder = data_folder
+         self.data_files = [fn for fn in os.listdir(data_folder)]
+         self.batch_size = batch_size
+         self.vocab = vocab
+         self.num_workers = num_workers
+         self.y_assm = y_assm
+         self.shuffle = shuffle
+
+         if replicate is not None:  # replicate is an int expansion factor
+             self.data_files = self.data_files * replicate
+
+     def __iter__(self):
+         for fn in self.data_files:
+             fn = os.path.join(self.data_folder, fn)
+             with open(fn, 'rb') as f:
+                 data = pickle.load(f)
+
+             if self.shuffle:
+                 random.shuffle(data)  # shuffle data before batching
+
+             batches = [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]
+             if len(batches[-1]) < self.batch_size:
+                 batches.pop()
+
+             dataset = PairTreeDataset(batches, self.vocab, self.y_assm)
+             dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.num_workers, collate_fn=lambda x: x[0])
+
+             for b in dataloader:
+                 yield b
+
+             del data, batches, dataset, dataloader
+
+ class MolTreeFolder(object):
+
+     def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, assm=True, replicate=None):
+         self.data_folder = data_folder
+         self.data_files = [fn for fn in os.listdir(data_folder)]
+         self.batch_size = batch_size
+         self.vocab = vocab
+         self.num_workers = num_workers
+         self.shuffle = shuffle
+         self.assm = assm
+
+         if replicate is not None:  # replicate is an int expansion factor
+             self.data_files = self.data_files * replicate
+
+     def __iter__(self):
+         for fn in self.data_files:
+             fn = os.path.join(self.data_folder, fn)
+             with open(fn, 'rb') as f:
+                 data = pickle.load(f)
+
+             if self.shuffle:
+                 random.shuffle(data)  # shuffle data before batching
+
+             batches = [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]
+             if len(batches[-1]) < self.batch_size:
+                 batches.pop()
+
+             dataset = MolTreeDataset(batches, self.vocab, self.assm)
+             dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.num_workers, collate_fn=lambda x: x[0])
+
+             for b in dataloader:
+                 yield b
+
+             del data, batches, dataset, dataloader
+
+ class PairTreeDataset(Dataset):
+
+     def __init__(self, data, vocab, y_assm):
+         self.data = data
+         self.vocab = vocab
+         self.y_assm = y_assm
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         batch0, batch1 = zip(*self.data[idx])
+         return tensorize(batch0, self.vocab, assm=False), tensorize(batch1, self.vocab, assm=self.y_assm)
+
+ class MolTreeDataset(Dataset):
+
+     def __init__(self, data, vocab, assm=True):
+         self.data = data
+         self.vocab = vocab
+         self.assm = assm
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         return tensorize(self.data[idx], self.vocab, assm=self.assm)
+
+ def tensorize(tree_batch, vocab, assm=True):
+     set_batch_nodeID(tree_batch, vocab)
+     smiles_batch = [tree.smiles for tree in tree_batch]
+     jtenc_holder, mess_dict = JTNNEncoder.tensorize(tree_batch)
+     mpn_holder = MPN.tensorize(smiles_batch)
+
+     if assm is False:
+         return tree_batch, jtenc_holder, mpn_holder
+
+     cands = []
+     batch_idx = []
+     for i, mol_tree in enumerate(tree_batch):
+         for node in mol_tree.nodes:
+             # A leaf node's attachment is determined by the neighboring node's attachment
+             if node.is_leaf or len(node.cands) == 1: continue
+             cands.extend([(cand, mol_tree.nodes, node) for cand in node.cands])
+             batch_idx.extend([i] * len(node.cands))
+
+     jtmpn_holder = JTMPN.tensorize(cands, mess_dict)
+     batch_idx = torch.LongTensor(batch_idx)
+
+     return tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx)
+
+ def set_batch_nodeID(mol_batch, vocab):
+     tot = 0
+     for mol_tree in mol_batch:
+         for node in mol_tree.nodes:
+             node.idx = tot
+             node.wid = vocab.get_index(node.smiles)
+             tot += 1
+
+ class PropMolTreeDataset(Dataset):
+
+     def __init__(self, data, vocab, assm=True):
+         self.data = data
+         self.vocab = vocab
+         self.assm = assm
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         return tensorize_prop(self.data[idx], self.vocab, assm=self.assm)
+
+ class PropMolTreeFolder(object):
+
+     def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, assm=True, replicate=None):
+         self.data_folder = data_folder
+         self.data_files = [fn for fn in os.listdir(data_folder)]
+         self.batch_size = batch_size
+         self.vocab = vocab
+         self.num_workers = num_workers
+         self.shuffle = shuffle
+         self.assm = assm
+
+         if replicate is not None:  # replicate is an int expansion factor
+             self.data_files = self.data_files * replicate
+
+     def __iter__(self):
+         for fn in self.data_files:
+             fn = os.path.join(self.data_folder, fn)
+             with open(fn, 'rb') as f:
+                 data = pickle.load(f)
+
+             if self.shuffle:
+                 random.shuffle(data)  # shuffle data before batching
+
+             batches = [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]
+             if len(batches[-1]) < self.batch_size:
+                 batches.pop()
+
+             dataset = PropMolTreeDataset(batches, self.vocab, self.assm)
+             dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.num_workers, collate_fn=lambda x: x[0])
+
+             for b in dataloader:
+                 yield b
+
+             del data, batches, dataset, dataloader
+
+ def tensorize_prop(data, vocab, assm=True):
+     tree_batch, prop = list(zip(*data))
+     set_batch_nodeID(tree_batch, vocab)
+     smiles_batch = [tree.smiles for tree in tree_batch]
+     jtenc_holder, mess_dict = JTNNEncoder.tensorize(tree_batch)
+     mpn_holder = MPN.tensorize(smiles_batch)
+
+     if assm is False:
+         return tree_batch, jtenc_holder, mpn_holder
+
+     cands = []
+     batch_idx = []
+     for i, mol_tree in enumerate(tree_batch):
+         for node in mol_tree.nodes:
+             # A leaf node's attachment is determined by the neighboring node's attachment
+             if node.is_leaf or len(node.cands) == 1: continue
+             cands.extend([(cand, mol_tree.nodes, node) for cand in node.cands])
+             batch_idx.extend([i] * len(node.cands))
+
+     jtmpn_holder = JTMPN.tensorize(cands, mess_dict)
+     batch_idx = torch.LongTensor(batch_idx)
+
+     return tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx), prop
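For orientation, a hedged sketch of the intended training-side iteration; the folder name is a placeholder, and each pickle in it must hold a list of preprocessed `MolTree` objects, which is what `__iter__` loads:

```python
# Sketch: stream tensorized batches for training (folder name hypothetical).
from mol_tree import Vocab
from datautils import MolTreeFolder

vocab = Vocab([x.strip("\r\n ") for x in open('vocab.txt')])
loader = MolTreeFolder('train-processed/', vocab, batch_size=32)
for tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx) in loader:
    pass  # each item has the shape returned by tensorize() with assm=True
```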
fast_jtnn/jtmpn.py ADDED
@@ -0,0 +1,138 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from nnutils import create_var, index_select_ND
+ from chemutils import get_mol
+ import rdkit.Chem as Chem
+
+ ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
+
+ ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
+ BOND_FDIM = 5
+ MAX_NB = 15
+
+ def onek_encoding_unk(x, allowable_set):
+     if x not in allowable_set:
+         x = allowable_set[-1]
+     return list(map(lambda s: x == s, allowable_set))
+
+ def atom_features(atom):
+     return torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+                         + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
+                         + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
+                         + [atom.GetIsAromatic()])
+
+ def bond_features(bond):
+     bt = bond.GetBondType()
+     return torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()])
+
+ class JTMPN(nn.Module):
+
+     def __init__(self, hidden_size, depth):
+         super(JTMPN, self).__init__()
+         self.hidden_size = hidden_size
+         self.depth = depth
+
+         self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)
+         self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
+         self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
+
+     def forward(self, fatoms, fbonds, agraph, bgraph, scope, tree_message):  # tree_message[0] == vec(0)
+         fatoms = create_var(fatoms)
+         fbonds = create_var(fbonds)
+         agraph = create_var(agraph)
+         bgraph = create_var(bgraph)
+
+         binput = self.W_i(fbonds)
+         graph_message = F.relu(binput)
+
+         for i in range(self.depth - 1):
+             message = torch.cat([tree_message, graph_message], dim=0)
+             nei_message = index_select_ND(message, 0, bgraph)
+             nei_message = nei_message.sum(dim=1)  # assuming tree_message[0] == vec(0)
+             nei_message = self.W_h(nei_message)
+             graph_message = F.relu(binput + nei_message)
+
+         message = torch.cat([tree_message, graph_message], dim=0)
+         nei_message = index_select_ND(message, 0, agraph)
+         nei_message = nei_message.sum(dim=1)
+         ainput = torch.cat([fatoms, nei_message], dim=1)
+         atom_hiddens = F.relu(self.W_o(ainput))
+
+         mol_vecs = []
+         for st, le in scope:
+             mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
+             mol_vecs.append(mol_vec)
+
+         mol_vecs = torch.stack(mol_vecs, dim=0)
+         return mol_vecs
+
+     @staticmethod
+     def tensorize(cand_batch, mess_dict):
+         fatoms, fbonds = [], []
+         in_bonds, all_bonds = [], []
+         total_atoms = 0
+         total_mess = len(mess_dict) + 1  # must include vec(0) padding
+         scope = []
+
+         for smiles, all_nodes, ctr_node in cand_batch:
+             mol = Chem.MolFromSmiles(smiles)
+             Chem.Kekulize(mol)  # The original jtnn version kekulizes. Need to revisit why it is necessary.
+             n_atoms = mol.GetNumAtoms()
+             ctr_bid = ctr_node.idx
+
+             for atom in mol.GetAtoms():
+                 fatoms.append(atom_features(atom))
+                 in_bonds.append([])
+
+             for bond in mol.GetBonds():
+                 a1 = bond.GetBeginAtom()
+                 a2 = bond.GetEndAtom()
+                 x = a1.GetIdx() + total_atoms
+                 y = a2.GetIdx() + total_atoms
+                 # Here x_nid,y_nid could be 0
+                 x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
+                 x_bid = all_nodes[x_nid - 1].idx if x_nid > 0 else -1
+                 y_bid = all_nodes[y_nid - 1].idx if y_nid > 0 else -1
+
+                 bfeature = bond_features(bond)
+
+                 b = total_mess + len(all_bonds)  # bond idx offset by total_mess
+                 all_bonds.append((x, y))
+                 fbonds.append(torch.cat([fatoms[x], bfeature], 0))
+                 in_bonds[y].append(b)
+
+                 b = total_mess + len(all_bonds)
+                 all_bonds.append((y, x))
+                 fbonds.append(torch.cat([fatoms[y], bfeature], 0))
+                 in_bonds[x].append(b)
+
+                 if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
+                     if (x_bid, y_bid) in mess_dict:
+                         mess_idx = mess_dict[(x_bid, y_bid)]
+                         in_bonds[y].append(mess_idx)
+                     if (y_bid, x_bid) in mess_dict:
+                         mess_idx = mess_dict[(y_bid, x_bid)]
+                         in_bonds[x].append(mess_idx)
+
+             scope.append((total_atoms, n_atoms))
+             total_atoms += n_atoms
+
+         total_bonds = len(all_bonds)
+         fatoms = torch.stack(fatoms, 0)
+         fbonds = torch.stack(fbonds, 0)
+         agraph = torch.zeros(total_atoms, MAX_NB).long()
+         bgraph = torch.zeros(total_bonds, MAX_NB).long()
+
+         for a in range(total_atoms):
+             for i, b in enumerate(in_bonds[a]):
+                 agraph[a, i] = b
+
+         for b1 in range(total_bonds):
+             x, y = all_bonds[b1]
+             for i, b2 in enumerate(in_bonds[x]):  # b2 is offset by total_mess
+                 if b2 < total_mess or all_bonds[b2 - total_mess][0] != y:
+                     bgraph[b1, i] = b2
+
+         return (fatoms, fbonds, agraph, bgraph, scope)
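The fixed feature widths can be checked directly; a quick sanity sketch:

```python
# Sketch: JTMPN featurization uses fixed-width one-hot blocks.
import rdkit.Chem as Chem
from jtmpn import atom_features, bond_features, ATOM_FDIM, BOND_FDIM

mol = Chem.MolFromSmiles('CCO')
# 23 element symbols + 6 degrees + 5 formal charges + 1 aromatic flag = 35
assert atom_features(mol.GetAtomWithIdx(0)).numel() == ATOM_FDIM
# 4 bond types + 1 ring flag = 5
assert bond_features(mol.GetBondWithIdx(0)).numel() == BOND_FDIM
```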
fast_jtnn/jtnn_dec.py ADDED
@@ -0,0 +1,347 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from mol_tree import Vocab, MolTree, MolTreeNode
+ from nnutils import create_var, GRU
+ from chemutils import enum_assemble, set_atommap
+ import copy
+
+ MAX_NB = 15
+ MAX_DECODE_LEN = 100
+
+ class JTNNDecoder(nn.Module):
+
+     def __init__(self, vocab, hidden_size, latent_size, embedding):
+         super(JTNNDecoder, self).__init__()
+         self.hidden_size = hidden_size
+         self.vocab_size = vocab.size()
+         self.vocab = vocab
+         self.embedding = embedding
+
+         # GRU weights
+         self.W_z = nn.Linear(2 * hidden_size, hidden_size)
+         self.U_r = nn.Linear(hidden_size, hidden_size, bias=False)
+         self.W_r = nn.Linear(hidden_size, hidden_size)
+         self.W_h = nn.Linear(2 * hidden_size, hidden_size)
+
+         # Word prediction weights
+         self.W = nn.Linear(hidden_size + latent_size, hidden_size)
+
+         # Stop prediction weights
+         self.U = nn.Linear(hidden_size + latent_size, hidden_size)
+         self.U_i = nn.Linear(2 * hidden_size, hidden_size)
+
+         # Output weights
+         self.W_o = nn.Linear(hidden_size, self.vocab_size)
+         self.U_o = nn.Linear(hidden_size, 1)
+
+         # Loss functions
+         # self.pred_loss = nn.CrossEntropyLoss(size_average=False)
+         # self.stop_loss = nn.BCEWithLogitsLoss(size_average=False)
+         self.pred_loss = nn.CrossEntropyLoss(reduction='sum')
+         self.stop_loss = nn.BCEWithLogitsLoss(reduction='sum')
+
+     def aggregate(self, hiddens, contexts, x_tree_vecs, mode):
+         if mode == 'word':
+             V, V_o = self.W, self.W_o
+         elif mode == 'stop':
+             V, V_o = self.U, self.U_o
+         else:
+             raise ValueError('aggregate mode is wrong')
+
+         tree_contexts = x_tree_vecs.index_select(0, contexts)
+         input_vec = torch.cat([hiddens, tree_contexts], dim=-1)
+         output_vec = F.relu(V(input_vec))
+         return V_o(output_vec)
+
+     def forward(self, mol_batch, x_tree_vecs):
+         pred_hiddens, pred_contexts, pred_targets = [], [], []
+         stop_hiddens, stop_contexts, stop_targets = [], [], []
+         traces = []
+         for mol_tree in mol_batch:
+             s = []
+             dfs(s, mol_tree.nodes[0], -1)
+             traces.append(s)
+             for node in mol_tree.nodes:
+                 node.neighbors = []
+
+         # Predict root
+         batch_size = len(mol_batch)
+         pred_hiddens.append(create_var(torch.zeros(len(mol_batch), self.hidden_size)))
+         pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
+         pred_contexts.append(create_var(torch.LongTensor(range(batch_size))))
+
+         max_iter = max([len(tr) for tr in traces])
+         padding = create_var(torch.zeros(self.hidden_size), False)
+         h = {}
+
+         for t in range(max_iter):
+             prop_list = []
+             batch_list = []
+             for i, plist in enumerate(traces):
+                 if t < len(plist):
+                     prop_list.append(plist[t])
+                     batch_list.append(i)
+
+             cur_x = []
+             cur_h_nei, cur_o_nei = [], []
+
+             for node_x, real_y, _ in prop_list:
+                 # Neighbors for message passing (target not included)
+                 cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
+                 pad_len = MAX_NB - len(cur_nei)
+                 cur_h_nei.extend(cur_nei)
+                 cur_h_nei.extend([padding] * pad_len)
+
+                 # Neighbors for stop prediction (all neighbors)
+                 cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
+                 pad_len = MAX_NB - len(cur_nei)
+                 cur_o_nei.extend(cur_nei)
+                 cur_o_nei.extend([padding] * pad_len)
+
+                 # Current clique embedding
+                 cur_x.append(node_x.wid)
+
+             # Clique embedding
+             cur_x = create_var(torch.LongTensor(cur_x))
+             cur_x = self.embedding(cur_x)
+
+             # Message passing
+             cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
+             new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
+
+             # Node aggregate
+             cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
+             cur_o = cur_o_nei.sum(dim=1)
+
+             # Gather targets
+             pred_target, pred_list = [], []
+             stop_target = []
+             for i, m in enumerate(prop_list):
+                 node_x, node_y, direction = m
+                 x, y = node_x.idx, node_y.idx
+                 h[(x, y)] = new_h[i]
+                 node_y.neighbors.append(node_x)
+                 if direction == 1:
+                     pred_target.append(node_y.wid)
+                     pred_list.append(i)
+                 stop_target.append(direction)
+
+             # Hidden states for stop prediction
+             cur_batch = create_var(torch.LongTensor(batch_list))
+             stop_hidden = torch.cat([cur_x, cur_o], dim=1)
+             stop_hiddens.append(stop_hidden)
+             stop_contexts.append(cur_batch)
+             stop_targets.extend(stop_target)
+
+             # Hidden states for clique prediction
+             if len(pred_list) > 0:
+                 batch_list = [batch_list[i] for i in pred_list]
+                 cur_batch = create_var(torch.LongTensor(batch_list))
+                 pred_contexts.append(cur_batch)
+
+                 cur_pred = create_var(torch.LongTensor(pred_list))
+                 pred_hiddens.append(new_h.index_select(0, cur_pred))
+                 pred_targets.extend(pred_target)
+
+         # Last stop at root
+         cur_x, cur_o_nei = [], []
+         for mol_tree in mol_batch:
+             node_x = mol_tree.nodes[0]
+             cur_x.append(node_x.wid)
+             cur_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
+             pad_len = MAX_NB - len(cur_nei)
+             cur_o_nei.extend(cur_nei)
+             cur_o_nei.extend([padding] * pad_len)
+
+         cur_x = create_var(torch.LongTensor(cur_x))
+         cur_x = self.embedding(cur_x)
+         cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1, MAX_NB, self.hidden_size)
+         cur_o = cur_o_nei.sum(dim=1)
+
+         stop_hidden = torch.cat([cur_x, cur_o], dim=1)
+         stop_hiddens.append(stop_hidden)
+         stop_contexts.append(create_var(torch.LongTensor(range(batch_size))))
+         stop_targets.extend([0] * len(mol_batch))
+
+         # Predict next clique
+         pred_contexts = torch.cat(pred_contexts, dim=0)
+         pred_hiddens = torch.cat(pred_hiddens, dim=0)
+         pred_scores = self.aggregate(pred_hiddens, pred_contexts, x_tree_vecs, 'word')
+         pred_targets = create_var(torch.LongTensor(pred_targets))
+
+         pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
+         _, preds = torch.max(pred_scores, dim=1)
+         pred_acc = torch.eq(preds, pred_targets).float()
+         pred_acc = torch.sum(pred_acc) / pred_targets.nelement()
+
+         # Predict stop
+         stop_contexts = torch.cat(stop_contexts, dim=0)
+         stop_hiddens = torch.cat(stop_hiddens, dim=0)
+         stop_hiddens = F.relu(self.U_i(stop_hiddens))
+         stop_scores = self.aggregate(stop_hiddens, stop_contexts, x_tree_vecs, 'stop')
+         stop_scores = stop_scores.squeeze(-1)
+         stop_targets = create_var(torch.Tensor(stop_targets))
+
+         stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
+         stops = torch.ge(stop_scores, 0).float()
+         stop_acc = torch.eq(stops, stop_targets).float()
+         stop_acc = torch.sum(stop_acc) / stop_targets.nelement()
+
+         return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
+
+     def decode(self, x_tree_vecs, prob_decode):
+         assert x_tree_vecs.size(0) == 1
+
+         stack = []
+         init_hiddens = create_var(torch.zeros(1, self.hidden_size))
+         zero_pad = create_var(torch.zeros(1, 1, self.hidden_size))
+         contexts = create_var(torch.LongTensor(1).zero_())
+
+         # Root prediction
+         root_score = self.aggregate(init_hiddens, contexts, x_tree_vecs, 'word')
+         _, root_wid = torch.max(root_score, dim=1)
+         root_wid = root_wid.item()
+
+         root = MolTreeNode(self.vocab.get_smiles(root_wid))
+         root.wid = root_wid
+         root.idx = 0
+         stack.append((root, self.vocab.get_slots(root.wid)))
+
+         all_nodes = [root]
+         h = {}
+         for step in range(MAX_DECODE_LEN):
+             node_x, fa_slot = stack[-1]
+             cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors]
+             if len(cur_h_nei) > 0:
+                 cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
+             else:
+                 cur_h_nei = zero_pad
+
+             cur_x = create_var(torch.LongTensor([node_x.wid]))
+             cur_x = self.embedding(cur_x)
+
+             # Predict stop
+             cur_h = cur_h_nei.sum(dim=1)
+             stop_hiddens = torch.cat([cur_x, cur_h], dim=1)
+             stop_hiddens = F.relu(self.U_i(stop_hiddens))
+             stop_score = self.aggregate(stop_hiddens, contexts, x_tree_vecs, 'stop')
+
+             if prob_decode:
+                 backtrack = (torch.bernoulli(torch.sigmoid(stop_score)).item() == 0)
+             else:
+                 backtrack = (stop_score.item() < 0)
+
+             if not backtrack:  # Forward: predict next clique
+                 new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
+                 pred_score = self.aggregate(new_h, contexts, x_tree_vecs, 'word')
+
+                 if prob_decode:
+                     sort_wid = torch.multinomial(F.softmax(pred_score, dim=1).squeeze(), 5)
+                 else:
+                     _, sort_wid = torch.sort(pred_score, dim=1, descending=True)
+                     sort_wid = sort_wid.data.squeeze()
+
+                 next_wid = None
+                 for wid in sort_wid[:5]:
+                     slots = self.vocab.get_slots(wid)
+                     node_y = MolTreeNode(self.vocab.get_smiles(wid))
+                     if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
+                         next_wid = wid
+                         next_slots = slots
+                         break
+
+                 if next_wid is None:
+                     backtrack = True  # No more children can be added
+                 else:
+                     node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
+                     node_y.wid = next_wid
+                     node_y.idx = len(all_nodes)
+                     node_y.neighbors.append(node_x)
+                     h[(node_x.idx, node_y.idx)] = new_h[0]
+                     stack.append((node_y, next_slots))
+                     all_nodes.append(node_y)
+
+             if backtrack:  # Backtrack; use if instead of else
+                 if len(stack) == 1:
+                     break  # At root, terminate
+
+                 node_fa, _ = stack[-2]
+                 cur_h_nei = [h[(node_y.idx, node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx]
+                 if len(cur_h_nei) > 0:
+                     cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
+                 else:
+                     cur_h_nei = zero_pad
+
+                 new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
+                 h[(node_x.idx, node_fa.idx)] = new_h[0]
+                 node_fa.neighbors.append(node_x)
+                 stack.pop()
+
+         return root, all_nodes
+
+ """
+ Helper Functions:
+ """
+ def dfs(stack, x, fa_idx):
+     for y in x.neighbors:
+         if y.idx == fa_idx: continue
+         stack.append((x, y, 1))
+         dfs(stack, y, x.idx)
+         stack.append((y, x, 0))
+
+ def have_slots(fa_slots, ch_slots):
+     if len(fa_slots) > 2 and len(ch_slots) > 2:
+         return True
+     matches = []
+     for i, s1 in enumerate(fa_slots):
+         a1, c1, h1 = s1
+         for j, s2 in enumerate(ch_slots):
+             a2, c2, h2 = s2
+             if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
+                 matches.append((i, j))
+
+     if len(matches) == 0: return False
+
+     fa_match, ch_match = zip(*matches)
+     if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2:  # never remove atom from ring
+         fa_slots.pop(fa_match[0])
+     if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2:  # never remove atom from ring
+         ch_slots.pop(ch_match[0])
+
+     return True
+
+ def can_assemble(node_x, node_y):
+     node_x.nid = 1
+     node_x.is_leaf = False
+     set_atommap(node_x.mol, node_x.nid)
+
+     neis = node_x.neighbors + [node_y]
+     for i, nei in enumerate(neis):
+         nei.nid = i + 2
+         nei.is_leaf = (len(nei.neighbors) <= 1)
+         if nei.is_leaf:
+             set_atommap(nei.mol, 0)
+         else:
+             set_atommap(nei.mol, nei.nid)
+
+     neighbors = [nei for nei in neis if nei.mol.GetNumAtoms() > 1]
+     neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
+     singletons = [nei for nei in neis if nei.mol.GetNumAtoms() == 1]
+     neighbors = singletons + neighbors
+     cands, aroma_scores = enum_assemble(node_x, neighbors)
+     return len(cands) > 0  # and sum(aroma_scores) >= 0
+
+ if __name__ == "__main__":
+     smiles = ["O=C1[C@@H]2C=C[C@@H](C=CC2)C1(c1ccccc1)c1ccccc1", "O=C([O-])CC[C@@]12CCCC[C@]1(O)OC(=O)CC2", "ON=C1C[C@H]2CC3(C[C@@H](C1)c1ccccc12)OCCO3", "C[C@H]1CC(=O)[C@H]2[C@@]3(O)C(=O)c4cccc(O)c4[C@@H]4O[C@@]43[C@@H](O)C[C@]2(O)C1", 'Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br', 'CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1', "O=c1c2ccc3c(=O)n(-c4nccs4)c(=O)c4ccc(c(=O)n1-c1nccs1)c2c34", "O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1"]
+     for s in smiles:
+         print(s)
+         tree = MolTree(s)
+         for i, node in enumerate(tree.nodes):
+             node.idx = i
+
+         stack = []
+         dfs(stack, tree.nodes[0], -1)
+         for x, y, d in stack:
+             print(x.smiles, y.smiles, d)
+         print('------------------------------')
fast_jtnn/jtnn_enc.py ADDED
@@ -0,0 +1,131 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from collections import deque
+ from mol_tree import Vocab, MolTree
+ from nnutils import create_var, index_select_ND
+
+ class JTNNEncoder(nn.Module):
+
+     def __init__(self, hidden_size, depth, embedding):
+         super(JTNNEncoder, self).__init__()
+         self.hidden_size = hidden_size
+         self.depth = depth
+
+         self.embedding = embedding
+         self.outputNN = nn.Sequential(
+             nn.Linear(2 * hidden_size, hidden_size),
+             nn.ReLU()
+         )
+         self.GRU = GraphGRU(hidden_size, hidden_size, depth=depth)
+
+     def forward(self, fnode, fmess, node_graph, mess_graph, scope):
+         fnode = create_var(fnode)
+         fmess = create_var(fmess)
+         node_graph = create_var(node_graph)
+         mess_graph = create_var(mess_graph)
+         messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))
+
+         fnode = self.embedding(fnode)
+         fmess = index_select_ND(fnode, 0, fmess)
+         messages = self.GRU(messages, fmess, mess_graph)
+
+         mess_nei = index_select_ND(messages, 0, node_graph)
+         node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
+         node_vecs = self.outputNN(node_vecs)
+
+         max_len = max([x for _, x in scope])
+         batch_vecs = []
+         for st, le in scope:
+             cur_vecs = node_vecs[st]  # Root is the first node
+             batch_vecs.append(cur_vecs)
+
+         tree_vecs = torch.stack(batch_vecs, dim=0)
+         return tree_vecs, messages
+
+     @staticmethod
+     def tensorize(tree_batch):
+         node_batch = []
+         scope = []
+         for tree in tree_batch:
+             scope.append((len(node_batch), len(tree.nodes)))
+             node_batch.extend(tree.nodes)
+
+         return JTNNEncoder.tensorize_nodes(node_batch, scope)
+
+     @staticmethod
+     def tensorize_nodes(node_batch, scope):
+         messages, mess_dict = [None], {}
+         fnode = []
+         for x in node_batch:
+             fnode.append(x.wid)
+             for y in x.neighbors:
+                 mess_dict[(x.idx, y.idx)] = len(messages)
+                 messages.append((x, y))
+
+         node_graph = [[] for i in range(len(node_batch))]
+         mess_graph = [[] for i in range(len(messages))]
+         fmess = [0] * len(messages)
+
+         for x, y in messages[1:]:
+             mid1 = mess_dict[(x.idx, y.idx)]
+             fmess[mid1] = x.idx
+             node_graph[y.idx].append(mid1)
+             for z in y.neighbors:
+                 if z.idx == x.idx: continue
+                 mid2 = mess_dict[(y.idx, z.idx)]
+                 mess_graph[mid2].append(mid1)
+
+         max_len = max([len(t) for t in node_graph] + [1])
+         for t in node_graph:
+             pad_len = max_len - len(t)
+             t.extend([0] * pad_len)
+
+         max_len = max([len(t) for t in mess_graph] + [1])
+         for t in mess_graph:
+             pad_len = max_len - len(t)
+             t.extend([0] * pad_len)
+
+         mess_graph = torch.LongTensor(mess_graph)
+         node_graph = torch.LongTensor(node_graph)
+         fmess = torch.LongTensor(fmess)
+         fnode = torch.LongTensor(fnode)
+         return (fnode, fmess, node_graph, mess_graph, scope), mess_dict
+
+ class GraphGRU(nn.Module):
+
+     def __init__(self, input_size, hidden_size, depth):
+         super(GraphGRU, self).__init__()
+         self.hidden_size = hidden_size
+         self.input_size = input_size
+         self.depth = depth
+
+         self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
+         self.W_r = nn.Linear(input_size, hidden_size, bias=False)
+         self.U_r = nn.Linear(hidden_size, hidden_size)
+         self.W_h = nn.Linear(input_size + hidden_size, hidden_size)
+
+     def forward(self, h, x, mess_graph):
+         mask = torch.ones(h.size(0), 1)
+         mask[0] = 0  # first vector is padding
+         mask = create_var(mask)
+         for it in range(self.depth):
+             h_nei = index_select_ND(h, 0, mess_graph)
+             sum_h = h_nei.sum(dim=1)
+             z_input = torch.cat([x, sum_h], dim=1)
+             z = F.sigmoid(self.W_z(z_input))
+
+             r_1 = self.W_r(x).view(-1, 1, self.hidden_size)
+             r_2 = self.U_r(h_nei)
+             r = F.sigmoid(r_1 + r_2)
+
+             gated_h = r * h_nei
+             sum_gated_h = gated_h.sum(dim=1)
+             h_input = torch.cat([x, sum_gated_h], dim=1)
+             pre_h = F.tanh(self.W_h(h_input))
+             h = (1.0 - z) * sum_h + z * pre_h
+             h = h * mask
+
+         return h
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from mol_tree import Vocab, MolTree
+ from nnutils import create_var, flatten_tensor, avg_pool
+ from jtnn_enc import JTNNEncoder
+ from jtnn_dec import JTNNDecoder
+ from mpn import MPN
+ from jtmpn import JTMPN
+ from datautils import tensorize
+
+ from chemutils import enum_assemble, set_atommap, copy_edit_mol, attach_mols
+ import rdkit
+ import rdkit.Chem as Chem
+ import copy, math
+
+ class JTNNVAE(nn.Module):
+
+     def __init__(self, vocab, hidden_size, latent_size, depthT, depthG):
+         super(JTNNVAE, self).__init__()
+         self.vocab = vocab
+         self.hidden_size = hidden_size
+         self.latent_size = latent_size = int(latent_size / 2)  #tree and mol each take half of the latent vector
+
+         self.jtnn = JTNNEncoder(hidden_size, depthT, nn.Embedding(vocab.size(), hidden_size))
+         self.decoder = JTNNDecoder(vocab, hidden_size, latent_size, nn.Embedding(vocab.size(), hidden_size))
+
+         self.jtmpn = JTMPN(hidden_size, depthG)
+         self.mpn = MPN(hidden_size, depthG)
+
+         self.A_assm = nn.Linear(latent_size, hidden_size, bias=False)
+         # self.assm_loss = nn.CrossEntropyLoss(size_average=False)
+         self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
+
+         self.T_mean = nn.Linear(hidden_size, latent_size)
+         self.T_var = nn.Linear(hidden_size, latent_size)
+         self.G_mean = nn.Linear(hidden_size, latent_size)
+         self.G_var = nn.Linear(hidden_size, latent_size)
+
+     def encode(self, jtenc_holder, mpn_holder):
+         tree_vecs, tree_mess = self.jtnn(*jtenc_holder)
+         mol_vecs = self.mpn(*mpn_holder)
+         return tree_vecs, tree_mess, mol_vecs
+
+     def encode_from_smiles(self, smiles_list):
+         tree_batch = [MolTree(s) for s in smiles_list]
+         _, jtenc_holder, mpn_holder = tensorize(tree_batch, self.vocab, assm=False)
+         tree_vecs, _, mol_vecs = self.encode(jtenc_holder, mpn_holder)
+         return torch.cat([tree_vecs, mol_vecs], dim=-1)
+
+     def encode_latent(self, jtenc_holder, mpn_holder):
+         tree_vecs, _ = self.jtnn(*jtenc_holder)
+         mol_vecs = self.mpn(*mpn_holder)
+         tree_mean = self.T_mean(tree_vecs)
+         mol_mean = self.G_mean(mol_vecs)
+         tree_var = -torch.abs(self.T_var(tree_vecs))
+         mol_var = -torch.abs(self.G_var(mol_vecs))
+         return torch.cat([tree_mean, mol_mean], dim=1), torch.cat([tree_var, mol_var], dim=1)
+
+     def rsample(self, z_vecs, W_mean, W_var):
+         batch_size = z_vecs.size(0)
+         z_mean = W_mean(z_vecs)
+         z_log_var = -torch.abs(W_var(z_vecs))  #Following Mueller et al.
+         kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
+         epsilon = create_var(torch.randn_like(z_mean))
+         z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
+         return z_vecs, kl_loss
+
+     def sample_prior(self, prob_decode=False):
+         z_tree = torch.randn(1, self.latent_size).cuda()
+         z_mol = torch.randn(1, self.latent_size).cuda()
+         return self.decode(z_tree, z_mol, prob_decode)
+
+     def forward(self, x_batch, beta):
+         x_batch, x_jtenc_holder, x_mpn_holder, x_jtmpn_holder = x_batch
+         x_tree_vecs, x_tree_mess, x_mol_vecs = self.encode(x_jtenc_holder, x_mpn_holder)
+         z_tree_vecs, tree_kl = self.rsample(x_tree_vecs, self.T_mean, self.T_var)
+         z_mol_vecs, mol_kl = self.rsample(x_mol_vecs, self.G_mean, self.G_var)
+
+         kl_div = tree_kl + mol_kl
+         word_loss, topo_loss, word_acc, topo_acc = self.decoder(x_batch, z_tree_vecs)
+         assm_loss, assm_acc = self.assm(x_batch, x_jtmpn_holder, z_mol_vecs, x_tree_mess)
+
+         return word_loss + topo_loss + assm_loss + beta * kl_div, kl_div.item(), word_acc, topo_acc, assm_acc
+
+     def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, x_tree_mess):
+         jtmpn_holder, batch_idx = jtmpn_holder
+         fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
+         batch_idx = create_var(batch_idx)
+
+         cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, x_tree_mess)
+
+         x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
+         x_mol_vecs = self.A_assm(x_mol_vecs)  #bilinear
+         scores = torch.bmm(
+             x_mol_vecs.unsqueeze(1),
+             cand_vecs.unsqueeze(-1)
+         ).squeeze()
+
+         cnt, tot, acc = 0, 0, 0
+         all_loss = []
+         for i, mol_tree in enumerate(mol_batch):
+             comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
+             cnt += len(comp_nodes)
+             for node in comp_nodes:
+                 label = node.cands.index(node.label)
+                 ncand = len(node.cands)
+                 cur_score = scores.narrow(0, tot, ncand)
+                 tot += ncand
+
+                 if cur_score.data[label] >= cur_score.max().item():
+                     acc += 1
+
+                 label = create_var(torch.LongTensor([label]))
+                 all_loss.append(self.assm_loss(cur_score.view(1, -1), label))
+
+         all_loss = sum(all_loss) / len(mol_batch)
+         return all_loss, acc * 1.0 / cnt
+
+     def decode(self, x_tree_vecs, x_mol_vecs, prob_decode):
+         #batch decoding is currently not supported
+         assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1
+
+         pred_root, pred_nodes = self.decoder.decode(x_tree_vecs, prob_decode)
+         if len(pred_nodes) == 0: return None
+         elif len(pred_nodes) == 1: return pred_root.smiles
+
+         #Mark nid & is_leaf & atommap
+         for i, node in enumerate(pred_nodes):
+             node.nid = i + 1
+             node.is_leaf = (len(node.neighbors) == 1)
+             if len(node.neighbors) > 1:
+                 set_atommap(node.mol, node.nid)
+
+         scope = [(0, len(pred_nodes))]
+         jtenc_holder, mess_dict = JTNNEncoder.tensorize_nodes(pred_nodes, scope)
+         _, tree_mess = self.jtnn(*jtenc_holder)
+         tree_mess = (tree_mess, mess_dict)  #Important: tree_mess is a matrix, mess_dict is a python dict
+
+         x_mol_vecs = self.A_assm(x_mol_vecs).squeeze()  #bilinear
+
+         cur_mol = copy_edit_mol(pred_root.mol)
+         global_amap = [{}] + [{} for node in pred_nodes]
+         global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
+
+         cur_mol, _ = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=True)
+         if cur_mol is None:
+             cur_mol = copy_edit_mol(pred_root.mol)
+             global_amap = [{}] + [{} for node in pred_nodes]
+             global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
+             cur_mol, pre_mol = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=False)
+             if cur_mol is None: cur_mol = pre_mol
+
+         if cur_mol is None:
+             return None
+
+         cur_mol = cur_mol.GetMol()
+         set_atommap(cur_mol)
+         cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
+         return Chem.MolToSmiles(cur_mol) if cur_mol is not None else None
+
+     def dfs_assemble(self, y_tree_mess, x_mol_vecs, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node, prob_decode, check_aroma):
+         fa_nid = fa_node.nid if fa_node is not None else -1
+         prev_nodes = [fa_node] if fa_node is not None else []
+
+         children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
+         neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
+         neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
+         singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
+         neighbors = singletons + neighbors
+
+         cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node.nid]
+         cands, aroma_score = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
+         if len(cands) == 0 or (sum(aroma_score) < 0 and check_aroma):
+             return None, cur_mol
+
+         cand_smiles, cand_amap = zip(*cands)
+         aroma_score = torch.Tensor(aroma_score).cuda()
+         cands = [(smiles, all_nodes, cur_node) for smiles in cand_smiles]
+
+         if len(cands) > 1:
+             jtmpn_holder = JTMPN.tensorize(cands, y_tree_mess[1])
+             fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
+             cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess[0])
+             scores = torch.mv(cand_vecs, x_mol_vecs) + aroma_score
+         else:
+             scores = torch.Tensor([1.0])
+
+         if prob_decode:
+             probs = F.softmax(scores.view(1, -1), dim=1).squeeze() + 1e-7  #prevent prob = 0
+             cand_idx = torch.multinomial(probs, probs.numel())
+         else:
+             _, cand_idx = torch.sort(scores, descending=True)
+
+         backup_mol = Chem.RWMol(cur_mol)
+         pre_mol = cur_mol
+         for i in range(cand_idx.numel()):
+             cur_mol = Chem.RWMol(backup_mol)
+             pred_amap = cand_amap[cand_idx[i].item()]
+             new_global_amap = copy.deepcopy(global_amap)
+
+             for nei_id, ctr_atom, nei_atom in pred_amap:
+                 if nei_id == fa_nid:
+                     continue
+                 new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom]
+
+             cur_mol = attach_mols(cur_mol, children, [], new_global_amap)  #father is already attached
+             new_mol = cur_mol.GetMol()
+             new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
+
+             if new_mol is None: continue
+
+             has_error = False
+             for nei_node in children:
+                 if nei_node.is_leaf: continue
+                 tmp_mol, tmp_mol2 = self.dfs_assemble(y_tree_mess, x_mol_vecs, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node, prob_decode, check_aroma)
+                 if tmp_mol is None:
+                     has_error = True
+                     if i == 0: pre_mol = tmp_mol2
+                     break
+                 cur_mol = tmp_mol
+
+             if not has_error: return cur_mol, cur_mol
+
+         return None, pre_mol
+
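For context, a minimal sketch of loading this class and sampling from the prior. The vocabulary file is the one added in this commit; the checkpoint path is a placeholder, and hidden_size=450, latent_size=56, depthT=20, depthG=3 are the hyperparameters commonly used with the reference JT-VAE code, not values pinned by this diff. Note that sample_prior allocates CUDA tensors as written, so this sketch assumes a GPU:

import torch
from mol_tree import Vocab
from jtnn_vae import JTNNVAE

vocab = Vocab([line.strip() for line in open('vocab.txt')])
model = JTNNVAE(vocab, hidden_size=450, latent_size=56, depthT=20, depthG=3)
model.load_state_dict(torch.load('jtnn_vae.ckpt'))  # hypothetical checkpoint path
model = model.cuda().eval()

with torch.no_grad():
    print(model.sample_prior())  # one decoded SMILES string, or None on failure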
fast_jtnn/jtprop_vae.py ADDED
@@ -0,0 +1,311 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from mol_tree import Vocab, MolTree
+ from nnutils import create_var, flatten_tensor, avg_pool
+ from jtnn_enc import JTNNEncoder
+ from jtnn_dec import JTNNDecoder
+ from mpn import MPN
+ from jtmpn import JTMPN
+ from datautils import tensorize
+
+ from chemutils import enum_assemble, set_atommap, copy_edit_mol, attach_mols
+ import rdkit
+ import rdkit.Chem as Chem
+ from rdkit import DataStructs
+ from rdkit.Chem import AllChem
+ import copy, math
+
+ class JTPropVAE(nn.Module):
+
+     def __init__(self, vocab, hidden_size, latent_size, depthT, depthG):
+         super(JTPropVAE, self).__init__()
+         self.vocab = vocab
+         self.hidden_size = hidden_size
+         self.latent_size = latent_size = int(latent_size / 2)  #tree and mol each take half of the latent vector
+
+         self.jtnn = JTNNEncoder(hidden_size, depthT, nn.Embedding(vocab.size(), hidden_size))
+         self.decoder = JTNNDecoder(vocab, hidden_size, latent_size, nn.Embedding(vocab.size(), hidden_size))
+
+         self.jtmpn = JTMPN(hidden_size, depthG)
+         self.mpn = MPN(hidden_size, depthG)
+
+         self.A_assm = nn.Linear(latent_size, hidden_size, bias=False)
+         # self.assm_loss = nn.CrossEntropyLoss(size_average=False)
+         self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
+
+         self.T_mean = nn.Linear(hidden_size, latent_size)
+         self.T_var = nn.Linear(hidden_size, latent_size)
+         self.G_mean = nn.Linear(hidden_size, latent_size)
+         self.G_var = nn.Linear(hidden_size, latent_size)
+
+         # Prop
+         self.propNN = nn.Sequential(
+             nn.Linear(self.latent_size * 2, self.hidden_size),
+             nn.Tanh(),
+             nn.Linear(self.hidden_size, 1)
+         )
+         self.prop_loss = nn.MSELoss()
+
+     def encode(self, jtenc_holder, mpn_holder):
+         tree_vecs, tree_mess = self.jtnn(*jtenc_holder)
+         mol_vecs = self.mpn(*mpn_holder)
+         return tree_vecs, tree_mess, mol_vecs
+
+     def encode_from_smiles(self, smiles_list):
+         tree_batch = [MolTree(s) for s in smiles_list]
+         _, jtenc_holder, mpn_holder = tensorize(tree_batch, self.vocab, assm=False)
+         tree_vecs, _, mol_vecs = self.encode(jtenc_holder, mpn_holder)
+         return torch.cat([tree_vecs, mol_vecs], dim=-1)
+
+     def encode_latent(self, jtenc_holder, mpn_holder):
+         tree_vecs, _ = self.jtnn(*jtenc_holder)
+         mol_vecs = self.mpn(*mpn_holder)
+         tree_mean = self.T_mean(tree_vecs)
+         mol_mean = self.G_mean(mol_vecs)
+         tree_var = -torch.abs(self.T_var(tree_vecs))
+         mol_var = -torch.abs(self.G_var(mol_vecs))
+         return torch.cat([tree_mean, mol_mean], dim=1), torch.cat([tree_var, mol_var], dim=1)
+
+     def rsample(self, z_vecs, W_mean, W_var):
+         batch_size = z_vecs.size(0)
+         z_mean = W_mean(z_vecs)
+         z_log_var = -torch.abs(W_var(z_vecs))  #Following Mueller et al.
+         kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
+         epsilon = create_var(torch.randn_like(z_mean))
+         z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
+         return z_vecs, kl_loss
+
+     def sample_prior(self, prob_decode=False):
+         z_tree = torch.randn(1, self.latent_size).cuda()
+         z_mol = torch.randn(1, self.latent_size).cuda()
+         return self.decode(z_tree, z_mol, prob_decode)
+
+     def forward(self, x_batch, beta):
+         x_batch, x_jtenc_holder, x_mpn_holder, x_jtmpn_holder, prop_batch = x_batch
+         x_tree_vecs, x_tree_mess, x_mol_vecs = self.encode(x_jtenc_holder, x_mpn_holder)
+         z_tree_vecs, tree_kl = self.rsample(x_tree_vecs, self.T_mean, self.T_var)
+         z_mol_vecs, mol_kl = self.rsample(x_mol_vecs, self.G_mean, self.G_var)
+
+         kl_div = tree_kl + mol_kl
+         word_loss, topo_loss, word_acc, topo_acc = self.decoder(x_batch, z_tree_vecs)
+         assm_loss, assm_acc = self.assm(x_batch, x_jtmpn_holder, z_mol_vecs, x_tree_mess)
+
+         all_vec = torch.cat([z_tree_vecs, z_mol_vecs], dim=1)
+         # prop_label = create_var(torch.Tensor(prop_batch))
+         prop_label = create_var(torch.Tensor(prop_batch))
+         prop_loss = self.prop_loss(self.propNN(all_vec).squeeze(), prop_label)
+
+         return word_loss + topo_loss + assm_loss + beta * kl_div + prop_loss, kl_div.item(), word_acc, topo_acc, assm_acc, prop_loss.item()
+
+     def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, x_tree_mess):
+         jtmpn_holder, batch_idx = jtmpn_holder
+         fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
+         batch_idx = create_var(batch_idx)
+
+         cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, x_tree_mess)
+
+         x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
+         x_mol_vecs = self.A_assm(x_mol_vecs)  #bilinear
+         scores = torch.bmm(
+             x_mol_vecs.unsqueeze(1),
+             cand_vecs.unsqueeze(-1)
+         ).squeeze()
+
+         cnt, tot, acc = 0, 0, 0
+         all_loss = []
+         for i, mol_tree in enumerate(mol_batch):
+             comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
+             cnt += len(comp_nodes)
+             for node in comp_nodes:
+                 label = node.cands.index(node.label)
+                 ncand = len(node.cands)
+                 cur_score = scores.narrow(0, tot, ncand)
+                 tot += ncand
+
+                 if cur_score.data[label] >= cur_score.max().item():
+                     acc += 1
+
+                 label = create_var(torch.LongTensor([label]))
+                 all_loss.append(self.assm_loss(cur_score.view(1, -1), label))
+
+         all_loss = sum(all_loss) / len(mol_batch)
+         return all_loss, acc * 1.0 / cnt
+
+     def decode(self, x_tree_vecs, x_mol_vecs, prob_decode):
+         #batch decoding is currently not supported
+         assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1
+
+         pred_root, pred_nodes = self.decoder.decode(x_tree_vecs, prob_decode)
+         if len(pred_nodes) == 0: return None
+         elif len(pred_nodes) == 1: return pred_root.smiles
+
+         #Mark nid & is_leaf & atommap
+         for i, node in enumerate(pred_nodes):
+             node.nid = i + 1
+             node.is_leaf = (len(node.neighbors) == 1)
+             if len(node.neighbors) > 1:
+                 set_atommap(node.mol, node.nid)
+
+         scope = [(0, len(pred_nodes))]
+         jtenc_holder, mess_dict = JTNNEncoder.tensorize_nodes(pred_nodes, scope)
+         _, tree_mess = self.jtnn(*jtenc_holder)
+         tree_mess = (tree_mess, mess_dict)  #Important: tree_mess is a matrix, mess_dict is a python dict
+
+         x_mol_vecs = self.A_assm(x_mol_vecs).squeeze()  #bilinear
+
+         cur_mol = copy_edit_mol(pred_root.mol)
+         global_amap = [{}] + [{} for node in pred_nodes]
+         global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
+
+         cur_mol, _ = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=True)
+         if cur_mol is None:
+             cur_mol = copy_edit_mol(pred_root.mol)
+             global_amap = [{}] + [{} for node in pred_nodes]
+             global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
+             cur_mol, pre_mol = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=False)
+             if cur_mol is None: cur_mol = pre_mol
+
+         if cur_mol is None:
+             return None
+
+         cur_mol = cur_mol.GetMol()
+         set_atommap(cur_mol)
+         cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
+         return Chem.MolToSmiles(cur_mol) if cur_mol is not None else None
+
+     def dfs_assemble(self, y_tree_mess, x_mol_vecs, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node, prob_decode, check_aroma):
+         fa_nid = fa_node.nid if fa_node is not None else -1
+         prev_nodes = [fa_node] if fa_node is not None else []
+
+         children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
+         neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
+         neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
+         singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
+         neighbors = singletons + neighbors
+
+         cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node.nid]
+         cands, aroma_score = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
+         if len(cands) == 0 or (sum(aroma_score) < 0 and check_aroma):
+             return None, cur_mol
+
+         cand_smiles, cand_amap = zip(*cands)
+         if torch.cuda.is_available():
+             aroma_score = torch.Tensor(aroma_score).cuda()
+         else:
+             aroma_score = torch.Tensor(aroma_score)
+         cands = [(smiles, all_nodes, cur_node) for smiles in cand_smiles]
+
+         if len(cands) > 1:
+             jtmpn_holder = JTMPN.tensorize(cands, y_tree_mess[1])
+             fatoms, fbonds, agraph, bgraph, scope = jtmpn_holder
+             cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess[0])
+             scores = torch.mv(cand_vecs, x_mol_vecs) + aroma_score
+         else:
+             scores = torch.Tensor([1.0])
+
+         if prob_decode:
+             probs = F.softmax(scores.view(1, -1), dim=1).squeeze() + 1e-7  #prevent prob = 0
+             cand_idx = torch.multinomial(probs, probs.numel())
+         else:
+             _, cand_idx = torch.sort(scores, descending=True)
+
+         backup_mol = Chem.RWMol(cur_mol)
+         pre_mol = cur_mol
+         for i in range(cand_idx.numel()):
+             cur_mol = Chem.RWMol(backup_mol)
+             pred_amap = cand_amap[cand_idx[i].item()]
+             new_global_amap = copy.deepcopy(global_amap)
+
+             for nei_id, ctr_atom, nei_atom in pred_amap:
+                 if nei_id == fa_nid:
+                     continue
+                 new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom]
+
+             cur_mol = attach_mols(cur_mol, children, [], new_global_amap)  #father is already attached
+             new_mol = cur_mol.GetMol()
+             new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
+
+             if new_mol is None: continue
+
+             has_error = False
+             for nei_node in children:
+                 if nei_node.is_leaf: continue
+                 tmp_mol, tmp_mol2 = self.dfs_assemble(y_tree_mess, x_mol_vecs, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node, prob_decode, check_aroma)
+                 if tmp_mol is None:
+                     has_error = True
+                     if i == 0: pre_mol = tmp_mol2
+                     break
+                 cur_mol = tmp_mol
+
+             if not has_error: return cur_mol, cur_mol
+
+         return None, pre_mol
+
+     def optimize(self, smiles, sim_cutoff, lr=2.0, num_iter=20):
+         # mol_tree = MolTree(smiles)
+         # mol_tree.recover()
+         tree_batch = [MolTree(smiles)]
+         _, jtenc_holder, mpn_holder = tensorize(tree_batch, self.vocab, assm=False)
+         tree_vec, _, mol_vec = self.encode(jtenc_holder, mpn_holder)
+
+         mol = Chem.MolFromSmiles(smiles)
+         fp1 = AllChem.GetMorganFingerprint(mol, 2)
+
+         tree_mean = self.T_mean(tree_vec)
+         tree_log_var = -torch.abs(self.T_var(tree_vec))  #Following Mueller et al.
+         mol_mean = self.G_mean(mol_vec)
+         mol_log_var = -torch.abs(self.G_var(mol_vec))  #Following Mueller et al.
+         mean = torch.cat([tree_mean, mol_mean], dim=1)
+         log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
+         cur_vec = create_var(mean.data, True)
+
+         visited = []
+         for step in range(num_iter):
+             prop_val = self.propNN(cur_vec).squeeze()
+             grad = torch.autograd.grad(prop_val, cur_vec)[0]
+             cur_vec = cur_vec.data + lr * grad.data
+             cur_vec = create_var(cur_vec, True)
+             visited.append(cur_vec)
+
+         l, r = 0, num_iter - 1
+         while l < r - 1:
+             mid = (l + r) // 2
+             new_vec = visited[mid]
+             tree_vec, mol_vec = torch.chunk(new_vec, 2, dim=1)
+             new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
+             if new_smiles is None:
+                 r = mid - 1
+                 continue
+
+             new_mol = Chem.MolFromSmiles(new_smiles)
+             fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
+             sim = DataStructs.TanimotoSimilarity(fp1, fp2)
+             if sim < sim_cutoff:
+                 r = mid - 1
+             else:
+                 l = mid
+         """
+         best_vec = visited[0]
+         for new_vec in visited:
+             tree_vec,mol_vec = torch.chunk(new_vec, 2, dim=1)
+             new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
+             if new_smiles is None: continue
+             new_mol = Chem.MolFromSmiles(new_smiles)
+             fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
+             sim = DataStructs.TanimotoSimilarity(fp1, fp2)
+             if sim >= sim_cutoff:
+                 best_vec = new_vec
+         """
+         tree_vec, mol_vec = torch.chunk(visited[l], 2, dim=1)
+         #tree_vec,mol_vec = torch.chunk(best_vec, 2, dim=1)
+         new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
+         if new_smiles is None:
+             return None, None
+         new_mol = Chem.MolFromSmiles(new_smiles)
+         fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
+         sim = DataStructs.TanimotoSimilarity(fp1, fp2)
+         if sim >= sim_cutoff:
+             return new_smiles, sim
+         else:
+             return None, None
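optimize performs the property-targeted modification: gradient ascent on propNN in latent space, followed by a binary search over the visited vectors for the furthest step that still decodes above the Tanimoto similarity cutoff. A minimal direct call, assuming a CPU-only environment (create_var auto-selects CUDA when available, so the model and inputs should end up on the same device); the checkpoint path and the 450/56/20/3 hyperparameters are placeholders, not values pinned by this diff:

import torch
from mol_tree import Vocab
from jtprop_vae import JTPropVAE

vocab = Vocab([line.strip() for line in open('vocab.txt')])
model = JTPropVAE(vocab, hidden_size=450, latent_size=56, depthT=20, depthG=3)
model.load_state_dict(torch.load('jtprop_vae.ckpt', map_location='cpu'))  # hypothetical path
model.eval()

new_smiles, sim = model.optimize('CCOC(=O)c1ccccc1', sim_cutoff=0.4, lr=2.0, num_iter=20)
print(new_smiles, sim)  # (None, None) when no visited vector decodes above the cutoff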
fast_jtnn/mol_tree.py ADDED
@@ -0,0 +1,168 @@
+ import rdkit
+ import rdkit.Chem as Chem
+ from chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, enum_assemble, decode_stereo
+ from vocab import *
+ import argparse
+
+ class MolTreeNode(object):
+
+     def __init__(self, smiles, clique=[]):
+         self.smiles = smiles
+         self.mol = get_mol(self.smiles)
+
+         self.clique = [x for x in clique]  #copy
+         self.neighbors = []
+
+     def add_neighbor(self, nei_node):
+         self.neighbors.append(nei_node)
+
+     def recover(self, original_mol):
+         clique = []
+         clique.extend(self.clique)
+         if not self.is_leaf:
+             for cidx in self.clique:
+                 original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)
+
+         for nei_node in self.neighbors:
+             clique.extend(nei_node.clique)
+             if nei_node.is_leaf:  #leaf node, no need to mark
+                 continue
+             for cidx in nei_node.clique:
+                 #allow singleton nodes to override the atom mapping
+                 if cidx not in self.clique or len(nei_node.clique) == 1:
+                     atom = original_mol.GetAtomWithIdx(cidx)
+                     atom.SetAtomMapNum(nei_node.nid)
+
+         clique = list(set(clique))
+         label_mol = get_clique_mol(original_mol, clique)
+         self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
+
+         for cidx in clique:
+             original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
+
+         return self.label
+
+     def assemble(self):
+         neighbors = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1]
+         neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
+         singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
+         neighbors = singletons + neighbors
+
+         cands, aroma = enum_assemble(self, neighbors)
+         new_cands = [cand for i, cand in enumerate(cands) if aroma[i] >= 0]
+         if len(new_cands) > 0: cands = new_cands
+
+         if len(cands) > 0:
+             self.cands, _ = zip(*cands)
+             self.cands = list(self.cands)
+         else:
+             self.cands = []
+
+ class MolTree(object):
+
+     def __init__(self, smiles):
+         self.smiles = smiles
+         self.mol = get_mol(smiles)
+
+         #Stereo Generation (currently disabled)
+         #mol = Chem.MolFromSmiles(smiles)
+         #self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
+         #self.smiles2D = Chem.MolToSmiles(mol)
+         #self.stereo_cands = decode_stereo(self.smiles2D)
+
+         cliques, edges = tree_decomp(self.mol)
+         self.nodes = []
+         root = 0
+         for i, c in enumerate(cliques):
+             cmol = get_clique_mol(self.mol, c)
+             node = MolTreeNode(get_smiles(cmol), c)
+             self.nodes.append(node)
+             if min(c) == 0: root = i
+
+         for x, y in edges:
+             self.nodes[x].add_neighbor(self.nodes[y])
+             self.nodes[y].add_neighbor(self.nodes[x])
+
+         if root > 0:
+             self.nodes[0], self.nodes[root] = self.nodes[root], self.nodes[0]
+
+         for i, node in enumerate(self.nodes):
+             node.nid = i + 1
+             if len(node.neighbors) > 1:  #leaf node mol is not marked
+                 set_atommap(node.mol, node.nid)
+             node.is_leaf = (len(node.neighbors) == 1)
+
+     def size(self):
+         return len(self.nodes)
+
+     def recover(self):
+         for node in self.nodes:
+             node.recover(self.mol)
+
+     def assemble(self):
+         for node in self.nodes:
+             node.assemble()
+
+ def dfs(node, fa_idx):
+     max_depth = 0
+     for child in node.neighbors:
+         if child.idx == fa_idx: continue
+         max_depth = max(max_depth, dfs(child, node.idx))
+     return max_depth + 1
+
+ def data_process_chunk(smiles_list):
+     cset = set()
+     for line in smiles_list:
+         smiles = line.split()[0]
+         # print(smiles)
+         mol = MolTree(smiles)
+         for c in mol.nodes:
+             cset.add(c.smiles)
+         # i+=1
+         # if i%10000 == 0:
+         #     # print(i,end='\x1b[1K\r')
+         #     print(i, ' / 1584663')
+     return list(cset)
+
+ if __name__ == "__main__":
+     import sys
+     lg = rdkit.RDLogger.logger()
+     lg.setLevel(rdkit.RDLogger.CRITICAL)
+
+     i = 0
+
+     import os
+     from joblib import Parallel, delayed
+     from tqdm import tqdm
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--smiles_path', type=str, required=True)
+     parser.add_argument('--vocab_path', type=str, required=True)
+     parser.add_argument('--prop', type=bool, default=False)
+     parser.add_argument('--ncpu', default=8, type=int)
+     args = parser.parse_args()
+
+     if args.prop:
+         import pandas as pd
+         smiles_list = pd.read_csv(args.smiles_path, usecols=['SMILES'])
+         smiles_list = list(smiles_list.SMILES)
+     else:
+         with open(args.smiles_path, 'r') as f:
+             smiles_list = [line.split()[0] for line in f]
+     print('Total smiles = ', len(smiles_list))
+
+     # moses: 1584663
+
+     chunk_size = 10000
+     vocab_set_list = Parallel(n_jobs=args.ncpu)(
+         delayed(data_process_chunk)(smiles_list[start: start + chunk_size])
+         for start in tqdm(range(0, len(smiles_list), chunk_size))
+     )
+     vocab_list = []
+     for set_list in vocab_set_list:
+         vocab_list.extend(set_list)
+
+     cset = set(vocab_list)
+     with open(args.vocab_path, 'w') as f:
+         for x in cset:
+             f.write(''.join([x, '\n']))
+
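To see what the tree decomposition produces, one can build a MolTree directly; the SMILES below is an arbitrary example, and the printed cliques depend on the molecule:

from mol_tree import MolTree

tree = MolTree('CC(C)c1ccc(O)cc1')  # arbitrary test molecule
print(tree.size())                  # number of clusters (junction-tree nodes)
for node in tree.nodes:
    print(node.nid, node.smiles, [nei.nid for nei in node.neighbors])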
fast_jtnn/mpn.py ADDED
@@ -0,0 +1,125 @@
+ import torch
+ import torch.nn as nn
+ import rdkit.Chem as Chem
+ import torch.nn.functional as F
+ from nnutils import *
+ from chemutils import get_mol
+
+ ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
+
+ ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
+ BOND_FDIM = 5 + 6
+ MAX_NB = 6
+
+ def onek_encoding_unk(x, allowable_set):
+     if x not in allowable_set:
+         x = allowable_set[-1]
+     return list(map(lambda s: x == s, allowable_set))
+
+ def atom_features(atom):
+     return torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
+             + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
+             + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
+             + onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
+             + [atom.GetIsAromatic()])
+
+ def bond_features(bond):
+     bt = bond.GetBondType()
+     stereo = int(bond.GetStereo())
+     fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
+     fstereo = onek_encoding_unk(stereo, [0, 1, 2, 3, 4, 5])
+     return torch.Tensor(fbond + fstereo)
+
+ class MPN(nn.Module):
+
+     def __init__(self, hidden_size, depth):
+         super(MPN, self).__init__()
+         self.hidden_size = hidden_size
+         self.depth = depth
+
+         self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)
+         self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
+         self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
+
+     def forward(self, fatoms, fbonds, agraph, bgraph, scope):
+         fatoms = create_var(fatoms)
+         fbonds = create_var(fbonds)
+         agraph = create_var(agraph)
+         bgraph = create_var(bgraph)
+
+         binput = self.W_i(fbonds)
+         message = F.relu(binput)
+
+         for i in range(self.depth - 1):
+             nei_message = index_select_ND(message, 0, bgraph)
+             nei_message = nei_message.sum(dim=1)
+             nei_message = self.W_h(nei_message)
+             message = F.relu(binput + nei_message)
+
+         nei_message = index_select_ND(message, 0, agraph)
+         nei_message = nei_message.sum(dim=1)
+         ainput = torch.cat([fatoms, nei_message], dim=1)
+         atom_hiddens = F.relu(self.W_o(ainput))
+
+         max_len = max([x for _, x in scope])
+         batch_vecs = []
+         for st, le in scope:
+             cur_vecs = atom_hiddens[st: st + le].mean(dim=0)
+             batch_vecs.append(cur_vecs)
+
+         mol_vecs = torch.stack(batch_vecs, dim=0)
+         return mol_vecs
+
+     @staticmethod
+     def tensorize(mol_batch):
+         padding = torch.zeros(ATOM_FDIM + BOND_FDIM)
+         fatoms, fbonds = [], [padding]  #Ensure bond is 1-indexed
+         in_bonds, all_bonds = [], [(-1, -1)]  #Ensure bond is 1-indexed
+         scope = []
+         total_atoms = 0
+
+         for smiles in mol_batch:
+             mol = get_mol(smiles)
+             #mol = Chem.MolFromSmiles(smiles)
+             n_atoms = mol.GetNumAtoms()
+             for atom in mol.GetAtoms():
+                 fatoms.append(atom_features(atom))
+                 in_bonds.append([])
+
+             for bond in mol.GetBonds():
+                 a1 = bond.GetBeginAtom()
+                 a2 = bond.GetEndAtom()
+                 x = a1.GetIdx() + total_atoms
+                 y = a2.GetIdx() + total_atoms
+
+                 b = len(all_bonds)
+                 all_bonds.append((x, y))
+                 fbonds.append(torch.cat([fatoms[x], bond_features(bond)], 0))
+                 in_bonds[y].append(b)
+
+                 b = len(all_bonds)
+                 all_bonds.append((y, x))
+                 fbonds.append(torch.cat([fatoms[y], bond_features(bond)], 0))
+                 in_bonds[x].append(b)
+
+             scope.append((total_atoms, n_atoms))
+             total_atoms += n_atoms
+
+         total_bonds = len(all_bonds)
+         fatoms = torch.stack(fatoms, 0)
+         fbonds = torch.stack(fbonds, 0)
+         agraph = torch.zeros(total_atoms, MAX_NB).long()
+         bgraph = torch.zeros(total_bonds, MAX_NB).long()
+
+         for a in range(total_atoms):
+             for i, b in enumerate(in_bonds[a]):
+                 agraph[a, i] = b
+
+         for b1 in range(1, total_bonds):
+             x, y = all_bonds[b1]
+             for i, b2 in enumerate(in_bonds[x]):
+                 if all_bonds[b2][0] != y:
+                     bgraph[b1, i] = b2
+
+         return (fatoms, fbonds, agraph, bgraph, scope)
+
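The tensorize/forward pair above is self-contained enough to smoke-test on its own (it assumes chemutils from this package is importable). A minimal sketch; molecule choice and hidden size are arbitrary:

from mpn import MPN

smiles_batch = ['CCO', 'c1ccccc1']        # arbitrary molecules
tensors = MPN.tensorize(smiles_batch)     # (fatoms, fbonds, agraph, bgraph, scope)
mpn = MPN(hidden_size=8, depth=3)
mol_vecs = mpn(*tensors)
print(mol_vecs.shape)                     # torch.Size([2, 8]), one vector per molecule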
fast_jtnn/nnutils.py ADDED
@@ -0,0 +1,72 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.autograd import Variable
+
+ def create_var(tensor, requires_grad=None):
+     if requires_grad is None:
+         if torch.cuda.is_available():
+             return Variable(data=tensor).cuda()
+         else:
+             return Variable(data=tensor)
+     else:
+         if torch.cuda.is_available():
+             return Variable(data=tensor, requires_grad=requires_grad).cuda()
+         else:
+             return Variable(data=tensor, requires_grad=requires_grad)
+
+ def index_select_ND(source, dim, index):
+     index_size = index.size()
+     suffix_dim = source.size()[1:]
+     final_size = index_size + suffix_dim
+     target = source.index_select(dim, index.view(-1))
+     return target.view(final_size)
+
+ def avg_pool(all_vecs, scope, dim):
+     size = create_var(torch.Tensor([le for _, le in scope]))
+     return all_vecs.sum(dim=dim) / size.unsqueeze(-1)
+
+ def stack_pad_tensor(tensor_list):
+     max_len = max([t.size(0) for t in tensor_list])
+     for i, tensor in enumerate(tensor_list):
+         pad_len = max_len - tensor.size(0)
+         tensor_list[i] = F.pad(tensor, (0, 0, 0, pad_len))
+     return torch.stack(tensor_list, dim=0)
+
+ #3D padded tensor to 2D matrix, with padded zeros removed
+ def flatten_tensor(tensor, scope):
+     assert tensor.size(0) == len(scope)
+     tlist = []
+     for i, tup in enumerate(scope):
+         le = tup[1]
+         tlist.append(tensor[i, 0:le])
+     return torch.cat(tlist, dim=0)
+
+ #2D matrix to 3D padded tensor
+ def inflate_tensor(tensor, scope):
+     max_len = max([le for _, le in scope])
+     batch_vecs = []
+     for st, le in scope:
+         cur_vecs = tensor[st: st + le]
+         cur_vecs = F.pad(cur_vecs, (0, 0, 0, max_len - le))
+         batch_vecs.append(cur_vecs)
+
+     return torch.stack(batch_vecs, dim=0)
+
+ def GRU(x, h_nei, W_z, W_r, U_r, W_h):
+     hidden_size = x.size()[-1]
+     sum_h = h_nei.sum(dim=1)
+     z_input = torch.cat([x, sum_h], dim=1)
+     z = F.sigmoid(W_z(z_input))
+
+     r_1 = W_r(x).view(-1, 1, hidden_size)
+     r_2 = U_r(h_nei)
+     r = F.sigmoid(r_1 + r_2)
+
+     gated_h = r * h_nei
+     sum_gated_h = gated_h.sum(dim=1)
+     h_input = torch.cat([x, sum_gated_h], dim=1)
+     pre_h = F.tanh(W_h(h_input))
+     new_h = (1.0 - z) * sum_h + z * pre_h
+     return new_h
+
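index_select_ND is the gather primitive the whole message passing stack relies on: index 0 is reserved for a zero row, so padded neighbor slots contribute nothing to the sums. A toy demonstration:

import torch
from nnutils import index_select_ND

source = torch.tensor([[0., 0., 0.],   # index 0 reserved as the zero pad
                       [1., 1., 1.],
                       [2., 2., 2.],
                       [3., 3., 3.]])
index = torch.tensor([[1, 2, 0],       # neighbors of item A (0 = padding)
                      [3, 0, 0]])      # neighbors of item B
out = index_select_ND(source, 0, index)
print(out.shape)       # torch.Size([2, 3, 3]) -> (items, max_neighbors, hidden)
print(out.sum(dim=1))  # tensor([[3., 3., 3.], [3., 3., 3.]]); pad rows add zero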
fast_jtnn/vocab.py ADDED
@@ -0,0 +1,31 @@
+ import rdkit
+ import rdkit.Chem as Chem
+ import copy
+
+ def get_slots(smiles):
+     mol = Chem.MolFromSmiles(smiles)
+     return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()]
+
+ class Vocab(object):
+     benzynes = ['C1=CC=CC=C1', 'C1=CC=NC=C1', 'C1=CC=NN=C1', 'C1=CN=CC=N1', 'C1=CN=CN=C1', 'C1=CN=NC=N1', 'C1=CN=NN=C1', 'C1=NC=NC=N1', 'C1=NN=CN=N1']
+     penzynes = ['C1=C[NH]C=C1', 'C1=C[NH]C=N1', 'C1=C[NH]N=C1', 'C1=C[NH]N=N1', 'C1=COC=C1', 'C1=COC=N1', 'C1=CON=C1', 'C1=CSC=C1', 'C1=CSC=N1', 'C1=CSN=C1', 'C1=CSN=N1', 'C1=NN=C[NH]1', 'C1=NN=CO1', 'C1=NN=CS1', 'C1=N[NH]C=N1', 'C1=N[NH]N=C1', 'C1=N[NH]N=N1', 'C1=NN=N[NH]1', 'C1=NN=NS1', 'C1=NOC=N1', 'C1=NON=C1', 'C1=NSC=N1', 'C1=NSN=C1']
+
+     def __init__(self, smiles_list):
+         self.vocab = smiles_list
+         self.vmap = {x: i for i, x in enumerate(self.vocab)}
+         self.slots = [get_slots(smiles) for smiles in self.vocab]
+         Vocab.benzynes = [s for s in smiles_list if s.count('=') >= 2 and Chem.MolFromSmiles(s).GetNumAtoms() == 6] + ['C1=CCNCC1']
+         Vocab.penzynes = [s for s in smiles_list if s.count('=') >= 2 and Chem.MolFromSmiles(s).GetNumAtoms() == 5] + ['C1=NCCN1', 'C1=NNCC1']
+
+     def get_index(self, smiles):
+         return self.vmap[smiles]
+
+     def get_smiles(self, idx):
+         return self.vocab[idx]
+
+     def get_slots(self, idx):
+         return copy.deepcopy(self.slots[idx])
+
+     def size(self):
+         return len(self.vocab)
+
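A small sketch of the Vocab API with a toy three-entry vocabulary (real runs load the vocab.txt added below):

from vocab import Vocab

vocab = Vocab(['C1=CC=CC=C1', 'CC', 'C1CCNCC1'])   # toy cluster vocabulary
idx = vocab.get_index('CC')
print(idx, vocab.get_smiles(idx), vocab.size())     # 1 CC 3
print(vocab.get_slots(0)[:2])                       # per-atom (symbol, charge, numHs)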
fpscores.pkl.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:10dcef9340c873e7b987924461b0af5365eb8dd96be607203debe8ddf80c1e73
+ size 3848394
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ rdkit
+ numpy
+ torch
+ tqdm
+ networkx
+ scipy
+ pandas
+ joblib
+ molbloom
sascorer.py ADDED
@@ -0,0 +1,173 @@
+ #
+ # calculation of synthetic accessibility score as described in:
+ #
+ # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
+ # Peter Ertl and Ansgar Schuffenhauer
+ # Journal of Cheminformatics 1:8 (2009)
+ # http://www.jcheminf.com/content/1/1/8
+ #
+ # several small modifications to the original paper are included
+ # particularly a slightly different formula for the macrocyclic penalty
+ # and taking into account also molecule symmetry (fingerprint density)
+ #
+ # for a set of 10k diverse molecules the agreement between the original method
+ # as implemented in PipelinePilot and this implementation is r2 = 0.97
+ #
+ # peter ertl & greg landrum, september 2013
+ #
+
+
+ from rdkit import Chem
+ from rdkit.Chem import rdMolDescriptors
+ import pickle
+
+ import math
+ from collections import defaultdict
+
+ import os.path as op
+
+ _fscores = None
+
+
+ def readFragmentScores(name='fpscores'):
+     import gzip
+     global _fscores
+     # generate the full path filename:
+     if name == "fpscores":
+         name = op.join(op.dirname(__file__), name)
+     _fscores = pickle.load(gzip.open('%s.pkl.gz' % name))
+     outDict = {}
+     for i in _fscores:
+         for j in range(1, len(i)):
+             outDict[i[j]] = float(i[0])
+     _fscores = outDict
+
+
+ def numBridgeheadsAndSpiro(mol, ri=None):
+     nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
+     nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
+     return nBridgehead, nSpiro
+
+
+ def calculateScore(m):
+     if _fscores is None:
+         readFragmentScores()
+
+     # fragment score
+     fp = rdMolDescriptors.GetMorganFingerprint(m,
+                                                2)  # <- 2 is the *radius* of the circular fingerprint
+     fps = fp.GetNonzeroElements()
+     score1 = 0.
+     nf = 0
+     for bitId, v in fps.items():
+         nf += v
+         sfp = bitId
+         score1 += _fscores.get(sfp, -4) * v
+     score1 /= nf
+
+     # features score
+     nAtoms = m.GetNumAtoms()
+     nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
+     ri = m.GetRingInfo()
+     nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
+     nMacrocycles = 0
+     for x in ri.AtomRings():
+         if len(x) > 8:
+             nMacrocycles += 1
+
+     sizePenalty = nAtoms**1.005 - nAtoms
+     stereoPenalty = math.log10(nChiralCenters + 1)
+     spiroPenalty = math.log10(nSpiro + 1)
+     bridgePenalty = math.log10(nBridgeheads + 1)
+     macrocyclePenalty = 0.
+     # ---------------------------------------
+     # This differs from the paper, which defines:
+     # macrocyclePenalty = math.log10(nMacrocycles+1)
+     # This form generates better results when 2 or more macrocycles are present
+     if nMacrocycles > 0:
+         macrocyclePenalty = math.log10(2)
+
+     score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
+
+     # correction for the fingerprint density
+     # not in the original publication, added in version 1.1
+     # to make highly symmetrical molecules easier to synthesise
+     score3 = 0.
+     if nAtoms > len(fps):
+         score3 = math.log(float(nAtoms) / len(fps)) * .5
+
+     sascore = score1 + score2 + score3
+
+     # need to transform "raw" value into scale between 1 and 10
+     min = -4.0
+     max = 2.5
+     sascore = 11. - (sascore - min + 1) / (max - min) * 9.
+     # smooth the 10-end
+     if sascore > 8.:
+         sascore = 8. + math.log(sascore + 1. - 9.)
+     if sascore > 10.:
+         sascore = 10.0
+     elif sascore < 1.:
+         sascore = 1.0
+
+     return sascore
+
+
+ def processMols(mols):
+     print('smiles\tName\tsa_score')
+     for i, m in enumerate(mols):
+         if m is None:
+             continue
+
+         s = calculateScore(m)
+
+         smiles = Chem.MolToSmiles(m)
+         print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
+
+
+ if __name__ == '__main__':
+     import sys
+     import time
+
+     t1 = time.time()
+     readFragmentScores("fpscores")
+     t2 = time.time()
+
+     suppl = Chem.SmilesMolSupplier(sys.argv[1])
+     t3 = time.time()
+     processMols(suppl)
+     t4 = time.time()
+
+     print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
+           file=sys.stderr)
+
+ #
+ # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
+ # All rights reserved.
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are
+ # met:
+ #
+ #     * Redistributions of source code must retain the above copyright
+ #       notice, this list of conditions and the following disclaimer.
+ #     * Redistributions in binary form must reproduce the above
+ #       copyright notice, this list of conditions and the following
+ #       disclaimer in the documentation and/or other materials provided
+ #       with the distribution.
+ #     * Neither the name of Novartis Institutes for BioMedical Research Inc.
+ #       nor the names of its contributors may be used to endorse or promote
+ #       products derived from this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ #
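A quick way to exercise the scorer from Python (aspirin is an arbitrary example; scores run from 1, easy to synthesize, to 10, hard):

from rdkit import Chem
import sascorer

mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')  # aspirin
print(sascorer.calculateScore(mol))  # fragment scores auto-load from fpscores.pkl.gz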
vocab.txt ADDED
@@ -0,0 +1,533 @@
+ C=N
+ C1CSCNN1
+ C1CCSCC1
+ C1=NSN=C1
+ C1=CSCS1
+ C1NCS1
+ C1=CSCC1
+ C1=CC2C=CC1O2
+ C1CC2COCC(CN1)N2
+ C1=COCCCN1
+ C1=CNSNC1
+ C1=NNCS1
+ C1=NNCSC1
+ C1=CSNCS1
+ C1=NC=NC=N1
+ C1=CCCCC=C1
+ C1CNCN=N1
+ C1=NNN=N1
+ C1NCC2OCC1O2
+ C1CC2CCOC(C2)N1
+ C1=CC2CCC1C2
+ C1=NN=NCC1
+ C1=CC2CC1C2
+ C1=CNCN=C1
+ C1=CC2C=CC1NN2
+ C1CN=NNC1
+ C1=CSCCN1
+ C1=CC=COC=C1
+ C1=NCSC1
+ C1CC2CNCC1N2
+ C1N=NNCN1
+ C=S
+ C1CC2OCC3CCC(O1)C2C3
+ C1COCCOC1
+ C1NCSN1
+ C1=CNN=C1
+ C1CNCCOC1
+ C1CSNCN1
+ C1CNSN1
+ C1CC2CCC(O1)O2
+ C1=CNSC=C1
+ C1=NCCO1
+ C1=CCSNC1
+ C1=NCOCC1
+ C1=CONC1
+ C1CC2CCC1O2
+ C1=NN=NS1
+ C1=CC2CCCC1NN2
+ C1CSCN1
+ C1NC2CC3CC1CC(C3)C2
+ C1=NC=N[SH]=N1
+ C1OC2CNC1C2
+ C1=CCSC1
+ C1=CC2C=CC1N2
+ C1=CCNC=NC1
+ O=S
+ C1NN=NN1
+ C1=CCNC1
+ C1=CCCN=C1
+ C1=COCCOC1
+ C1CC2CNCC(C1)NC2
+ C1=CCOC1
+ C1=CCNNCC1
+ C1=CNNCC1
+ C1=CSC1
+ C1CC2CCNC(C1)C2
+ C1CC2NC(CCS2)S1
+ C1CNN1
+ C1=CNCCNC1
+ C1=CNCCCN1
+ C1NCNCN1
+ C1=CC2C=CC1C2
+ OCl
+ C1=CN=CSC1
+ C1=CCNC=CC1
+ C1=CC2CC1CN2
+ C1=CC2CCCN(C1)C2
+ C1=CC2CC(C1)CN2
+ NO
+ C1=COC=NC1
+ C1=NSN=CC1
+ C1=CN=COC1
+ C1=CSCCN=C1
+ C1=CC=CNC=C1
+ C1=NN=CO1
+ C1=CNSC1
+ C1CNC1
+ C1=COCN1
+ C1CCSCNC1
+ C1CC2CNCC1C2
+ N
+ C1=NN=CN1
+ C1=NCCNCC1
+ C1CCCCCC1
+ C1CC2CC(N1)C1COC2O1
+ C1=NCNNC1
+ C1=CN=NC=N1
+ C1COCN1
+ C1CNNCN1
+ C1=NCC=NCC1
+ C1=NCNS1
+ C1OC2OCC1CO2
+ C1CCOCNC1
+ CF
+ C1=NNCO1
+ C1=NC=NCCN1
+ C1CC2CCCC(C1)O2
+ C1=NNCN1
+ C1=CCN=CC1
+ C1CC2CNCC1CN2
+ C1CC2CNCC(C1)N2
+ C1=CCSCC1
+ C1CSNS1
+ C1=NCNC1
+ C1CC2C3NNC2C1O3
+ C1CC2CNC(C2)N1
+ C1=CCOCCC1
+ C1CC2CCC(C1)O2
+ C1=CC2CC1NN2
+ C1=CC2CCCC(C1)C2
+ C1=CN=CCC1
+ C1CC2NCC1CN2
+ C1=CSCN1
+ C1=NCCC1
+ C1CC2C3CC1CC23
+ C1=CN=CN=C1
+ C1=NSNC1
+ C1CSN=N1
+ C1COCCSC1
+ C1CNSC1
+ C1=CCCCCC1
+ C1=NCNCN1
+ C1CC2CCC(C1)C2
+ C1CC2CNCC(C1)C2
+ C1=CC2CCCC1C2
+ C1=CC2CCC1O2
+ C1=CC2CCC1CC2
+ C1=CC2C=CC(C1)CC2
+ C1CSNCS1
+ NBr
+ C1COCSN1
+ C1=CSCNC1
+ C1CC2NCC1NN2
+ C1=CN=CCNC1
+ C1CC2CNCC1O2
+ C1=NNCCC1
+ C1CC2CC(C1)N2
+ C1=NNSC1
+ C1=COCCN1
+ C1=CCC2CC(C1)C2
+ C1=CC2CCCC1CC2
+ C1=CCNCC1
+ C1CNNC1
+ C1=NC=NC1
+ C1=COC=N1
+ C1CC2CNC(C2)O1
+ C1=CCC2CC=CC(C1)C2
+ C1=CC=CC=C1
+ C1=CCSOC1
+ C1CN2CCC1CC2
+ C1=CSCCCN1
+ C1CNCCSC1
+ C1=NCCSN1
+ C1=NCCNN1
+ C1NCON1
+ C1NCC2CC1CN2
+ C1=CNC=N1
+ C1=CC2CCC1NN2
+ C1CNCOC1
+ C1C2CC3CC1OC(O2)O3
+ C1=NNCSCC1
+ C1=CCOCC1
+ C1=NSCC1
+ C1=CNSCC1
+ C1=CCC1
+ C1CCCNCC1
+ C1CC2CCCN(C1)C2
+ CN
+ C1CC2CC(O1)C1OCC2O1
+ C1CC2CNC(C1)N2
+ CO
+ C1=C2CCCCC1C2
+ C1=COCCNC1
+ C1=CN=CNC1
+ C1=CNCCCC1
+ C1NN=NS1
+ C1=NC=NCS1
+ C1=NN=CCC1
+ C1=NCCCC1
+ C1CC2CCC3CC1NC23
+ C1=NCCOCC1
+ C1=CSNC=N1
+ C1=CSCNNC1
+ C1=COCCC1
+ C1=COCCC=N1
+ C1=NNCCN1
+ C1C2CC1C2
+ N1C2NC3NC4NC3NC2NC14
+ C1=CC2C3NSC2C13
+ C1CC2COCC1N2
+ C1=CC2CCN3CC1CC23
+ C1CCONC1
+ C1=CC1
+ C1CCNNC1
+ C1=CSC=N1
+ C1CC1
+ C1=NCNCC1
+ C1=CC2CC(NCN2)O1
+ C1=CNSN=C1
+ C1=CCCNCC1
+ C1=CSN=CN1
+ NS
+ C1=CSN=N1
+ C1=CN=CC1
+ C1=CC2CCC(O1)O2
+ C1C2CC3C1OC1NCC2C13
+ C1COCSC1
+ C1=CCCOCC1
+ OBr
+ C1=COCSN1
+ C1=CN=N[SH]=C1
+ C1=CSCC=NC1
+ C1=CC=C2CCCCC(=C1)C2
+ C1=CSC=CC1
+ C=C
+ C1=CNCOC1
+ C1=NON=C1
+ C1=C[SH]=NC=N1
+ C1=CNCNC1
+ C1CN2C3CC4CC(C1C3)C2C4
+ C1=CONCNN1
+ CBr
+ C1CCNCNC1
+ C1=CCNN=C1
+ C1=CCNCCC1
+ C1=CSCCCC1
+ C1=CC2COCC(C1)C2
+ C1=CCCC=CC1
+ C1COCCN1
+ N1NN1
+ C1=CC2CCCC(C1)N2
+ C1=CC=NC=C1
+ C1=NC=NNC1
+ C1CC2CCC1CNC2
+ C1=CNCC1
+ N=S
+ CC
+ C1=NNCNCC1
+ C1=NN=NN1
+ C1=NSNCC1
+ C1=CNCCSC1
+ C1=NC=NCCC1
+ C1CCNSCC1
+ C1=NC=NN=C1
+ C1=CCOCOC1
+ C1=NN=CNC1
+ C1=COCCCC1
+ C1=NCCCO1
+ C1CC2CNC(C1)C2
+ C1=CSN=CS1
+ C1CC2CCC(C1)CNC2
+ C1C[SH]=NS1
+ C1CC2CCCC(C1)C2
+ C1CC2CCC(C1)N2
+ C1=NOCC1
+ C1=NCCOC1
+ C1CC2C3CC1C2CO3
+ C1NCN1
+ C1COC2CCC(C2)N1
+ C1=CNC=CC1
+ C1=CNCCC1
+ C1=NN=CN=N1
+ C1=CNCCC=N1
+ C1NCC2CC1C2
+ C1C2CC3CC1CC(C2)C3
+ C1=NC2CC(C1)CN2
+ C1=CNCN=CC1
+ C1NCSCN1
+ S
+ C1CC2CCC1C2
+ C1=CCC=NC=C1
+ C1CCOCC1
+ C1CCN2CCCC(C1)C2
+ C1=NSCCS1
+ C1NC2CCOC(C2)N1
+ C1=NCNN1
+ C1=CSC=CNC1
+ C1=CCSC=C1
+ C1=COC=CO1
+ C1=CCCNC=C1
+ C1CSCCN1
+ C1CCCC1
+ C1CCSOC1
+ C1=NSCO1
+ C1=NCN=CO1
+ C1CN2CCN1CC2
+ C1=NCN=CN1
+ C1CC2CCC(N1)O2
+ C1COCO1
+ C1=COC=C1
+ C1=CCC2CCC(C1)N2
+ C1CCC2CCCC(C1)C2
+ C1=COCC1
+ C1=NSN=CO1
+ C1=CCCOC=C1
+ C1NNCO1
+ CS
+ C1CC2CC1CO2
+ C1NC2CNC1C2
+ C1=CC=C1
+ C1=COC=CC1
+ C1=CCNNC1
+ NN
+ C1=COCCO1
+ C1=NCC=NN1
+ C1CNCNCN1
+ C1=CN=NC=C1
+ C1=CSCO1
+ C1=NC=NSO1
+ C1=NCCS1
+ C1CC2CNC1CN2
+ C1=NNCC1
+ C1=CCCC=C1
+ C1=NCCN1
+ C1=NCCSCC1
+ C1=NNNC1
+ C1=CC2CCNC(C1)C2
+ C1=NCCN=N1
+ C1=CNSN1
+ C1=CNOC1
+ C1=CSN=C1
+ C1CCC2CCC(C1)C2
+ C1NCC2COCC1C2
+ C1COC2CC(N1)C1COC2O1
+ C1=NC=NN1
+ C1CSNSC1
+ C1CCN=NC1
+ C1=CO1
+ C1=CNCCOC1
+ C1=NNN=C1
+ C1=CC=NN=C1
+ C1CC2CCC1CC2
+ C1NC2CC(CO2)O1
+ C1NNCS1
+ C1=CC2CCCCC(CC1)C2
+ C1=CN=NCC1
+ C1=CON=C1
+ C1CSN1
+ C1=CSNC1
+ C1CC2CC(C1)C2
+ C1CNCN1
+ C1=CSCNN1
+ C1=CSCC=N1
+ C1=CC2CCC1CN2
+ C1=CSCCC1
+ C1=CC2CCC(C2)N1
+ C1CC2CC3CC1CC(C2)C3
+ C1=CSCCS1
+ C1=NC=NCC1
+ C1=CCC2CC(C=CN2)C1
+ C1OC2C3OC4C1C1OC2C3OCC41
+ C1=NCCCN1
+ C1COCOC1
+ C1=CNCC=N1
+ C1=NCCCS1
+ C1=CC2C=CC(C=C1)C2
+ C1CC2CC1CN2
+ C1=CNN1
+ C1=CC2CCCC1O2
+ C1=CCCCC1
+ C1=CNC=NC1
+ C1=CSNCN1
+ C1=CSCNN=C1
+ C1=CNC=CNC1
+ C1OC2CC3CC1C2C3
+ C1=NCN=C1
+ C1=CSC=CS1
+ C
+ C1=NCCN=CC1
+ C1=CSNCCN1
+ C1=CCN=CCC1
+ C1=CC2CCNC(C2)O1
+ C1C2CN3CCN(C2)CC1C3
+ C1C2CN3CC1CN(C2)C3
+ C1=NCCCCN1
+ C1C2CC3CC(CC1O3)N2
+ C1=NN=COC1
+ C1=CSCCNC1
+ C1=NOCN1
+ C1NCC2CNCC1C2
+ C1CCSC1
+ C1=CN=NNC1
+ C1CSCS1
+ C1CNN=N1
+ C1NCNSN1
+ C1=NCC1
+ N1NO1
+ C1COSCSO1
+ C1=CNNCN=C1
+ C1CSCCS1
+ C1=CSC=NC1
+ C1=CC2CCC(C1)NN2
+ C1CNCCNC1
+ C1C2CC3OC1CC3O2
+ C1CNCCN1
+ C1=NCCCNC1
+ C=O
+ C1CNSCCN1
+ C1CNCSC1
+ C1CCC2C3CCC2C(C1)C3
+ C1=CCC=C1
+ C1COC1
+ C1CC2CC1NN2
+ C1=NN=CS1
+ C1COCCSN1
+ C1CCNC1
+ C1=NOCCC1
+ C1CCC1
+ C1CC2CCC1CO2
+ C1=CCNC=C1
+ C1=CN=CN=CC1
+ C1=NC=NO1
+ C1=CC2CCOC(C2)N1
+ C1=CNCN=N1
+ C1NCN2CNCC1C2
+ C1CN2CCC3CC1CC2C3
+ C1=CSC=CO1
+ C1=CNCCN=C1
+ C1CCC2CC(C1)C2
+ C1=COCO1
+ C1=NCOC1
+ C1CSCO1
+ C1=NCCSC1
+ C1CCC2CCCC(C1)N2
+ C1=NCC=NC1
+ C1CSCCO1
+ C1=CNC=C1
+ C1COSN1
+ C1=CC2CC(O1)C1OCC2O1
+ C1CC2CCC(CN1)N2
+ C1=COCCN=C1
+ C1=CCOC=CC1
+ C1=NCN=CC1
+ C1OC2COC1C2
+ C1=NC2CC(C1)CCO2
+ C1=NC=NS1
+ C1=CSN=CO1
+ C1=CNN=CC1
+ C1=CC2C=CC(C1)CNC2
+ C1CC2CCNC(C2)N1
+ C1=NCCNC1
+ C1CC2COCC(C1)C2
+ C1=CN=CCN=C1
+ C1=CNN=NC1
+ C1CC2CC3CCN2C(C1)C3
+ C1=CC2COCC(C1)N2
+ C1CNOC1
+ C1=CCCC1
+ C1CCCSCC1
+ C1CCOCOC1
+ C1=NC=NCN1
+ C1=NCCNS1
+ C1NCNN1
+ C1CN2C3CC4CC1CC2C4O3
+ C1=COCNC1
+ C1CN2CC3OCC2CC13
+ C1NC2CC1C2
+ C1=NSCCN1
+ C1=CCOC=C1
+ C1CCSNC1
+ C1=CC2C=CC1CC2
+ C1CNSNC1
+ C1=CN=CNCC1
+ C1=NNCNC1
+ C1=CNCNN=C1
+ C1=NNCCCC1
+ C1=NCC2CCCC1C2
+ C1CCNCC1
+ C1NCC2CNCC1O2
+ C1=COCOC1
+ C1CC2COC(C1)O2
+ C#N
+ C1=NCNN=C1
+ C1CC2COC(C1)C2
+ C1CC2CCC1NN2
+ C1=CN=CC=N1
+ C1=NCCN=C1
+ C1CSCNCN1
+ C1NCC2COCC1N2
+ C1=CCN=C1
+ C1COCCO1
+ C1=COCCCO1
+ C1CC2COCC(C1)N2
+ C1=COCC=N1
+ C1CC2OC(CCS2)S1
+ C1=CCC=CC1
+ C1=NC1
+ C1CCCOCC1
+ C1=CC2CCC(C1)N2
+ C1=CNC=CN1
+ C1CCOC1
+ C1=CC=CCC=C1
+ C1=CC2CCNC(CC1)C2
+ C1=CNCSC1
+ C1CC2CCC1N2
+ C1=NSNCN1
+ C1=CSNCC1
+ C1CC2CCCC(C1)N2
+ C1=CN=CCC=N1
+ C1NCC2COCC1CN2
+ C1=CSC=C1
+ C1C2CN3CN1CN(C2)C3
+ C1=CCON=C1
+ C1=NCCCCC1
+ C1=CSCCCS1
+ C1=CNN=N1
+ C1CC2C3CC1C2CN3
+ C1CC2CCC(C2)O1
+ OS
+ C1CC2OC3CC1CC2O3
+ C1CC2CCC1CN2
+ C1=CNCN1
+ C1=CC2CC(N1)C1C=CC2C1
+ C1=CC2CC(C1)C2
+ C1=NCCCSC1
+ C1NC2CC(N1)C1OCC2O1
+ C#C
+ CCl
+ C1=CN=NN=C1
+ C1=CNCCN1
+ C1CCCCC1
+ C1CNCNC1
+ C1=CNNC1