# jtvae-demo / fast_jtnn / datautils.py
# Author: Trương Gia Bảo
# Commit: a3ea5d3 ("Initial commit")
# Original file size: 7.34 kB
import torch
from torch.utils.data import Dataset, DataLoader
from mol_tree import MolTree
import numpy as np
from jtnn_enc import JTNNEncoder
from mpn import MPN
from jtmpn import JTMPN
import pickle
import os, random
class PairTreeFolder(object):
    """Stream tensorized batches of tree pairs from a folder of pickle files.

    Every file in ``data_folder`` is unpickled in turn. Its items are optionally
    shuffled and cut into chunks of ``batch_size``; a trailing partial chunk is
    dropped. Each chunk is then tensorized by ``PairTreeDataset`` inside a
    ``DataLoader``, so tensorization can run in ``num_workers`` worker processes.
    Each yielded value is whatever ``PairTreeDataset.__getitem__`` returns for
    one chunk.

    Args:
        data_folder: directory in which every file is a pickled list of
            2-element pairs. These are presumably ``(x_tree, y_tree)`` MolTree
            pairs; confirm against the preprocessing script.
        vocab: vocabulary forwarded to ``PairTreeDataset``.
        batch_size: number of pairs per yielded batch.
        num_workers: number of ``DataLoader`` worker processes (0 tensorizes
            in the calling process).
        shuffle: shuffle items within each file before batching. The order of
            the files themselves is not shuffled.
        y_assm: forwarded to ``PairTreeDataset``.
        replicate: optional int. The file list is repeated this many times, so
            each file is visited ``replicate`` times per full iteration.
    """

    def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, y_assm=True, replicate=None):
        self.data_folder = data_folder
        # Sorted so the visiting order does not depend on the filesystem's
        # arbitrary listing order (reproducible runs under a fixed seed).
        self.data_files = sorted(os.listdir(data_folder))
        self.batch_size = batch_size
        self.vocab = vocab
        self.num_workers = num_workers
        self.y_assm = y_assm
        self.shuffle = shuffle
        if replicate is not None:  # replicate is an int repeat count
            self.data_files = self.data_files * replicate

    def __iter__(self):
        # itemgetter(0) is picklable, unlike a lambda. This lets the DataLoader
        # run with num_workers > 0 under the "spawn" start method (Windows/macOS).
        from operator import itemgetter

        for fn in self.data_files:
            path = os.path.join(self.data_folder, fn)
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only point data_folder at trusted, self-generated data.
            with open(path, 'rb') as f:
                data = pickle.load(f)

            if self.shuffle:
                random.shuffle(data)  # shuffle items before batching

            batches = [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]
            # Drop the trailing partial batch. The emptiness guard prevents an
            # IndexError on batches[-1] when a file contains no items.
            if batches and len(batches[-1]) < self.batch_size:
                batches.pop()
            if not batches:
                continue

            dataset = PairTreeDataset(batches, self.vocab, self.y_assm)
            # batch_size=1 because each dataset item is already a whole batch.
            # The collate_fn unwraps the one-element list the loader builds.
            dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
                                    num_workers=self.num_workers, collate_fn=itemgetter(0))
            for b in dataloader:
                yield b
            # Release this file's data before the next file is loaded.
            del data, batches, dataset, dataloader
class MolTreeFolder(object):
    """Stream tensorized batches of trees from a folder of pickle files.

    Every file in ``data_folder`` is unpickled in turn. Its items are optionally
    shuffled and cut into chunks of ``batch_size``; a trailing partial chunk is
    dropped. Each chunk is then tensorized by ``MolTreeDataset`` inside a
    ``DataLoader``, so tensorization can run in ``num_workers`` worker processes.
    Each yielded value is whatever ``MolTreeDataset.__getitem__`` returns for
    one chunk.

    Args:
        data_folder: directory in which every file is a pickled list of trees
            (presumably MolTree objects; confirm against the preprocessing script).
        vocab: vocabulary forwarded to ``MolTreeDataset``.
        batch_size: number of trees per yielded batch.
        num_workers: number of ``DataLoader`` worker processes (0 tensorizes
            in the calling process).
        shuffle: shuffle items within each file before batching. The order of
            the files themselves is not shuffled.
        assm: forwarded to ``MolTreeDataset``.
        replicate: optional int. The file list is repeated this many times, so
            each file is visited ``replicate`` times per full iteration.
    """

    def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, assm=True, replicate=None):
        self.data_folder = data_folder
        # Sorted so the visiting order does not depend on the filesystem's
        # arbitrary listing order (reproducible runs under a fixed seed).
        self.data_files = sorted(os.listdir(data_folder))
        self.batch_size = batch_size
        self.vocab = vocab
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.assm = assm
        if replicate is not None:  # replicate is an int repeat count
            self.data_files = self.data_files * replicate

    def __iter__(self):
        # itemgetter(0) is picklable, unlike a lambda. This lets the DataLoader
        # run with num_workers > 0 under the "spawn" start method (Windows/macOS).
        from operator import itemgetter

        for fn in self.data_files:
            path = os.path.join(self.data_folder, fn)
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only point data_folder at trusted, self-generated data.
            with open(path, 'rb') as f:
                data = pickle.load(f)

            if self.shuffle:
                random.shuffle(data)  # shuffle items before batching

            batches = [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]
            # Drop the trailing partial batch. The emptiness guard prevents an
            # IndexError on batches[-1] when a file contains no items.
            if batches and len(batches[-1]) < self.batch_size:
                batches.pop()
            if not batches:
                continue

            dataset = MolTreeDataset(batches, self.vocab, self.assm)
            # batch_size=1 because each dataset item is already a whole batch.
            # The collate_fn unwraps the one-element list the loader builds.
            dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
                                    num_workers=self.num_workers, collate_fn=itemgetter(0))
            for b in dataloader:
                yield b
            # Release this file's data before the next file is loaded.
            del data, batches, dataset, dataloader
