import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from utils import enumerate_spans


class SpanEnumerationLayer(nn.Module):
    """Pool token embeddings over enumerated spans.

    For every sequence in a batch, each ``(start, end)`` span produced by
    ``enumerate_spans`` is reduced to a single vector by summing (or
    averaging) the embedding rows ``start`` through ``end`` inclusive.
    """

    # Reduction modes accepted by ``compute_embeddings``.
    SUPPORTED_OPERATIONS = ("sum", "average")

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def compute_embeddings(self, embeddings, enumerations, operation="sum"):
        """Reduce each enumerated span to a single embedding vector.

        Args:
            embeddings: Iterable of per-sequence embedding tensors. Each one is
                sliced along its first axis, so presumably shaped
                ``(seq_len, hidden)`` — TODO confirm against callers.
            enumerations: Iterable, parallel to ``embeddings``, of span
                collections; each span is a ``(start, end)`` pair whose
                ``end`` index is inclusive.
            operation: ``"sum"`` to sum the rows of each span, or
                ``"average"`` to divide that sum by the span width.

        Returns:
            A list with one tensor per sequence, stacking that sequence's span
            vectors along a new first axis. A sequence with no spans yields an
            empty ``(0, *embedding.shape[1:])`` tensor with the same
            device/dtype as its embedding.

        Raises:
            ValueError: If ``operation`` is not a supported mode.
        """
        if operation not in self.SUPPORTED_OPERATIONS:
            # Previously any unknown value silently fell back to summing.
            raise ValueError(
                f"Unsupported operation {operation!r}; expected one of "
                f"{self.SUPPORTED_OPERATIONS}"
            )
        average = operation == "average"

        computed_embeddings = []
        for enumeration, embedding in zip(enumerations, embeddings):
            span_vectors = []
            for x1, x2 in enumeration:
                # ``end`` is inclusive, hence the +1 on the slice bound.
                span_vector = embedding[x1:x2 + 1].sum(dim=0)
                if average:
                    span_vector = torch.div(span_vector, abs((x2 + 1) - x1))
                span_vectors.append(span_vector)

            if span_vectors:
                computed_embeddings.append(torch.stack(span_vectors))
            else:
                # torch.stack rejects an empty list; emit an empty stack whose
                # trailing shape matches what a non-empty result would have.
                computed_embeddings.append(
                    embedding.new_zeros((0, *embedding.shape[1:]))
                )
        return computed_embeddings

    def forward(self, embeddings, lengths, operation="sum"):
        """Enumerate spans for each sequence and pool their embeddings.

        Args:
            embeddings: Per-sequence embedding tensors, forwarded to
                ``compute_embeddings``.
            lengths: Passed to ``enumerate_spans``; presumably the valid
                length of each sequence — TODO confirm.
            operation: Reduction mode forwarded to ``compute_embeddings``.
                Defaults to ``"sum"``, matching the previous behavior.

        Returns:
            ``(computed_embeddings, enumerations)``: the pooled span vectors
            for each sequence, and the spans they were computed from.
        """
        enumerations = enumerate_spans(lengths)
        computed_embeddings = self.compute_embeddings(
            embeddings, enumerations, operation=operation
        )
        return computed_embeddings, enumerations