Spaces:
Sleeping
Sleeping
Trương Gia Bảo
commited on
Commit
·
a3ea5d3
1
Parent(s):
6c75a42
Initial commit
Browse files- .gitattributes +1 -0
- app.py +111 -0
- fast_jtnn/__init__.py +9 -0
- fast_jtnn/chemutils.py +429 -0
- fast_jtnn/datautils.py +213 -0
- fast_jtnn/jtmpn.py +138 -0
- fast_jtnn/jtnn_dec.py +347 -0
- fast_jtnn/jtnn_enc.py +131 -0
- fast_jtnn/jtnn_vae.py +226 -0
- fast_jtnn/jtprop_vae.py +311 -0
- fast_jtnn/mol_tree.py +168 -0
- fast_jtnn/mpn.py +125 -0
- fast_jtnn/nnutils.py +72 -0
- fast_jtnn/vocab.py +31 -0
- fpscores.pkl.gz +3 -0
- requirements.txt +9 -0
- sascorer.py +173 -0
- vocab.txt +533 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
model.iter-685000 filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import streamlit as st
|
3 |
+
import sys, os
|
4 |
+
import rdkit
|
5 |
+
import rdkit.Chem as Chem
|
6 |
+
from rdkit.Chem.Draw import MolToImage
|
7 |
+
from rdkit.Chem import Descriptors
|
8 |
+
import sascorer
|
9 |
+
import networkx as nx
|
10 |
+
|
11 |
+
os.environ['KMP_DUPLICATE_LIB_OK']='True'
|
12 |
+
|
13 |
+
sys.path.append('%s/fast_jtnn/' % os.path.dirname(os.path.realpath(__file__)))
|
14 |
+
from mol_tree import Vocab, MolTree
|
15 |
+
from jtprop_vae import JTPropVAE
|
16 |
+
from molbloom import buy
|
17 |
+
|
18 |
+
|
19 |
+
lg = rdkit.RDLogger.logger()
|
20 |
+
lg.setLevel(rdkit.RDLogger.CRITICAL)
|
21 |
+
|
22 |
+
st.header('Junction Tree Variational Autoencoder for Molecular Graph Generation (JTVAE)')
|
23 |
+
st.subheader('Wengong Jin, Regina Barzilay, Tommi Jaakkola')
|
24 |
+
descrip = '''
|
25 |
+
We seek to automate the design of molecules based on specific chemical properties. In computational terms, this task involves continuous embedding and generation of molecular graphs. Our primary contribution is the direct realization of molecular graphs, a task previously approached by generating linear SMILES strings instead of graphs. Our junction tree variational autoencoder generates molecular graphs in two phases, by first generating a tree-structured scaffold over chemical substructures, and then combining them into a molecule with a graph message passing network. This approach allows us to incrementally expand molecules while maintaining chemical validity at every step. We evaluate our model on multiple tasks ranging from molecular generation to optimization. Across these tasks, our model outperforms previous state-of-the-art baselines by a significant margin.
|
26 |
+
|
27 |
+
[https://arxiv.org/abs/1802.04364](https://arxiv.org/abs/1802.04364)'''
|
28 |
+
|
29 |
+
with st.expander('About'):
|
30 |
+
st.markdown(descrip)
|
31 |
+
|
32 |
+
st.text_input('Enter a SMILES string:','CNC(=O)C1=NC=CC(=C1)OC2=CC=C(C=C2)NC(=O)NC3=CC(=C(C=C3)Cl)C(F)(F)F',key='smiles')
|
33 |
+
|
34 |
+
def penalized_logp_standard(mol):
|
35 |
+
|
36 |
+
logP_mean = 2.4399606244103639873799239
|
37 |
+
logP_std = 0.9293197802518905481505840
|
38 |
+
SA_mean = -2.4485512208785431553792478
|
39 |
+
SA_std = 0.4603110476923852334429910
|
40 |
+
cycle_mean = -0.0307270378623088931402396
|
41 |
+
cycle_std = 0.2163675785228087178335699
|
42 |
+
|
43 |
+
log_p = Descriptors.MolLogP(mol)
|
44 |
+
SA = -sascorer.calculateScore(mol)
|
45 |
+
|
46 |
+
# cycle score
|
47 |
+
cycle_list = nx.cycle_basis(nx.Graph(Chem.rdmolops.GetAdjacencyMatrix(mol)))
|
48 |
+
if len(cycle_list) == 0:
|
49 |
+
cycle_length = 0
|
50 |
+
else:
|
51 |
+
cycle_length = max([len(j) for j in cycle_list])
|
52 |
+
if cycle_length <= 6:
|
53 |
+
cycle_length = 0
|
54 |
+
else:
|
55 |
+
cycle_length = cycle_length - 6
|
56 |
+
cycle_score = -cycle_length
|
57 |
+
# print(logP_mean)
|
58 |
+
|
59 |
+
standardized_log_p = (log_p - logP_mean) / logP_std
|
60 |
+
standardized_SA = (SA - SA_mean) / SA_std
|
61 |
+
standardized_cycle = (cycle_score - cycle_mean) / cycle_std
|
62 |
+
return standardized_log_p + standardized_SA + standardized_cycle
|
63 |
+
|
64 |
+
mol = Chem.MolFromSmiles(st.session_state.smiles)
|
65 |
+
if mol is None:
|
66 |
+
st.write('SMILES is invalid. Please enter a valid SMILES.')
|
67 |
+
else:
|
68 |
+
st.write('Molecule:')
|
69 |
+
st.image(MolToImage(mol,size=(300,300)))
|
70 |
+
score = penalized_logp_standard(mol)
|
71 |
+
st.write('Penalized logP score: %.5f' % (score))
|
72 |
+
|
73 |
+
if mol is not None:
|
74 |
+
st.slider('Choose learning rate: ',0.0,10.0,0.4,key='lr')
|
75 |
+
st.slider('Choose similarity cutoff: ',0.0,3.0,0.4,key='sim_cutoff')
|
76 |
+
st.slider('Choose number of iterations: ',1,100,80,key='n_iter')
|
77 |
+
vocab = [x.strip("\r\n ") for x in open('./vocab.txt')]
|
78 |
+
vocab = Vocab(vocab)
|
79 |
+
if st.button('Optimize'):
|
80 |
+
st.write('Testing')
|
81 |
+
|
82 |
+
model = JTPropVAE(vocab, 450, 56, 20, 3)
|
83 |
+
|
84 |
+
model.load_state_dict(torch.load('./model.iter-685000',map_location=torch.device('cpu')))
|
85 |
+
|
86 |
+
new_smiles,sim = model.optimize(st.session_state.smiles, sim_cutoff=st.session_state.sim_cutoff, lr=st.session_state.lr, num_iter=st.session_state.n_iter)
|
87 |
+
|
88 |
+
del model
|
89 |
+
if new_smiles is None:
|
90 |
+
st.write('Cannot optimize.')
|
91 |
+
else:
|
92 |
+
st.write('New SMILES:')
|
93 |
+
st.code(new_smiles)
|
94 |
+
new_mol = Chem.MolFromSmiles(new_smiles)
|
95 |
+
if new_mol is None:
|
96 |
+
st.write('New SMILES is invalid.')
|
97 |
+
else:
|
98 |
+
st.write('New SMILES molecule:')
|
99 |
+
st.image(MolToImage(new_mol,size=(300,300)))
|
100 |
+
new_score = penalized_logp_standard(new_mol)
|
101 |
+
st.write('New penalized logP score: %.5f' % (new_score))
|
102 |
+
st.write('Caching ZINC20 if necessary...')
|
103 |
+
if buy(new_smiles, catalog='zinc20',canonicalize=True):
|
104 |
+
st.write('This molecule exists.')
|
105 |
+
st.caption('Checked by molbloom.')
|
106 |
+
else:
|
107 |
+
st.write('THIS MOLECULE DOES NOT EXIST!')
|
108 |
+
st.caption('Checked by molbloom.')
|
109 |
+
|
110 |
+
|
111 |
+
|
fast_jtnn/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import sys
|
2 |
+
# sys.path.append('./')
|
3 |
+
from mol_tree import Vocab, MolTree
|
4 |
+
from jtnn_vae import JTNNVAE
|
5 |
+
from jtnn_enc import JTNNEncoder
|
6 |
+
from jtmpn import JTMPN
|
7 |
+
from mpn import MPN
|
8 |
+
from nnutils import create_var
|
9 |
+
from datautils import MolTreeFolder, PairTreeFolder, MolTreeDataset
|
fast_jtnn/chemutils.py
ADDED
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import rdkit
|
2 |
+
import rdkit.Chem as Chem
|
3 |
+
from scipy.sparse import csr_matrix
|
4 |
+
from scipy.sparse.csgraph import minimum_spanning_tree
|
5 |
+
from collections import defaultdict
|
6 |
+
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
|
7 |
+
from vocab import Vocab
|
8 |
+
|
9 |
+
MST_MAX_WEIGHT = 100
|
10 |
+
MAX_NCAND = 2000
|
11 |
+
|
12 |
+
def set_atommap(mol, num=0):
|
13 |
+
for atom in mol.GetAtoms():
|
14 |
+
atom.SetAtomMapNum(num)
|
15 |
+
|
16 |
+
def get_mol(smiles):
|
17 |
+
mol = Chem.MolFromSmiles(smiles)
|
18 |
+
if mol is None:
|
19 |
+
return None
|
20 |
+
Chem.Kekulize(mol, clearAromaticFlags=True)
|
21 |
+
return mol
|
22 |
+
|
23 |
+
def get_smiles(mol):
|
24 |
+
return Chem.MolToSmiles(mol, kekuleSmiles=True)
|
25 |
+
|
26 |
+
def decode_stereo(smiles2D):
|
27 |
+
mol = Chem.MolFromSmiles(smiles2D)
|
28 |
+
dec_isomers = list(EnumerateStereoisomers(mol))
|
29 |
+
|
30 |
+
dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True)) for mol in dec_isomers]
|
31 |
+
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers]
|
32 |
+
|
33 |
+
chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
|
34 |
+
if len(chiralN) > 0:
|
35 |
+
for mol in dec_isomers:
|
36 |
+
for idx in chiralN:
|
37 |
+
mol.GetAtomWithIdx(idx).SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
|
38 |
+
smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))
|
39 |
+
|
40 |
+
return smiles3D
|
41 |
+
|
42 |
+
def sanitize(mol):
|
43 |
+
try:
|
44 |
+
smiles = get_smiles(mol)
|
45 |
+
mol = get_mol(smiles)
|
46 |
+
except Exception as e:
|
47 |
+
return None
|
48 |
+
return mol
|
49 |
+
|
50 |
+
def copy_atom(atom):
|
51 |
+
new_atom = Chem.Atom(atom.GetSymbol())
|
52 |
+
new_atom.SetFormalCharge(atom.GetFormalCharge())
|
53 |
+
new_atom.SetAtomMapNum(atom.GetAtomMapNum())
|
54 |
+
return new_atom
|
55 |
+
|
56 |
+
def copy_edit_mol(mol):
|
57 |
+
new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
|
58 |
+
for atom in mol.GetAtoms():
|
59 |
+
new_atom = copy_atom(atom)
|
60 |
+
new_mol.AddAtom(new_atom)
|
61 |
+
for bond in mol.GetBonds():
|
62 |
+
a1 = bond.GetBeginAtom().GetIdx()
|
63 |
+
a2 = bond.GetEndAtom().GetIdx()
|
64 |
+
bt = bond.GetBondType()
|
65 |
+
new_mol.AddBond(a1, a2, bt)
|
66 |
+
return new_mol
|
67 |
+
|
68 |
+
def get_clique_mol(mol, atoms):
|
69 |
+
smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
|
70 |
+
new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
|
71 |
+
new_mol = copy_edit_mol(new_mol).GetMol()
|
72 |
+
new_mol = sanitize(new_mol) #We assume this is not None
|
73 |
+
return new_mol
|
74 |
+
|
75 |
+
def tree_decomp(mol):
|
76 |
+
n_atoms = mol.GetNumAtoms()
|
77 |
+
if n_atoms == 1: #special case
|
78 |
+
return [[0]], []
|
79 |
+
|
80 |
+
cliques = []
|
81 |
+
for bond in mol.GetBonds():
|
82 |
+
a1 = bond.GetBeginAtom().GetIdx()
|
83 |
+
a2 = bond.GetEndAtom().GetIdx()
|
84 |
+
if not bond.IsInRing():
|
85 |
+
cliques.append([a1,a2])
|
86 |
+
|
87 |
+
ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
|
88 |
+
cliques.extend(ssr)
|
89 |
+
|
90 |
+
nei_list = [[] for i in range(n_atoms)]
|
91 |
+
for i in range(len(cliques)):
|
92 |
+
for atom in cliques[i]:
|
93 |
+
nei_list[atom].append(i)
|
94 |
+
|
95 |
+
#Merge Rings with intersection > 2 atoms
|
96 |
+
for i in range(len(cliques)):
|
97 |
+
if len(cliques[i]) <= 2: continue
|
98 |
+
for atom in cliques[i]:
|
99 |
+
for j in nei_list[atom]:
|
100 |
+
if i >= j or len(cliques[j]) <= 2: continue
|
101 |
+
inter = set(cliques[i]) & set(cliques[j])
|
102 |
+
if len(inter) > 2:
|
103 |
+
cliques[i].extend(cliques[j])
|
104 |
+
cliques[i] = list(set(cliques[i]))
|
105 |
+
cliques[j] = []
|
106 |
+
|
107 |
+
cliques = [c for c in cliques if len(c) > 0]
|
108 |
+
nei_list = [[] for i in range(n_atoms)]
|
109 |
+
for i in range(len(cliques)):
|
110 |
+
for atom in cliques[i]:
|
111 |
+
nei_list[atom].append(i)
|
112 |
+
|
113 |
+
#Build edges and add singleton cliques
|
114 |
+
edges = defaultdict(int)
|
115 |
+
for atom in range(n_atoms):
|
116 |
+
if len(nei_list[atom]) <= 1:
|
117 |
+
continue
|
118 |
+
cnei = nei_list[atom]
|
119 |
+
bonds = [c for c in cnei if len(cliques[c]) == 2]
|
120 |
+
rings = [c for c in cnei if len(cliques[c]) > 4]
|
121 |
+
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2): #In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with.
|
122 |
+
cliques.append([atom])
|
123 |
+
c2 = len(cliques) - 1
|
124 |
+
for c1 in cnei:
|
125 |
+
edges[(c1,c2)] = 1
|
126 |
+
elif len(rings) > 2: #Multiple (n>2) complex rings
|
127 |
+
cliques.append([atom])
|
128 |
+
c2 = len(cliques) - 1
|
129 |
+
for c1 in cnei:
|
130 |
+
edges[(c1,c2)] = MST_MAX_WEIGHT - 1
|
131 |
+
else:
|
132 |
+
for i in range(len(cnei)):
|
133 |
+
for j in range(i + 1, len(cnei)):
|
134 |
+
c1,c2 = cnei[i],cnei[j]
|
135 |
+
inter = set(cliques[c1]) & set(cliques[c2])
|
136 |
+
if edges[(c1,c2)] < len(inter):
|
137 |
+
edges[(c1,c2)] = len(inter) #cnei[i] < cnei[j] by construction
|
138 |
+
|
139 |
+
edges = [u + (MST_MAX_WEIGHT-v,) for u,v in edges.items()]
|
140 |
+
if len(edges) == 0:
|
141 |
+
return cliques, edges
|
142 |
+
|
143 |
+
#Compute Maximum Spanning Tree
|
144 |
+
row,col,data = zip(*edges)
|
145 |
+
n_clique = len(cliques)
|
146 |
+
clique_graph = csr_matrix( (data,(row,col)), shape=(n_clique,n_clique) )
|
147 |
+
junc_tree = minimum_spanning_tree(clique_graph)
|
148 |
+
row,col = junc_tree.nonzero()
|
149 |
+
edges = [(row[i],col[i]) for i in range(len(row))]
|
150 |
+
return (cliques, edges)
|
151 |
+
|
152 |
+
def atom_equal(a1, a2):
|
153 |
+
return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()
|
154 |
+
|
155 |
+
#Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
|
156 |
+
def ring_bond_equal(b1, b2, reverse=False):
|
157 |
+
b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
|
158 |
+
if reverse:
|
159 |
+
b2 = (b2.GetEndAtom(), b2.GetBeginAtom())
|
160 |
+
else:
|
161 |
+
b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
|
162 |
+
return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
|
163 |
+
|
164 |
+
def attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap):
|
165 |
+
prev_nids = [node.nid for node in prev_nodes]
|
166 |
+
for nei_node in prev_nodes + neighbors:
|
167 |
+
nei_id,nei_mol = nei_node.nid,nei_node.mol
|
168 |
+
amap = nei_amap[nei_id]
|
169 |
+
for atom in nei_mol.GetAtoms():
|
170 |
+
if atom.GetIdx() not in amap:
|
171 |
+
new_atom = copy_atom(atom)
|
172 |
+
amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)
|
173 |
+
|
174 |
+
if nei_mol.GetNumBonds() == 0:
|
175 |
+
nei_atom = nei_mol.GetAtomWithIdx(0)
|
176 |
+
ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
|
177 |
+
ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
|
178 |
+
else:
|
179 |
+
for bond in nei_mol.GetBonds():
|
180 |
+
a1 = amap[bond.GetBeginAtom().GetIdx()]
|
181 |
+
a2 = amap[bond.GetEndAtom().GetIdx()]
|
182 |
+
if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
|
183 |
+
ctr_mol.AddBond(a1, a2, bond.GetBondType())
|
184 |
+
elif nei_id in prev_nids: #father node overrides
|
185 |
+
ctr_mol.RemoveBond(a1, a2)
|
186 |
+
ctr_mol.AddBond(a1, a2, bond.GetBondType())
|
187 |
+
return ctr_mol
|
188 |
+
|
189 |
+
def local_attach(ctr_mol, neighbors, prev_nodes, amap_list):
|
190 |
+
ctr_mol = copy_edit_mol(ctr_mol)
|
191 |
+
nei_amap = {nei.nid:{} for nei in prev_nodes + neighbors}
|
192 |
+
|
193 |
+
for nei_id,ctr_atom,nei_atom in amap_list:
|
194 |
+
nei_amap[nei_id][nei_atom] = ctr_atom
|
195 |
+
|
196 |
+
ctr_mol = attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap)
|
197 |
+
return ctr_mol.GetMol()
|
198 |
+
|
199 |
+
#This version records idx mapping between ctr_mol and nei_mol
|
200 |
+
def enum_attach(ctr_mol, nei_node, amap, singletons):
|
201 |
+
nei_mol,nei_idx = nei_node.mol,nei_node.nid
|
202 |
+
att_confs = []
|
203 |
+
black_list = [atom_idx for nei_id,atom_idx,_ in amap if nei_id in singletons]
|
204 |
+
ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx() not in black_list]
|
205 |
+
ctr_bonds = [bond for bond in ctr_mol.GetBonds()]
|
206 |
+
|
207 |
+
if nei_mol.GetNumBonds() == 0: #neighbor singleton
|
208 |
+
nei_atom = nei_mol.GetAtomWithIdx(0)
|
209 |
+
used_list = [atom_idx for _,atom_idx,_ in amap]
|
210 |
+
for atom in ctr_atoms:
|
211 |
+
if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
|
212 |
+
new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
|
213 |
+
att_confs.append( new_amap )
|
214 |
+
|
215 |
+
elif nei_mol.GetNumBonds() == 1: #neighbor is a bond
|
216 |
+
bond = nei_mol.GetBondWithIdx(0)
|
217 |
+
bond_val = int(bond.GetBondTypeAsDouble())
|
218 |
+
b1,b2 = bond.GetBeginAtom(), bond.GetEndAtom()
|
219 |
+
|
220 |
+
for atom in ctr_atoms:
|
221 |
+
#Optimize if atom is carbon (other atoms may change valence)
|
222 |
+
if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
|
223 |
+
continue
|
224 |
+
if atom_equal(atom, b1):
|
225 |
+
new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
|
226 |
+
att_confs.append( new_amap )
|
227 |
+
elif atom_equal(atom, b2):
|
228 |
+
new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
|
229 |
+
att_confs.append( new_amap )
|
230 |
+
else:
|
231 |
+
#intersection is an atom
|
232 |
+
for a1 in ctr_atoms:
|
233 |
+
for a2 in nei_mol.GetAtoms():
|
234 |
+
if atom_equal(a1, a2):
|
235 |
+
#Optimize if atom is carbon (other atoms may change valence)
|
236 |
+
if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
|
237 |
+
continue
|
238 |
+
new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
|
239 |
+
att_confs.append( new_amap )
|
240 |
+
|
241 |
+
#intersection is an bond
|
242 |
+
if ctr_mol.GetNumBonds() > 1:
|
243 |
+
for b1 in ctr_bonds:
|
244 |
+
for b2 in nei_mol.GetBonds():
|
245 |
+
if ring_bond_equal(b1, b2):
|
246 |
+
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
|
247 |
+
att_confs.append( new_amap )
|
248 |
+
|
249 |
+
if ring_bond_equal(b1, b2, reverse=True):
|
250 |
+
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
|
251 |
+
att_confs.append( new_amap )
|
252 |
+
return att_confs
|
253 |
+
|
254 |
+
#Try rings first: Speed-Up
|
255 |
+
def enum_assemble(node, neighbors, prev_nodes=[], prev_amap=[]):
|
256 |
+
all_attach_confs = []
|
257 |
+
singletons = [nei_node.nid for nei_node in neighbors + prev_nodes if nei_node.mol.GetNumAtoms() == 1]
|
258 |
+
|
259 |
+
def search(cur_amap, depth):
|
260 |
+
if len(all_attach_confs) > MAX_NCAND:
|
261 |
+
return
|
262 |
+
if depth == len(neighbors):
|
263 |
+
all_attach_confs.append(cur_amap)
|
264 |
+
return
|
265 |
+
|
266 |
+
nei_node = neighbors[depth]
|
267 |
+
cand_amap = enum_attach(node.mol, nei_node, cur_amap, singletons)
|
268 |
+
cand_smiles = set()
|
269 |
+
candidates = []
|
270 |
+
for amap in cand_amap:
|
271 |
+
cand_mol = local_attach(node.mol, neighbors[:depth+1], prev_nodes, amap)
|
272 |
+
cand_mol = sanitize(cand_mol)
|
273 |
+
if cand_mol is None:
|
274 |
+
continue
|
275 |
+
smiles = get_smiles(cand_mol)
|
276 |
+
if smiles in cand_smiles:
|
277 |
+
continue
|
278 |
+
cand_smiles.add(smiles)
|
279 |
+
candidates.append(amap)
|
280 |
+
|
281 |
+
if len(candidates) == 0:
|
282 |
+
return
|
283 |
+
|
284 |
+
for new_amap in candidates:
|
285 |
+
search(new_amap, depth + 1)
|
286 |
+
|
287 |
+
search(prev_amap, 0)
|
288 |
+
cand_smiles = set()
|
289 |
+
candidates = []
|
290 |
+
aroma_score = []
|
291 |
+
for amap in all_attach_confs:
|
292 |
+
cand_mol = local_attach(node.mol, neighbors, prev_nodes, amap)
|
293 |
+
cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
|
294 |
+
smiles = Chem.MolToSmiles(cand_mol)
|
295 |
+
if smiles in cand_smiles or check_singleton(cand_mol, node, neighbors) == False:
|
296 |
+
continue
|
297 |
+
cand_smiles.add(smiles)
|
298 |
+
candidates.append( (smiles,amap) )
|
299 |
+
aroma_score.append( check_aroma(cand_mol, node, neighbors) )
|
300 |
+
|
301 |
+
return candidates, aroma_score
|
302 |
+
|
303 |
+
def check_singleton(cand_mol, ctr_node, nei_nodes):
|
304 |
+
rings = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() > 2]
|
305 |
+
singletons = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() == 1]
|
306 |
+
if len(singletons) > 0 or len(rings) == 0: return True
|
307 |
+
|
308 |
+
n_leaf2_atoms = 0
|
309 |
+
for atom in cand_mol.GetAtoms():
|
310 |
+
nei_leaf_atoms = [a for a in atom.GetNeighbors() if not a.IsInRing()] #a.GetDegree() == 1]
|
311 |
+
if len(nei_leaf_atoms) > 1:
|
312 |
+
n_leaf2_atoms += 1
|
313 |
+
|
314 |
+
return n_leaf2_atoms == 0
|
315 |
+
|
316 |
+
def check_aroma(cand_mol, ctr_node, nei_nodes):
|
317 |
+
rings = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() >= 3]
|
318 |
+
if len(rings) < 2: return 0 #Only multi-ring system needs to be checked
|
319 |
+
|
320 |
+
get_nid = lambda x: 0 if x.is_leaf else x.nid
|
321 |
+
benzynes = [get_nid(node) for node in nei_nodes + [ctr_node] if node.smiles in Vocab.benzynes]
|
322 |
+
penzynes = [get_nid(node) for node in nei_nodes + [ctr_node] if node.smiles in Vocab.penzynes]
|
323 |
+
if len(benzynes) + len(penzynes) == 0:
|
324 |
+
return 0 #No specific aromatic rings
|
325 |
+
|
326 |
+
n_aroma_atoms = 0
|
327 |
+
for atom in cand_mol.GetAtoms():
|
328 |
+
if atom.GetAtomMapNum() in benzynes+penzynes and atom.GetIsAromatic():
|
329 |
+
n_aroma_atoms += 1
|
330 |
+
|
331 |
+
if n_aroma_atoms >= len(benzynes) * 4 + len(penzynes) * 3:
|
332 |
+
return 1000
|
333 |
+
else:
|
334 |
+
return -0.001
|
335 |
+
|
336 |
+
#Only used for debugging purpose
|
337 |
+
def dfs_assemble(cur_mol, global_amap, fa_amap, cur_node, fa_node):
|
338 |
+
fa_nid = fa_node.nid if fa_node is not None else -1
|
339 |
+
prev_nodes = [fa_node] if fa_node is not None else []
|
340 |
+
|
341 |
+
children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
|
342 |
+
neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
|
343 |
+
neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
|
344 |
+
singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
|
345 |
+
neighbors = singletons + neighbors
|
346 |
+
|
347 |
+
cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node.nid]
|
348 |
+
cands = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
|
349 |
+
|
350 |
+
cand_smiles,cand_amap = zip(*cands)
|
351 |
+
label_idx = cand_smiles.index(cur_node.label)
|
352 |
+
label_amap = cand_amap[label_idx]
|
353 |
+
|
354 |
+
for nei_id,ctr_atom,nei_atom in label_amap:
|
355 |
+
if nei_id == fa_nid:
|
356 |
+
continue
|
357 |
+
global_amap[nei_id][nei_atom] = global_amap[cur_node.nid][ctr_atom]
|
358 |
+
|
359 |
+
cur_mol = attach_mols(cur_mol, children, [], global_amap) #father is already attached
|
360 |
+
for nei_node in children:
|
361 |
+
if not nei_node.is_leaf:
|
362 |
+
dfs_assemble(cur_mol, global_amap, label_amap, nei_node, cur_node)
|
363 |
+
|
364 |
+
if __name__ == "__main__":
|
365 |
+
import sys
|
366 |
+
from mol_tree import MolTree
|
367 |
+
lg = rdkit.RDLogger.logger()
|
368 |
+
lg.setLevel(rdkit.RDLogger.CRITICAL)
|
369 |
+
|
370 |
+
smiles = ["O=C1[C@@H]2C=C[C@@H](C=CC2)C1(c1ccccc1)c1ccccc1","O=C([O-])CC[C@@]12CCCC[C@]1(O)OC(=O)CC2", "ON=C1C[C@H]2CC3(C[C@@H](C1)c1ccccc12)OCCO3", "C[C@H]1CC(=O)[C@H]2[C@@]3(O)C(=O)c4cccc(O)c4[C@@H]4O[C@@]43[C@@H](O)C[C@]2(O)C1", 'Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br', 'CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1', "O=c1c2ccc3c(=O)n(-c4nccs4)c(=O)c4ccc(c(=O)n1-c1nccs1)c2c34", "O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1"]
|
371 |
+
|
372 |
+
def tree_test():
|
373 |
+
for s in sys.stdin:
|
374 |
+
s = s.split()[0]
|
375 |
+
tree = MolTree(s)
|
376 |
+
print('-------------------------------------------')
|
377 |
+
print(s)
|
378 |
+
for node in tree.nodes:
|
379 |
+
print(node.smiles, [x.smiles for x in node.neighbors])
|
380 |
+
|
381 |
+
def decode_test():
|
382 |
+
wrong = 0
|
383 |
+
for tot,s in enumerate(sys.stdin):
|
384 |
+
s = s.split()[0]
|
385 |
+
tree = MolTree(s)
|
386 |
+
tree.recover()
|
387 |
+
|
388 |
+
cur_mol = copy_edit_mol(tree.nodes[0].mol)
|
389 |
+
global_amap = [{}] + [{} for node in tree.nodes]
|
390 |
+
global_amap[1] = {atom.GetIdx():atom.GetIdx() for atom in cur_mol.GetAtoms()}
|
391 |
+
|
392 |
+
dfs_assemble(cur_mol, global_amap, [], tree.nodes[0], None)
|
393 |
+
|
394 |
+
cur_mol = cur_mol.GetMol()
|
395 |
+
cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
|
396 |
+
set_atommap(cur_mol)
|
397 |
+
dec_smiles = Chem.MolToSmiles(cur_mol)
|
398 |
+
|
399 |
+
gold_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(s))
|
400 |
+
if gold_smiles != dec_smiles:
|
401 |
+
print(gold_smiles, dec_smiles)
|
402 |
+
wrong += 1
|
403 |
+
print(wrong, tot + 1)
|
404 |
+
|
405 |
+
def enum_test():
|
406 |
+
for s in sys.stdin:
|
407 |
+
s = s.split()[0]
|
408 |
+
tree = MolTree(s)
|
409 |
+
tree.recover()
|
410 |
+
tree.assemble()
|
411 |
+
for node in tree.nodes:
|
412 |
+
if node.label not in node.cands:
|
413 |
+
print(tree.smiles)
|
414 |
+
print(node.smiles, [x.smiles for x in node.neighbors])
|
415 |
+
print(node.label, len(node.cands))
|
416 |
+
|
417 |
+
def count():
|
418 |
+
cnt,n = 0,0
|
419 |
+
for s in sys.stdin:
|
420 |
+
s = s.split()[0]
|
421 |
+
tree = MolTree(s)
|
422 |
+
tree.recover()
|
423 |
+
tree.assemble()
|
424 |
+
for node in tree.nodes:
|
425 |
+
cnt += len(node.cands)
|
426 |
+
n += len(tree.nodes)
|
427 |
+
#print cnt * 1.0 / n
|
428 |
+
|
429 |
+
count()
|
fast_jtnn/datautils.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset, DataLoader
|
3 |
+
from mol_tree import MolTree
|
4 |
+
import numpy as np
|
5 |
+
from jtnn_enc import JTNNEncoder
|
6 |
+
from mpn import MPN
|
7 |
+
from jtmpn import JTMPN
|
8 |
+
import pickle
|
9 |
+
import os, random
|
10 |
+
|
11 |
+
class PairTreeFolder(object):
|
12 |
+
|
13 |
+
def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, y_assm=True, replicate=None):
|
14 |
+
self.data_folder = data_folder
|
15 |
+
self.data_files = [fn for fn in os.listdir(data_folder)]
|
16 |
+
self.batch_size = batch_size
|
17 |
+
self.vocab = vocab
|
18 |
+
self.num_workers = num_workers
|
19 |
+
self.y_assm = y_assm
|
20 |
+
self.shuffle = shuffle
|
21 |
+
|
22 |
+
if replicate is not None: #expand is int
|
23 |
+
self.data_files = self.data_files * replicate
|
24 |
+
|
25 |
+
def __iter__(self):
|
26 |
+
for fn in self.data_files:
|
27 |
+
fn = os.path.join(self.data_folder, fn)
|
28 |
+
with open(fn, 'rb') as f:
|
29 |
+
data = pickle.load(f)
|
30 |
+
|
31 |
+
if self.shuffle:
|
32 |
+
random.shuffle(data) #shuffle data before batch
|
33 |
+
|
34 |
+
batches = [data[i : i + self.batch_size] for i in range(0, len(data), self.batch_size)]
|
35 |
+
if len(batches[-1]) < self.batch_size:
|
36 |
+
batches.pop()
|
37 |
+
|
38 |
+
dataset = PairTreeDataset(batches, self.vocab, self.y_assm)
|
39 |
+
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.num_workers, collate_fn=lambda x:x[0])
|
40 |
+
|
41 |
+
for b in dataloader:
|
42 |
+
yield b
|
43 |
+
|
44 |
+
del data, batches, dataset, dataloader
|
45 |
+
|
46 |
+
class MolTreeFolder(object):
|
47 |
+
|
48 |
+
def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, assm=True, replicate=None):
|
49 |
+
self.data_folder = data_folder
|
50 |
+
self.data_files = [fn for fn in os.listdir(data_folder)]
|
51 |
+
self.batch_size = batch_size
|
52 |
+
self.vocab = vocab
|
53 |
+
self.num_workers = num_workers
|
54 |
+
self.shuffle = shuffle
|
55 |
+
self.assm = assm
|
56 |
+
|
57 |
+
if replicate is not None: #expand is int
|
58 |
+
self.data_files = self.data_files * replicate
|
59 |
+
|
60 |
+
def __iter__(self):
|
61 |
+
for fn in self.data_files:
|
62 |
+
fn = os.path.join(self.data_folder, fn)
|
63 |
+
with open(fn, 'rb') as f:
|
64 |
+
data = pickle.load(f)
|
65 |
+
|
66 |
+
if self.shuffle:
|
67 |
+
random.shuffle(data) #shuffle data before batch
|
68 |
+
|
69 |
+
batches = [data[i : i + self.batch_size] for i in range(0, len(data), self.batch_size)]
|
70 |
+
if len(batches[-1]) < self.batch_size:
|
71 |
+
batches.pop()
|
72 |
+
|
73 |
+
dataset = MolTreeDataset(batches, self.vocab, self.assm)
|
74 |
+
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.num_workers, collate_fn=lambda x:x[0])
|
75 |
+
|
76 |
+
for b in dataloader:
|
77 |
+
yield b
|
78 |
+
|
79 |
+
del data, batches, dataset, dataloader
|
80 |
+
|
81 |
+
class PairTreeDataset(Dataset):
|
82 |
+
|
83 |
+
def __init__(self, data, vocab, y_assm):
|
84 |
+
self.data = data
|
85 |
+
self.vocab = vocab
|
86 |
+
self.y_assm = y_assm
|
87 |
+
|
88 |
+
def __len__(self):
|
89 |
+
return len(self.data)
|
90 |
+
|
91 |
+
def __getitem__(self, idx):
|
92 |
+
batch0, batch1 = zip(*self.data[idx])
|
93 |
+
return tensorize(batch0, self.vocab, assm=False), tensorize(batch1, self.vocab, assm=self.y_assm)
|
94 |
+
|
95 |
+
class MolTreeDataset(Dataset):
|
96 |
+
|
97 |
+
def __init__(self, data, vocab, assm=True):
|
98 |
+
self.data = data
|
99 |
+
self.vocab = vocab
|
100 |
+
self.assm = assm
|
101 |
+
|
102 |
+
def __len__(self):
|
103 |
+
return len(self.data)
|
104 |
+
|
105 |
+
def __getitem__(self, idx):
|
106 |
+
return tensorize(self.data[idx], self.vocab, assm=self.assm)
|
107 |
+
|
108 |
+
def tensorize(tree_batch, vocab, assm=True):
|
109 |
+
set_batch_nodeID(tree_batch, vocab)
|
110 |
+
smiles_batch = [tree.smiles for tree in tree_batch]
|
111 |
+
jtenc_holder,mess_dict = JTNNEncoder.tensorize(tree_batch)
|
112 |
+
jtenc_holder = jtenc_holder
|
113 |
+
mpn_holder = MPN.tensorize(smiles_batch)
|
114 |
+
|
115 |
+
if assm is False:
|
116 |
+
return tree_batch, jtenc_holder, mpn_holder
|
117 |
+
|
118 |
+
cands = []
|
119 |
+
batch_idx = []
|
120 |
+
for i,mol_tree in enumerate(tree_batch):
|
121 |
+
for node in mol_tree.nodes:
|
122 |
+
#Leaf node's attachment is determined by neighboring node's attachment
|
123 |
+
if node.is_leaf or len(node.cands) == 1: continue
|
124 |
+
cands.extend( [(cand, mol_tree.nodes, node) for cand in node.cands] )
|
125 |
+
batch_idx.extend([i] * len(node.cands))
|
126 |
+
|
127 |
+
jtmpn_holder = JTMPN.tensorize(cands, mess_dict)
|
128 |
+
batch_idx = torch.LongTensor(batch_idx)
|
129 |
+
|
130 |
+
return tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder,batch_idx)
|
131 |
+
|
132 |
+
def set_batch_nodeID(mol_batch, vocab):
|
133 |
+
tot = 0
|
134 |
+
for mol_tree in mol_batch:
|
135 |
+
for node in mol_tree.nodes:
|
136 |
+
node.idx = tot
|
137 |
+
node.wid = vocab.get_index(node.smiles)
|
138 |
+
tot += 1
|
139 |
+
|
140 |
+
class PropMolTreeDataset(Dataset):
|
141 |
+
|
142 |
+
def __init__(self, data, vocab, assm=True):
|
143 |
+
self.data = data
|
144 |
+
self.vocab = vocab
|
145 |
+
self.assm = assm
|
146 |
+
|
147 |
+
def __len__(self):
|
148 |
+
return len(self.data)
|
149 |
+
|
150 |
+
def __getitem__(self, idx):
|
151 |
+
return tensorize_prop(self.data[idx],self.vocab, assm=self.assm)
|
152 |
+
|
153 |
+
class PropMolTreeFolder(object):
|
154 |
+
|
155 |
+
def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, assm=True, replicate=None):
|
156 |
+
self.data_folder = data_folder
|
157 |
+
self.data_files = [fn for fn in os.listdir(data_folder)]
|
158 |
+
self.batch_size = batch_size
|
159 |
+
self.vocab = vocab
|
160 |
+
self.num_workers = num_workers
|
161 |
+
self.shuffle = shuffle
|
162 |
+
self.assm = assm
|
163 |
+
|
164 |
+
if replicate is not None: #expand is int
|
165 |
+
self.data_files = self.data_files * replicate
|
166 |
+
|
167 |
+
def __iter__(self):
|
168 |
+
for fn in self.data_files:
|
169 |
+
fn = os.path.join(self.data_folder, fn)
|
170 |
+
with open(fn, 'rb') as f:
|
171 |
+
data = pickle.load(f)
|
172 |
+
|
173 |
+
# print(data[0:5])
|
174 |
+
|
175 |
+
if self.shuffle:
|
176 |
+
random.shuffle(data) #shuffle data before batch
|
177 |
+
|
178 |
+
batches = [data[i : i + self.batch_size] for i in range(0, len(data), self.batch_size)]
|
179 |
+
if len(batches[-1]) < self.batch_size:
|
180 |
+
batches.pop()
|
181 |
+
|
182 |
+
dataset = PropMolTreeDataset(batches, self.vocab, self.assm)
|
183 |
+
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.num_workers, collate_fn=lambda x:x[0])
|
184 |
+
|
185 |
+
for b in dataloader:
|
186 |
+
yield b
|
187 |
+
|
188 |
+
del data, batches, dataset, dataloader
|
189 |
+
|
190 |
+
def tensorize_prop(data, vocab, assm=True):
|
191 |
+
tree_batch,prop = list(zip(*data))
|
192 |
+
set_batch_nodeID(tree_batch, vocab)
|
193 |
+
smiles_batch = [tree.smiles for tree in tree_batch]
|
194 |
+
jtenc_holder,mess_dict = JTNNEncoder.tensorize(tree_batch)
|
195 |
+
jtenc_holder = jtenc_holder
|
196 |
+
mpn_holder = MPN.tensorize(smiles_batch)
|
197 |
+
|
198 |
+
if assm is False:
|
199 |
+
return tree_batch, jtenc_holder, mpn_holder
|
200 |
+
|
201 |
+
cands = []
|
202 |
+
batch_idx = []
|
203 |
+
for i,mol_tree in enumerate(tree_batch):
|
204 |
+
for node in mol_tree.nodes:
|
205 |
+
#Leaf node's attachment is determined by neighboring node's attachment
|
206 |
+
if node.is_leaf or len(node.cands) == 1: continue
|
207 |
+
cands.extend( [(cand, mol_tree.nodes, node) for cand in node.cands] )
|
208 |
+
batch_idx.extend([i] * len(node.cands))
|
209 |
+
|
210 |
+
jtmpn_holder = JTMPN.tensorize(cands, mess_dict)
|
211 |
+
batch_idx = torch.LongTensor(batch_idx)
|
212 |
+
|
213 |
+
return tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder,batch_idx), prop
|
fast_jtnn/jtmpn.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from nnutils import create_var, index_select_ND
|
5 |
+
from chemutils import get_mol
|
6 |
+
import rdkit.Chem as Chem
|
7 |
+
|
8 |
+
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
|
9 |
+
|
10 |
+
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1
|
11 |
+
BOND_FDIM = 5
|
12 |
+
MAX_NB = 15
|
13 |
+
|
14 |
+
def onek_encoding_unk(x, allowable_set):
|
15 |
+
if x not in allowable_set:
|
16 |
+
x = allowable_set[-1]
|
17 |
+
return list(map(lambda s: x == s, allowable_set))
|
18 |
+
|
19 |
+
def atom_features(atom):
|
20 |
+
return torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
|
21 |
+
+ onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
|
22 |
+
+ onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
|
23 |
+
+ [atom.GetIsAromatic()])
|
24 |
+
|
25 |
+
def bond_features(bond):
|
26 |
+
bt = bond.GetBondType()
|
27 |
+
return torch.Tensor([bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()])
|
28 |
+
|
29 |
+
class JTMPN(nn.Module):
|
30 |
+
|
31 |
+
def __init__(self, hidden_size, depth):
|
32 |
+
super(JTMPN, self).__init__()
|
33 |
+
self.hidden_size = hidden_size
|
34 |
+
self.depth = depth
|
35 |
+
|
36 |
+
self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)
|
37 |
+
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
|
38 |
+
self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
|
39 |
+
|
40 |
+
def forward(self, fatoms, fbonds, agraph, bgraph, scope, tree_message): #tree_message[0] == vec(0)
|
41 |
+
fatoms = create_var(fatoms)
|
42 |
+
fbonds = create_var(fbonds)
|
43 |
+
agraph = create_var(agraph)
|
44 |
+
bgraph = create_var(bgraph)
|
45 |
+
|
46 |
+
binput = self.W_i(fbonds)
|
47 |
+
graph_message = F.relu(binput)
|
48 |
+
|
49 |
+
for i in range(self.depth - 1):
|
50 |
+
message = torch.cat([tree_message,graph_message], dim=0)
|
51 |
+
nei_message = index_select_ND(message, 0, bgraph)
|
52 |
+
nei_message = nei_message.sum(dim=1) #assuming tree_message[0] == vec(0)
|
53 |
+
nei_message = self.W_h(nei_message)
|
54 |
+
graph_message = F.relu(binput + nei_message)
|
55 |
+
|
56 |
+
message = torch.cat([tree_message,graph_message], dim=0)
|
57 |
+
nei_message = index_select_ND(message, 0, agraph)
|
58 |
+
nei_message = nei_message.sum(dim=1)
|
59 |
+
ainput = torch.cat([fatoms, nei_message], dim=1)
|
60 |
+
atom_hiddens = F.relu(self.W_o(ainput))
|
61 |
+
|
62 |
+
mol_vecs = []
|
63 |
+
for st,le in scope:
|
64 |
+
mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
|
65 |
+
mol_vecs.append(mol_vec)
|
66 |
+
|
67 |
+
mol_vecs = torch.stack(mol_vecs, dim=0)
|
68 |
+
return mol_vecs
|
69 |
+
|
70 |
+
@staticmethod
|
71 |
+
def tensorize(cand_batch, mess_dict):
|
72 |
+
fatoms,fbonds = [],[]
|
73 |
+
in_bonds,all_bonds = [],[]
|
74 |
+
total_atoms = 0
|
75 |
+
total_mess = len(mess_dict) + 1 #must include vec(0) padding
|
76 |
+
scope = []
|
77 |
+
|
78 |
+
for smiles,all_nodes,ctr_node in cand_batch:
|
79 |
+
mol = Chem.MolFromSmiles(smiles)
|
80 |
+
Chem.Kekulize(mol) #The original jtnn version kekulizes. Need to revisit why it is necessary
|
81 |
+
n_atoms = mol.GetNumAtoms()
|
82 |
+
ctr_bid = ctr_node.idx
|
83 |
+
|
84 |
+
for atom in mol.GetAtoms():
|
85 |
+
fatoms.append( atom_features(atom) )
|
86 |
+
in_bonds.append([])
|
87 |
+
|
88 |
+
for bond in mol.GetBonds():
|
89 |
+
a1 = bond.GetBeginAtom()
|
90 |
+
a2 = bond.GetEndAtom()
|
91 |
+
x = a1.GetIdx() + total_atoms
|
92 |
+
y = a2.GetIdx() + total_atoms
|
93 |
+
#Here x_nid,y_nid could be 0
|
94 |
+
x_nid,y_nid = a1.GetAtomMapNum(),a2.GetAtomMapNum()
|
95 |
+
x_bid = all_nodes[x_nid - 1].idx if x_nid > 0 else -1
|
96 |
+
y_bid = all_nodes[y_nid - 1].idx if y_nid > 0 else -1
|
97 |
+
|
98 |
+
bfeature = bond_features(bond)
|
99 |
+
|
100 |
+
b = total_mess + len(all_bonds) #bond idx offseted by total_mess
|
101 |
+
all_bonds.append((x,y))
|
102 |
+
fbonds.append( torch.cat([fatoms[x], bfeature], 0) )
|
103 |
+
in_bonds[y].append(b)
|
104 |
+
|
105 |
+
b = total_mess + len(all_bonds)
|
106 |
+
all_bonds.append((y,x))
|
107 |
+
fbonds.append( torch.cat([fatoms[y], bfeature], 0) )
|
108 |
+
in_bonds[x].append(b)
|
109 |
+
|
110 |
+
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
|
111 |
+
if (x_bid,y_bid) in mess_dict:
|
112 |
+
mess_idx = mess_dict[(x_bid,y_bid)]
|
113 |
+
in_bonds[y].append(mess_idx)
|
114 |
+
if (y_bid,x_bid) in mess_dict:
|
115 |
+
mess_idx = mess_dict[(y_bid,x_bid)]
|
116 |
+
in_bonds[x].append(mess_idx)
|
117 |
+
|
118 |
+
scope.append((total_atoms,n_atoms))
|
119 |
+
total_atoms += n_atoms
|
120 |
+
|
121 |
+
total_bonds = len(all_bonds)
|
122 |
+
fatoms = torch.stack(fatoms, 0)
|
123 |
+
fbonds = torch.stack(fbonds, 0)
|
124 |
+
agraph = torch.zeros(total_atoms,MAX_NB).long()
|
125 |
+
bgraph = torch.zeros(total_bonds,MAX_NB).long()
|
126 |
+
|
127 |
+
for a in range(total_atoms):
|
128 |
+
for i,b in enumerate(in_bonds[a]):
|
129 |
+
agraph[a,i] = b
|
130 |
+
|
131 |
+
for b1 in range(total_bonds):
|
132 |
+
x,y = all_bonds[b1]
|
133 |
+
for i,b2 in enumerate(in_bonds[x]): #b2 is offseted by total_mess
|
134 |
+
if b2 < total_mess or all_bonds[b2-total_mess][0] != y:
|
135 |
+
bgraph[b1,i] = b2
|
136 |
+
|
137 |
+
return (fatoms, fbonds, agraph, bgraph, scope)
|
138 |
+
|
fast_jtnn/jtnn_dec.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from mol_tree import Vocab, MolTree, MolTreeNode
|
5 |
+
from nnutils import create_var, GRU
|
6 |
+
from chemutils import enum_assemble, set_atommap
|
7 |
+
import copy
|
8 |
+
|
9 |
+
MAX_NB = 15
|
10 |
+
MAX_DECODE_LEN = 100
|
11 |
+
|
12 |
+
class JTNNDecoder(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, vocab, hidden_size, latent_size, embedding):
|
15 |
+
super(JTNNDecoder, self).__init__()
|
16 |
+
self.hidden_size = hidden_size
|
17 |
+
self.vocab_size = vocab.size()
|
18 |
+
self.vocab = vocab
|
19 |
+
self.embedding = embedding
|
20 |
+
|
21 |
+
#GRU Weights
|
22 |
+
self.W_z = nn.Linear(2 * hidden_size, hidden_size)
|
23 |
+
self.U_r = nn.Linear(hidden_size, hidden_size, bias=False)
|
24 |
+
self.W_r = nn.Linear(hidden_size, hidden_size)
|
25 |
+
self.W_h = nn.Linear(2 * hidden_size, hidden_size)
|
26 |
+
|
27 |
+
#Word Prediction Weights
|
28 |
+
self.W = nn.Linear(hidden_size + latent_size, hidden_size)
|
29 |
+
|
30 |
+
#Stop Prediction Weights
|
31 |
+
self.U = nn.Linear(hidden_size + latent_size, hidden_size)
|
32 |
+
self.U_i = nn.Linear(2 * hidden_size, hidden_size)
|
33 |
+
|
34 |
+
#Output Weights
|
35 |
+
self.W_o = nn.Linear(hidden_size, self.vocab_size)
|
36 |
+
self.U_o = nn.Linear(hidden_size, 1)
|
37 |
+
|
38 |
+
#Loss Functions
|
39 |
+
# self.pred_loss = nn.CrossEntropyLoss(size_average=False)
|
40 |
+
# self.stop_loss = nn.BCEWithLogitsLoss(size_average=False)
|
41 |
+
self.pred_loss = nn.CrossEntropyLoss(reduction='sum')
|
42 |
+
self.stop_loss = nn.BCEWithLogitsLoss(reduction='sum')
|
43 |
+
|
44 |
+
def aggregate(self, hiddens, contexts, x_tree_vecs, mode):
|
45 |
+
if mode == 'word':
|
46 |
+
V, V_o = self.W, self.W_o
|
47 |
+
elif mode == 'stop':
|
48 |
+
V, V_o = self.U, self.U_o
|
49 |
+
else:
|
50 |
+
raise ValueError('aggregate mode is wrong')
|
51 |
+
|
52 |
+
tree_contexts = x_tree_vecs.index_select(0, contexts)
|
53 |
+
input_vec = torch.cat([hiddens, tree_contexts], dim=-1)
|
54 |
+
output_vec = F.relu( V(input_vec) )
|
55 |
+
return V_o(output_vec)
|
56 |
+
|
57 |
+
def forward(self, mol_batch, x_tree_vecs):
|
58 |
+
pred_hiddens,pred_contexts,pred_targets = [],[],[]
|
59 |
+
stop_hiddens,stop_contexts,stop_targets = [],[],[]
|
60 |
+
traces = []
|
61 |
+
for mol_tree in mol_batch:
|
62 |
+
s = []
|
63 |
+
dfs(s, mol_tree.nodes[0], -1)
|
64 |
+
traces.append(s)
|
65 |
+
for node in mol_tree.nodes:
|
66 |
+
node.neighbors = []
|
67 |
+
|
68 |
+
#Predict Root
|
69 |
+
batch_size = len(mol_batch)
|
70 |
+
pred_hiddens.append(create_var(torch.zeros(len(mol_batch),self.hidden_size)))
|
71 |
+
pred_targets.extend([mol_tree.nodes[0].wid for mol_tree in mol_batch])
|
72 |
+
pred_contexts.append( create_var( torch.LongTensor(range(batch_size)) ) )
|
73 |
+
|
74 |
+
max_iter = max([len(tr) for tr in traces])
|
75 |
+
padding = create_var(torch.zeros(self.hidden_size), False)
|
76 |
+
h = {}
|
77 |
+
|
78 |
+
for t in range(max_iter):
|
79 |
+
prop_list = []
|
80 |
+
batch_list = []
|
81 |
+
for i,plist in enumerate(traces):
|
82 |
+
if t < len(plist):
|
83 |
+
prop_list.append(plist[t])
|
84 |
+
batch_list.append(i)
|
85 |
+
|
86 |
+
cur_x = []
|
87 |
+
cur_h_nei,cur_o_nei = [],[]
|
88 |
+
|
89 |
+
for node_x, real_y, _ in prop_list:
|
90 |
+
#Neighbors for message passing (target not included)
|
91 |
+
cur_nei = [h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors if node_y.idx != real_y.idx]
|
92 |
+
pad_len = MAX_NB - len(cur_nei)
|
93 |
+
cur_h_nei.extend(cur_nei)
|
94 |
+
cur_h_nei.extend([padding] * pad_len)
|
95 |
+
|
96 |
+
#Neighbors for stop prediction (all neighbors)
|
97 |
+
cur_nei = [h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors]
|
98 |
+
pad_len = MAX_NB - len(cur_nei)
|
99 |
+
cur_o_nei.extend(cur_nei)
|
100 |
+
cur_o_nei.extend([padding] * pad_len)
|
101 |
+
|
102 |
+
#Current clique embedding
|
103 |
+
cur_x.append(node_x.wid)
|
104 |
+
|
105 |
+
#Clique embedding
|
106 |
+
cur_x = create_var(torch.LongTensor(cur_x))
|
107 |
+
cur_x = self.embedding(cur_x)
|
108 |
+
|
109 |
+
#Message passing
|
110 |
+
cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1,MAX_NB,self.hidden_size)
|
111 |
+
new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
|
112 |
+
|
113 |
+
#Node Aggregate
|
114 |
+
cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1,MAX_NB,self.hidden_size)
|
115 |
+
cur_o = cur_o_nei.sum(dim=1)
|
116 |
+
|
117 |
+
#Gather targets
|
118 |
+
pred_target,pred_list = [],[]
|
119 |
+
stop_target = []
|
120 |
+
for i,m in enumerate(prop_list):
|
121 |
+
node_x,node_y,direction = m
|
122 |
+
x,y = node_x.idx,node_y.idx
|
123 |
+
h[(x,y)] = new_h[i]
|
124 |
+
node_y.neighbors.append(node_x)
|
125 |
+
if direction == 1:
|
126 |
+
pred_target.append(node_y.wid)
|
127 |
+
pred_list.append(i)
|
128 |
+
stop_target.append(direction)
|
129 |
+
|
130 |
+
#Hidden states for stop prediction
|
131 |
+
cur_batch = create_var(torch.LongTensor(batch_list))
|
132 |
+
stop_hidden = torch.cat([cur_x,cur_o], dim=1)
|
133 |
+
stop_hiddens.append( stop_hidden )
|
134 |
+
stop_contexts.append( cur_batch )
|
135 |
+
stop_targets.extend( stop_target )
|
136 |
+
|
137 |
+
#Hidden states for clique prediction
|
138 |
+
if len(pred_list) > 0:
|
139 |
+
batch_list = [batch_list[i] for i in pred_list]
|
140 |
+
cur_batch = create_var(torch.LongTensor(batch_list))
|
141 |
+
pred_contexts.append( cur_batch )
|
142 |
+
|
143 |
+
cur_pred = create_var(torch.LongTensor(pred_list))
|
144 |
+
pred_hiddens.append( new_h.index_select(0, cur_pred) )
|
145 |
+
pred_targets.extend( pred_target )
|
146 |
+
|
147 |
+
#Last stop at root
|
148 |
+
cur_x,cur_o_nei = [],[]
|
149 |
+
for mol_tree in mol_batch:
|
150 |
+
node_x = mol_tree.nodes[0]
|
151 |
+
cur_x.append(node_x.wid)
|
152 |
+
cur_nei = [h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors]
|
153 |
+
pad_len = MAX_NB - len(cur_nei)
|
154 |
+
cur_o_nei.extend(cur_nei)
|
155 |
+
cur_o_nei.extend([padding] * pad_len)
|
156 |
+
|
157 |
+
cur_x = create_var(torch.LongTensor(cur_x))
|
158 |
+
cur_x = self.embedding(cur_x)
|
159 |
+
cur_o_nei = torch.stack(cur_o_nei, dim=0).view(-1,MAX_NB,self.hidden_size)
|
160 |
+
cur_o = cur_o_nei.sum(dim=1)
|
161 |
+
|
162 |
+
stop_hidden = torch.cat([cur_x,cur_o], dim=1)
|
163 |
+
stop_hiddens.append( stop_hidden )
|
164 |
+
stop_contexts.append( create_var( torch.LongTensor(range(batch_size)) ) )
|
165 |
+
stop_targets.extend( [0] * len(mol_batch) )
|
166 |
+
|
167 |
+
#Predict next clique
|
168 |
+
pred_contexts = torch.cat(pred_contexts, dim=0)
|
169 |
+
pred_hiddens = torch.cat(pred_hiddens, dim=0)
|
170 |
+
pred_scores = self.aggregate(pred_hiddens, pred_contexts, x_tree_vecs, 'word')
|
171 |
+
pred_targets = create_var(torch.LongTensor(pred_targets))
|
172 |
+
|
173 |
+
pred_loss = self.pred_loss(pred_scores, pred_targets) / len(mol_batch)
|
174 |
+
_,preds = torch.max(pred_scores, dim=1)
|
175 |
+
pred_acc = torch.eq(preds, pred_targets).float()
|
176 |
+
pred_acc = torch.sum(pred_acc) / pred_targets.nelement()
|
177 |
+
|
178 |
+
#Predict stop
|
179 |
+
stop_contexts = torch.cat(stop_contexts, dim=0)
|
180 |
+
stop_hiddens = torch.cat(stop_hiddens, dim=0)
|
181 |
+
stop_hiddens = F.relu( self.U_i(stop_hiddens) )
|
182 |
+
stop_scores = self.aggregate(stop_hiddens, stop_contexts, x_tree_vecs, 'stop')
|
183 |
+
stop_scores = stop_scores.squeeze(-1)
|
184 |
+
stop_targets = create_var(torch.Tensor(stop_targets))
|
185 |
+
|
186 |
+
stop_loss = self.stop_loss(stop_scores, stop_targets) / len(mol_batch)
|
187 |
+
stops = torch.ge(stop_scores, 0).float()
|
188 |
+
stop_acc = torch.eq(stops, stop_targets).float()
|
189 |
+
stop_acc = torch.sum(stop_acc) / stop_targets.nelement()
|
190 |
+
|
191 |
+
return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
|
192 |
+
|
193 |
+
def decode(self, x_tree_vecs, prob_decode):
|
194 |
+
assert x_tree_vecs.size(0) == 1
|
195 |
+
|
196 |
+
stack = []
|
197 |
+
init_hiddens = create_var( torch.zeros(1, self.hidden_size) )
|
198 |
+
zero_pad = create_var(torch.zeros(1,1,self.hidden_size))
|
199 |
+
contexts = create_var( torch.LongTensor(1).zero_() )
|
200 |
+
|
201 |
+
#Root Prediction
|
202 |
+
root_score = self.aggregate(init_hiddens, contexts, x_tree_vecs, 'word')
|
203 |
+
_,root_wid = torch.max(root_score, dim=1)
|
204 |
+
root_wid = root_wid.item()
|
205 |
+
|
206 |
+
root = MolTreeNode(self.vocab.get_smiles(root_wid))
|
207 |
+
root.wid = root_wid
|
208 |
+
root.idx = 0
|
209 |
+
stack.append( (root, self.vocab.get_slots(root.wid)) )
|
210 |
+
|
211 |
+
all_nodes = [root]
|
212 |
+
h = {}
|
213 |
+
for step in range(MAX_DECODE_LEN):
|
214 |
+
node_x,fa_slot = stack[-1]
|
215 |
+
cur_h_nei = [ h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors ]
|
216 |
+
if len(cur_h_nei) > 0:
|
217 |
+
cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size)
|
218 |
+
else:
|
219 |
+
cur_h_nei = zero_pad
|
220 |
+
|
221 |
+
cur_x = create_var(torch.LongTensor([node_x.wid]))
|
222 |
+
cur_x = self.embedding(cur_x)
|
223 |
+
|
224 |
+
#Predict stop
|
225 |
+
cur_h = cur_h_nei.sum(dim=1)
|
226 |
+
stop_hiddens = torch.cat([cur_x,cur_h], dim=1)
|
227 |
+
stop_hiddens = F.relu( self.U_i(stop_hiddens) )
|
228 |
+
stop_score = self.aggregate(stop_hiddens, contexts, x_tree_vecs, 'stop')
|
229 |
+
|
230 |
+
if prob_decode:
|
231 |
+
backtrack = (torch.bernoulli( torch.sigmoid(stop_score) ).item() == 0)
|
232 |
+
else:
|
233 |
+
backtrack = (stop_score.item() < 0)
|
234 |
+
|
235 |
+
if not backtrack: #Forward: Predict next clique
|
236 |
+
new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
|
237 |
+
pred_score = self.aggregate(new_h, contexts, x_tree_vecs, 'word')
|
238 |
+
|
239 |
+
if prob_decode:
|
240 |
+
sort_wid = torch.multinomial(F.softmax(pred_score, dim=1).squeeze(), 5)
|
241 |
+
else:
|
242 |
+
_,sort_wid = torch.sort(pred_score, dim=1, descending=True)
|
243 |
+
sort_wid = sort_wid.data.squeeze()
|
244 |
+
|
245 |
+
next_wid = None
|
246 |
+
for wid in sort_wid[:5]:
|
247 |
+
slots = self.vocab.get_slots(wid)
|
248 |
+
node_y = MolTreeNode(self.vocab.get_smiles(wid))
|
249 |
+
if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
|
250 |
+
next_wid = wid
|
251 |
+
next_slots = slots
|
252 |
+
break
|
253 |
+
|
254 |
+
if next_wid is None:
|
255 |
+
backtrack = True #No more children can be added
|
256 |
+
else:
|
257 |
+
node_y = MolTreeNode(self.vocab.get_smiles(next_wid))
|
258 |
+
node_y.wid = next_wid
|
259 |
+
node_y.idx = len(all_nodes)
|
260 |
+
node_y.neighbors.append(node_x)
|
261 |
+
h[(node_x.idx,node_y.idx)] = new_h[0]
|
262 |
+
stack.append( (node_y,next_slots) )
|
263 |
+
all_nodes.append(node_y)
|
264 |
+
|
265 |
+
if backtrack: #Backtrack, use if instead of else
|
266 |
+
if len(stack) == 1:
|
267 |
+
break #At root, terminate
|
268 |
+
|
269 |
+
node_fa,_ = stack[-2]
|
270 |
+
cur_h_nei = [ h[(node_y.idx,node_x.idx)] for node_y in node_x.neighbors if node_y.idx != node_fa.idx ]
|
271 |
+
if len(cur_h_nei) > 0:
|
272 |
+
cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1,-1,self.hidden_size)
|
273 |
+
else:
|
274 |
+
cur_h_nei = zero_pad
|
275 |
+
|
276 |
+
new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
|
277 |
+
h[(node_x.idx,node_fa.idx)] = new_h[0]
|
278 |
+
node_fa.neighbors.append(node_x)
|
279 |
+
stack.pop()
|
280 |
+
|
281 |
+
return root, all_nodes
|
282 |
+
|
283 |
+
"""
|
284 |
+
Helper Functions:
|
285 |
+
"""
|
286 |
+
def dfs(stack, x, fa_idx):
|
287 |
+
for y in x.neighbors:
|
288 |
+
if y.idx == fa_idx: continue
|
289 |
+
stack.append( (x,y,1) )
|
290 |
+
dfs(stack, y, x.idx)
|
291 |
+
stack.append( (y,x,0) )
|
292 |
+
|
293 |
+
def have_slots(fa_slots, ch_slots):
|
294 |
+
if len(fa_slots) > 2 and len(ch_slots) > 2:
|
295 |
+
return True
|
296 |
+
matches = []
|
297 |
+
for i,s1 in enumerate(fa_slots):
|
298 |
+
a1,c1,h1 = s1
|
299 |
+
for j,s2 in enumerate(ch_slots):
|
300 |
+
a2,c2,h2 = s2
|
301 |
+
if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
|
302 |
+
matches.append( (i,j) )
|
303 |
+
|
304 |
+
if len(matches) == 0: return False
|
305 |
+
|
306 |
+
fa_match,ch_match = zip(*matches)
|
307 |
+
if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2: #never remove atom from ring
|
308 |
+
fa_slots.pop(fa_match[0])
|
309 |
+
if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2: #never remove atom from ring
|
310 |
+
ch_slots.pop(ch_match[0])
|
311 |
+
|
312 |
+
return True
|
313 |
+
|
314 |
+
def can_assemble(node_x, node_y):
|
315 |
+
node_x.nid = 1
|
316 |
+
node_x.is_leaf = False
|
317 |
+
set_atommap(node_x.mol, node_x.nid)
|
318 |
+
|
319 |
+
neis = node_x.neighbors + [node_y]
|
320 |
+
for i,nei in enumerate(neis):
|
321 |
+
nei.nid = i + 2
|
322 |
+
nei.is_leaf = (len(nei.neighbors) <= 1)
|
323 |
+
if nei.is_leaf:
|
324 |
+
set_atommap(nei.mol, 0)
|
325 |
+
else:
|
326 |
+
set_atommap(nei.mol, nei.nid)
|
327 |
+
|
328 |
+
neighbors = [nei for nei in neis if nei.mol.GetNumAtoms() > 1]
|
329 |
+
neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
|
330 |
+
singletons = [nei for nei in neis if nei.mol.GetNumAtoms() == 1]
|
331 |
+
neighbors = singletons + neighbors
|
332 |
+
cands,aroma_scores = enum_assemble(node_x, neighbors)
|
333 |
+
return len(cands) > 0# and sum(aroma_scores) >= 0
|
334 |
+
|
335 |
+
if __name__ == "__main__":
|
336 |
+
smiles = ["O=C1[C@@H]2C=C[C@@H](C=CC2)C1(c1ccccc1)c1ccccc1","O=C([O-])CC[C@@]12CCCC[C@]1(O)OC(=O)CC2", "ON=C1C[C@H]2CC3(C[C@@H](C1)c1ccccc12)OCCO3", "C[C@H]1CC(=O)[C@H]2[C@@]3(O)C(=O)c4cccc(O)c4[C@@H]4O[C@@]43[C@@H](O)C[C@]2(O)C1", 'Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br', 'CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1', "O=c1c2ccc3c(=O)n(-c4nccs4)c(=O)c4ccc(c(=O)n1-c1nccs1)c2c34", "O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1"]
|
337 |
+
for s in smiles:
|
338 |
+
print(s)
|
339 |
+
tree = MolTree(s)
|
340 |
+
for i,node in enumerate(tree.nodes):
|
341 |
+
node.idx = i
|
342 |
+
|
343 |
+
stack = []
|
344 |
+
dfs(stack, tree.nodes[0], -1)
|
345 |
+
for x,y,d in stack:
|
346 |
+
print(x.smiles, y.smiles, d)
|
347 |
+
print('------------------------------')
|
fast_jtnn/jtnn_enc.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from collections import deque
|
5 |
+
from mol_tree import Vocab, MolTree
|
6 |
+
from nnutils import create_var, index_select_ND
|
7 |
+
|
8 |
+
class JTNNEncoder(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, hidden_size, depth, embedding):
|
11 |
+
super(JTNNEncoder, self).__init__()
|
12 |
+
self.hidden_size = hidden_size
|
13 |
+
self.depth = depth
|
14 |
+
|
15 |
+
self.embedding = embedding
|
16 |
+
self.outputNN = nn.Sequential(
|
17 |
+
nn.Linear(2 * hidden_size, hidden_size),
|
18 |
+
nn.ReLU()
|
19 |
+
)
|
20 |
+
self.GRU = GraphGRU(hidden_size, hidden_size, depth=depth)
|
21 |
+
|
22 |
+
def forward(self, fnode, fmess, node_graph, mess_graph, scope):
|
23 |
+
fnode = create_var(fnode)
|
24 |
+
fmess = create_var(fmess)
|
25 |
+
node_graph = create_var(node_graph)
|
26 |
+
mess_graph = create_var(mess_graph)
|
27 |
+
messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))
|
28 |
+
|
29 |
+
fnode = self.embedding(fnode)
|
30 |
+
fmess = index_select_ND(fnode, 0, fmess)
|
31 |
+
messages = self.GRU(messages, fmess, mess_graph)
|
32 |
+
|
33 |
+
mess_nei = index_select_ND(messages, 0, node_graph)
|
34 |
+
node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
|
35 |
+
node_vecs = self.outputNN(node_vecs)
|
36 |
+
|
37 |
+
max_len = max([x for _,x in scope])
|
38 |
+
batch_vecs = []
|
39 |
+
for st,le in scope:
|
40 |
+
cur_vecs = node_vecs[st] #Root is the first node
|
41 |
+
batch_vecs.append( cur_vecs )
|
42 |
+
|
43 |
+
tree_vecs = torch.stack(batch_vecs, dim=0)
|
44 |
+
return tree_vecs, messages
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def tensorize(tree_batch):
|
48 |
+
node_batch = []
|
49 |
+
scope = []
|
50 |
+
for tree in tree_batch:
|
51 |
+
scope.append( (len(node_batch), len(tree.nodes)) )
|
52 |
+
node_batch.extend(tree.nodes)
|
53 |
+
|
54 |
+
return JTNNEncoder.tensorize_nodes(node_batch, scope)
|
55 |
+
|
56 |
+
@staticmethod
|
57 |
+
def tensorize_nodes(node_batch, scope):
|
58 |
+
messages,mess_dict = [None],{}
|
59 |
+
fnode = []
|
60 |
+
for x in node_batch:
|
61 |
+
fnode.append(x.wid)
|
62 |
+
for y in x.neighbors:
|
63 |
+
mess_dict[(x.idx,y.idx)] = len(messages)
|
64 |
+
messages.append( (x,y) )
|
65 |
+
|
66 |
+
node_graph = [[] for i in range(len(node_batch))]
|
67 |
+
mess_graph = [[] for i in range(len(messages))]
|
68 |
+
fmess = [0] * len(messages)
|
69 |
+
|
70 |
+
for x,y in messages[1:]:
|
71 |
+
mid1 = mess_dict[(x.idx,y.idx)]
|
72 |
+
fmess[mid1] = x.idx
|
73 |
+
node_graph[y.idx].append(mid1)
|
74 |
+
for z in y.neighbors:
|
75 |
+
if z.idx == x.idx: continue
|
76 |
+
mid2 = mess_dict[(y.idx,z.idx)]
|
77 |
+
mess_graph[mid2].append(mid1)
|
78 |
+
|
79 |
+
max_len = max([len(t) for t in node_graph] + [1])
|
80 |
+
for t in node_graph:
|
81 |
+
pad_len = max_len - len(t)
|
82 |
+
t.extend([0] * pad_len)
|
83 |
+
|
84 |
+
max_len = max([len(t) for t in mess_graph] + [1])
|
85 |
+
for t in mess_graph:
|
86 |
+
pad_len = max_len - len(t)
|
87 |
+
t.extend([0] * pad_len)
|
88 |
+
|
89 |
+
mess_graph = torch.LongTensor(mess_graph)
|
90 |
+
node_graph = torch.LongTensor(node_graph)
|
91 |
+
fmess = torch.LongTensor(fmess)
|
92 |
+
fnode = torch.LongTensor(fnode)
|
93 |
+
return (fnode, fmess, node_graph, mess_graph, scope), mess_dict
|
94 |
+
|
95 |
+
class GraphGRU(nn.Module):
|
96 |
+
|
97 |
+
def __init__(self, input_size, hidden_size, depth):
|
98 |
+
super(GraphGRU, self).__init__()
|
99 |
+
self.hidden_size = hidden_size
|
100 |
+
self.input_size = input_size
|
101 |
+
self.depth = depth
|
102 |
+
|
103 |
+
self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
|
104 |
+
self.W_r = nn.Linear(input_size, hidden_size, bias=False)
|
105 |
+
self.U_r = nn.Linear(hidden_size, hidden_size)
|
106 |
+
self.W_h = nn.Linear(input_size + hidden_size, hidden_size)
|
107 |
+
|
108 |
+
def forward(self, h, x, mess_graph):
|
109 |
+
mask = torch.ones(h.size(0), 1)
|
110 |
+
mask[0] = 0 #first vector is padding
|
111 |
+
mask = create_var(mask)
|
112 |
+
for it in range(self.depth):
|
113 |
+
h_nei = index_select_ND(h, 0, mess_graph)
|
114 |
+
sum_h = h_nei.sum(dim=1)
|
115 |
+
z_input = torch.cat([x, sum_h], dim=1)
|
116 |
+
z = F.sigmoid(self.W_z(z_input))
|
117 |
+
|
118 |
+
r_1 = self.W_r(x).view(-1, 1, self.hidden_size)
|
119 |
+
r_2 = self.U_r(h_nei)
|
120 |
+
r = F.sigmoid(r_1 + r_2)
|
121 |
+
|
122 |
+
gated_h = r * h_nei
|
123 |
+
sum_gated_h = gated_h.sum(dim=1)
|
124 |
+
h_input = torch.cat([x, sum_gated_h], dim=1)
|
125 |
+
pre_h = F.tanh(self.W_h(h_input))
|
126 |
+
h = (1.0 - z) * sum_h + z * pre_h
|
127 |
+
h = h * mask
|
128 |
+
|
129 |
+
return h
|
130 |
+
|
131 |
+
|
fast_jtnn/jtnn_vae.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from mol_tree import Vocab, MolTree
|
5 |
+
from nnutils import create_var, flatten_tensor, avg_pool
|
6 |
+
from jtnn_enc import JTNNEncoder
|
7 |
+
from jtnn_dec import JTNNDecoder
|
8 |
+
from mpn import MPN
|
9 |
+
from jtmpn import JTMPN
|
10 |
+
from datautils import tensorize
|
11 |
+
|
12 |
+
from chemutils import enum_assemble, set_atommap, copy_edit_mol, attach_mols
|
13 |
+
import rdkit
|
14 |
+
import rdkit.Chem as Chem
|
15 |
+
import copy, math
|
16 |
+
|
17 |
+
class JTNNVAE(nn.Module):
|
18 |
+
|
19 |
+
def __init__(self, vocab, hidden_size, latent_size, depthT, depthG):
|
20 |
+
super(JTNNVAE, self).__init__()
|
21 |
+
self.vocab = vocab
|
22 |
+
self.hidden_size = hidden_size
|
23 |
+
self.latent_size = latent_size = int(latent_size / 2) #Tree and Mol has two vectors
|
24 |
+
|
25 |
+
self.jtnn = JTNNEncoder(hidden_size, depthT, nn.Embedding(vocab.size(), hidden_size))
|
26 |
+
self.decoder = JTNNDecoder(vocab, hidden_size, latent_size, nn.Embedding(vocab.size(), hidden_size))
|
27 |
+
|
28 |
+
self.jtmpn = JTMPN(hidden_size, depthG)
|
29 |
+
self.mpn = MPN(hidden_size, depthG)
|
30 |
+
|
31 |
+
self.A_assm = nn.Linear(latent_size, hidden_size, bias=False)
|
32 |
+
# self.assm_loss = nn.CrossEntropyLoss(size_average=False)
|
33 |
+
self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
|
34 |
+
|
35 |
+
self.T_mean = nn.Linear(hidden_size, latent_size)
|
36 |
+
self.T_var = nn.Linear(hidden_size, latent_size)
|
37 |
+
self.G_mean = nn.Linear(hidden_size, latent_size)
|
38 |
+
self.G_var = nn.Linear(hidden_size, latent_size)
|
39 |
+
|
40 |
+
def encode(self, jtenc_holder, mpn_holder):
|
41 |
+
tree_vecs, tree_mess = self.jtnn(*jtenc_holder)
|
42 |
+
mol_vecs = self.mpn(*mpn_holder)
|
43 |
+
return tree_vecs, tree_mess, mol_vecs
|
44 |
+
|
45 |
+
def encode_from_smiles(self, smiles_list):
|
46 |
+
tree_batch = [MolTree(s) for s in smiles_list]
|
47 |
+
_, jtenc_holder, mpn_holder = tensorize(tree_batch, self.vocab, assm=False)
|
48 |
+
tree_vecs, _, mol_vecs = self.encode(jtenc_holder, mpn_holder)
|
49 |
+
return torch.cat([tree_vecs, mol_vecs], dim=-1)
|
50 |
+
|
51 |
+
def encode_latent(self, jtenc_holder, mpn_holder):
|
52 |
+
tree_vecs, _ = self.jtnn(*jtenc_holder)
|
53 |
+
mol_vecs = self.mpn(*mpn_holder)
|
54 |
+
tree_mean = self.T_mean(tree_vecs)
|
55 |
+
mol_mean = self.G_mean(mol_vecs)
|
56 |
+
tree_var = -torch.abs(self.T_var(tree_vecs))
|
57 |
+
mol_var = -torch.abs(self.G_var(mol_vecs))
|
58 |
+
return torch.cat([tree_mean, mol_mean], dim=1), torch.cat([tree_var, mol_var], dim=1)
|
59 |
+
|
60 |
+
def rsample(self, z_vecs, W_mean, W_var):
|
61 |
+
batch_size = z_vecs.size(0)
|
62 |
+
z_mean = W_mean(z_vecs)
|
63 |
+
z_log_var = -torch.abs(W_var(z_vecs)) #Following Mueller et al.
|
64 |
+
kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
|
65 |
+
epsilon = create_var(torch.randn_like(z_mean))
|
66 |
+
z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
|
67 |
+
return z_vecs, kl_loss
|
68 |
+
|
69 |
+
def sample_prior(self, prob_decode=False):
|
70 |
+
z_tree = torch.randn(1, self.latent_size).cuda()
|
71 |
+
z_mol = torch.randn(1, self.latent_size).cuda()
|
72 |
+
return self.decode(z_tree, z_mol, prob_decode)
|
73 |
+
|
74 |
+
def forward(self, x_batch, beta):
|
75 |
+
x_batch, x_jtenc_holder, x_mpn_holder, x_jtmpn_holder = x_batch
|
76 |
+
x_tree_vecs, x_tree_mess, x_mol_vecs = self.encode(x_jtenc_holder, x_mpn_holder)
|
77 |
+
z_tree_vecs,tree_kl = self.rsample(x_tree_vecs, self.T_mean, self.T_var)
|
78 |
+
z_mol_vecs,mol_kl = self.rsample(x_mol_vecs, self.G_mean, self.G_var)
|
79 |
+
|
80 |
+
kl_div = tree_kl + mol_kl
|
81 |
+
word_loss, topo_loss, word_acc, topo_acc = self.decoder(x_batch, z_tree_vecs)
|
82 |
+
assm_loss, assm_acc = self.assm(x_batch, x_jtmpn_holder, z_mol_vecs, x_tree_mess)
|
83 |
+
|
84 |
+
return word_loss + topo_loss + assm_loss + beta * kl_div, kl_div.item(), word_acc, topo_acc, assm_acc
|
85 |
+
|
86 |
+
def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, x_tree_mess):
|
87 |
+
jtmpn_holder,batch_idx = jtmpn_holder
|
88 |
+
fatoms,fbonds,agraph,bgraph,scope = jtmpn_holder
|
89 |
+
batch_idx = create_var(batch_idx)
|
90 |
+
|
91 |
+
cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, x_tree_mess)
|
92 |
+
|
93 |
+
x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
|
94 |
+
x_mol_vecs = self.A_assm(x_mol_vecs) #bilinear
|
95 |
+
scores = torch.bmm(
|
96 |
+
x_mol_vecs.unsqueeze(1),
|
97 |
+
cand_vecs.unsqueeze(-1)
|
98 |
+
).squeeze()
|
99 |
+
|
100 |
+
cnt,tot,acc = 0,0,0
|
101 |
+
all_loss = []
|
102 |
+
for i,mol_tree in enumerate(mol_batch):
|
103 |
+
comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
|
104 |
+
cnt += len(comp_nodes)
|
105 |
+
for node in comp_nodes:
|
106 |
+
label = node.cands.index(node.label)
|
107 |
+
ncand = len(node.cands)
|
108 |
+
cur_score = scores.narrow(0, tot, ncand)
|
109 |
+
tot += ncand
|
110 |
+
|
111 |
+
if cur_score.data[label] >= cur_score.max().item():
|
112 |
+
acc += 1
|
113 |
+
|
114 |
+
label = create_var(torch.LongTensor([label]))
|
115 |
+
all_loss.append( self.assm_loss(cur_score.view(1,-1), label) )
|
116 |
+
|
117 |
+
all_loss = sum(all_loss) / len(mol_batch)
|
118 |
+
return all_loss, acc * 1.0 / cnt
|
119 |
+
|
120 |
+
def decode(self, x_tree_vecs, x_mol_vecs, prob_decode):
|
121 |
+
#currently do not support batch decoding
|
122 |
+
assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1
|
123 |
+
|
124 |
+
pred_root,pred_nodes = self.decoder.decode(x_tree_vecs, prob_decode)
|
125 |
+
if len(pred_nodes) == 0: return None
|
126 |
+
elif len(pred_nodes) == 1: return pred_root.smiles
|
127 |
+
|
128 |
+
#Mark nid & is_leaf & atommap
|
129 |
+
for i,node in enumerate(pred_nodes):
|
130 |
+
node.nid = i + 1
|
131 |
+
node.is_leaf = (len(node.neighbors) == 1)
|
132 |
+
if len(node.neighbors) > 1:
|
133 |
+
set_atommap(node.mol, node.nid)
|
134 |
+
|
135 |
+
scope = [(0, len(pred_nodes))]
|
136 |
+
jtenc_holder,mess_dict = JTNNEncoder.tensorize_nodes(pred_nodes, scope)
|
137 |
+
_,tree_mess = self.jtnn(*jtenc_holder)
|
138 |
+
tree_mess = (tree_mess, mess_dict) #Important: tree_mess is a matrix, mess_dict is a python dict
|
139 |
+
|
140 |
+
x_mol_vecs = self.A_assm(x_mol_vecs).squeeze() #bilinear
|
141 |
+
|
142 |
+
cur_mol = copy_edit_mol(pred_root.mol)
|
143 |
+
global_amap = [{}] + [{} for node in pred_nodes]
|
144 |
+
global_amap[1] = {atom.GetIdx():atom.GetIdx() for atom in cur_mol.GetAtoms()}
|
145 |
+
|
146 |
+
cur_mol,_ = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=True)
|
147 |
+
if cur_mol is None:
|
148 |
+
cur_mol = copy_edit_mol(pred_root.mol)
|
149 |
+
global_amap = [{}] + [{} for node in pred_nodes]
|
150 |
+
global_amap[1] = {atom.GetIdx():atom.GetIdx() for atom in cur_mol.GetAtoms()}
|
151 |
+
cur_mol,pre_mol = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=False)
|
152 |
+
if cur_mol is None: cur_mol = pre_mol
|
153 |
+
|
154 |
+
if cur_mol is None:
|
155 |
+
return None
|
156 |
+
|
157 |
+
cur_mol = cur_mol.GetMol()
|
158 |
+
set_atommap(cur_mol)
|
159 |
+
cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
|
160 |
+
return Chem.MolToSmiles(cur_mol) if cur_mol is not None else None
|
161 |
+
|
162 |
+
def dfs_assemble(self, y_tree_mess, x_mol_vecs, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node, prob_decode, check_aroma):
|
163 |
+
fa_nid = fa_node.nid if fa_node is not None else -1
|
164 |
+
prev_nodes = [fa_node] if fa_node is not None else []
|
165 |
+
|
166 |
+
children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
|
167 |
+
neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
|
168 |
+
neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
|
169 |
+
singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
|
170 |
+
neighbors = singletons + neighbors
|
171 |
+
|
172 |
+
cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node.nid]
|
173 |
+
cands,aroma_score = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
|
174 |
+
if len(cands) == 0 or (sum(aroma_score) < 0 and check_aroma):
|
175 |
+
return None, cur_mol
|
176 |
+
|
177 |
+
cand_smiles,cand_amap = zip(*cands)
|
178 |
+
aroma_score = torch.Tensor(aroma_score).cuda()
|
179 |
+
cands = [(smiles, all_nodes, cur_node) for smiles in cand_smiles]
|
180 |
+
|
181 |
+
if len(cands) > 1:
|
182 |
+
jtmpn_holder = JTMPN.tensorize(cands, y_tree_mess[1])
|
183 |
+
fatoms,fbonds,agraph,bgraph,scope = jtmpn_holder
|
184 |
+
cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess[0])
|
185 |
+
scores = torch.mv(cand_vecs, x_mol_vecs) + aroma_score
|
186 |
+
else:
|
187 |
+
scores = torch.Tensor([1.0])
|
188 |
+
|
189 |
+
if prob_decode:
|
190 |
+
probs = F.softmax(scores.view(1,-1), dim=1).squeeze() + 1e-7 #prevent prob = 0
|
191 |
+
cand_idx = torch.multinomial(probs, probs.numel())
|
192 |
+
else:
|
193 |
+
_,cand_idx = torch.sort(scores, descending=True)
|
194 |
+
|
195 |
+
backup_mol = Chem.RWMol(cur_mol)
|
196 |
+
pre_mol = cur_mol
|
197 |
+
for i in range(cand_idx.numel()):
|
198 |
+
cur_mol = Chem.RWMol(backup_mol)
|
199 |
+
pred_amap = cand_amap[cand_idx[i].item()]
|
200 |
+
new_global_amap = copy.deepcopy(global_amap)
|
201 |
+
|
202 |
+
for nei_id,ctr_atom,nei_atom in pred_amap:
|
203 |
+
if nei_id == fa_nid:
|
204 |
+
continue
|
205 |
+
new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom]
|
206 |
+
|
207 |
+
cur_mol = attach_mols(cur_mol, children, [], new_global_amap) #father is already attached
|
208 |
+
new_mol = cur_mol.GetMol()
|
209 |
+
new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
|
210 |
+
|
211 |
+
if new_mol is None: continue
|
212 |
+
|
213 |
+
has_error = False
|
214 |
+
for nei_node in children:
|
215 |
+
if nei_node.is_leaf: continue
|
216 |
+
tmp_mol, tmp_mol2 = self.dfs_assemble(y_tree_mess, x_mol_vecs, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node, prob_decode, check_aroma)
|
217 |
+
if tmp_mol is None:
|
218 |
+
has_error = True
|
219 |
+
if i == 0: pre_mol = tmp_mol2
|
220 |
+
break
|
221 |
+
cur_mol = tmp_mol
|
222 |
+
|
223 |
+
if not has_error: return cur_mol, cur_mol
|
224 |
+
|
225 |
+
return None, pre_mol
|
226 |
+
|
fast_jtnn/jtprop_vae.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from mol_tree import Vocab, MolTree
|
5 |
+
from nnutils import create_var, flatten_tensor, avg_pool
|
6 |
+
from jtnn_enc import JTNNEncoder
|
7 |
+
from jtnn_dec import JTNNDecoder
|
8 |
+
from mpn import MPN
|
9 |
+
from jtmpn import JTMPN
|
10 |
+
from datautils import tensorize
|
11 |
+
|
12 |
+
from chemutils import enum_assemble, set_atommap, copy_edit_mol, attach_mols
|
13 |
+
import rdkit
|
14 |
+
import rdkit.Chem as Chem
|
15 |
+
from rdkit import DataStructs
|
16 |
+
from rdkit.Chem import AllChem
|
17 |
+
import copy, math
|
18 |
+
|
19 |
+
class JTPropVAE(nn.Module):
|
20 |
+
|
21 |
+
def __init__(self, vocab, hidden_size, latent_size, depthT, depthG):
|
22 |
+
super(JTPropVAE, self).__init__()
|
23 |
+
self.vocab = vocab
|
24 |
+
self.hidden_size = hidden_size
|
25 |
+
self.latent_size = latent_size = int(latent_size / 2) #Tree and Mol has two vectors
|
26 |
+
|
27 |
+
self.jtnn = JTNNEncoder(hidden_size, depthT, nn.Embedding(vocab.size(), hidden_size))
|
28 |
+
self.decoder = JTNNDecoder(vocab, hidden_size, latent_size, nn.Embedding(vocab.size(), hidden_size))
|
29 |
+
|
30 |
+
self.jtmpn = JTMPN(hidden_size, depthG)
|
31 |
+
self.mpn = MPN(hidden_size, depthG)
|
32 |
+
|
33 |
+
self.A_assm = nn.Linear(latent_size, hidden_size, bias=False)
|
34 |
+
# self.assm_loss = nn.CrossEntropyLoss(size_average=False)
|
35 |
+
self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
|
36 |
+
|
37 |
+
self.T_mean = nn.Linear(hidden_size, latent_size)
|
38 |
+
self.T_var = nn.Linear(hidden_size, latent_size)
|
39 |
+
self.G_mean = nn.Linear(hidden_size, latent_size)
|
40 |
+
self.G_var = nn.Linear(hidden_size, latent_size)
|
41 |
+
|
42 |
+
# Prop
|
43 |
+
self.propNN = nn.Sequential(
|
44 |
+
nn.Linear(self.latent_size*2, self.hidden_size),
|
45 |
+
nn.Tanh(),
|
46 |
+
nn.Linear(self.hidden_size, 1)
|
47 |
+
)
|
48 |
+
self.prop_loss = nn.MSELoss()
|
49 |
+
|
50 |
+
def encode(self, jtenc_holder, mpn_holder):
|
51 |
+
tree_vecs, tree_mess = self.jtnn(*jtenc_holder)
|
52 |
+
mol_vecs = self.mpn(*mpn_holder)
|
53 |
+
return tree_vecs, tree_mess, mol_vecs
|
54 |
+
|
55 |
+
def encode_from_smiles(self, smiles_list):
|
56 |
+
tree_batch = [MolTree(s) for s in smiles_list]
|
57 |
+
_, jtenc_holder, mpn_holder = tensorize(tree_batch, self.vocab, assm=False)
|
58 |
+
tree_vecs, _, mol_vecs = self.encode(jtenc_holder, mpn_holder)
|
59 |
+
return torch.cat([tree_vecs, mol_vecs], dim=-1)
|
60 |
+
|
61 |
+
def encode_latent(self, jtenc_holder, mpn_holder):
|
62 |
+
tree_vecs, _ = self.jtnn(*jtenc_holder)
|
63 |
+
mol_vecs = self.mpn(*mpn_holder)
|
64 |
+
tree_mean = self.T_mean(tree_vecs)
|
65 |
+
mol_mean = self.G_mean(mol_vecs)
|
66 |
+
tree_var = -torch.abs(self.T_var(tree_vecs))
|
67 |
+
mol_var = -torch.abs(self.G_var(mol_vecs))
|
68 |
+
return torch.cat([tree_mean, mol_mean], dim=1), torch.cat([tree_var, mol_var], dim=1)
|
69 |
+
|
70 |
+
def rsample(self, z_vecs, W_mean, W_var):
|
71 |
+
batch_size = z_vecs.size(0)
|
72 |
+
z_mean = W_mean(z_vecs)
|
73 |
+
z_log_var = -torch.abs(W_var(z_vecs)) #Following Mueller et al.
|
74 |
+
kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
|
75 |
+
epsilon = create_var(torch.randn_like(z_mean))
|
76 |
+
z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
|
77 |
+
return z_vecs, kl_loss
|
78 |
+
|
79 |
+
def sample_prior(self, prob_decode=False):
|
80 |
+
z_tree = torch.randn(1, self.latent_size).cuda()
|
81 |
+
z_mol = torch.randn(1, self.latent_size).cuda()
|
82 |
+
return self.decode(z_tree, z_mol, prob_decode)
|
83 |
+
|
84 |
+
def forward(self, x_batch, beta):
|
85 |
+
x_batch, x_jtenc_holder, x_mpn_holder, x_jtmpn_holder, prop_batch = x_batch
|
86 |
+
x_tree_vecs, x_tree_mess, x_mol_vecs = self.encode(x_jtenc_holder, x_mpn_holder)
|
87 |
+
z_tree_vecs,tree_kl = self.rsample(x_tree_vecs, self.T_mean, self.T_var)
|
88 |
+
z_mol_vecs,mol_kl = self.rsample(x_mol_vecs, self.G_mean, self.G_var)
|
89 |
+
|
90 |
+
kl_div = tree_kl + mol_kl
|
91 |
+
word_loss, topo_loss, word_acc, topo_acc = self.decoder(x_batch, z_tree_vecs)
|
92 |
+
assm_loss, assm_acc = self.assm(x_batch, x_jtmpn_holder, z_mol_vecs, x_tree_mess)
|
93 |
+
|
94 |
+
all_vec = torch.cat([z_tree_vecs, z_mol_vecs], dim=1)
|
95 |
+
# prop_label = create_var(torch.Tensor(prop_batch))
|
96 |
+
prop_label = create_var(torch.Tensor(prop_batch))
|
97 |
+
prop_loss = self.prop_loss(self.propNN(all_vec).squeeze(), prop_label)
|
98 |
+
|
99 |
+
return word_loss + topo_loss + assm_loss + beta * kl_div + prop_loss, kl_div.item(), word_acc, topo_acc, assm_acc, prop_loss.item()
|
100 |
+
|
101 |
+
def assm(self, mol_batch, jtmpn_holder, x_mol_vecs, x_tree_mess):
|
102 |
+
jtmpn_holder,batch_idx = jtmpn_holder
|
103 |
+
fatoms,fbonds,agraph,bgraph,scope = jtmpn_holder
|
104 |
+
batch_idx = create_var(batch_idx)
|
105 |
+
|
106 |
+
cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, x_tree_mess)
|
107 |
+
|
108 |
+
x_mol_vecs = x_mol_vecs.index_select(0, batch_idx)
|
109 |
+
x_mol_vecs = self.A_assm(x_mol_vecs) #bilinear
|
110 |
+
scores = torch.bmm(
|
111 |
+
x_mol_vecs.unsqueeze(1),
|
112 |
+
cand_vecs.unsqueeze(-1)
|
113 |
+
).squeeze()
|
114 |
+
|
115 |
+
cnt,tot,acc = 0,0,0
|
116 |
+
all_loss = []
|
117 |
+
for i,mol_tree in enumerate(mol_batch):
|
118 |
+
comp_nodes = [node for node in mol_tree.nodes if len(node.cands) > 1 and not node.is_leaf]
|
119 |
+
cnt += len(comp_nodes)
|
120 |
+
for node in comp_nodes:
|
121 |
+
label = node.cands.index(node.label)
|
122 |
+
ncand = len(node.cands)
|
123 |
+
cur_score = scores.narrow(0, tot, ncand)
|
124 |
+
tot += ncand
|
125 |
+
|
126 |
+
if cur_score.data[label] >= cur_score.max().item():
|
127 |
+
acc += 1
|
128 |
+
|
129 |
+
label = create_var(torch.LongTensor([label]))
|
130 |
+
all_loss.append( self.assm_loss(cur_score.view(1,-1), label) )
|
131 |
+
|
132 |
+
all_loss = sum(all_loss) / len(mol_batch)
|
133 |
+
return all_loss, acc * 1.0 / cnt
|
134 |
+
|
135 |
+
def decode(self, x_tree_vecs, x_mol_vecs, prob_decode):
|
136 |
+
#currently do not support batch decoding
|
137 |
+
assert x_tree_vecs.size(0) == 1 and x_mol_vecs.size(0) == 1
|
138 |
+
|
139 |
+
pred_root,pred_nodes = self.decoder.decode(x_tree_vecs, prob_decode)
|
140 |
+
if len(pred_nodes) == 0: return None
|
141 |
+
elif len(pred_nodes) == 1: return pred_root.smiles
|
142 |
+
|
143 |
+
#Mark nid & is_leaf & atommap
|
144 |
+
for i,node in enumerate(pred_nodes):
|
145 |
+
node.nid = i + 1
|
146 |
+
node.is_leaf = (len(node.neighbors) == 1)
|
147 |
+
if len(node.neighbors) > 1:
|
148 |
+
set_atommap(node.mol, node.nid)
|
149 |
+
|
150 |
+
scope = [(0, len(pred_nodes))]
|
151 |
+
jtenc_holder,mess_dict = JTNNEncoder.tensorize_nodes(pred_nodes, scope)
|
152 |
+
_,tree_mess = self.jtnn(*jtenc_holder)
|
153 |
+
tree_mess = (tree_mess, mess_dict) #Important: tree_mess is a matrix, mess_dict is a python dict
|
154 |
+
|
155 |
+
x_mol_vecs = self.A_assm(x_mol_vecs).squeeze() #bilinear
|
156 |
+
|
157 |
+
cur_mol = copy_edit_mol(pred_root.mol)
|
158 |
+
global_amap = [{}] + [{} for node in pred_nodes]
|
159 |
+
global_amap[1] = {atom.GetIdx():atom.GetIdx() for atom in cur_mol.GetAtoms()}
|
160 |
+
|
161 |
+
cur_mol,_ = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=True)
|
162 |
+
if cur_mol is None:
|
163 |
+
cur_mol = copy_edit_mol(pred_root.mol)
|
164 |
+
global_amap = [{}] + [{} for node in pred_nodes]
|
165 |
+
global_amap[1] = {atom.GetIdx():atom.GetIdx() for atom in cur_mol.GetAtoms()}
|
166 |
+
cur_mol,pre_mol = self.dfs_assemble(tree_mess, x_mol_vecs, pred_nodes, cur_mol, global_amap, [], pred_root, None, prob_decode, check_aroma=False)
|
167 |
+
if cur_mol is None: cur_mol = pre_mol
|
168 |
+
|
169 |
+
if cur_mol is None:
|
170 |
+
return None
|
171 |
+
|
172 |
+
cur_mol = cur_mol.GetMol()
|
173 |
+
set_atommap(cur_mol)
|
174 |
+
cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
|
175 |
+
return Chem.MolToSmiles(cur_mol) if cur_mol is not None else None
|
176 |
+
|
177 |
+
def dfs_assemble(self, y_tree_mess, x_mol_vecs, all_nodes, cur_mol, global_amap, fa_amap, cur_node, fa_node, prob_decode, check_aroma):
|
178 |
+
fa_nid = fa_node.nid if fa_node is not None else -1
|
179 |
+
prev_nodes = [fa_node] if fa_node is not None else []
|
180 |
+
|
181 |
+
children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
|
182 |
+
neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
|
183 |
+
neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
|
184 |
+
singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
|
185 |
+
neighbors = singletons + neighbors
|
186 |
+
|
187 |
+
cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node.nid]
|
188 |
+
cands,aroma_score = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
|
189 |
+
if len(cands) == 0 or (sum(aroma_score) < 0 and check_aroma):
|
190 |
+
return None, cur_mol
|
191 |
+
|
192 |
+
cand_smiles,cand_amap = zip(*cands)
|
193 |
+
if torch.cuda.is_available():
|
194 |
+
aroma_score = torch.Tensor(aroma_score).cuda()
|
195 |
+
else:
|
196 |
+
aroma_score = torch.Tensor(aroma_score)
|
197 |
+
cands = [(smiles, all_nodes, cur_node) for smiles in cand_smiles]
|
198 |
+
|
199 |
+
if len(cands) > 1:
|
200 |
+
jtmpn_holder = JTMPN.tensorize(cands, y_tree_mess[1])
|
201 |
+
fatoms,fbonds,agraph,bgraph,scope = jtmpn_holder
|
202 |
+
cand_vecs = self.jtmpn(fatoms, fbonds, agraph, bgraph, scope, y_tree_mess[0])
|
203 |
+
scores = torch.mv(cand_vecs, x_mol_vecs) + aroma_score
|
204 |
+
else:
|
205 |
+
scores = torch.Tensor([1.0])
|
206 |
+
|
207 |
+
if prob_decode:
|
208 |
+
probs = F.softmax(scores.view(1,-1), dim=1).squeeze() + 1e-7 #prevent prob = 0
|
209 |
+
cand_idx = torch.multinomial(probs, probs.numel())
|
210 |
+
else:
|
211 |
+
_,cand_idx = torch.sort(scores, descending=True)
|
212 |
+
|
213 |
+
backup_mol = Chem.RWMol(cur_mol)
|
214 |
+
pre_mol = cur_mol
|
215 |
+
for i in range(cand_idx.numel()):
|
216 |
+
cur_mol = Chem.RWMol(backup_mol)
|
217 |
+
pred_amap = cand_amap[cand_idx[i].item()]
|
218 |
+
new_global_amap = copy.deepcopy(global_amap)
|
219 |
+
|
220 |
+
for nei_id,ctr_atom,nei_atom in pred_amap:
|
221 |
+
if nei_id == fa_nid:
|
222 |
+
continue
|
223 |
+
new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node.nid][ctr_atom]
|
224 |
+
|
225 |
+
cur_mol = attach_mols(cur_mol, children, [], new_global_amap) #father is already attached
|
226 |
+
new_mol = cur_mol.GetMol()
|
227 |
+
new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
|
228 |
+
|
229 |
+
if new_mol is None: continue
|
230 |
+
|
231 |
+
has_error = False
|
232 |
+
for nei_node in children:
|
233 |
+
if nei_node.is_leaf: continue
|
234 |
+
tmp_mol, tmp_mol2 = self.dfs_assemble(y_tree_mess, x_mol_vecs, all_nodes, cur_mol, new_global_amap, pred_amap, nei_node, cur_node, prob_decode, check_aroma)
|
235 |
+
if tmp_mol is None:
|
236 |
+
has_error = True
|
237 |
+
if i == 0: pre_mol = tmp_mol2
|
238 |
+
break
|
239 |
+
cur_mol = tmp_mol
|
240 |
+
|
241 |
+
if not has_error: return cur_mol, cur_mol
|
242 |
+
|
243 |
+
return None, pre_mol
|
244 |
+
|
245 |
+
def optimize(self, smiles, sim_cutoff, lr=2.0, num_iter=20):
|
246 |
+
# mol_tree = MolTree(smiles)
|
247 |
+
# mol_tree.recover()
|
248 |
+
tree_batch = [MolTree(smiles)]
|
249 |
+
_, jtenc_holder, mpn_holder = tensorize(tree_batch, self.vocab, assm=False)
|
250 |
+
tree_vec, _, mol_vec = self.encode(jtenc_holder, mpn_holder)
|
251 |
+
|
252 |
+
mol = Chem.MolFromSmiles(smiles)
|
253 |
+
fp1 = AllChem.GetMorganFingerprint(mol, 2)
|
254 |
+
|
255 |
+
tree_mean = self.T_mean(tree_vec)
|
256 |
+
tree_log_var = -torch.abs(self.T_var(tree_vec)) #Following Mueller et al.
|
257 |
+
mol_mean = self.G_mean(mol_vec)
|
258 |
+
mol_log_var = -torch.abs(self.G_var(mol_vec)) #Following Mueller et al.
|
259 |
+
mean = torch.cat([tree_mean, mol_mean], dim=1)
|
260 |
+
log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
|
261 |
+
cur_vec = create_var(mean.data, True)
|
262 |
+
|
263 |
+
visited = []
|
264 |
+
for step in range(num_iter):
|
265 |
+
prop_val = self.propNN(cur_vec).squeeze()
|
266 |
+
grad = torch.autograd.grad(prop_val, cur_vec)[0]
|
267 |
+
cur_vec = cur_vec.data + lr * grad.data
|
268 |
+
cur_vec = create_var(cur_vec, True)
|
269 |
+
visited.append(cur_vec)
|
270 |
+
|
271 |
+
l,r = 0, num_iter - 1
|
272 |
+
while l < r - 1:
|
273 |
+
mid = (l + r) // 2
|
274 |
+
new_vec = visited[mid]
|
275 |
+
tree_vec,mol_vec = torch.chunk(new_vec, 2, dim=1)
|
276 |
+
new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
|
277 |
+
if new_smiles is None:
|
278 |
+
r = mid - 1
|
279 |
+
continue
|
280 |
+
|
281 |
+
new_mol = Chem.MolFromSmiles(new_smiles)
|
282 |
+
fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
|
283 |
+
sim = DataStructs.TanimotoSimilarity(fp1, fp2)
|
284 |
+
if sim < sim_cutoff:
|
285 |
+
r = mid - 1
|
286 |
+
else:
|
287 |
+
l = mid
|
288 |
+
"""
|
289 |
+
best_vec = visited[0]
|
290 |
+
for new_vec in visited:
|
291 |
+
tree_vec,mol_vec = torch.chunk(new_vec, 2, dim=1)
|
292 |
+
new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
|
293 |
+
if new_smiles is None: continue
|
294 |
+
new_mol = Chem.MolFromSmiles(new_smiles)
|
295 |
+
fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
|
296 |
+
sim = DataStructs.TanimotoSimilarity(fp1, fp2)
|
297 |
+
if sim >= sim_cutoff:
|
298 |
+
best_vec = new_vec
|
299 |
+
"""
|
300 |
+
tree_vec,mol_vec = torch.chunk(visited[l], 2, dim=1)
|
301 |
+
#tree_vec,mol_vec = torch.chunk(best_vec, 2, dim=1)
|
302 |
+
new_smiles = self.decode(tree_vec, mol_vec, prob_decode=False)
|
303 |
+
if new_smiles is None:
|
304 |
+
return None, None
|
305 |
+
new_mol = Chem.MolFromSmiles(new_smiles)
|
306 |
+
fp2 = AllChem.GetMorganFingerprint(new_mol, 2)
|
307 |
+
sim = DataStructs.TanimotoSimilarity(fp1, fp2)
|
308 |
+
if sim >= sim_cutoff:
|
309 |
+
return new_smiles, sim
|
310 |
+
else:
|
311 |
+
return None, None
|
fast_jtnn/mol_tree.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import rdkit
|
2 |
+
import rdkit.Chem as Chem
|
3 |
+
from chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, enum_assemble, decode_stereo
|
4 |
+
from vocab import *
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
class MolTreeNode(object):
|
8 |
+
|
9 |
+
def __init__(self, smiles, clique=[]):
|
10 |
+
self.smiles = smiles
|
11 |
+
self.mol = get_mol(self.smiles)
|
12 |
+
|
13 |
+
self.clique = [x for x in clique] #copy
|
14 |
+
self.neighbors = []
|
15 |
+
|
16 |
+
def add_neighbor(self, nei_node):
|
17 |
+
self.neighbors.append(nei_node)
|
18 |
+
|
19 |
+
def recover(self, original_mol):
|
20 |
+
clique = []
|
21 |
+
clique.extend(self.clique)
|
22 |
+
if not self.is_leaf:
|
23 |
+
for cidx in self.clique:
|
24 |
+
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)
|
25 |
+
|
26 |
+
for nei_node in self.neighbors:
|
27 |
+
clique.extend(nei_node.clique)
|
28 |
+
if nei_node.is_leaf: #Leaf node, no need to mark
|
29 |
+
continue
|
30 |
+
for cidx in nei_node.clique:
|
31 |
+
#allow singleton node override the atom mapping
|
32 |
+
if cidx not in self.clique or len(nei_node.clique) == 1:
|
33 |
+
atom = original_mol.GetAtomWithIdx(cidx)
|
34 |
+
atom.SetAtomMapNum(nei_node.nid)
|
35 |
+
|
36 |
+
clique = list(set(clique))
|
37 |
+
label_mol = get_clique_mol(original_mol, clique)
|
38 |
+
self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
|
39 |
+
|
40 |
+
for cidx in clique:
|
41 |
+
original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
|
42 |
+
|
43 |
+
return self.label
|
44 |
+
|
45 |
+
def assemble(self):
|
46 |
+
neighbors = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1]
|
47 |
+
neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
|
48 |
+
singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
|
49 |
+
neighbors = singletons + neighbors
|
50 |
+
|
51 |
+
cands,aroma = enum_assemble(self, neighbors)
|
52 |
+
new_cands = [cand for i,cand in enumerate(cands) if aroma[i] >= 0]
|
53 |
+
if len(new_cands) > 0: cands = new_cands
|
54 |
+
|
55 |
+
if len(cands) > 0:
|
56 |
+
self.cands, _ = zip(*cands)
|
57 |
+
self.cands = list(self.cands)
|
58 |
+
else:
|
59 |
+
self.cands = []
|
60 |
+
|
61 |
+
class MolTree(object):
|
62 |
+
|
63 |
+
def __init__(self, smiles):
|
64 |
+
self.smiles = smiles
|
65 |
+
self.mol = get_mol(smiles)
|
66 |
+
|
67 |
+
#Stereo Generation (currently disabled)
|
68 |
+
#mol = Chem.MolFromSmiles(smiles)
|
69 |
+
#self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
|
70 |
+
#self.smiles2D = Chem.MolToSmiles(mol)
|
71 |
+
#self.stereo_cands = decode_stereo(self.smiles2D)
|
72 |
+
|
73 |
+
cliques, edges = tree_decomp(self.mol)
|
74 |
+
self.nodes = []
|
75 |
+
root = 0
|
76 |
+
for i,c in enumerate(cliques):
|
77 |
+
cmol = get_clique_mol(self.mol, c)
|
78 |
+
node = MolTreeNode(get_smiles(cmol), c)
|
79 |
+
self.nodes.append(node)
|
80 |
+
if min(c) == 0: root = i
|
81 |
+
|
82 |
+
for x,y in edges:
|
83 |
+
self.nodes[x].add_neighbor(self.nodes[y])
|
84 |
+
self.nodes[y].add_neighbor(self.nodes[x])
|
85 |
+
|
86 |
+
if root > 0:
|
87 |
+
self.nodes[0],self.nodes[root] = self.nodes[root],self.nodes[0]
|
88 |
+
|
89 |
+
for i,node in enumerate(self.nodes):
|
90 |
+
node.nid = i + 1
|
91 |
+
if len(node.neighbors) > 1: #Leaf node mol is not marked
|
92 |
+
set_atommap(node.mol, node.nid)
|
93 |
+
node.is_leaf = (len(node.neighbors) == 1)
|
94 |
+
|
95 |
+
def size(self):
|
96 |
+
return len(self.nodes)
|
97 |
+
|
98 |
+
def recover(self):
|
99 |
+
for node in self.nodes:
|
100 |
+
node.recover(self.mol)
|
101 |
+
|
102 |
+
def assemble(self):
|
103 |
+
for node in self.nodes:
|
104 |
+
node.assemble()
|
105 |
+
|
106 |
+
def dfs(node, fa_idx):
|
107 |
+
max_depth = 0
|
108 |
+
for child in node.neighbors:
|
109 |
+
if child.idx == fa_idx: continue
|
110 |
+
max_depth = max(max_depth, dfs(child, node.idx))
|
111 |
+
return max_depth + 1
|
112 |
+
|
113 |
+
def data_process_chunk(smiles_list):
|
114 |
+
cset = set()
|
115 |
+
for line in smiles_list:
|
116 |
+
smiles = line.split()[0]
|
117 |
+
# print(smiles)
|
118 |
+
mol = MolTree(smiles)
|
119 |
+
for c in mol.nodes:
|
120 |
+
cset.add(c.smiles)
|
121 |
+
# i+=1
|
122 |
+
# if i%10000 == 0:
|
123 |
+
# # print(i,end='\x1b[1K\r')
|
124 |
+
# print(i, ' / 1584663')
|
125 |
+
return list(cset)
|
126 |
+
|
127 |
+
if __name__ == "__main__":
|
128 |
+
import sys
|
129 |
+
lg = rdkit.RDLogger.logger()
|
130 |
+
lg.setLevel(rdkit.RDLogger.CRITICAL)
|
131 |
+
|
132 |
+
i = 0
|
133 |
+
|
134 |
+
import os
|
135 |
+
from joblib import Parallel,delayed
|
136 |
+
from tqdm import tqdm
|
137 |
+
parser = argparse.ArgumentParser()
|
138 |
+
parser.add_argument('--smiles_path', type=str,required=True)
|
139 |
+
parser.add_argument('--vocab_path', type=str,required=True)
|
140 |
+
parser.add_argument('--prop', type=bool,default=False)
|
141 |
+
parser.add_argument('--ncpu', default=8,type=int)
|
142 |
+
args = parser.parse_args()
|
143 |
+
|
144 |
+
if args.prop:
|
145 |
+
import pandas as pd
|
146 |
+
smiles_list = pd.read_csv(args.smiles_path,usecols=['SMILES'])
|
147 |
+
smiles_list = list(smiles_list.SMILES)
|
148 |
+
else:
|
149 |
+
with open(args.smiles_path,'r') as f:
|
150 |
+
smiles_list = [line.split()[0] for line in f]
|
151 |
+
print('Total smiles = ',len(smiles_list))
|
152 |
+
|
153 |
+
# moses: 1584663
|
154 |
+
|
155 |
+
chunk_size = 10000
|
156 |
+
vocab_set_list = Parallel(n_jobs=args.ncpu)(
|
157 |
+
delayed(data_process_chunk)(smiles_list[start: start + chunk_size])
|
158 |
+
for start in tqdm(range(0, len(smiles_list), chunk_size))
|
159 |
+
)
|
160 |
+
vocab_list =[]
|
161 |
+
for set_list in vocab_set_list:
|
162 |
+
vocab_list.extend(set_list)
|
163 |
+
|
164 |
+
cset = set(vocab_list)
|
165 |
+
with open(args.vocab_path,'w') as f:
|
166 |
+
for x in cset:
|
167 |
+
f.write(''.join([x,'\n']))
|
168 |
+
|
fast_jtnn/mpn.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import rdkit.Chem as Chem
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from nnutils import *
|
6 |
+
from chemutils import get_mol
|
7 |
+
|
8 |
+
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
|
9 |
+
|
10 |
+
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 4 + 1
|
11 |
+
BOND_FDIM = 5 + 6
|
12 |
+
MAX_NB = 6
|
13 |
+
|
14 |
+
def onek_encoding_unk(x, allowable_set):
|
15 |
+
if x not in allowable_set:
|
16 |
+
x = allowable_set[-1]
|
17 |
+
return list(map(lambda s: x == s, allowable_set))
|
18 |
+
|
19 |
+
def atom_features(atom):
|
20 |
+
return torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
|
21 |
+
+ onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
|
22 |
+
+ onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
|
23 |
+
+ onek_encoding_unk(int(atom.GetChiralTag()), [0,1,2,3])
|
24 |
+
+ [atom.GetIsAromatic()])
|
25 |
+
|
26 |
+
def bond_features(bond):
|
27 |
+
bt = bond.GetBondType()
|
28 |
+
stereo = int(bond.GetStereo())
|
29 |
+
fbond = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC, bond.IsInRing()]
|
30 |
+
fstereo = onek_encoding_unk(stereo, [0,1,2,3,4,5])
|
31 |
+
return torch.Tensor(fbond + fstereo)
|
32 |
+
|
33 |
+
class MPN(nn.Module):
|
34 |
+
|
35 |
+
def __init__(self, hidden_size, depth):
|
36 |
+
super(MPN, self).__init__()
|
37 |
+
self.hidden_size = hidden_size
|
38 |
+
self.depth = depth
|
39 |
+
|
40 |
+
self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)
|
41 |
+
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
|
42 |
+
self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)
|
43 |
+
|
44 |
+
def forward(self, fatoms, fbonds, agraph, bgraph, scope):
|
45 |
+
fatoms = create_var(fatoms)
|
46 |
+
fbonds = create_var(fbonds)
|
47 |
+
agraph = create_var(agraph)
|
48 |
+
bgraph = create_var(bgraph)
|
49 |
+
|
50 |
+
binput = self.W_i(fbonds)
|
51 |
+
message = F.relu(binput)
|
52 |
+
|
53 |
+
for i in range(self.depth - 1):
|
54 |
+
nei_message = index_select_ND(message, 0, bgraph)
|
55 |
+
nei_message = nei_message.sum(dim=1)
|
56 |
+
nei_message = self.W_h(nei_message)
|
57 |
+
message = F.relu(binput + nei_message)
|
58 |
+
|
59 |
+
nei_message = index_select_ND(message, 0, agraph)
|
60 |
+
nei_message = nei_message.sum(dim=1)
|
61 |
+
ainput = torch.cat([fatoms, nei_message], dim=1)
|
62 |
+
atom_hiddens = F.relu(self.W_o(ainput))
|
63 |
+
|
64 |
+
max_len = max([x for _,x in scope])
|
65 |
+
batch_vecs = []
|
66 |
+
for st,le in scope:
|
67 |
+
cur_vecs = atom_hiddens[st : st + le].mean(dim=0)
|
68 |
+
batch_vecs.append( cur_vecs )
|
69 |
+
|
70 |
+
mol_vecs = torch.stack(batch_vecs, dim=0)
|
71 |
+
return mol_vecs
|
72 |
+
|
73 |
+
@staticmethod
|
74 |
+
def tensorize(mol_batch):
|
75 |
+
padding = torch.zeros(ATOM_FDIM + BOND_FDIM)
|
76 |
+
fatoms,fbonds = [],[padding] #Ensure bond is 1-indexed
|
77 |
+
in_bonds,all_bonds = [],[(-1,-1)] #Ensure bond is 1-indexed
|
78 |
+
scope = []
|
79 |
+
total_atoms = 0
|
80 |
+
|
81 |
+
for smiles in mol_batch:
|
82 |
+
mol = get_mol(smiles)
|
83 |
+
#mol = Chem.MolFromSmiles(smiles)
|
84 |
+
n_atoms = mol.GetNumAtoms()
|
85 |
+
for atom in mol.GetAtoms():
|
86 |
+
fatoms.append( atom_features(atom) )
|
87 |
+
in_bonds.append([])
|
88 |
+
|
89 |
+
for bond in mol.GetBonds():
|
90 |
+
a1 = bond.GetBeginAtom()
|
91 |
+
a2 = bond.GetEndAtom()
|
92 |
+
x = a1.GetIdx() + total_atoms
|
93 |
+
y = a2.GetIdx() + total_atoms
|
94 |
+
|
95 |
+
b = len(all_bonds)
|
96 |
+
all_bonds.append((x,y))
|
97 |
+
fbonds.append( torch.cat([fatoms[x], bond_features(bond)], 0) )
|
98 |
+
in_bonds[y].append(b)
|
99 |
+
|
100 |
+
b = len(all_bonds)
|
101 |
+
all_bonds.append((y,x))
|
102 |
+
fbonds.append( torch.cat([fatoms[y], bond_features(bond)], 0) )
|
103 |
+
in_bonds[x].append(b)
|
104 |
+
|
105 |
+
scope.append((total_atoms,n_atoms))
|
106 |
+
total_atoms += n_atoms
|
107 |
+
|
108 |
+
total_bonds = len(all_bonds)
|
109 |
+
fatoms = torch.stack(fatoms, 0)
|
110 |
+
fbonds = torch.stack(fbonds, 0)
|
111 |
+
agraph = torch.zeros(total_atoms,MAX_NB).long()
|
112 |
+
bgraph = torch.zeros(total_bonds,MAX_NB).long()
|
113 |
+
|
114 |
+
for a in range(total_atoms):
|
115 |
+
for i,b in enumerate(in_bonds[a]):
|
116 |
+
agraph[a,i] = b
|
117 |
+
|
118 |
+
for b1 in range(1, total_bonds):
|
119 |
+
x,y = all_bonds[b1]
|
120 |
+
for i,b2 in enumerate(in_bonds[x]):
|
121 |
+
if all_bonds[b2][0] != y:
|
122 |
+
bgraph[b1,i] = b2
|
123 |
+
|
124 |
+
return (fatoms, fbonds, agraph, bgraph, scope)
|
125 |
+
|
fast_jtnn/nnutils.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.autograd import Variable
|
5 |
+
|
6 |
+
def create_var(tensor, requires_grad=None):
|
7 |
+
if requires_grad is None:
|
8 |
+
if torch.cuda.is_available():
|
9 |
+
return Variable(data = tensor).cuda()
|
10 |
+
else:
|
11 |
+
return Variable(data = tensor)
|
12 |
+
else:
|
13 |
+
if torch.cuda.is_available():
|
14 |
+
return Variable(data = tensor, requires_grad=requires_grad).cuda()
|
15 |
+
else:
|
16 |
+
return Variable(data = tensor, requires_grad=requires_grad)
|
17 |
+
|
18 |
+
def index_select_ND(source, dim, index):
|
19 |
+
index_size = index.size()
|
20 |
+
suffix_dim = source.size()[1:]
|
21 |
+
final_size = index_size + suffix_dim
|
22 |
+
target = source.index_select(dim, index.view(-1))
|
23 |
+
return target.view(final_size)
|
24 |
+
|
25 |
+
def avg_pool(all_vecs, scope, dim):
|
26 |
+
size = create_var(torch.Tensor([le for _,le in scope]))
|
27 |
+
return all_vecs.sum(dim=dim) / size.unsqueeze(-1)
|
28 |
+
|
29 |
+
def stack_pad_tensor(tensor_list):
|
30 |
+
max_len = max([t.size(0) for t in tensor_list])
|
31 |
+
for i,tensor in enumerate(tensor_list):
|
32 |
+
pad_len = max_len - tensor.size(0)
|
33 |
+
tensor_list[i] = F.pad( tensor, (0,0,0,pad_len) )
|
34 |
+
return torch.stack(tensor_list, dim=0)
|
35 |
+
|
36 |
+
#3D padded tensor to 2D matrix, with padded zeros removed
|
37 |
+
def flatten_tensor(tensor, scope):
|
38 |
+
assert tensor.size(0) == len(scope)
|
39 |
+
tlist = []
|
40 |
+
for i,tup in enumerate(scope):
|
41 |
+
le = tup[1]
|
42 |
+
tlist.append( tensor[i, 0:le] )
|
43 |
+
return torch.cat(tlist, dim=0)
|
44 |
+
|
45 |
+
#2D matrix to 3D padded tensor
|
46 |
+
def inflate_tensor(tensor, scope):
|
47 |
+
max_len = max([le for _,le in scope])
|
48 |
+
batch_vecs = []
|
49 |
+
for st,le in scope:
|
50 |
+
cur_vecs = tensor[st : st + le]
|
51 |
+
cur_vecs = F.pad( cur_vecs, (0,0,0,max_len-le) )
|
52 |
+
batch_vecs.append( cur_vecs )
|
53 |
+
|
54 |
+
return torch.stack(batch_vecs, dim=0)
|
55 |
+
|
56 |
+
def GRU(x, h_nei, W_z, W_r, U_r, W_h):
|
57 |
+
hidden_size = x.size()[-1]
|
58 |
+
sum_h = h_nei.sum(dim=1)
|
59 |
+
z_input = torch.cat([x,sum_h], dim=1)
|
60 |
+
z = F.sigmoid(W_z(z_input))
|
61 |
+
|
62 |
+
r_1 = W_r(x).view(-1,1,hidden_size)
|
63 |
+
r_2 = U_r(h_nei)
|
64 |
+
r = F.sigmoid(r_1 + r_2)
|
65 |
+
|
66 |
+
gated_h = r * h_nei
|
67 |
+
sum_gated_h = gated_h.sum(dim=1)
|
68 |
+
h_input = torch.cat([x,sum_gated_h], dim=1)
|
69 |
+
pre_h = F.tanh(W_h(h_input))
|
70 |
+
new_h = (1.0 - z) * sum_h + z * pre_h
|
71 |
+
return new_h
|
72 |
+
|
fast_jtnn/vocab.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import rdkit
|
2 |
+
import rdkit.Chem as Chem
|
3 |
+
import copy
|
4 |
+
|
5 |
+
def get_slots(smiles):
|
6 |
+
mol = Chem.MolFromSmiles(smiles)
|
7 |
+
return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()]
|
8 |
+
|
9 |
+
class Vocab(object):
|
10 |
+
benzynes = ['C1=CC=CC=C1', 'C1=CC=NC=C1', 'C1=CC=NN=C1', 'C1=CN=CC=N1', 'C1=CN=CN=C1', 'C1=CN=NC=N1', 'C1=CN=NN=C1', 'C1=NC=NC=N1', 'C1=NN=CN=N1']
|
11 |
+
penzynes = ['C1=C[NH]C=C1', 'C1=C[NH]C=N1', 'C1=C[NH]N=C1', 'C1=C[NH]N=N1', 'C1=COC=C1', 'C1=COC=N1', 'C1=CON=C1', 'C1=CSC=C1', 'C1=CSC=N1', 'C1=CSN=C1', 'C1=CSN=N1', 'C1=NN=C[NH]1', 'C1=NN=CO1', 'C1=NN=CS1', 'C1=N[NH]C=N1', 'C1=N[NH]N=C1', 'C1=N[NH]N=N1', 'C1=NN=N[NH]1', 'C1=NN=NS1', 'C1=NOC=N1', 'C1=NON=C1', 'C1=NSC=N1', 'C1=NSN=C1']
|
12 |
+
|
13 |
+
def __init__(self, smiles_list):
|
14 |
+
self.vocab = smiles_list
|
15 |
+
self.vmap = {x:i for i,x in enumerate(self.vocab)}
|
16 |
+
self.slots = [get_slots(smiles) for smiles in self.vocab]
|
17 |
+
Vocab.benzynes = [s for s in smiles_list if s.count('=') >= 2 and Chem.MolFromSmiles(s).GetNumAtoms() == 6] + ['C1=CCNCC1']
|
18 |
+
Vocab.penzynes = [s for s in smiles_list if s.count('=') >= 2 and Chem.MolFromSmiles(s).GetNumAtoms() == 5] + ['C1=NCCN1','C1=NNCC1']
|
19 |
+
|
20 |
+
def get_index(self, smiles):
|
21 |
+
return self.vmap[smiles]
|
22 |
+
|
23 |
+
def get_smiles(self, idx):
|
24 |
+
return self.vocab[idx]
|
25 |
+
|
26 |
+
def get_slots(self, idx):
|
27 |
+
return copy.deepcopy(self.slots[idx])
|
28 |
+
|
29 |
+
def size(self):
|
30 |
+
return len(self.vocab)
|
31 |
+
|
fpscores.pkl.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10dcef9340c873e7b987924461b0af5365eb8dd96be607203debe8ddf80c1e73
|
3 |
+
size 3848394
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
rdkit
|
2 |
+
numpy
|
3 |
+
torch
|
4 |
+
argparse
|
5 |
+
tqdm
|
6 |
+
networkx
|
7 |
+
scipy
|
8 |
+
copy
|
9 |
+
molbloom
|
sascorer.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# calculation of synthetic accessibility score as described in:
|
3 |
+
#
|
4 |
+
# Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
|
5 |
+
# Peter Ertl and Ansgar Schuffenhauer
|
6 |
+
# Journal of Cheminformatics 1:8 (2009)
|
7 |
+
# http://www.jcheminf.com/content/1/1/8
|
8 |
+
#
|
9 |
+
# several small modifications to the original paper are included
|
10 |
+
# particularly slightly different formula for marocyclic penalty
|
11 |
+
# and taking into account also molecule symmetry (fingerprint density)
|
12 |
+
#
|
13 |
+
# for a set of 10k diverse molecules the agreement between the original method
|
14 |
+
# as implemented in PipelinePilot and this implementation is r2 = 0.97
|
15 |
+
#
|
16 |
+
# peter ertl & greg landrum, september 2013
|
17 |
+
#
|
18 |
+
|
19 |
+
|
20 |
+
from rdkit import Chem
|
21 |
+
from rdkit.Chem import rdMolDescriptors
|
22 |
+
import pickle
|
23 |
+
|
24 |
+
import math
|
25 |
+
from collections import defaultdict
|
26 |
+
|
27 |
+
import os.path as op
|
28 |
+
|
29 |
+
_fscores = None
|
30 |
+
|
31 |
+
|
32 |
+
def readFragmentScores(name='fpscores'):
|
33 |
+
import gzip
|
34 |
+
global _fscores
|
35 |
+
# generate the full path filename:
|
36 |
+
if name == "fpscores":
|
37 |
+
name = op.join(op.dirname(__file__), name)
|
38 |
+
_fscores = pickle.load(gzip.open('%s.pkl.gz' % name))
|
39 |
+
outDict = {}
|
40 |
+
for i in _fscores:
|
41 |
+
for j in range(1, len(i)):
|
42 |
+
outDict[i[j]] = float(i[0])
|
43 |
+
_fscores = outDict
|
44 |
+
|
45 |
+
|
46 |
+
def numBridgeheadsAndSpiro(mol, ri=None):
|
47 |
+
nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
|
48 |
+
nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
|
49 |
+
return nBridgehead, nSpiro
|
50 |
+
|
51 |
+
|
52 |
+
def calculateScore(m):
|
53 |
+
if _fscores is None:
|
54 |
+
readFragmentScores()
|
55 |
+
|
56 |
+
# fragment score
|
57 |
+
fp = rdMolDescriptors.GetMorganFingerprint(m,
|
58 |
+
2) # <- 2 is the *radius* of the circular fingerprint
|
59 |
+
fps = fp.GetNonzeroElements()
|
60 |
+
score1 = 0.
|
61 |
+
nf = 0
|
62 |
+
for bitId, v in fps.items():
|
63 |
+
nf += v
|
64 |
+
sfp = bitId
|
65 |
+
score1 += _fscores.get(sfp, -4) * v
|
66 |
+
score1 /= nf
|
67 |
+
|
68 |
+
# features score
|
69 |
+
nAtoms = m.GetNumAtoms()
|
70 |
+
nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
|
71 |
+
ri = m.GetRingInfo()
|
72 |
+
nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
|
73 |
+
nMacrocycles = 0
|
74 |
+
for x in ri.AtomRings():
|
75 |
+
if len(x) > 8:
|
76 |
+
nMacrocycles += 1
|
77 |
+
|
78 |
+
sizePenalty = nAtoms**1.005 - nAtoms
|
79 |
+
stereoPenalty = math.log10(nChiralCenters + 1)
|
80 |
+
spiroPenalty = math.log10(nSpiro + 1)
|
81 |
+
bridgePenalty = math.log10(nBridgeheads + 1)
|
82 |
+
macrocyclePenalty = 0.
|
83 |
+
# ---------------------------------------
|
84 |
+
# This differs from the paper, which defines:
|
85 |
+
# macrocyclePenalty = math.log10(nMacrocycles+1)
|
86 |
+
# This form generates better results when 2 or more macrocycles are present
|
87 |
+
if nMacrocycles > 0:
|
88 |
+
macrocyclePenalty = math.log10(2)
|
89 |
+
|
90 |
+
score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
|
91 |
+
|
92 |
+
# correction for the fingerprint density
|
93 |
+
# not in the original publication, added in version 1.1
|
94 |
+
# to make highly symmetrical molecules easier to synthetise
|
95 |
+
score3 = 0.
|
96 |
+
if nAtoms > len(fps):
|
97 |
+
score3 = math.log(float(nAtoms) / len(fps)) * .5
|
98 |
+
|
99 |
+
sascore = score1 + score2 + score3
|
100 |
+
|
101 |
+
# need to transform "raw" value into scale between 1 and 10
|
102 |
+
min = -4.0
|
103 |
+
max = 2.5
|
104 |
+
sascore = 11. - (sascore - min + 1) / (max - min) * 9.
|
105 |
+
# smooth the 10-end
|
106 |
+
if sascore > 8.:
|
107 |
+
sascore = 8. + math.log(sascore + 1. - 9.)
|
108 |
+
if sascore > 10.:
|
109 |
+
sascore = 10.0
|
110 |
+
elif sascore < 1.:
|
111 |
+
sascore = 1.0
|
112 |
+
|
113 |
+
return sascore
|
114 |
+
|
115 |
+
|
116 |
+
def processMols(mols):
|
117 |
+
print('smiles\tName\tsa_score')
|
118 |
+
for i, m in enumerate(mols):
|
119 |
+
if m is None:
|
120 |
+
continue
|
121 |
+
|
122 |
+
s = calculateScore(m)
|
123 |
+
|
124 |
+
smiles = Chem.MolToSmiles(m)
|
125 |
+
print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
|
126 |
+
|
127 |
+
|
128 |
+
if __name__ == '__main__':
|
129 |
+
import sys
|
130 |
+
import time
|
131 |
+
|
132 |
+
t1 = time.time()
|
133 |
+
readFragmentScores("fpscores")
|
134 |
+
t2 = time.time()
|
135 |
+
|
136 |
+
suppl = Chem.SmilesMolSupplier(sys.argv[1])
|
137 |
+
t3 = time.time()
|
138 |
+
processMols(suppl)
|
139 |
+
t4 = time.time()
|
140 |
+
|
141 |
+
print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
|
142 |
+
file=sys.stderr)
|
143 |
+
|
144 |
+
#
|
145 |
+
# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
|
146 |
+
# All rights reserved.
|
147 |
+
#
|
148 |
+
# Redistribution and use in source and binary forms, with or without
|
149 |
+
# modification, are permitted provided that the following conditions are
|
150 |
+
# met:
|
151 |
+
#
|
152 |
+
# * Redistributions of source code must retain the above copyright
|
153 |
+
# notice, this list of conditions and the following disclaimer.
|
154 |
+
# * Redistributions in binary form must reproduce the above
|
155 |
+
# copyright notice, this list of conditions and the following
|
156 |
+
# disclaimer in the documentation and/or other materials provided
|
157 |
+
# with the distribution.
|
158 |
+
# * Neither the name of Novartis Institutes for BioMedical Research Inc.
|
159 |
+
# nor the names of its contributors may be used to endorse or promote
|
160 |
+
# products derived from this software without specific prior written permission.
|
161 |
+
#
|
162 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
163 |
+
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
164 |
+
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
165 |
+
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
166 |
+
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
167 |
+
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
168 |
+
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
169 |
+
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
170 |
+
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
171 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
172 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
173 |
+
#
|
vocab.txt
ADDED
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
C=N
|
2 |
+
C1CSCNN1
|
3 |
+
C1CCSCC1
|
4 |
+
C1=NSN=C1
|
5 |
+
C1=CSCS1
|
6 |
+
C1NCS1
|
7 |
+
C1=CSCC1
|
8 |
+
C1=CC2C=CC1O2
|
9 |
+
C1CC2COCC(CN1)N2
|
10 |
+
C1=COCCCN1
|
11 |
+
C1=CNSNC1
|
12 |
+
C1=NNCS1
|
13 |
+
C1=NNCSC1
|
14 |
+
C1=CSNCS1
|
15 |
+
C1=NC=NC=N1
|
16 |
+
C1=CCCCC=C1
|
17 |
+
C1CNCN=N1
|
18 |
+
C1=NNN=N1
|
19 |
+
C1NCC2OCC1O2
|
20 |
+
C1CC2CCOC(C2)N1
|
21 |
+
C1=CC2CCC1C2
|
22 |
+
C1=NN=NCC1
|
23 |
+
C1=CC2CC1C2
|
24 |
+
C1=CNCN=C1
|
25 |
+
C1=CC2C=CC1NN2
|
26 |
+
C1CN=NNC1
|
27 |
+
C1=CSCCN1
|
28 |
+
C1=CC=COC=C1
|
29 |
+
C1=NCSC1
|
30 |
+
C1CC2CNCC1N2
|
31 |
+
C1N=NNCN1
|
32 |
+
C=S
|
33 |
+
C1CC2OCC3CCC(O1)C2C3
|
34 |
+
C1COCCOC1
|
35 |
+
C1NCSN1
|
36 |
+
C1=CNN=C1
|
37 |
+
C1CNCCOC1
|
38 |
+
C1CSNCN1
|
39 |
+
C1CNSN1
|
40 |
+
C1CC2CCC(O1)O2
|
41 |
+
C1=CNSC=C1
|
42 |
+
C1=NCCO1
|
43 |
+
C1=CCSNC1
|
44 |
+
C1=NCOCC1
|
45 |
+
C1=CONC1
|
46 |
+
C1CC2CCC1O2
|
47 |
+
C1=NN=NS1
|
48 |
+
C1=CC2CCCC1NN2
|
49 |
+
C1CSCN1
|
50 |
+
C1NC2CC3CC1CC(C3)C2
|
51 |
+
C1=NC=N[SH]=N1
|
52 |
+
C1OC2CNC1C2
|
53 |
+
C1=CCSC1
|
54 |
+
C1=CC2C=CC1N2
|
55 |
+
C1=CCNC=NC1
|
56 |
+
O=S
|
57 |
+
C1NN=NN1
|
58 |
+
C1=CCNC1
|
59 |
+
C1=CCCN=C1
|
60 |
+
C1=COCCOC1
|
61 |
+
C1CC2CNCC(C1)NC2
|
62 |
+
C1=CCOC1
|
63 |
+
C1=CCNNCC1
|
64 |
+
C1=CNNCC1
|
65 |
+
C1=CSC1
|
66 |
+
C1CC2CCNC(C1)C2
|
67 |
+
C1CC2NC(CCS2)S1
|
68 |
+
C1CNN1
|
69 |
+
C1=CNCCNC1
|
70 |
+
C1=CNCCCN1
|
71 |
+
C1NCNCN1
|
72 |
+
C1=CC2C=CC1C2
|
73 |
+
OCl
|
74 |
+
C1=CN=CSC1
|
75 |
+
C1=CCNC=CC1
|
76 |
+
C1=CC2CC1CN2
|
77 |
+
C1=CC2CCCN(C1)C2
|
78 |
+
C1=CC2CC(C1)CN2
|
79 |
+
NO
|
80 |
+
C1=COC=NC1
|
81 |
+
C1=NSN=CC1
|
82 |
+
C1=CN=COC1
|
83 |
+
C1=CSCCN=C1
|
84 |
+
C1=CC=CNC=C1
|
85 |
+
C1=NN=CO1
|
86 |
+
C1=CNSC1
|
87 |
+
C1CNC1
|
88 |
+
C1=COCN1
|
89 |
+
C1CCSCNC1
|
90 |
+
C1CC2CNCC1C2
|
91 |
+
N
|
92 |
+
C1=NN=CN1
|
93 |
+
C1=NCCNCC1
|
94 |
+
C1CCCCCC1
|
95 |
+
C1CC2CC(N1)C1COC2O1
|
96 |
+
C1=NCNNC1
|
97 |
+
C1=CN=NC=N1
|
98 |
+
C1COCN1
|
99 |
+
C1CNNCN1
|
100 |
+
C1=NCC=NCC1
|
101 |
+
C1=NCNS1
|
102 |
+
C1OC2OCC1CO2
|
103 |
+
C1CCOCNC1
|
104 |
+
CF
|
105 |
+
C1=NNCO1
|
106 |
+
C1=NC=NCCN1
|
107 |
+
C1CC2CCCC(C1)O2
|
108 |
+
C1=NNCN1
|
109 |
+
C1=CCN=CC1
|
110 |
+
C1CC2CNCC1CN2
|
111 |
+
C1CC2CNCC(C1)N2
|
112 |
+
C1=CCSCC1
|
113 |
+
C1CSNS1
|
114 |
+
C1=NCNC1
|
115 |
+
C1CC2C3NNC2C1O3
|
116 |
+
C1CC2CNC(C2)N1
|
117 |
+
C1=CCOCCC1
|
118 |
+
C1CC2CCC(C1)O2
|
119 |
+
C1=CC2CC1NN2
|
120 |
+
C1=CC2CCCC(C1)C2
|
121 |
+
C1=CN=CCC1
|
122 |
+
C1CC2NCC1CN2
|
123 |
+
C1=CSCN1
|
124 |
+
C1=NCCC1
|
125 |
+
C1CC2C3CC1CC23
|
126 |
+
C1=CN=CN=C1
|
127 |
+
C1=NSNC1
|
128 |
+
C1CSN=N1
|
129 |
+
C1COCCSC1
|
130 |
+
C1CNSC1
|
131 |
+
C1=CCCCCC1
|
132 |
+
C1=NCNCN1
|
133 |
+
C1CC2CCC(C1)C2
|
134 |
+
C1CC2CNCC(C1)C2
|
135 |
+
C1=CC2CCCC1C2
|
136 |
+
C1=CC2CCC1O2
|
137 |
+
C1=CC2CCC1CC2
|
138 |
+
C1=CC2C=CC(C1)CC2
|
139 |
+
C1CSNCS1
|
140 |
+
NBr
|
141 |
+
C1COCSN1
|
142 |
+
C1=CSCNC1
|
143 |
+
C1CC2NCC1NN2
|
144 |
+
C1=CN=CCNC1
|
145 |
+
C1CC2CNCC1O2
|
146 |
+
C1=NNCCC1
|
147 |
+
C1CC2CC(C1)N2
|
148 |
+
C1=NNSC1
|
149 |
+
C1=COCCN1
|
150 |
+
C1=CCC2CC(C1)C2
|
151 |
+
C1=CC2CCCC1CC2
|
152 |
+
C1=CCNCC1
|
153 |
+
C1CNNC1
|
154 |
+
C1=NC=NC1
|
155 |
+
C1=COC=N1
|
156 |
+
C1CC2CNC(C2)O1
|
157 |
+
C1=CCC2CC=CC(C1)C2
|
158 |
+
C1=CC=CC=C1
|
159 |
+
C1=CCSOC1
|
160 |
+
C1CN2CCC1CC2
|
161 |
+
C1=CSCCCN1
|
162 |
+
C1CNCCSC1
|
163 |
+
C1=NCCSN1
|
164 |
+
C1=NCCNN1
|
165 |
+
C1NCON1
|
166 |
+
C1NCC2CC1CN2
|
167 |
+
C1=CNC=N1
|
168 |
+
C1=CC2CCC1NN2
|
169 |
+
C1CNCOC1
|
170 |
+
C1C2CC3CC1OC(O2)O3
|
171 |
+
C1=NNCSCC1
|
172 |
+
C1=CCOCC1
|
173 |
+
C1=NSCC1
|
174 |
+
C1=CNSCC1
|
175 |
+
C1=CCC1
|
176 |
+
C1CCCNCC1
|
177 |
+
C1CC2CCCN(C1)C2
|
178 |
+
CN
|
179 |
+
C1CC2CC(O1)C1OCC2O1
|
180 |
+
C1CC2CNC(C1)N2
|
181 |
+
CO
|
182 |
+
C1=C2CCCCC1C2
|
183 |
+
C1=COCCNC1
|
184 |
+
C1=CN=CNC1
|
185 |
+
C1=CNCCCC1
|
186 |
+
C1NN=NS1
|
187 |
+
C1=NC=NCS1
|
188 |
+
C1=NN=CCC1
|
189 |
+
C1=NCCCC1
|
190 |
+
C1CC2CCC3CC1NC23
|
191 |
+
C1=NCCOCC1
|
192 |
+
C1=CSNC=N1
|
193 |
+
C1=CSCNNC1
|
194 |
+
C1=COCCC1
|
195 |
+
C1=COCCC=N1
|
196 |
+
C1=NNCCN1
|
197 |
+
C1C2CC1C2
|
198 |
+
N1C2NC3NC4NC3NC2NC14
|
199 |
+
C1=CC2C3NSC2C13
|
200 |
+
C1CC2COCC1N2
|
201 |
+
C1=CC2CCN3CC1CC23
|
202 |
+
C1CCONC1
|
203 |
+
C1=CC1
|
204 |
+
C1CCNNC1
|
205 |
+
C1=CSC=N1
|
206 |
+
C1CC1
|
207 |
+
C1=NCNCC1
|
208 |
+
C1=CC2CC(NCN2)O1
|
209 |
+
C1=CNSN=C1
|
210 |
+
C1=CCCNCC1
|
211 |
+
C1=CSN=CN1
|
212 |
+
NS
|
213 |
+
C1=CSN=N1
|
214 |
+
C1=CN=CC1
|
215 |
+
C1=CC2CCC(O1)O2
|
216 |
+
C1C2CC3C1OC1NCC2C13
|
217 |
+
C1COCSC1
|
218 |
+
C1=CCCOCC1
|
219 |
+
OBr
|
220 |
+
C1=COCSN1
|
221 |
+
C1=CN=N[SH]=C1
|
222 |
+
C1=CSCC=NC1
|
223 |
+
C1=CC=C2CCCCC(=C1)C2
|
224 |
+
C1=CSC=CC1
|
225 |
+
C=C
|
226 |
+
C1=CNCOC1
|
227 |
+
C1=NON=C1
|
228 |
+
C1=C[SH]=NC=N1
|
229 |
+
C1=CNCNC1
|
230 |
+
C1CN2C3CC4CC(C1C3)C2C4
|
231 |
+
C1=CONCNN1
|
232 |
+
CBr
|
233 |
+
C1CCNCNC1
|
234 |
+
C1=CCNN=C1
|
235 |
+
C1=CCNCCC1
|
236 |
+
C1=CSCCCC1
|
237 |
+
C1=CC2COCC(C1)C2
|
238 |
+
C1=CCCC=CC1
|
239 |
+
C1COCCN1
|
240 |
+
N1NN1
|
241 |
+
C1=CC2CCCC(C1)N2
|
242 |
+
C1=CC=NC=C1
|
243 |
+
C1=NC=NNC1
|
244 |
+
C1CC2CCC1CNC2
|
245 |
+
C1=CNCC1
|
246 |
+
N=S
|
247 |
+
CC
|
248 |
+
C1=NNCNCC1
|
249 |
+
C1=NN=NN1
|
250 |
+
C1=NSNCC1
|
251 |
+
C1=CNCCSC1
|
252 |
+
C1=NC=NCCC1
|
253 |
+
C1CCNSCC1
|
254 |
+
C1=NC=NN=C1
|
255 |
+
C1=CCOCOC1
|
256 |
+
C1=NN=CNC1
|
257 |
+
C1=COCCCC1
|
258 |
+
C1=NCCCO1
|
259 |
+
C1CC2CNC(C1)C2
|
260 |
+
C1=CSN=CS1
|
261 |
+
C1CC2CCC(C1)CNC2
|
262 |
+
C1C[SH]=NS1
|
263 |
+
C1CC2CCCC(C1)C2
|
264 |
+
C1CC2CCC(C1)N2
|
265 |
+
C1=NOCC1
|
266 |
+
C1=NCCOC1
|
267 |
+
C1CC2C3CC1C2CO3
|
268 |
+
C1NCN1
|
269 |
+
C1COC2CCC(C2)N1
|
270 |
+
C1=CNC=CC1
|
271 |
+
C1=CNCCC1
|
272 |
+
C1=NN=CN=N1
|
273 |
+
C1=CNCCC=N1
|
274 |
+
C1NCC2CC1C2
|
275 |
+
C1C2CC3CC1CC(C2)C3
|
276 |
+
C1=NC2CC(C1)CN2
|
277 |
+
C1=CNCN=CC1
|
278 |
+
C1NCSCN1
|
279 |
+
S
|
280 |
+
C1CC2CCC1C2
|
281 |
+
C1=CCC=NC=C1
|
282 |
+
C1CCOCC1
|
283 |
+
C1CCN2CCCC(C1)C2
|
284 |
+
C1=NSCCS1
|
285 |
+
C1NC2CCOC(C2)N1
|
286 |
+
C1=NCNN1
|
287 |
+
C1=CSC=CNC1
|
288 |
+
C1=CCSC=C1
|
289 |
+
C1=COC=CO1
|
290 |
+
C1=CCCNC=C1
|
291 |
+
C1CSCCN1
|
292 |
+
C1CCCC1
|
293 |
+
C1CCSOC1
|
294 |
+
C1=NSCO1
|
295 |
+
C1=NCN=CO1
|
296 |
+
C1CN2CCN1CC2
|
297 |
+
C1=NCN=CN1
|
298 |
+
C1CC2CCC(N1)O2
|
299 |
+
C1COCO1
|
300 |
+
C1=COC=C1
|
301 |
+
C1=CCC2CCC(C1)N2
|
302 |
+
C1CCC2CCCC(C1)C2
|
303 |
+
C1=COCC1
|
304 |
+
C1=NSN=CO1
|
305 |
+
C1=CCCOC=C1
|
306 |
+
C1NNCO1
|
307 |
+
CS
|
308 |
+
C1CC2CC1CO2
|
309 |
+
C1NC2CNC1C2
|
310 |
+
C1=CC=C1
|
311 |
+
C1=COC=CC1
|
312 |
+
C1=CCNNC1
|
313 |
+
NN
|
314 |
+
C1=COCCO1
|
315 |
+
C1=NCC=NN1
|
316 |
+
C1CNCNCN1
|
317 |
+
C1=CN=NC=C1
|
318 |
+
C1=CSCO1
|
319 |
+
C1=NC=NSO1
|
320 |
+
C1=NCCS1
|
321 |
+
C1CC2CNC1CN2
|
322 |
+
C1=NNCC1
|
323 |
+
C1=CCCC=C1
|
324 |
+
C1=NCCN1
|
325 |
+
C1=NCCSCC1
|
326 |
+
C1=NNNC1
|
327 |
+
C1=CC2CCNC(C1)C2
|
328 |
+
C1=NCCN=N1
|
329 |
+
C1=CNSN1
|
330 |
+
C1=CNOC1
|
331 |
+
C1=CSN=C1
|
332 |
+
C1CCC2CCC(C1)C2
|
333 |
+
C1NCC2COCC1C2
|
334 |
+
C1COC2CC(N1)C1COC2O1
|
335 |
+
C1=NC=NN1
|
336 |
+
C1CSNSC1
|
337 |
+
C1CCN=NC1
|
338 |
+
C1=CO1
|
339 |
+
C1=CNCCOC1
|
340 |
+
C1=NNN=C1
|
341 |
+
C1=CC=NN=C1
|
342 |
+
C1CC2CCC1CC2
|
343 |
+
C1NC2CC(CO2)O1
|
344 |
+
C1NNCS1
|
345 |
+
C1=CC2CCCCC(CC1)C2
|
346 |
+
C1=CN=NCC1
|
347 |
+
C1=CON=C1
|
348 |
+
C1CSN1
|
349 |
+
C1=CSNC1
|
350 |
+
C1CC2CC(C1)C2
|
351 |
+
C1CNCN1
|
352 |
+
C1=CSCNN1
|
353 |
+
C1=CSCC=N1
|
354 |
+
C1=CC2CCC1CN2
|
355 |
+
C1=CSCCC1
|
356 |
+
C1=CC2CCC(C2)N1
|
357 |
+
C1CC2CC3CC1CC(C2)C3
|
358 |
+
C1=CSCCS1
|
359 |
+
C1=NC=NCC1
|
360 |
+
C1=CCC2CC(C=CN2)C1
|
361 |
+
C1OC2C3OC4C1C1OC2C3OCC41
|
362 |
+
C1=NCCCN1
|
363 |
+
C1COCOC1
|
364 |
+
C1=CNCC=N1
|
365 |
+
C1=NCCCS1
|
366 |
+
C1=CC2C=CC(C=C1)C2
|
367 |
+
C1CC2CC1CN2
|
368 |
+
C1=CNN1
|
369 |
+
C1=CC2CCCC1O2
|
370 |
+
C1=CCCCC1
|
371 |
+
C1=CNC=NC1
|
372 |
+
C1=CSNCN1
|
373 |
+
C1=CSCNN=C1
|
374 |
+
C1=CNC=CNC1
|
375 |
+
C1OC2CC3CC1C2C3
|
376 |
+
C1=NCN=C1
|
377 |
+
C1=CSC=CS1
|
378 |
+
C
|
379 |
+
C1=NCCN=CC1
|
380 |
+
C1=CSNCCN1
|
381 |
+
C1=CCN=CCC1
|
382 |
+
C1=CC2CCNC(C2)O1
|
383 |
+
C1C2CN3CCN(C2)CC1C3
|
384 |
+
C1C2CN3CC1CN(C2)C3
|
385 |
+
C1=NCCCCN1
|
386 |
+
C1C2CC3CC(CC1O3)N2
|
387 |
+
C1=NN=COC1
|
388 |
+
C1=CSCCNC1
|
389 |
+
C1=NOCN1
|
390 |
+
C1NCC2CNCC1C2
|
391 |
+
C1CCSC1
|
392 |
+
C1=CN=NNC1
|
393 |
+
C1CSCS1
|
394 |
+
C1CNN=N1
|
395 |
+
C1NCNSN1
|
396 |
+
C1=NCC1
|
397 |
+
N1NO1
|
398 |
+
C1COSCSO1
|
399 |
+
C1=CNNCN=C1
|
400 |
+
C1CSCCS1
|
401 |
+
C1=CSC=NC1
|
402 |
+
C1=CC2CCC(C1)NN2
|
403 |
+
C1CNCCNC1
|
404 |
+
C1C2CC3OC1CC3O2
|
405 |
+
C1CNCCN1
|
406 |
+
C1=NCCCNC1
|
407 |
+
C=O
|
408 |
+
C1CNSCCN1
|
409 |
+
C1CNCSC1
|
410 |
+
C1CCC2C3CCC2C(C1)C3
|
411 |
+
C1=CCC=C1
|
412 |
+
C1COC1
|
413 |
+
C1CC2CC1NN2
|
414 |
+
C1=NN=CS1
|
415 |
+
C1COCCSN1
|
416 |
+
C1CCNC1
|
417 |
+
C1=NOCCC1
|
418 |
+
C1CCC1
|
419 |
+
C1CC2CCC1CO2
|
420 |
+
C1=CCNC=C1
|
421 |
+
C1=CN=CN=CC1
|
422 |
+
C1=NC=NO1
|
423 |
+
C1=CC2CCOC(C2)N1
|
424 |
+
C1=CNCN=N1
|
425 |
+
C1NCN2CNCC1C2
|
426 |
+
C1CN2CCC3CC1CC2C3
|
427 |
+
C1=CSC=CO1
|
428 |
+
C1=CNCCN=C1
|
429 |
+
C1CCC2CC(C1)C2
|
430 |
+
C1=COCO1
|
431 |
+
C1=NCOC1
|
432 |
+
C1CSCO1
|
433 |
+
C1=NCCSC1
|
434 |
+
C1CCC2CCCC(C1)N2
|
435 |
+
C1=NCC=NC1
|
436 |
+
C1CSCCO1
|
437 |
+
C1=CNC=C1
|
438 |
+
C1COSN1
|
439 |
+
C1=CC2CC(O1)C1OCC2O1
|
440 |
+
C1CC2CCC(CN1)N2
|
441 |
+
C1=COCCN=C1
|
442 |
+
C1=CCOC=CC1
|
443 |
+
C1=NCN=CC1
|
444 |
+
C1OC2COC1C2
|
445 |
+
C1=NC2CC(C1)CCO2
|
446 |
+
C1=NC=NS1
|
447 |
+
C1=CSN=CO1
|
448 |
+
C1=CNN=CC1
|
449 |
+
C1=CC2C=CC(C1)CNC2
|
450 |
+
C1CC2CCNC(C2)N1
|
451 |
+
C1=NCCNC1
|
452 |
+
C1CC2COCC(C1)C2
|
453 |
+
C1=CN=CCN=C1
|
454 |
+
C1=CNN=NC1
|
455 |
+
C1CC2CC3CCN2C(C1)C3
|
456 |
+
C1=CC2COCC(C1)N2
|
457 |
+
C1CNOC1
|
458 |
+
C1=CCCC1
|
459 |
+
C1CCCSCC1
|
460 |
+
C1CCOCOC1
|
461 |
+
C1=NC=NCN1
|
462 |
+
C1=NCCNS1
|
463 |
+
C1NCNN1
|
464 |
+
C1CN2C3CC4CC1CC2C4O3
|
465 |
+
C1=COCNC1
|
466 |
+
C1CN2CC3OCC2CC13
|
467 |
+
C1NC2CC1C2
|
468 |
+
C1=NSCCN1
|
469 |
+
C1=CCOC=C1
|
470 |
+
C1CCSNC1
|
471 |
+
C1=CC2C=CC1CC2
|
472 |
+
C1CNSNC1
|
473 |
+
C1=CN=CNCC1
|
474 |
+
C1=NNCNC1
|
475 |
+
C1=CNCNN=C1
|
476 |
+
C1=NNCCCC1
|
477 |
+
C1=NCC2CCCC1C2
|
478 |
+
C1CCNCC1
|
479 |
+
C1NCC2CNCC1O2
|
480 |
+
C1=COCOC1
|
481 |
+
C1CC2COC(C1)O2
|
482 |
+
C#N
|
483 |
+
C1=NCNN=C1
|
484 |
+
C1CC2COC(C1)C2
|
485 |
+
C1CC2CCC1NN2
|
486 |
+
C1=CN=CC=N1
|
487 |
+
C1=NCCN=C1
|
488 |
+
C1CSCNCN1
|
489 |
+
C1NCC2COCC1N2
|
490 |
+
C1=CCN=C1
|
491 |
+
C1COCCO1
|
492 |
+
C1=COCCCO1
|
493 |
+
C1CC2COCC(C1)N2
|
494 |
+
C1=COCC=N1
|
495 |
+
C1CC2OC(CCS2)S1
|
496 |
+
C1=CCC=CC1
|
497 |
+
C1=NC1
|
498 |
+
C1CCCOCC1
|
499 |
+
C1=CC2CCC(C1)N2
|
500 |
+
C1=CNC=CN1
|
501 |
+
C1CCOC1
|
502 |
+
C1=CC=CCC=C1
|
503 |
+
C1=CC2CCNC(CC1)C2
|
504 |
+
C1=CNCSC1
|
505 |
+
C1CC2CCC1N2
|
506 |
+
C1=NSNCN1
|
507 |
+
C1=CSNCC1
|
508 |
+
C1CC2CCCC(C1)N2
|
509 |
+
C1=CN=CCC=N1
|
510 |
+
C1NCC2COCC1CN2
|
511 |
+
C1=CSC=C1
|
512 |
+
C1C2CN3CN1CN(C2)C3
|
513 |
+
C1=CCON=C1
|
514 |
+
C1=NCCCCC1
|
515 |
+
C1=CSCCCS1
|
516 |
+
C1=CNN=N1
|
517 |
+
C1CC2C3CC1C2CN3
|
518 |
+
C1CC2CCC(C2)O1
|
519 |
+
OS
|
520 |
+
C1CC2OC3CC1CC2O3
|
521 |
+
C1CC2CCC1CN2
|
522 |
+
C1=CNCN1
|
523 |
+
C1=CC2CC(N1)C1C=CC2C1
|
524 |
+
C1=CC2CC(C1)C2
|
525 |
+
C1=NCCCSC1
|
526 |
+
C1NC2CC(N1)C1OCC2O1
|
527 |
+
C#C
|
528 |
+
CCl
|
529 |
+
C1=CN=NN=C1
|
530 |
+
C1=CNCCN1
|
531 |
+
C1CCCCC1
|
532 |
+
C1CNCNC1
|
533 |
+
C1=CNNC1
|