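# NOTE: page residue captured when this file was copied from its hosting page,
# kept as a comment so the module stays valid Python: "Spaces: Sleeping Sleeping".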
import rdkit | |
import rdkit.Chem as Chem | |
from chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, enum_assemble, decode_stereo | |
from vocab import * | |
# import argparse | |
class MolTreeNode(object):
    """One node (clique) of a molecular junction tree.

    Attributes set by ``__init__``:
        smiles: SMILES string of the clique fragment.
        mol: molecule built from ``smiles`` via ``get_mol``.
        clique: list of atom indices (into the parent molecule) in this clique.
        neighbors: adjacent ``MolTreeNode`` objects, filled by ``add_neighbor``.

    ``nid`` and ``is_leaf`` are NOT set here; ``MolTree.__init__`` assigns
    them after building the tree, and ``recover`` reads both.
    """

    def __init__(self, smiles, clique=None):
        self.smiles = smiles
        self.mol = get_mol(self.smiles)
        # Always store a private copy so the node never aliases the caller's
        # list; ``None`` (the default) means an empty clique.
        self.clique = list(clique) if clique is not None else []
        self.neighbors = []

    def add_neighbor(self, nei_node):
        """Append ``nei_node`` to this node's neighbor list (one direction only)."""
        self.neighbors.append(nei_node)

    def recover(self, original_mol):
        """Build and store ``self.label`` for this node and its neighborhood.

        Atoms of this node (unless it is a leaf) and of every non-leaf
        neighbor are temporarily tagged in ``original_mol`` with the owning
        node's ``nid`` as atom-map number.  Where cliques overlap, this node's
        tag is kept unless the neighbor is a single-atom clique, which may
        override it.  The union of this node's and all neighbors' atoms is
        then extracted with ``get_clique_mol``, and its SMILES is passed
        through an RDKit parse/write round trip.  Finally every atom in that
        union has its map number reset to 0.

        NOTE(review): presumably the map numbers survive into the label so it
        encodes which neighbor each atom came from — confirm against
        ``chemutils.get_smiles``.

        Args:
            original_mol: the RDKit molecule this tree was decomposed from.
                Its atom-map numbers are modified and then cleared.

        Returns:
            The label SMILES (also stored as ``self.label``).
        """
        own_atoms = set(self.clique)  # O(1) membership in the inner loop
        clique = list(self.clique)
        if not self.is_leaf:
            for cidx in self.clique:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)

        for nei_node in self.neighbors:
            clique.extend(nei_node.clique)
            if nei_node.is_leaf:  # leaf neighbors are not marked
                continue
            for cidx in nei_node.clique:
                # A singleton neighbor is allowed to override this node's mapping.
                if cidx not in own_atoms or len(nei_node.clique) == 1:
                    original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(nei_node.nid)

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))

        # Undo the temporary atom-map tags on the parent molecule.
        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return self.label

    def assemble(self):
        """Enumerate candidate attachments of the neighbors; store them in ``self.cands``.

        Neighbors are passed to ``enum_assemble`` in this order: single-atom
        neighbors first, then multi-atom neighbors by descending atom count.
        Candidates whose aroma score is negative are discarded, unless that
        would discard all of them.  ``self.cands`` receives the first element
        of each remaining candidate pair (presumably its SMILES — confirm in
        ``chemutils.enum_assemble``), or ``[]`` when there are none.
        """
        singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
        multi_atom = sorted(
            (nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1),
            key=lambda nei: nei.mol.GetNumAtoms(),
            reverse=True,
        )

        cands, aroma = enum_assemble(self, singletons + multi_atom)
        # Prefer candidates with a non-negative aroma score when any exist.
        scored = [cand for i, cand in enumerate(cands) if aroma[i] >= 0]
        if scored:
            cands = scored

        # Each candidate is unpacked as a pair; only the first item is kept.
        self.cands = [head for head, _ in cands]
class MolTree(object):
    """Junction-tree decomposition of a molecule given as SMILES.

    ``nodes`` holds one ``MolTreeNode`` per clique produced by
    ``tree_decomp``.  The (last) clique containing atom 0 is swapped to
    position 0 to serve as the root.  Node ids ``nid`` are then assigned
    1..len(nodes) in list order, and ``is_leaf`` is set for degree-1 nodes.
    """

    def __init__(self, smiles):
        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo generation is currently disabled (it used
        # Chem.MolToSmiles(..., isomericSmiles=True) and decode_stereo).

        cliques, edges = tree_decomp(self.mol)

        self.nodes = []
        root_pos = 0
        for pos, atom_ids in enumerate(cliques):
            fragment = get_clique_mol(self.mol, atom_ids)
            self.nodes.append(MolTreeNode(get_smiles(fragment), atom_ids))
            if min(atom_ids) == 0:
                root_pos = pos  # later matches overwrite earlier ones

        for a, b in edges:
            node_a = self.nodes[a]
            node_b = self.nodes[b]
            node_a.add_neighbor(node_b)
            node_b.add_neighbor(node_a)

        if root_pos:
            nodes = self.nodes
            nodes[0], nodes[root_pos] = nodes[root_pos], nodes[0]

        for nid, node in enumerate(self.nodes, start=1):
            node.nid = nid
            degree = len(node.neighbors)
            # Only internal nodes get their fragment atoms mapped.
            if degree > 1:
                set_atommap(node.mol, nid)
            node.is_leaf = degree == 1

    def size(self):
        """Return how many nodes the tree has."""
        return len(self.nodes)

    def recover(self):
        """Run ``recover`` on every node against this tree's molecule, in order."""
        for node in self.nodes:
            node.recover(self.mol)

    def assemble(self):
        """Run ``assemble`` on every node, in order."""
        for node in self.nodes:
            node.assemble()
def dfs(node, fa_idx):
    """Return the depth, counted in nodes, of the subtree hanging from ``node``.

    ``fa_idx`` is the ``idx`` of the node we came from.  Neighbors with that
    ``idx`` are not descended into.  A node with no other neighbors has depth 1.

    NOTE(review): every node must carry an ``idx`` attribute.
    ``MolTreeNode`` in this file does not set one, so presumably the caller
    assigns it (e.g. during batching) — confirm.
    """
    child_depths = (
        dfs(child, node.idx) for child in node.neighbors if child.idx != fa_idx
    )
    return 1 + max(child_depths, default=0)
def data_process_chunk(smiles_list):
    """Collect the distinct clique SMILES (vocabulary) for a chunk of molecules.

    Args:
        smiles_list: iterable of text lines.  The first whitespace-separated
            token of each line is taken as the molecule's SMILES.  Blank or
            whitespace-only lines are skipped; indexing their empty token
            list would otherwise raise ``IndexError``.

    Returns:
        Sorted list of the unique ``smiles`` strings of all nodes in the
        ``MolTree`` of every molecule in the chunk.
    """
    vocab = set()
    for line in smiles_list:
        fields = line.split()
        if not fields:
            continue
        tree = MolTree(fields[0])
        vocab.update(node.smiles for node in tree.nodes)
    # Sorting makes the result independent of string-hash randomization, so
    # repeated runs yield the same vocabulary ordering.
    return sorted(vocab)
# if __name__ == "__main__": | |
# import sys | |
# lg = rdkit.RDLogger.logger() | |
# lg.setLevel(rdkit.RDLogger.CRITICAL) | |
# i = 0 | |
# import os | |
# from joblib import Parallel,delayed | |
# from tqdm import tqdm | |
# parser = argparse.ArgumentParser() | |
# parser.add_argument('--smiles_path', type=str,required=True) | |
# parser.add_argument('--vocab_path', type=str,required=True) | |
# parser.add_argument('--prop', type=bool,default=False) | |
# parser.add_argument('--ncpu', default=8,type=int) | |
# args = parser.parse_args() | |
# if args.prop: | |
# import pandas as pd | |
# smiles_list = pd.read_csv(args.smiles_path,usecols=['SMILES']) | |
# smiles_list = list(smiles_list.SMILES) | |
# else: | |
# with open(args.smiles_path,'r') as f: | |
# smiles_list = [line.split()[0] for line in f] | |
# print('Total smiles = ',len(smiles_list)) | |
# # moses: 1584663 | |
# chunk_size = 10000 | |
# vocab_set_list = Parallel(n_jobs=args.ncpu)( | |
# delayed(data_process_chunk)(smiles_list[start: start + chunk_size]) | |
# for start in tqdm(range(0, len(smiles_list), chunk_size)) | |
# ) | |
# vocab_list =[] | |
# for set_list in vocab_set_list: | |
# vocab_list.extend(set_list) | |
# cset = set(vocab_list) | |
# with open(args.vocab_path,'w') as f: | |
# for x in cset: | |
# f.write(''.join([x,'\n'])) | |