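# Junction-tree decomposition of molecules: MolTreeNode/MolTree decompose an
# RDKit molecule into a tree of cliques (rings, bonds, singleton atoms) and
# enumerate candidate attachments for reassembly; data_process_chunk builds a
# clique vocabulary from SMILES strings.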
import rdkit
import rdkit.Chem as Chem
from chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, enum_assemble, decode_stereo
from vocab import *
# import argparse

class MolTreeNode(object):

    def __init__(self, smiles, clique=None):
        self.smiles = smiles
        self.mol = get_mol(self.smiles)

        # Copy the clique so the node does not alias the caller's list.
        self.clique = list(clique) if clique is not None else []
        self.neighbors = []
        
    def add_neighbor(self, nei_node):
        self.neighbors.append(nei_node)

    def recover(self, original_mol):
        # Rebuild this node's label SMILES: mark this clique and its non-leaf neighbors
        # with their nid via atom maps, extract the combined clique, then clear the maps.
        clique = []
        clique.extend(self.clique)
        if not self.is_leaf:
            for cidx in self.clique:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)

        for nei_node in self.neighbors:
            clique.extend(nei_node.clique)
            if nei_node.is_leaf: #Leaf node, no need to mark 
                continue
            for cidx in nei_node.clique:
                #allow a singleton node to override the atom mapping
                if cidx not in self.clique or len(nei_node.clique) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node.nid)

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))

        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return self.label
    
    def assemble(self):
        # Enumerate candidate attachment configurations of this node with its neighbors
        # (singleton atoms first, then larger neighbors by decreasing size).
        neighbors = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x:x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cands,aroma = enum_assemble(self, neighbors)
        new_cands = [cand for i,cand in enumerate(cands) if aroma[i] >= 0]
        if len(new_cands) > 0: cands = new_cands

        if len(cands) > 0:
            self.cands, _ = zip(*cands)
            self.cands = list(self.cands)
        else:
            self.cands = []

class MolTree(object):

    def __init__(self, smiles):
        # Decompose `smiles` into a junction tree: one node per clique from tree_decomp,
        # with the clique containing atom 0 moved to the front as the root.
        self.smiles = smiles
        self.mol = get_mol(smiles)

        #Stereo Generation (currently disabled)
        #mol = Chem.MolFromSmiles(smiles)
        #self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
        #self.smiles2D = Chem.MolToSmiles(mol)
        #self.stereo_cands = decode_stereo(self.smiles2D)

        cliques, edges = tree_decomp(self.mol)
        self.nodes = []
        root = 0
        for i,c in enumerate(cliques):
            cmol = get_clique_mol(self.mol, c)
            node = MolTreeNode(get_smiles(cmol), c)
            self.nodes.append(node)
            if min(c) == 0: root = i

        for x,y in edges:
            self.nodes[x].add_neighbor(self.nodes[y])
            self.nodes[y].add_neighbor(self.nodes[x])
        
        if root > 0:
            self.nodes[0],self.nodes[root] = self.nodes[root],self.nodes[0]

        for i,node in enumerate(self.nodes):
            node.nid = i + 1
            if len(node.neighbors) > 1: #Leaf node mol is not marked
                set_atommap(node.mol, node.nid)
            node.is_leaf = (len(node.neighbors) == 1)

    def size(self):
        return len(self.nodes)

    def recover(self):
        for node in self.nodes:
            node.recover(self.mol)

    def assemble(self):
        for node in self.nodes:
            node.assemble()

def dfs(node, fa_idx):
    # Depth of the subtree rooted at `node`, skipping the parent (fa_idx).
    # NOTE: relies on an `idx` attribute assigned by the caller; MolTree itself only sets `nid`.
    max_depth = 0
    for child in node.neighbors:
        if child.idx == fa_idx: continue
        max_depth = max(max_depth, dfs(child, node.idx))
    return max_depth + 1

def data_process_chunk(smiles_list):
    # Collect the set of clique SMILES (vocabulary entries) appearing in a chunk of molecules.
    cset = set()
    for line in smiles_list:
        smiles = line.split()[0]
        mol = MolTree(smiles)
        for c in mol.nodes:
            cset.add(c.smiles)
    return list(cset)

# if __name__ == "__main__":
#     import sys
#     lg = rdkit.RDLogger.logger() 
#     lg.setLevel(rdkit.RDLogger.CRITICAL)
    
#     i = 0
    
#     import os
#     from joblib import Parallel,delayed
#     from tqdm import tqdm
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--smiles_path', type=str,required=True)
#     parser.add_argument('--vocab_path', type=str,required=True)
#     parser.add_argument('--prop', action='store_true')  # type=bool would treat any non-empty string as True
#     parser.add_argument('--ncpu', default=8,type=int)
#     args = parser.parse_args()

#     if args.prop:
#         import pandas as pd
#         smiles_list = pd.read_csv(args.smiles_path,usecols=['SMILES'])
#         smiles_list = list(smiles_list.SMILES)
#     else:
#         with open(args.smiles_path,'r') as f:
#             smiles_list = [line.split()[0] for line in f]
#     print('Total smiles = ',len(smiles_list))

#     # moses: 1584663
    
#     chunk_size = 10000
#     vocab_set_list = Parallel(n_jobs=args.ncpu)(
#         delayed(data_process_chunk)(smiles_list[start: start + chunk_size])
#         for start in tqdm(range(0, len(smiles_list), chunk_size))
#     )
#     vocab_list =[]
#     for set_list in vocab_set_list:
#         vocab_list.extend(set_list)

#     cset = set(vocab_list)
#     with open(args.vocab_path,'w') as f:
#         for x in cset:
#             f.write(''.join([x,'\n']))
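
# Illustrative sketch (not part of the original pipeline): minimal usage of MolTree and
# data_process_chunk to build a clique vocabulary from a couple of SMILES strings.
# Kept as comments so importing this module stays side-effect free.
#
#     smiles_batch = ["CC(=O)Oc1ccccc1C(=O)O", "c1ccccc1"]
#     tree = MolTree(smiles_batch[0])
#     print(tree.size(), [node.smiles for node in tree.nodes])
#
#     vocab = data_process_chunk(smiles_batch)
#     print(sorted(vocab))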