Spaces:
Running
Running
File size: 2,283 Bytes
a3ea5d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
def create_var(tensor, requires_grad=None):
    """Wrap `tensor` in an autograd Variable, moved to GPU when available.

    Args:
        tensor: the torch tensor to wrap.
        requires_grad: if None, use Variable's default (no grad tracking);
            otherwise pass the flag through explicitly.

    Returns:
        The wrapped tensor, on CUDA if `torch.cuda.is_available()`.

    NOTE(review): `Variable` is deprecated in modern PyTorch (tensors carry
    autograd state directly); kept here to preserve the existing interface.
    """
    if requires_grad is None:
        var = Variable(tensor)
    else:
        var = Variable(tensor, requires_grad=requires_grad)
    # Single CUDA branch instead of duplicating it in every case.
    return var.cuda() if torch.cuda.is_available() else var
def index_select_ND(source, dim, index):
    """Gather slices of `source` along `dim` using a multi-dimensional `index`.

    The result has shape `index.size() + source.size()[1:]`: each entry of
    `index` is replaced by the corresponding slice of `source`.
    """
    flat_index = index.view(-1)
    gathered = source.index_select(dim, flat_index)
    out_shape = index.size() + source.size()[1:]
    return gathered.view(out_shape)
def avg_pool(all_vecs, scope, dim):
    """Average `all_vecs` over `dim`, dividing each row by its true length.

    `scope` is a list of (start, length) tuples; only the length of each
    entry is used here, as the divisor for the corresponding summed vector.
    """
    lengths = create_var(torch.Tensor([tup[1] for tup in scope]))
    summed = all_vecs.sum(dim=dim)
    return summed / lengths.unsqueeze(-1)
def stack_pad_tensor(tensor_list):
    """Stack 2-D tensors of varying first-dimension length into one 3-D tensor.

    Each tensor is zero-padded along dim 0 up to the longest entry, then all
    are stacked on a new leading batch dimension.

    Args:
        tensor_list: list of 2-D tensors sharing the same second dimension.

    Returns:
        A tensor of shape (len(tensor_list), max_len, feat_dim).

    Fix: the original wrote the padded tensors back into `tensor_list`,
    mutating the caller's list as a side effect; we now build a new list.
    """
    max_len = max(t.size(0) for t in tensor_list)
    padded = [F.pad(t, (0, 0, 0, max_len - t.size(0))) for t in tensor_list]
    return torch.stack(padded, dim=0)
#3D padded tensor to 2D matrix, with padded zeros removed
def flatten_tensor(tensor, scope):
    """Convert a 3-D padded tensor to a 2-D matrix, dropping the padded rows.

    `scope` holds one (start, length) tuple per batch entry; only the length
    is used — row i contributes its first `length` vectors to the output.
    """
    assert tensor.size(0) == len(scope)
    pieces = [tensor[i, :tup[1]] for i, tup in enumerate(scope)]
    return torch.cat(pieces, dim=0)
#2D matrix to 3D padded tensor
def inflate_tensor(tensor, scope):
    """Convert a 2-D matrix back to a 3-D padded tensor.

    For each (start, length) in `scope`, take `length` rows starting at
    `start`, zero-pad them up to the longest segment, and stack the segments
    on a new leading batch dimension.
    """
    max_len = max(le for _, le in scope)
    segments = []
    for start, length in scope:
        seg = tensor[start:start + length]
        segments.append(F.pad(seg, (0, 0, 0, max_len - length)))
    return torch.stack(segments, dim=0)
def GRU(x, h_nei, W_z, W_r, U_r, W_h):
    """One GRU-style update aggregating neighbor messages.

    Args:
        x: node features, shape (batch, hidden_size).
        h_nei: neighbor messages, shape (batch, num_neighbors, hidden_size)
            — presumably zero-padded where neighbors are absent; TODO confirm.
        W_z, W_r, U_r, W_h: linear layers for the update gate, the two halves
            of the reset gate, and the candidate hidden state.

    Returns:
        New hidden state, shape (batch, hidden_size).

    Fix: `F.sigmoid` / `F.tanh` are deprecated in torch.nn.functional;
    use `torch.sigmoid` / `torch.tanh` (numerically identical).
    """
    hidden_size = x.size()[-1]
    sum_h = h_nei.sum(dim=1)
    # Update gate z from node features and the summed neighbor messages.
    z_input = torch.cat([x, sum_h], dim=1)
    z = torch.sigmoid(W_z(z_input))
    # Reset gate r, computed per neighbor (broadcast W_r(x) over neighbors).
    r_1 = W_r(x).view(-1, 1, hidden_size)
    r_2 = U_r(h_nei)
    r = torch.sigmoid(r_1 + r_2)
    # Candidate state from the reset-gated neighbor sum.
    gated_h = r * h_nei
    sum_gated_h = gated_h.sum(dim=1)
    h_input = torch.cat([x, sum_gated_h], dim=1)
    pre_h = torch.tanh(W_h(h_input))
    # Convex combination of old (summed) and candidate states.
    new_h = (1.0 - z) * sum_h + z * pre_h
    return new_h
|