import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from mol_tree import Vocab, MolTree
from nnutils import create_var, index_select_ND
class JTNNEncoder(nn.Module):
    """Junction-tree encoder.

    Runs GRU message passing over the directed edges of each junction tree,
    then aggregates incoming messages at every node; each tree is represented
    by the vector of its root (the first node in its scope span).
    """

    def __init__(self, hidden_size, depth, embedding):
        super(JTNNEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.depth = depth          # number of message-passing iterations
        self.embedding = embedding  # node-label embedding (shared externally)

        # Fuses a node's own embedding with the sum of its incoming messages.
        self.outputNN = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU()
        )
        self.GRU = GraphGRU(hidden_size, hidden_size, depth=depth)

    def forward(self, fnode, fmess, node_graph, mess_graph, scope):
        """Encode a batch of trees.

        Args:
            fnode: LongTensor of node label ids, one per node in the batch.
            fmess: LongTensor mapping each message id to its source node index.
            node_graph: per-node indices of incoming messages (0-padded).
            mess_graph: per-message indices of predecessor messages (0-padded).
            scope: list of (start, length) node spans, one per tree.

        Returns:
            (tree_vecs, messages): stacked per-tree root vectors and the final
            message matrix (reused by downstream decoders).
        """
        fnode = create_var(fnode)
        fmess = create_var(fmess)
        node_graph = create_var(node_graph)
        mess_graph = create_var(mess_graph)
        # Row 0 is the dummy padding message and stays zero.
        messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

        fnode = self.embedding(fnode)
        fmess = index_select_ND(fnode, 0, fmess)
        messages = self.GRU(messages, fmess, mess_graph)

        mess_nei = index_select_ND(messages, 0, node_graph)
        node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        node_vecs = self.outputNN(node_vecs)

        # Root is the first node of each tree's span.
        batch_vecs = [node_vecs[st] for st, le in scope]
        tree_vecs = torch.stack(batch_vecs, dim=0)
        return tree_vecs, messages

    @staticmethod
    def tensorize(tree_batch):
        """Flatten a batch of MolTrees into one node list plus per-tree scopes."""
        node_batch = []
        scope = []
        for tree in tree_batch:
            scope.append((len(node_batch), len(tree.nodes)))
            node_batch.extend(tree.nodes)
        return JTNNEncoder.tensorize_nodes(node_batch, scope)

    @staticmethod
    def tensorize_nodes(node_batch, scope):
        """Build the graph tensors consumed by forward().

        Each directed edge (x, y) becomes a message; mess_dict maps
        (x.idx, y.idx) -> message id. Message id 0 is a dummy padding slot,
        which is why `messages` starts with a None sentinel.

        Returns:
            ((fnode, fmess, node_graph, mess_graph, scope), mess_dict)
        """
        messages, mess_dict = [None], {}
        fnode = []
        for x in node_batch:
            fnode.append(x.wid)
            for y in x.neighbors:
                mess_dict[(x.idx, y.idx)] = len(messages)
                messages.append((x, y))

        node_graph = [[] for i in range(len(node_batch))]
        mess_graph = [[] for i in range(len(messages))]
        fmess = [0] * len(messages)

        for x, y in messages[1:]:
            mid1 = mess_dict[(x.idx, y.idx)]
            fmess[mid1] = x.idx
            node_graph[y.idx].append(mid1)
            # Message x->y is a predecessor of every outgoing message y->z, z != x.
            for z in y.neighbors:
                if z.idx == x.idx:
                    continue
                mid2 = mess_dict[(y.idx, z.idx)]
                mess_graph[mid2].append(mid1)

        # Right-pad adjacency lists with the dummy id 0 so they tensorize.
        max_len = max([len(t) for t in node_graph] + [1])
        for t in node_graph:
            t.extend([0] * (max_len - len(t)))

        max_len = max([len(t) for t in mess_graph] + [1])
        for t in mess_graph:
            t.extend([0] * (max_len - len(t)))

        mess_graph = torch.LongTensor(mess_graph)
        node_graph = torch.LongTensor(node_graph)
        fmess = torch.LongTensor(fmess)
        fnode = torch.LongTensor(fnode)
        return (fnode, fmess, node_graph, mess_graph, scope), mess_dict
class GraphGRU(nn.Module):
    """GRU-style cell unrolled for `depth` steps over the message graph.

    At each step, every message state is updated from its predecessor
    messages (given by mess_graph) and its source-node features x.
    """

    def __init__(self, input_size, hidden_size, depth):
        super(GraphGRU, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.depth = depth

        self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_r = nn.Linear(input_size, hidden_size, bias=False)
        self.U_r = nn.Linear(hidden_size, hidden_size)
        self.W_h = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, h, x, mess_graph):
        """Iterate the GRU update `depth` times and return final message states.

        Args:
            h: (num_messages, hidden) message states; row 0 is the padding slot.
            x: (num_messages, input) per-message source-node features.
            mess_graph: (num_messages, max_nei) predecessor ids, 0-padded.
        """
        mask = torch.ones(h.size(0), 1)
        mask[0] = 0  # first vector is padding
        mask = create_var(mask)
        for it in range(self.depth):
            h_nei = index_select_ND(h, 0, mess_graph)
            sum_h = h_nei.sum(dim=1)
            z_input = torch.cat([x, sum_h], dim=1)
            # torch.sigmoid/tanh replace the deprecated F.sigmoid/F.tanh
            # (numerically identical).
            z = torch.sigmoid(self.W_z(z_input))

            # Reset gate is computed per predecessor, so broadcast r_1 over
            # the neighbor dimension.
            r_1 = self.W_r(x).view(-1, 1, self.hidden_size)
            r_2 = self.U_r(h_nei)
            r = torch.sigmoid(r_1 + r_2)

            gated_h = r * h_nei
            sum_gated_h = gated_h.sum(dim=1)
            h_input = torch.cat([x, sum_gated_h], dim=1)
            pre_h = torch.tanh(self.W_h(h_input))
            h = (1.0 - z) * sum_h + z * pre_h
            h = h * mask  # keep the padding row at zero every iteration
        return h