# Junction-tree encoder for the JT-VAE model (JTNNEncoder + GraphGRU).
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
from mol_tree import Vocab, MolTree
from nnutils import create_var, index_select_ND
class JTNNEncoder(nn.Module):
    """Junction-tree encoder.

    Runs `depth` rounds of GRU-style message passing over batched junction
    trees and returns one embedding per tree (taken from its root node).
    """

    def __init__(self, hidden_size, depth, embedding):
        """
        Args:
            hidden_size: dimension of node and message hidden vectors.
            depth: number of message-passing iterations in the GraphGRU.
            embedding: nn.Embedding mapping vocabulary ids to vectors.
        """
        super(JTNNEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.depth = depth
        self.embedding = embedding
        # Fuses a node's embedding with the sum of its incoming messages.
        self.outputNN = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU()
        )
        self.GRU = GraphGRU(hidden_size, hidden_size, depth=depth)

    def forward(self, fnode, fmess, node_graph, mess_graph, scope):
        """Encode a batch of trees produced by `tensorize`.

        Args:
            fnode: LongTensor of node vocabulary ids, one per node.
            fmess: LongTensor mapping each message to its source node index.
            node_graph: LongTensor of incoming-message indices per node (0-padded).
            mess_graph: LongTensor of predecessor-message indices per message (0-padded).
            scope: list of (start, length) spans, one per tree, over the node axis.

        Returns:
            (tree_vecs, messages): root vectors of shape (batch, hidden_size),
            and the final message matrix from the GRU.
        """
        fnode = create_var(fnode)
        fmess = create_var(fmess)
        node_graph = create_var(node_graph)
        mess_graph = create_var(mess_graph)
        # Row 0 of the message matrix is the all-zero padding vector.
        messages = create_var(torch.zeros(mess_graph.size(0), self.hidden_size))

        fnode = self.embedding(fnode)
        fmess = index_select_ND(fnode, 0, fmess)
        messages = self.GRU(messages, fmess, mess_graph)

        mess_nei = index_select_ND(messages, 0, node_graph)
        node_vecs = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        node_vecs = self.outputNN(node_vecs)

        # Each tree's root is its first node; gather one vector per tree.
        # (Dead `max_len` computation from the original removed.)
        tree_vecs = torch.stack([node_vecs[st] for st, _ in scope], dim=0)
        return tree_vecs, messages

    @staticmethod
    def tensorize(tree_batch):
        """Flatten a batch of trees into one node list plus per-tree spans."""
        node_batch = []
        scope = []
        for tree in tree_batch:
            scope.append((len(node_batch), len(tree.nodes)))
            node_batch.extend(tree.nodes)
        return JTNNEncoder.tensorize_nodes(node_batch, scope)

    @staticmethod
    def tensorize_nodes(node_batch, scope):
        """Build the index tensors that drive message passing.

        Nodes must carry `.wid` (vocab id), `.idx` (position in node_batch),
        and `.neighbors`. Message id 0 is reserved for padding.

        Returns:
            ((fnode, fmess, node_graph, mess_graph, scope), mess_dict) where
            mess_dict maps directed edge (x.idx, y.idx) -> message id.
        """
        messages, mess_dict = [None], {}  # slot 0 reserved for padding
        fnode = []
        for x in node_batch:
            fnode.append(x.wid)
            for y in x.neighbors:
                mess_dict[(x.idx, y.idx)] = len(messages)
                messages.append((x, y))

        node_graph = [[] for i in range(len(node_batch))]
        mess_graph = [[] for i in range(len(messages))]
        fmess = [0] * len(messages)

        for x, y in messages[1:]:
            mid1 = mess_dict[(x.idx, y.idx)]
            fmess[mid1] = x.idx
            node_graph[y.idx].append(mid1)
            # Predecessors of message x->y are all messages z->x with z != y,
            # recorded here as: messages y->z depend on x->y.
            for z in y.neighbors:
                if z.idx == x.idx: continue
                mid2 = mess_dict[(y.idx, z.idx)]
                mess_graph[mid2].append(mid1)

        # Right-pad the ragged adjacency lists with message id 0 (padding).
        max_len = max([len(t) for t in node_graph] + [1])
        for t in node_graph:
            pad_len = max_len - len(t)
            t.extend([0] * pad_len)

        max_len = max([len(t) for t in mess_graph] + [1])
        for t in mess_graph:
            pad_len = max_len - len(t)
            t.extend([0] * pad_len)

        mess_graph = torch.LongTensor(mess_graph)
        node_graph = torch.LongTensor(node_graph)
        fmess = torch.LongTensor(fmess)
        fnode = torch.LongTensor(fnode)
        return (fnode, fmess, node_graph, mess_graph, scope), mess_dict
class GraphGRU(nn.Module):
    """GRU-style message update unit for tree message passing.

    Iterates `depth` times: each message hidden state is updated from its
    source-node input `x` and the hidden states of its predecessor messages
    (given by `mess_graph`), using gated (GRU-like) aggregation.
    """

    def __init__(self, input_size, hidden_size, depth):
        """
        Args:
            input_size: dimension of the per-message input features.
            hidden_size: dimension of the message hidden states.
            depth: number of message-passing iterations.
        """
        super(GraphGRU, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.depth = depth

        self.W_z = nn.Linear(input_size + hidden_size, hidden_size)  # update gate
        self.W_r = nn.Linear(input_size, hidden_size, bias=False)    # reset gate (input part)
        self.U_r = nn.Linear(hidden_size, hidden_size)               # reset gate (neighbor part)
        self.W_h = nn.Linear(input_size + hidden_size, hidden_size)  # candidate state

    def forward(self, h, x, mess_graph):
        """Run `depth` rounds of gated message updates.

        Args:
            h: (num_messages, hidden_size) initial message states; row 0 is padding.
            x: (num_messages, input_size) input features per message.
            mess_graph: (num_messages, max_nei) LongTensor of predecessor
                message indices, 0-padded.

        Returns:
            Updated message matrix of the same shape as `h`.
        """
        mask = torch.ones(h.size(0), 1)
        mask[0] = 0  # first vector is padding; keep it zero through all rounds
        mask = create_var(mask)
        for it in range(self.depth):
            h_nei = index_select_ND(h, 0, mess_graph)
            sum_h = h_nei.sum(dim=1)
            z_input = torch.cat([x, sum_h], dim=1)
            # torch.sigmoid/torch.tanh replace the deprecated F.sigmoid/F.tanh.
            z = torch.sigmoid(self.W_z(z_input))

            # Per-neighbor reset gate: broadcast the input term over neighbors.
            r_1 = self.W_r(x).view(-1, 1, self.hidden_size)
            r_2 = self.U_r(h_nei)
            r = torch.sigmoid(r_1 + r_2)

            gated_h = r * h_nei
            sum_gated_h = gated_h.sum(dim=1)
            h_input = torch.cat([x, sum_gated_h], dim=1)
            pre_h = torch.tanh(self.W_h(h_input))
            h = (1.0 - z) * sum_h + z * pre_h
            h = h * mask  # re-zero the padding row
        return h