# jtvae-demo / fast_jtnn / mol_tree.py
import rdkit
import rdkit.Chem as Chem
from chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, enum_assemble, decode_stereo
from vocab import *
# import argparse
# One node of the junction tree: a single clique (ring, bond, or singleton atom) of the molecule.
class MolTreeNode(object):

    def __init__(self, smiles, clique=[]):
        self.smiles = smiles
        self.mol = get_mol(self.smiles)

        self.clique = [x for x in clique]  #copy
        self.neighbors = []

    def add_neighbor(self, nei_node):
        self.neighbors.append(nei_node)
    def recover(self, original_mol):
        # Mark this clique (and its non-leaf neighbors' cliques) with node ids in the
        # original molecule, then extract the marked fragment as this node's label SMILES.
        clique = []
        clique.extend(self.clique)
        if not self.is_leaf:
            for cidx in self.clique:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)

        for nei_node in self.neighbors:
            clique.extend(nei_node.clique)
            if nei_node.is_leaf:  #Leaf node, no need to mark
                continue
            for cidx in nei_node.clique:
                #allow singleton node override the atom mapping
                if cidx not in self.clique or len(nei_node.clique) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node.nid)

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))

        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return self.label
    def assemble(self):
        # Enumerate candidate attachments of this clique with its neighbors:
        # singleton neighbors first, then larger neighbors in decreasing size.
        neighbors = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cands, aroma = enum_assemble(self, neighbors)
        new_cands = [cand for i, cand in enumerate(cands) if aroma[i] >= 0]
        if len(new_cands) > 0:
            cands = new_cands  # prefer candidates with a non-negative aromaticity score

        if len(cands) > 0:
            self.cands, _ = zip(*cands)
            self.cands = list(self.cands)
        else:
            self.cands = []
class MolTree(object):

    def __init__(self, smiles):
        self.smiles = smiles
        self.mol = get_mol(smiles)

        #Stereo Generation (currently disabled)
        #mol = Chem.MolFromSmiles(smiles)
        #self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
        #self.smiles2D = Chem.MolToSmiles(mol)
        #self.stereo_cands = decode_stereo(self.smiles2D)

        # Tree decomposition: one MolTreeNode per clique, connected by decomposition edges.
        cliques, edges = tree_decomp(self.mol)
        self.nodes = []
        root = 0
        for i, c in enumerate(cliques):
            cmol = get_clique_mol(self.mol, c)
            node = MolTreeNode(get_smiles(cmol), c)
            self.nodes.append(node)
            if min(c) == 0:
                root = i

        for x, y in edges:
            self.nodes[x].add_neighbor(self.nodes[y])
            self.nodes[y].add_neighbor(self.nodes[x])

        # Make the clique containing atom 0 the root node.
        if root > 0:
            self.nodes[0], self.nodes[root] = self.nodes[root], self.nodes[0]

        for i, node in enumerate(self.nodes):
            node.nid = i + 1
            if len(node.neighbors) > 1:  #Leaf node mol is not marked
                set_atommap(node.mol, node.nid)
            node.is_leaf = (len(node.neighbors) == 1)

    def size(self):
        return len(self.nodes)

    def recover(self):
        for node in self.nodes:
            node.recover(self.mol)

    def assemble(self):
        for node in self.nodes:
            node.assemble()
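

# Illustrative sketch (not part of the original module): a minimal walkthrough of how
# MolTree is typically used. The helper name `_demo_moltree` and the example SMILES
# are assumptions made purely for demonstration.
def _demo_moltree():
    tree = MolTree('CC(=O)Nc1ccc(O)cc1')         # decompose a molecule into a clique tree
    print('nodes:', tree.size())                 # number of cliques (rings, bonds, atoms)
    tree.recover()                               # compute each node's label SMILES in context
    tree.assemble()                              # enumerate attachment candidates per node
    for node in tree.nodes:
        print(node.nid, node.smiles, len(node.cands))
    return tree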
def dfs(node, fa_idx):
    # Depth of the subtree rooted at `node`, without descending back to the parent
    # (fa_idx). Relies on node.idx, which is assigned outside this file.
    max_depth = 0
    for child in node.neighbors:
        if child.idx == fa_idx:
            continue
        max_depth = max(max_depth, dfs(child, node.idx))
    return max_depth + 1
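

# Illustrative sketch (not part of the original module): dfs expects each node to carry
# an integer `idx`, which this file never sets; the loop below assigns it explicitly so
# the depth computation is self-contained. `_demo_tree_depth` is a hypothetical helper name.
def _demo_tree_depth(smiles='c1ccccc1O'):
    tree = MolTree(smiles)
    for i, node in enumerate(tree.nodes):
        node.idx = i                             # dfs keys on idx to avoid revisiting the parent
    return dfs(tree.nodes[0], -1)                # depth of the clique tree from the root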
def data_process_chunk(smiles_list):
    # Collect the unique clique SMILES (vocabulary entries) over a chunk of input lines.
    cset = set()
    for line in smiles_list:
        smiles = line.split()[0]
        mol = MolTree(smiles)
        for c in mol.nodes:
            cset.add(c.smiles)
    return list(cset)
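

# Illustrative sketch (not part of the original module): data_process_chunk yields the
# clique vocabulary for one chunk of SMILES; the commented-out CLI below merges chunks
# processed in parallel. `_demo_vocab_chunk` is a hypothetical helper name.
def _demo_vocab_chunk():
    chunk = ['CCO', 'c1ccccc1', 'CC(=O)Nc1ccc(O)cc1']
    vocab = sorted(data_process_chunk(chunk))    # unique clique SMILES across the chunk
    print(len(vocab), 'vocabulary entries')
    return vocab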
# if __name__ == "__main__":
#     import sys
#     lg = rdkit.RDLogger.logger()
#     lg.setLevel(rdkit.RDLogger.CRITICAL)
#     i = 0
#     import os
#     from joblib import Parallel, delayed
#     from tqdm import tqdm
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--smiles_path', type=str, required=True)
#     parser.add_argument('--vocab_path', type=str, required=True)
#     parser.add_argument('--prop', type=bool, default=False)
#     parser.add_argument('--ncpu', default=8, type=int)
#     args = parser.parse_args()
#
#     if args.prop:
#         import pandas as pd
#         smiles_list = pd.read_csv(args.smiles_path, usecols=['SMILES'])
#         smiles_list = list(smiles_list.SMILES)
#     else:
#         with open(args.smiles_path, 'r') as f:
#             smiles_list = [line.split()[0] for line in f]
#     print('Total smiles = ', len(smiles_list))
#
#     # moses: 1584663
#     chunk_size = 10000
#     vocab_set_list = Parallel(n_jobs=args.ncpu)(
#         delayed(data_process_chunk)(smiles_list[start: start + chunk_size])
#         for start in tqdm(range(0, len(smiles_list), chunk_size))
#     )
#     vocab_list = []
#     for set_list in vocab_set_list:
#         vocab_list.extend(set_list)
#     cset = set(vocab_list)
#     with open(args.vocab_path, 'w') as f:
#         for x in cset:
#             f.write(''.join([x, '\n']))