# nehalelkaref's picture
# Update layers.py
# aaeb391
# raw
# history blame
# 1.26 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import enumerate_spans
from torch.nn.utils.rnn import pad_sequence
class SpanEnumerationLayer(nn.Module):
    """Enumerate candidate spans over token embeddings and pool each span
    into a single fixed-size vector (sum or average over the span tokens).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def compute_embeddings(self, embeddings, enumerations, operation='sum'):
        """Pool each enumerated span of every sequence into one vector.

        Args:
            embeddings: iterable of per-sequence tensors, each assumed to be
                shaped (seq_len, dim) — TODO confirm against caller.
            enumerations: per-sequence lists of inclusive (start, end) index
                pairs into the corresponding embedding.
            operation: 'sum' (default) or 'average' pooling over the span.

        Returns:
            A list with one tensor per sequence, shaped (num_spans, dim),
            where row i pools embedding[start_i : end_i + 1].
        """
        computed_embeddings = []
        for enumeration, embedding in zip(enumerations, embeddings):
            output_embeddings = []
            for start, end in enumeration:
                # Spans are inclusive on both ends, hence the +1 on the slice.
                pooled = embedding[start:end + 1].sum(dim=0)
                if operation == 'average':
                    # Guard against a degenerate zero-length span, which
                    # would otherwise divide by zero and yield inf/nan.
                    span_len = max(abs((end + 1) - start), 1)
                    pooled = torch.div(pooled, span_len)
                output_embeddings.append(pooled)
            computed_embeddings.append(torch.stack(output_embeddings))
        return computed_embeddings

    def forward(self, embeddings, lengths, operation='sum'):
        """Enumerate all spans for each sequence, then pool them.

        Args:
            embeddings: per-sequence embedding tensors (see compute_embeddings).
            lengths: sequence lengths consumed by `enumerate_spans` to build
                the candidate (start, end) pairs.
            operation: pooling mode forwarded to compute_embeddings
                ('sum' or 'average'); the default preserves prior behavior.

        Returns:
            Tuple of (computed_embeddings, enumerations).
        """
        enumerations = enumerate_spans(lengths)
        computed_embeddings = self.compute_embeddings(
            embeddings, enumerations, operation=operation)
        return computed_embeddings, enumerations