Spaces:
Sleeping
Sleeping
File size: 5,502 Bytes
a3ea5d3 f9355e9 a3ea5d3 f9355e9 a3ea5d3 f9355e9 a3ea5d3 f9355e9 a3ea5d3 f9355e9 a3ea5d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import rdkit
import rdkit.Chem as Chem
from chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, enum_assemble, decode_stereo
from vocab import *
# import argparse
class MolTreeNode(object):
    """One node (clique) of a junction-tree decomposition of a molecule.

    Attributes set by the constructor:
        smiles: SMILES string of the clique fragment.
        mol: molecule object produced by ``get_mol(smiles)``.
        clique: list of atom indices (into the parent molecule) forming this clique.
        neighbors: adjacent ``MolTreeNode`` objects in the junction tree.

    ``nid`` and ``is_leaf`` are NOT set here: ``MolTree.__init__`` assigns them
    after the tree edges are built, and ``recover`` relies on both.
    """

    def __init__(self, smiles, clique=None):
        """Create a node for fragment *smiles* covering atom indices *clique*.

        Args:
            smiles: SMILES of the clique fragment.
            clique: iterable of atom indices in the parent molecule. It is
                copied, so the caller's sequence is never aliased. ``None``
                (the default) means an empty clique.
        """
        self.smiles = smiles
        self.mol = get_mol(self.smiles)
        # A None default avoids the shared-mutable-default pitfall; always copy.
        self.clique = list(clique) if clique is not None else []
        self.neighbors = []

    def add_neighbor(self, nei_node):
        """Append *nei_node* to this node's adjacency list (one direction only)."""
        self.neighbors.append(nei_node)

    def recover(self, original_mol):
        """Compute and store ``self.label`` for this clique plus its neighbors' cliques.

        Atoms of non-leaf cliques are temporarily tagged with their node's
        ``nid`` as atom-map numbers on *original_mol*. The extracted fragment
        SMILES therefore records which node each atom belongs to. Every atom in
        the combined clique is reset to map number 0 before returning.

        Args:
            original_mol: the full molecule the cliques index into. Its atom-map
                numbers are modified during the call and reset to 0 afterwards.

        Returns:
            The label SMILES, which is also stored as ``self.label``.
        """
        clique = list(self.clique)
        if not self.is_leaf:
            for cidx in self.clique:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)

        for nei_node in self.neighbors:
            clique.extend(nei_node.clique)
            if nei_node.is_leaf:  # Leaf node, no need to mark
                continue
            for cidx in nei_node.clique:
                # Allow a singleton neighbor to override the mapping of a shared atom.
                if cidx not in self.clique or len(nei_node.clique) == 1:
                    original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(nei_node.nid)

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        # Parse and re-serialize with RDKit (MolToSmiles emits canonical SMILES by default).
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))

        # Undo the temporary atom-map tagging.
        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
        return self.label

    def assemble(self):
        """Enumerate attachment candidates for this node and store them in ``self.cands``.

        Neighbors are passed to ``enum_assemble`` in a fixed order: single-atom
        neighbors first, then multi-atom neighbors by decreasing atom count.
        Candidates with a negative aroma score are discarded, unless that would
        discard all of them. ``self.cands`` receives the first element of each
        remaining candidate pair, or ``[]`` when there are none.
        """
        singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
        larger = sorted(
            (nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1),
            key=lambda x: x.mol.GetNumAtoms(),
            reverse=True,
        )
        neighbors = singletons + larger

        cands, aroma = enum_assemble(self, neighbors)
        # Prefer non-negative aroma scores, but fall back to all candidates if none qualify.
        new_cands = [cand for i, cand in enumerate(cands) if aroma[i] >= 0]
        if new_cands:
            cands = new_cands

        if cands:
            self.cands, _ = zip(*cands)
            self.cands = list(self.cands)
        else:
            self.cands = []