class PairTreeDataset(Dataset):
    """Map-style dataset whose items are pre-built batches of 2-element pairs.

    ``data[idx]`` is a list of pairs. ``__getitem__`` splits it into the
    first-element batch and the second-element batch and tensorizes each one.
    The first side never gets assembly tensors; the second side gets them only
    if ``y_assm`` is set.
    """

    def __init__(self, data, vocab, y_assm):
        self.data, self.vocab, self.y_assm = data, vocab, y_assm

    def __len__(self):
        # One item per pre-built batch, not per pair.
        return len(self.data)

    def __getitem__(self, idx):
        pairs = self.data[idx]
        x_side, y_side = zip(*pairs)
        x_tensors = tensorize(x_side, self.vocab, assm=False)
        y_tensors = tensorize(y_side, self.vocab, assm=self.y_assm)
        return x_tensors, y_tensors
class MolTreeDataset(Dataset):
    """Map-style dataset over pre-built batches of trees.

    ``data`` is a list of batches, presumably lists of MolTree objects.
    Item ``idx`` is the whole batch ``data[idx]`` run through ``tensorize``
    with the stored ``assm`` flag.
    """

    def __init__(self, data, vocab, assm=True):
        self.data, self.vocab, self.assm = data, vocab, assm

    def __len__(self):
        # One item per pre-built batch, not per molecule.
        return len(self.data)

    def __getitem__(self, idx):
        tree_batch = self.data[idx]
        return tensorize(tree_batch, self.vocab, assm=self.assm)
def tensorize(tree_batch, vocab, assm=True):
    """Build the encoder inputs for one batch of trees.

    Args:
        tree_batch: sequence (list or tuple) of trees exposing ``.smiles`` and
            ``.nodes``. These are presumably MolTree objects.
        vocab: vocabulary exposing ``get_index(smiles)``; used to assign ``wid``.
        assm: when falsy, skip building the candidate-assembly tensors.

    Returns:
        If ``assm`` is falsy: ``(tree_batch, jtenc_holder, mpn_holder)``.
        Otherwise: ``(tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx))``.
        Here ``batch_idx`` is a ``torch.LongTensor`` giving, for each assembly
        candidate, the position of its tree in ``tree_batch``.

    Side effects:
        Sets ``node.idx`` (a running index across the whole batch) and
        ``node.wid`` on every node, via ``set_batch_nodeID``.
    """
    set_batch_nodeID(tree_batch, vocab)
    smiles_batch = [tree.smiles for tree in tree_batch]
    jtenc_holder, mess_dict = JTNNEncoder.tensorize(tree_batch)
    mpn_holder = MPN.tensorize(smiles_batch)

    if not assm:
        return tree_batch, jtenc_holder, mpn_holder

    cands = []
    batch_idx = []
    for i, mol_tree in enumerate(tree_batch):
        for node in mol_tree.nodes:
            # A leaf node's attachment is determined by its neighbor's attachment,
            # and a node with a single candidate has nothing to choose between.
            if node.is_leaf or len(node.cands) == 1:
                continue
            cands.extend((cand, mol_tree.nodes, node) for cand in node.cands)
            batch_idx.extend([i] * len(node.cands))

    # mess_dict comes from JTNNEncoder.tensorize. It is presumably the
    # tree-message index lookup that JTMPN needs; confirm in jtnn_enc/jtmpn.
    jtmpn_holder = JTMPN.tensorize(cands, mess_dict)
    batch_idx = torch.LongTensor(batch_idx)
    return tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx)
