import torch
from torch import Tensor
from torch import nn
from typing import Union, Tuple, List, Iterable, Dict
import os
import json


class Pooling(nn.Module):
    """Performs pooling (max, mean, CLS, weighted-mean or last-token) on the token embeddings.

    Using pooling, it generates a fixed-sized sentence embedding from a variable-sized sentence. This layer also
    allows using the CLS token if it is returned by the underlying word embedding model.
    You can concatenate multiple poolings together (a usage sketch is given at the end of this file).

    :param word_embedding_dimension: Dimension of the word embeddings
    :param pooling_mode: Can be a string: mean/max/cls/weightedmean/lasttoken. If set, it overrides the other
        pooling_mode_* settings
    :param pooling_mode_cls_token: Use the first token (CLS token) as the text representation
    :param pooling_mode_max_tokens: Use the maximum in each dimension over all tokens
    :param pooling_mode_mean_tokens: Perform mean pooling
    :param pooling_mode_mean_sqrt_len_tokens: Perform mean pooling, but divide by sqrt(input_length)
    :param pooling_mode_weightedmean_tokens: Perform (position-)weighted mean pooling, see https://arxiv.org/abs/2202.08904
    :param pooling_mode_lasttoken: Perform last-token pooling, see https://arxiv.org/abs/2202.08904 and
        https://arxiv.org/abs/2201.10005
    """
    def __init__(self,
                 word_embedding_dimension: int,
                 pooling_mode: str = None,
                 pooling_mode_cls_token: bool = False,
                 pooling_mode_max_tokens: bool = False,
                 pooling_mode_mean_tokens: bool = True,
                 pooling_mode_mean_sqrt_len_tokens: bool = False,
                 pooling_mode_weightedmean_tokens: bool = False,
                 pooling_mode_lasttoken: bool = False,
                 ):
        super(Pooling, self).__init__()

        self.config_keys = ['word_embedding_dimension', 'pooling_mode_cls_token', 'pooling_mode_mean_tokens', 'pooling_mode_max_tokens',
                            'pooling_mode_mean_sqrt_len_tokens', 'pooling_mode_weightedmean_tokens', 'pooling_mode_lasttoken']

        if pooling_mode is not None:
            pooling_mode = pooling_mode.lower()
            assert pooling_mode in ['mean', 'max', 'cls', 'weightedmean', 'lasttoken']
            pooling_mode_cls_token = (pooling_mode == 'cls')
            pooling_mode_max_tokens = (pooling_mode == 'max')
            pooling_mode_mean_tokens = (pooling_mode == 'mean')
            pooling_mode_weightedmean_tokens = (pooling_mode == 'weightedmean')
            pooling_mode_lasttoken = (pooling_mode == 'lasttoken')

        self.word_embedding_dimension = word_embedding_dimension
        self.pooling_mode_cls_token = pooling_mode_cls_token
        self.pooling_mode_mean_tokens = pooling_mode_mean_tokens
        self.pooling_mode_max_tokens = pooling_mode_max_tokens
        self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens
        self.pooling_mode_weightedmean_tokens = pooling_mode_weightedmean_tokens
        self.pooling_mode_lasttoken = pooling_mode_lasttoken

        pooling_mode_multiplier = sum([pooling_mode_cls_token, pooling_mode_max_tokens, pooling_mode_mean_tokens,
                                       pooling_mode_mean_sqrt_len_tokens, pooling_mode_weightedmean_tokens, pooling_mode_lasttoken])
        self.pooling_output_dimension = (pooling_mode_multiplier * word_embedding_dimension)
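        # Example: with word_embedding_dimension=768 and both mean and max pooling enabled,
        # pooling_mode_multiplier is 2 and the concatenated output has 1536 dimensions.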

    def __repr__(self):
        return "Pooling({})".format(self.get_config_dict())

    def get_pooling_mode_str(self) -> str:
        """
        Returns the pooling mode as a string
        """
        modes = []
        if self.pooling_mode_cls_token:
            modes.append('cls')
        if self.pooling_mode_mean_tokens:
            modes.append('mean')
        if self.pooling_mode_max_tokens:
            modes.append('max')
        if self.pooling_mode_mean_sqrt_len_tokens:
            modes.append('mean_sqrt_len_tokens')
        if self.pooling_mode_weightedmean_tokens:
            modes.append('weightedmean')
        if self.pooling_mode_lasttoken:
            modes.append('lasttoken')

        return "+".join(modes)

    def forward(self, features: Dict[str, Tensor]):
        token_embeddings = features['token_embeddings']
        attention_mask = features['attention_mask']

        ## Pooling strategy
        output_vectors = []
        if self.pooling_mode_cls_token:
            # Use the CLS token embedding if the word embedding model returns it; otherwise fall back to the first token
            cls_token = features.get('cls_token_embeddings', token_embeddings[:, 0])
            output_vectors.append(cls_token)
        if self.pooling_mode_max_tokens:
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            # Set padding tokens to a large negative value so they never win the max
            token_embeddings[input_mask_expanded == 0] = -1e9
            max_over_time = torch.max(token_embeddings, 1)[0]
            output_vectors.append(max_over_time)
        if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (e.g. by a word-weighting layer), 'token_weights_sum' holds the per-sequence weight sum
            if 'token_weights_sum' in features:
                sum_mask = features['token_weights_sum'].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)  # Avoid division by zero for fully masked sequences

            if self.pooling_mode_mean_tokens:
                output_vectors.append(sum_embeddings / sum_mask)
            if self.pooling_mode_mean_sqrt_len_tokens:
                output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))
        if self.pooling_mode_weightedmean_tokens:
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            # Position weights 1, 2, ..., seq_len: later tokens contribute more to the mean
            weights = (
                torch.arange(start=1, end=token_embeddings.shape[1] + 1)
                .unsqueeze(0)
                .unsqueeze(-1)
                .expand(token_embeddings.size())
                .float().to(token_embeddings.device)
            )
            assert weights.shape == token_embeddings.shape == input_mask_expanded.shape
            input_mask_expanded = input_mask_expanded * weights

            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (e.g. by a word-weighting layer), 'token_weights_sum' holds the per-sequence weight sum
            if 'token_weights_sum' in features:
                sum_mask = features['token_weights_sum'].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)
            output_vectors.append(sum_embeddings / sum_mask)
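            # Worked example (illustrative): for a single sequence with attention_mask [1, 1, 1, 0],
            # the position weights are [1, 2, 3, 4] and the effective mask becomes [1, 2, 3, 0], so the
            # pooled vector is (1*t1 + 2*t2 + 3*t3) / (1 + 2 + 3): later tokens receive more weight.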
        if self.pooling_mode_lasttoken:
            bs, seq_len, hidden_dim = token_embeddings.shape
            # attention_mask has shape (bs, seq_len). The first 0 in each row marks the first padding
            # position, and the last attended token sits one position before it. Rows without padding
            # contain no 0, so use seq_len there; after subtracting 1 this selects the final position.
            values, indices = torch.min(attention_mask, 1, keepdim=False)
            gather_indices = torch.where(values == 0, indices, seq_len) - 1  # Shape: (bs,)

            # Guard against completely empty sequences, where the index would become -1
            gather_indices = torch.clamp(gather_indices, min=0)

            # Turn the indices from shape (bs,) into (bs, 1, hidden_dim) for torch.gather
            gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim)
            gather_indices = gather_indices.unsqueeze(1)
            assert gather_indices.shape == (bs, 1, hidden_dim)

            # Gather along the sequence dimension: (bs, seq_len, hidden_dim) -> (bs, hidden_dim).
            # Multiplying by the attention mask zeroes out positions that were clamped above
            # and should not contribute to the embedding.
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            embedding = torch.gather(token_embeddings * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
            output_vectors.append(embedding)

        output_vector = torch.cat(output_vectors, 1)
        features.update({'sentence_embedding': output_vector})
        return features

    def get_sentence_embedding_dimension(self):
        return self.pooling_output_dimension

    def get_config_dict(self):
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path):
        with open(os.path.join(output_path, 'config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, 'config.json')) as fIn:
            config = json.load(fIn)

        return Pooling(**config)
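

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the library API): builds a mean-pooling layer
    # and runs it on random token embeddings to show the expected input/output contract.
    # All shapes and values below are made up for demonstration.
    pooling = Pooling(word_embedding_dimension=8, pooling_mode='mean')
    print(pooling.get_pooling_mode_str())                # -> 'mean'
    print(pooling.get_sentence_embedding_dimension())    # -> 8

    features = {
        'token_embeddings': torch.randn(2, 5, 8),                            # (batch, seq_len, dim)
        'attention_mask': torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]),  # 1 = token, 0 = padding
    }
    out = pooling(features)
    print(out['sentence_embedding'].shape)                # -> torch.Size([2, 8])

    # Enabling several modes concatenates their vectors, which multiplies the output dimension.
    combined = Pooling(word_embedding_dimension=8, pooling_mode_mean_tokens=True, pooling_mode_max_tokens=True)
    print(combined.get_sentence_embedding_dimension())    # -> 16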