class MolTree(object):
    """Junction-tree decomposition of a molecule given as a SMILES string.

    After construction, ``self.nodes`` holds one ``MolTreeNode`` per clique
    returned by ``tree_decomp``. The last clique whose smallest atom index is 0
    is swapped into position 0 (the root). Every node receives a 1-based ``nid``
    and an ``is_leaf`` flag, which is true when the node has exactly one neighbor.
    """

    def __init__(self, smiles):
        self.smiles = smiles
        self.mol = get_mol(smiles)
        # NOTE: stereo candidate generation (decode_stereo) is currently disabled.

        cliques, edges = tree_decomp(self.mol)
        self.nodes = []
        root = 0
        for idx, atoms in enumerate(cliques):
            fragment = get_clique_mol(self.mol, atoms)
            self.nodes.append(MolTreeNode(get_smiles(fragment), atoms))
            if min(atoms) == 0:
                root = idx  # a later match overrides an earlier one

        # Edges index the cliques in their original order, before the root swap.
        for a, b in edges:
            self.nodes[a].add_neighbor(self.nodes[b])
            self.nodes[b].add_neighbor(self.nodes[a])

        if root:
            self.nodes[0], self.nodes[root] = self.nodes[root], self.nodes[0]

        for position, node in enumerate(self.nodes, start=1):
            node.nid = position
            degree = len(node.neighbors)
            # Only internal nodes get their fragment atoms mapped; leaves stay unmarked.
            if degree > 1:
                set_atommap(node.mol, node.nid)
            node.is_leaf = degree == 1

    def size(self):
        """Return how many nodes the tree has."""
        return len(self.nodes)

    def recover(self):
        """Run ``recover`` on each node, in order, against this tree's molecule."""
        for tree_node in self.nodes:
            tree_node.recover(self.mol)

    def assemble(self):
        """Run ``assemble`` on each node, in order."""
        for tree_node in self.nodes:
            tree_node.assemble()
def dfs(node, fa_idx):
    """Return the depth of the subtree rooted at *node*, counting *node* itself as 1.

    Recurses through ``node.neighbors``, skipping any neighbor whose ``idx``
    equals *fa_idx* (the node we came from). Every node needs an ``idx``
    attribute. Nothing visible in this file assigns ``idx`` (``MolTree`` sets
    ``nid``), so presumably the caller sets it first. Verify before use.
    """
    child_depths = [
        dfs(child, node.idx)
        for child in node.neighbors
        if child.idx != fa_idx
    ]
    return 1 + max(child_depths, default=0)
def data_process_chunk(smiles_list):
    """Collect the distinct clique SMILES (vocabulary entries) from a batch of molecules.

    Args:
        smiles_list: iterable of text lines. The first whitespace-separated
            token of each line is used as the molecule's SMILES. Blank or
            whitespace-only lines are skipped.

    Returns:
        A list of the unique ``smiles`` of every node of every molecule's
        ``MolTree``. Order is unspecified because it comes from a set.
    """
    cset = set()
    for line in smiles_list:
        tokens = line.split()
        if not tokens:
            # Skip blank lines instead of failing with IndexError on tokens[0].
            continue
        mol = MolTree(tokens[0])
        cset.update(node.smiles for node in mol.nodes)
    return list(cset)
# if __name__ == "__main__":
# import sys
# lg = rdkit.RDLogger.logger()
# lg.setLevel(rdkit.RDLogger.CRITICAL)
# i = 0
# import os
# from joblib import Parallel,delayed
# from tqdm import tqdm
# parser = argparse.ArgumentParser()
# parser.add_argument('--smiles_path', type=str,required=True)
# parser.add_argument('--vocab_path', type=str,required=True)
# parser.add_argument('--prop', type=bool,default=False)
# parser.add_argument('--ncpu', default=8,type=int)
# args = parser.parse_args()
# if args.prop:
# import pandas as pd
# smiles_list = pd.read_csv(args.smiles_path,usecols=['SMILES'])
# smiles_list = list(smiles_list.SMILES)
# else:
# with open(args.smiles_path,'r') as f:
# smiles_list = [line.split()[0] for line in f]
# print('Total smiles = ',len(smiles_list))
# # moses: 1584663
# chunk_size = 10000
# vocab_set_list = Parallel(n_jobs=args.ncpu)(
# delayed(data_process_chunk)(smiles_list[start: start + chunk_size])
# for start in tqdm(range(0, len(smiles_list), chunk_size))
# )
# vocab_list =[]
# for set_list in vocab_set_list:
# vocab_list.extend(set_list)
# cset = set(vocab_list)
# with open(args.vocab_path,'w') as f:
# for x in cset:
# f.write(''.join([x,'\n']))
|