Spaces:
Sleeping
Sleeping
import rdkit | |
import rdkit.Chem as Chem | |
from scipy.sparse import csr_matrix | |
from scipy.sparse.csgraph import minimum_spanning_tree | |
from collections import defaultdict | |
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions | |
from vocab import Vocab | |
MST_MAX_WEIGHT = 100 | |
MAX_NCAND = 2000 | |
def set_atommap(mol, num=0): | |
for atom in mol.GetAtoms(): | |
atom.SetAtomMapNum(num) | |
def get_mol(smiles): | |
mol = Chem.MolFromSmiles(smiles) | |
if mol is None: | |
return None | |
Chem.Kekulize(mol, clearAromaticFlags=True) | |
return mol | |
def get_smiles(mol): | |
return Chem.MolToSmiles(mol, kekuleSmiles=True) | |
def decode_stereo(smiles2D): | |
mol = Chem.MolFromSmiles(smiles2D) | |
dec_isomers = list(EnumerateStereoisomers(mol)) | |
dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True)) for mol in dec_isomers] | |
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers] | |
chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"] | |
if len(chiralN) > 0: | |
for mol in dec_isomers: | |
for idx in chiralN: | |
mol.GetAtomWithIdx(idx).SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED) | |
smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True)) | |
return smiles3D | |
def sanitize(mol): | |
try: | |
smiles = get_smiles(mol) | |
mol = get_mol(smiles) | |
except Exception as e: | |
return None | |
return mol | |
def copy_atom(atom): | |
new_atom = Chem.Atom(atom.GetSymbol()) | |
new_atom.SetFormalCharge(atom.GetFormalCharge()) | |
new_atom.SetAtomMapNum(atom.GetAtomMapNum()) | |
return new_atom | |
def copy_edit_mol(mol): | |
new_mol = Chem.RWMol(Chem.MolFromSmiles('')) | |
for atom in mol.GetAtoms(): | |
new_atom = copy_atom(atom) | |
new_mol.AddAtom(new_atom) | |
for bond in mol.GetBonds(): | |
a1 = bond.GetBeginAtom().GetIdx() | |
a2 = bond.GetEndAtom().GetIdx() | |
bt = bond.GetBondType() | |
new_mol.AddBond(a1, a2, bt) | |
return new_mol | |
def get_clique_mol(mol, atoms): | |
smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True) | |
new_mol = Chem.MolFromSmiles(smiles, sanitize=False) | |
new_mol = copy_edit_mol(new_mol).GetMol() | |
new_mol = sanitize(new_mol) #We assume this is not None | |
return new_mol | |
def tree_decomp(mol): | |
n_atoms = mol.GetNumAtoms() | |
if n_atoms == 1: #special case | |
return [[0]], [] | |
cliques = [] | |
for bond in mol.GetBonds(): | |
a1 = bond.GetBeginAtom().GetIdx() | |
a2 = bond.GetEndAtom().GetIdx() | |
if not bond.IsInRing(): | |
cliques.append([a1,a2]) | |
ssr = [list(x) for x in Chem.GetSymmSSSR(mol)] | |
cliques.extend(ssr) | |
nei_list = [[] for i in range(n_atoms)] | |
for i in range(len(cliques)): | |
for atom in cliques[i]: | |
nei_list[atom].append(i) | |
#Merge Rings with intersection > 2 atoms | |
for i in range(len(cliques)): | |
if len(cliques[i]) <= 2: continue | |
for atom in cliques[i]: | |
for j in nei_list[atom]: | |
if i >= j or len(cliques[j]) <= 2: continue | |
inter = set(cliques[i]) & set(cliques[j]) | |
if len(inter) > 2: | |
cliques[i].extend(cliques[j]) | |
cliques[i] = list(set(cliques[i])) | |
cliques[j] = [] | |
cliques = [c for c in cliques if len(c) > 0] | |
nei_list = [[] for i in range(n_atoms)] | |
for i in range(len(cliques)): | |
for atom in cliques[i]: | |
nei_list[atom].append(i) | |
#Build edges and add singleton cliques | |
edges = defaultdict(int) | |
for atom in range(n_atoms): | |
if len(nei_list[atom]) <= 1: | |
continue | |
cnei = nei_list[atom] | |
bonds = [c for c in cnei if len(cliques[c]) == 2] | |
rings = [c for c in cnei if len(cliques[c]) > 4] | |
if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2): #In general, if len(cnei) >= 3, a singleton should be added, but 1 bond + 2 ring is currently not dealt with. | |
cliques.append([atom]) | |
c2 = len(cliques) - 1 | |
for c1 in cnei: | |
edges[(c1,c2)] = 1 | |
elif len(rings) > 2: #Multiple (n>2) complex rings | |
cliques.append([atom]) | |
c2 = len(cliques) - 1 | |
for c1 in cnei: | |
edges[(c1,c2)] = MST_MAX_WEIGHT - 1 | |
else: | |
for i in range(len(cnei)): | |
for j in range(i + 1, len(cnei)): | |
c1,c2 = cnei[i],cnei[j] | |
inter = set(cliques[c1]) & set(cliques[c2]) | |
if edges[(c1,c2)] < len(inter): | |
edges[(c1,c2)] = len(inter) #cnei[i] < cnei[j] by construction | |
edges = [u + (MST_MAX_WEIGHT-v,) for u,v in edges.items()] | |
if len(edges) == 0: | |
return cliques, edges | |
#Compute Maximum Spanning Tree | |
row,col,data = zip(*edges) | |
n_clique = len(cliques) | |
clique_graph = csr_matrix( (data,(row,col)), shape=(n_clique,n_clique) ) | |
junc_tree = minimum_spanning_tree(clique_graph) | |
row,col = junc_tree.nonzero() | |
edges = [(row[i],col[i]) for i in range(len(row))] | |
return (cliques, edges) | |
def atom_equal(a1, a2): | |
return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge() | |
#Bond type not considered because all aromatic (so SINGLE matches DOUBLE) | |
def ring_bond_equal(b1, b2, reverse=False): | |
b1 = (b1.GetBeginAtom(), b1.GetEndAtom()) | |
if reverse: | |
b2 = (b2.GetEndAtom(), b2.GetBeginAtom()) | |
else: | |
b2 = (b2.GetBeginAtom(), b2.GetEndAtom()) | |
return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1]) | |
def attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap): | |
prev_nids = [node.nid for node in prev_nodes] | |
for nei_node in prev_nodes + neighbors: | |
nei_id,nei_mol = nei_node.nid,nei_node.mol | |
amap = nei_amap[nei_id] | |
for atom in nei_mol.GetAtoms(): | |
if atom.GetIdx() not in amap: | |
new_atom = copy_atom(atom) | |
amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom) | |
if nei_mol.GetNumBonds() == 0: | |
nei_atom = nei_mol.GetAtomWithIdx(0) | |
ctr_atom = ctr_mol.GetAtomWithIdx(amap[0]) | |
ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum()) | |
else: | |
for bond in nei_mol.GetBonds(): | |
a1 = amap[bond.GetBeginAtom().GetIdx()] | |
a2 = amap[bond.GetEndAtom().GetIdx()] | |
if ctr_mol.GetBondBetweenAtoms(a1, a2) is None: | |
ctr_mol.AddBond(a1, a2, bond.GetBondType()) | |
elif nei_id in prev_nids: #father node overrides | |
ctr_mol.RemoveBond(a1, a2) | |
ctr_mol.AddBond(a1, a2, bond.GetBondType()) | |
return ctr_mol | |
def local_attach(ctr_mol, neighbors, prev_nodes, amap_list): | |
ctr_mol = copy_edit_mol(ctr_mol) | |
nei_amap = {nei.nid:{} for nei in prev_nodes + neighbors} | |
for nei_id,ctr_atom,nei_atom in amap_list: | |
nei_amap[nei_id][nei_atom] = ctr_atom | |
ctr_mol = attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap) | |
return ctr_mol.GetMol() | |
#This version records idx mapping between ctr_mol and nei_mol | |
def enum_attach(ctr_mol, nei_node, amap, singletons): | |
nei_mol,nei_idx = nei_node.mol,nei_node.nid | |
att_confs = [] | |
black_list = [atom_idx for nei_id,atom_idx,_ in amap if nei_id in singletons] | |
ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx() not in black_list] | |
ctr_bonds = [bond for bond in ctr_mol.GetBonds()] | |
if nei_mol.GetNumBonds() == 0: #neighbor singleton | |
nei_atom = nei_mol.GetAtomWithIdx(0) | |
used_list = [atom_idx for _,atom_idx,_ in amap] | |
for atom in ctr_atoms: | |
if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list: | |
new_amap = amap + [(nei_idx, atom.GetIdx(), 0)] | |
att_confs.append( new_amap ) | |
elif nei_mol.GetNumBonds() == 1: #neighbor is a bond | |
bond = nei_mol.GetBondWithIdx(0) | |
bond_val = int(bond.GetBondTypeAsDouble()) | |
b1,b2 = bond.GetBeginAtom(), bond.GetEndAtom() | |
for atom in ctr_atoms: | |
#Optimize if atom is carbon (other atoms may change valence) | |
if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val: | |
continue | |
if atom_equal(atom, b1): | |
new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())] | |
att_confs.append( new_amap ) | |
elif atom_equal(atom, b2): | |
new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())] | |
att_confs.append( new_amap ) | |
else: | |
#intersection is an atom | |
for a1 in ctr_atoms: | |
for a2 in nei_mol.GetAtoms(): | |
if atom_equal(a1, a2): | |
#Optimize if atom is carbon (other atoms may change valence) | |
if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4: | |
continue | |
new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())] | |
att_confs.append( new_amap ) | |
#intersection is an bond | |
if ctr_mol.GetNumBonds() > 1: | |
for b1 in ctr_bonds: | |
for b2 in nei_mol.GetBonds(): | |
if ring_bond_equal(b1, b2): | |
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())] | |
att_confs.append( new_amap ) | |
if ring_bond_equal(b1, b2, reverse=True): | |
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()), (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())] | |
att_confs.append( new_amap ) | |
return att_confs | |
#Try rings first: Speed-Up | |
def enum_assemble(node, neighbors, prev_nodes=[], prev_amap=[]): | |
all_attach_confs = [] | |
singletons = [nei_node.nid for nei_node in neighbors + prev_nodes if nei_node.mol.GetNumAtoms() == 1] | |
def search(cur_amap, depth): | |
if len(all_attach_confs) > MAX_NCAND: | |
return | |
if depth == len(neighbors): | |
all_attach_confs.append(cur_amap) | |
return | |
nei_node = neighbors[depth] | |
cand_amap = enum_attach(node.mol, nei_node, cur_amap, singletons) | |
cand_smiles = set() | |
candidates = [] | |
for amap in cand_amap: | |
cand_mol = local_attach(node.mol, neighbors[:depth+1], prev_nodes, amap) | |
cand_mol = sanitize(cand_mol) | |
if cand_mol is None: | |
continue | |
smiles = get_smiles(cand_mol) | |
if smiles in cand_smiles: | |
continue | |
cand_smiles.add(smiles) | |
candidates.append(amap) | |
if len(candidates) == 0: | |
return | |
for new_amap in candidates: | |
search(new_amap, depth + 1) | |
search(prev_amap, 0) | |
cand_smiles = set() | |
candidates = [] | |
aroma_score = [] | |
for amap in all_attach_confs: | |
cand_mol = local_attach(node.mol, neighbors, prev_nodes, amap) | |
cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol)) | |
smiles = Chem.MolToSmiles(cand_mol) | |
if smiles in cand_smiles or check_singleton(cand_mol, node, neighbors) == False: | |
continue | |
cand_smiles.add(smiles) | |
candidates.append( (smiles,amap) ) | |
aroma_score.append( check_aroma(cand_mol, node, neighbors) ) | |
return candidates, aroma_score | |
def check_singleton(cand_mol, ctr_node, nei_nodes): | |
rings = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() > 2] | |
singletons = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() == 1] | |
if len(singletons) > 0 or len(rings) == 0: return True | |
n_leaf2_atoms = 0 | |
for atom in cand_mol.GetAtoms(): | |
nei_leaf_atoms = [a for a in atom.GetNeighbors() if not a.IsInRing()] #a.GetDegree() == 1] | |
if len(nei_leaf_atoms) > 1: | |
n_leaf2_atoms += 1 | |
return n_leaf2_atoms == 0 | |
def check_aroma(cand_mol, ctr_node, nei_nodes): | |
rings = [node for node in nei_nodes + [ctr_node] if node.mol.GetNumAtoms() >= 3] | |
if len(rings) < 2: return 0 #Only multi-ring system needs to be checked | |
get_nid = lambda x: 0 if x.is_leaf else x.nid | |
benzynes = [get_nid(node) for node in nei_nodes + [ctr_node] if node.smiles in Vocab.benzynes] | |
penzynes = [get_nid(node) for node in nei_nodes + [ctr_node] if node.smiles in Vocab.penzynes] | |
if len(benzynes) + len(penzynes) == 0: | |
return 0 #No specific aromatic rings | |
n_aroma_atoms = 0 | |
for atom in cand_mol.GetAtoms(): | |
if atom.GetAtomMapNum() in benzynes+penzynes and atom.GetIsAromatic(): | |
n_aroma_atoms += 1 | |
if n_aroma_atoms >= len(benzynes) * 4 + len(penzynes) * 3: | |
return 1000 | |
else: | |
return -0.001 | |
#Only used for debugging purpose | |
def dfs_assemble(cur_mol, global_amap, fa_amap, cur_node, fa_node): | |
fa_nid = fa_node.nid if fa_node is not None else -1 | |
prev_nodes = [fa_node] if fa_node is not None else [] | |
children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid] | |
neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1] | |
neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True) | |
singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1] | |
neighbors = singletons + neighbors | |
cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node.nid] | |
cands = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap) | |
cand_smiles,cand_amap = zip(*cands) | |
label_idx = cand_smiles.index(cur_node.label) | |
label_amap = cand_amap[label_idx] | |
for nei_id,ctr_atom,nei_atom in label_amap: | |
if nei_id == fa_nid: | |
continue | |
global_amap[nei_id][nei_atom] = global_amap[cur_node.nid][ctr_atom] | |
cur_mol = attach_mols(cur_mol, children, [], global_amap) #father is already attached | |
for nei_node in children: | |
if not nei_node.is_leaf: | |
dfs_assemble(cur_mol, global_amap, label_amap, nei_node, cur_node) | |
if __name__ == "__main__": | |
import sys | |
from mol_tree import MolTree | |
lg = rdkit.RDLogger.logger() | |
lg.setLevel(rdkit.RDLogger.CRITICAL) | |
smiles = ["O=C1[C@@H]2C=C[C@@H](C=CC2)C1(c1ccccc1)c1ccccc1","O=C([O-])CC[C@@]12CCCC[C@]1(O)OC(=O)CC2", "ON=C1C[C@H]2CC3(C[C@@H](C1)c1ccccc12)OCCO3", "C[C@H]1CC(=O)[C@H]2[C@@]3(O)C(=O)c4cccc(O)c4[C@@H]4O[C@@]43[C@@H](O)C[C@]2(O)C1", 'Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br', 'CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1', "O=c1c2ccc3c(=O)n(-c4nccs4)c(=O)c4ccc(c(=O)n1-c1nccs1)c2c34", "O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1"] | |
def tree_test(): | |
for s in sys.stdin: | |
s = s.split()[0] | |
tree = MolTree(s) | |
print('-------------------------------------------') | |
print(s) | |
for node in tree.nodes: | |
print(node.smiles, [x.smiles for x in node.neighbors]) | |
def decode_test(): | |
wrong = 0 | |
for tot,s in enumerate(sys.stdin): | |
s = s.split()[0] | |
tree = MolTree(s) | |
tree.recover() | |
cur_mol = copy_edit_mol(tree.nodes[0].mol) | |
global_amap = [{}] + [{} for node in tree.nodes] | |
global_amap[1] = {atom.GetIdx():atom.GetIdx() for atom in cur_mol.GetAtoms()} | |
dfs_assemble(cur_mol, global_amap, [], tree.nodes[0], None) | |
cur_mol = cur_mol.GetMol() | |
cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol)) | |
set_atommap(cur_mol) | |
dec_smiles = Chem.MolToSmiles(cur_mol) | |
gold_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(s)) | |
if gold_smiles != dec_smiles: | |
print(gold_smiles, dec_smiles) | |
wrong += 1 | |
print(wrong, tot + 1) | |
def enum_test(): | |
for s in sys.stdin: | |
s = s.split()[0] | |
tree = MolTree(s) | |
tree.recover() | |
tree.assemble() | |
for node in tree.nodes: | |
if node.label not in node.cands: | |
print(tree.smiles) | |
print(node.smiles, [x.smiles for x in node.neighbors]) | |
print(node.label, len(node.cands)) | |
def count(): | |
cnt,n = 0,0 | |
for s in sys.stdin: | |
s = s.split()[0] | |
tree = MolTree(s) | |
tree.recover() | |
tree.assemble() | |
for node in tree.nodes: | |
cnt += len(node.cands) | |
n += len(tree.nodes) | |
#print cnt * 1.0 / n | |
count() | |