import rdkit
import rdkit.Chem as Chem
from chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, enum_assemble, decode_stereo
from vocab import *
# import argparse


class MolTreeNode(object):
    """One node of a junction tree: a single clique (ring, bond, or atom)
    of the parent molecule, stored both as a SMILES string and as the list
    of atom indices it covers in the original molecule."""

    def __init__(self, smiles, clique=None):
        """Build a node from a clique's SMILES.

        smiles: SMILES string of the clique fragment.
        clique: atom indices of this fragment in the original molecule
                (copied; defaults to an empty list).
        """
        # NOTE(review): the original signature used a mutable default
        # (clique=[]); replaced with a None sentinel — same behavior,
        # without sharing state across calls.
        self.smiles = smiles
        self.mol = get_mol(self.smiles)
        self.clique = [] if clique is None else [x for x in clique]  # copy
        self.neighbors = []

    def add_neighbor(self, nei_node):
        """Register another MolTreeNode as adjacent in the junction tree."""
        self.neighbors.append(nei_node)

    def recover(self, original_mol):
        """Recover this node's label SMILES: the fragment formed by this
        clique plus all neighbor cliques, with atom-map numbers marking
        which node each atom belongs to.

        Mutates atom-map numbers on original_mol temporarily and resets
        them to 0 before returning. Requires self.nid / self.is_leaf to
        have been assigned (done in MolTree.__init__).
        """
        clique = []
        clique.extend(self.clique)
        if not self.is_leaf:
            for cidx in self.clique:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)

        for nei_node in self.neighbors:
            clique.extend(nei_node.clique)
            if nei_node.is_leaf:  # Leaf node, no need to mark
                continue
            for cidx in nei_node.clique:
                # allow singleton node override the atom mapping
                if cidx not in self.clique or len(nei_node.clique) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node.nid)

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        # Round-trip through RDKit to get a canonical SMILES for the label.
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))

        # Restore the molecule: clear every atom-map number we set above.
        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return self.label

    def assemble(self):
        """Enumerate candidate attachments of this node with its neighbors.

        Stores the candidate SMILES list in self.cands ([] if none).
        Singleton (single-atom) neighbors are placed first; the rest are
        ordered by descending atom count, as enum_assemble expects.
        """
        neighbors = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cands, aroma = enum_assemble(self, neighbors)
        # Prefer candidates whose aromaticity score is non-negative;
        # fall back to the full list if that filter empties it.
        new_cands = [cand for i, cand in enumerate(cands) if aroma[i] >= 0]
        if len(new_cands) > 0:
            cands = new_cands

        if len(cands) > 0:
            self.cands, _ = zip(*cands)
            self.cands = list(self.cands)
        else:
            self.cands = []


class MolTree(object):
    """Junction tree of a molecule: tree decomposition into clique nodes
    (MolTreeNode) connected by edges, rooted at the node containing atom 0."""

    def __init__(self, smiles):
        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo Generation (currently disabled)
        # mol = Chem.MolFromSmiles(smiles)
        # self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
        # self.smiles2D = Chem.MolToSmiles(mol)
        # self.stereo_cands = decode_stereo(self.smiles2D)

        cliques, edges = tree_decomp(self.mol)
        self.nodes = []
        root = 0
        for i, c in enumerate(cliques):
            cmol = get_clique_mol(self.mol, c)
            node = MolTreeNode(get_smiles(cmol), c)
            self.nodes.append(node)
            if min(c) == 0:
                # The clique containing atom 0 becomes the tree root.
                root = i

        for x, y in edges:
            self.nodes[x].add_neighbor(self.nodes[y])
            self.nodes[y].add_neighbor(self.nodes[x])

        # Move the root node to position 0 (neighbor references are object
        # identities, so swapping list slots is safe after wiring edges).
        if root > 0:
            self.nodes[0], self.nodes[root] = self.nodes[root], self.nodes[0]

        for i, node in enumerate(self.nodes):
            node.nid = i + 1
            if len(node.neighbors) > 1:  # Leaf node mol is not marked
                set_atommap(node.mol, node.nid)
            node.is_leaf = (len(node.neighbors) == 1)

    def size(self):
        """Number of nodes in the junction tree."""
        return len(self.nodes)

    def recover(self):
        """Compute the label SMILES of every node against self.mol."""
        for node in self.nodes:
            node.recover(self.mol)

    def assemble(self):
        """Enumerate assembly candidates for every node."""
        for node in self.nodes:
            node.assemble()


def dfs(node, fa_idx):
    """Return the depth of the subtree rooted at `node`, excluding the
    parent identified by fa_idx.

    NOTE(review): this reads `child.idx`, but MolTree.__init__ only
    assigns `nid` — presumably `idx` is attached by an external caller
    before dfs is used; confirm before relying on it.
    """
    max_depth = 0
    for child in node.neighbors:
        if child.idx == fa_idx:
            continue
        max_depth = max(max_depth, dfs(child, node.idx))
    return max_depth + 1


def data_process_chunk(smiles_list):
    """Extract the clique-SMILES vocabulary from a chunk of input lines.

    smiles_list: iterable of whitespace-separated lines whose first token
    is a SMILES string.
    Returns the unique clique SMILES found, as a list.
    """
    cset = set()
    for line in smiles_list:
        smiles = line.split()[0]
        mol = MolTree(smiles)
        for c in mol.nodes:
            cset.add(c.smiles)
    return list(cset)


# Vocabulary-extraction script (kept for reference; requires argparse,
# joblib and tqdm — re-enable the `import argparse` above to use it).
# if __name__ == "__main__":
#     import sys
#     lg = rdkit.RDLogger.logger()
#     lg.setLevel(rdkit.RDLogger.CRITICAL)
#     i = 0
#     import os
#     from joblib import Parallel, delayed
#     from tqdm import tqdm
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--smiles_path', type=str, required=True)
#     parser.add_argument('--vocab_path', type=str, required=True)
#     parser.add_argument('--prop', type=bool, default=False)
#     parser.add_argument('--ncpu', default=8, type=int)
#     args = parser.parse_args()
#     if args.prop:
#         import pandas as pd
#         smiles_list = pd.read_csv(args.smiles_path, usecols=['SMILES'])
#         smiles_list = list(smiles_list.SMILES)
#     else:
#         with open(args.smiles_path, 'r') as f:
#             smiles_list = [line.split()[0] for line in f]
#     print('Total smiles = ', len(smiles_list))
#     # moses: 1584663
#     chunk_size = 10000
#     vocab_set_list = Parallel(n_jobs=args.ncpu)(
#         delayed(data_process_chunk)(smiles_list[start: start + chunk_size])
#         for start in tqdm(range(0, len(smiles_list), chunk_size))
#     )
#     vocab_list = []
#     for set_list in vocab_set_list:
#         vocab_list.extend(set_list)
#     cset = set(vocab_list)
#     with open(args.vocab_path, 'w') as f:
#         for x in cset:
#             f.write(''.join([x, '\n']))