def set_batch_nodeID(mol_batch, vocab):
    """Number every node in the batch and look up its vocabulary id, in place.

    Nodes are visited tree by tree, in ``tree.nodes`` order. Each node gets
    ``idx`` set to its position in that batch-wide sequence (starting at 0) and
    ``wid`` set to ``vocab.get_index(node.smiles)``. Returns ``None``.
    """
    every_node = (node for tree in mol_batch for node in tree.nodes)
    for running_idx, node in enumerate(every_node):
        node.idx = running_idx
        node.wid = vocab.get_index(node.smiles)
class PropMolTreeDataset(Dataset):
    """Map-style dataset over pre-built batches of (tree, property) items.

    ``data`` is a list of batches. Item ``idx`` is the whole batch ``data[idx]``
    run through ``tensorize_prop`` with the stored ``assm`` flag.
    """

    def __init__(self, data, vocab, assm=True):
        self.data, self.vocab, self.assm = data, vocab, assm

    def __len__(self):
        # One item per pre-built batch, not per molecule.
        return len(self.data)

    def __getitem__(self, idx):
        prop_batch = self.data[idx]
        return tensorize_prop(prop_batch, self.vocab, assm=self.assm)
class PropMolTreeFolder(object):
    """Stream tensorized batches of (tree, property) items from a folder of pickle files.

    Every file in ``data_folder`` is unpickled in turn. Its items are optionally
    shuffled and cut into chunks of ``batch_size``; a trailing partial chunk is
    dropped. Each chunk is then tensorized by ``PropMolTreeDataset`` inside a
    ``DataLoader``, so tensorization can run in ``num_workers`` worker processes.
    Each yielded value is whatever ``PropMolTreeDataset.__getitem__`` returns
    for one chunk.

    Args:
        data_folder: directory in which every file is a pickled list of items.
            These are presumably ``(MolTree, property)`` pairs; confirm against
            the preprocessing script.
        vocab: vocabulary forwarded to ``PropMolTreeDataset``.
        batch_size: number of items per yielded batch.
        num_workers: number of ``DataLoader`` worker processes (0 tensorizes
            in the calling process).
        shuffle: shuffle items within each file before batching. The order of
            the files themselves is not shuffled.
        assm: forwarded to ``PropMolTreeDataset``.
        replicate: optional int. The file list is repeated this many times, so
            each file is visited ``replicate`` times per full iteration.
    """

    def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, assm=True, replicate=None):
        self.data_folder = data_folder
        # Sorted so the visiting order does not depend on the filesystem's
        # arbitrary listing order (reproducible runs under a fixed seed).
        self.data_files = sorted(os.listdir(data_folder))
        self.batch_size = batch_size
        self.vocab = vocab
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.assm = assm
        if replicate is not None:  # replicate is an int repeat count
            self.data_files = self.data_files * replicate

    def __iter__(self):
        # itemgetter(0) is picklable, unlike a lambda. This lets the DataLoader
        # run with num_workers > 0 under the "spawn" start method (Windows/macOS).
        from operator import itemgetter

        for fn in self.data_files:
            path = os.path.join(self.data_folder, fn)
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only point data_folder at trusted, self-generated data.
            with open(path, 'rb') as f:
                data = pickle.load(f)

            if self.shuffle:
                random.shuffle(data)  # shuffle items before batching

            batches = [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]
            # Drop the trailing partial batch. The emptiness guard prevents an
            # IndexError on batches[-1] when a file contains no items.
            if batches and len(batches[-1]) < self.batch_size:
                batches.pop()
            if not batches:
                continue

            dataset = PropMolTreeDataset(batches, self.vocab, self.assm)
            # batch_size=1 because each dataset item is already a whole batch.
            # The collate_fn unwraps the one-element list the loader builds.
            dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
                                    num_workers=self.num_workers, collate_fn=itemgetter(0))
            for b in dataloader:
                yield b
            # Release this file's data before the next file is loaded.
            del data, batches, dataset, dataloader
def tensorize_prop(data, vocab, assm=True):
    """Tensorize one batch of (tree, property) items.

    The tree side is handled entirely by ``tensorize``, so the two code paths
    cannot drift apart.

    Args:
        data: non-empty sequence of 2-element ``(tree, property)`` items.
        vocab: vocabulary forwarded to ``tensorize``.
        assm: when falsy, skip building the candidate-assembly tensors.

    Returns:
        If ``assm`` is falsy: ``(tree_batch, jtenc_holder, mpn_holder)``. The
        property values are NOT included on this path.
        Otherwise: ``(tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx), prop)``.
        ``tree_batch`` and ``prop`` are tuples in the original batch order.
    """
    tree_batch, prop = zip(*data)
    if not assm:
        # NOTE(review): properties are dropped here, exactly as in the original
        # implementation. The 3-tuple is kept so existing callers still unpack
        # correctly. Confirm this is intended.
        return tensorize(tree_batch, vocab, assm=False)
    return tensorize(tree_batch, vocab, assm=True) + (prop,)