Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from utils import enumerate_spans | |
from torch.nn.utils.rnn import pad_sequence | |
class SpanEnumerationLayer(nn.Module):
    """Pool token embeddings over every enumerated span of each sequence.

    The layer holds no learnable parameters of its own. It turns per-token
    embeddings into one embedding per span by summing (or averaging) the
    token rows each span covers.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    @staticmethod
    def _pool_span(embedding, start, end, operation):
        """Sum (or average) rows ``start`` through ``end`` (inclusive) of ``embedding``."""
        span = embedding[start:end + 1]
        total = span.sum(dim=0)
        if operation == 'average':
            # Divide by the number of rows actually summed: slicing clips at
            # the end of the sequence, so the nominal span length can be
            # wrong. max(..., 1) keeps an empty span at zeros, not NaN/inf.
            total = torch.div(total, max(span.shape[0], 1))
        return total

    def compute_embeddings(self, embeddings, enumerations, operation='sum'):
        """Pool each sequence's token embeddings over its spans.

        Args:
            embeddings: Iterable of per-sequence tensors whose first axis
                indexes tokens (presumably ``(seq_len, dim)`` -- TODO confirm
                with callers).
            enumerations: Iterable, parallel to ``embeddings``, of span lists;
                each span is a ``(start, end)`` pair with ``end`` inclusive.
            operation: ``'average'`` divides each span sum by the number of
                token rows summed; any other value returns the plain sum.

        Returns:
            A list with one tensor per sequence: the pooled span vectors
            stacked along a new first axis, in span order. A sequence with no
            spans yields an empty tensor of shape ``(0, *embedding.shape[1:])``
            instead of raising.
        """
        computed_embeddings = []
        for enumeration, embedding in zip(enumerations, embeddings):
            span_vectors = [
                self._pool_span(embedding, start, end, operation)
                for start, end in enumeration
            ]
            if span_vectors:
                computed_embeddings.append(torch.stack(span_vectors))
            else:
                # torch.stack raises on an empty list; keep the trailing
                # shape so downstream code can still pad/concatenate.
                computed_embeddings.append(
                    embedding.new_zeros((0, *embedding.shape[1:]))
                )
        return computed_embeddings

    def forward(self, embeddings, lengths, operation='sum'):
        """Enumerate spans from ``lengths`` and pool ``embeddings`` over them.

        Args:
            embeddings: Per-sequence token embeddings, as accepted by
                ``compute_embeddings``.
            lengths: Passed unchanged to ``enumerate_spans`` (imported from
                ``utils``; its output format is not visible here).
            operation: Pooling mode forwarded to ``compute_embeddings``;
                defaults to ``'sum'``, matching the previous behavior.

        Returns:
            ``(computed_embeddings, enumerations)``: the pooled span tensors
            and the span lists they were computed from.
        """
        enumerations = enumerate_spans(lengths)
        computed_embeddings = self.compute_embeddings(
            embeddings, enumerations, operation=operation
        )
        return computed_embeddings, enumerations