# Source: Hugging Face Hub file view (uploaded by lengocduc195, commit 2359bda "pushNe", 566 bytes).
from torch import Tensor
from torch import nn
from typing import Dict
import torch.nn.functional as F
class Normalize(nn.Module):
    """
    This layer normalizes embeddings to unit length.

    The layer has no parameters or configuration: it rescales the tensor
    stored under the ``'sentence_embedding'`` key with an L2 norm computed
    along dimension 1.
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor]):
        """
        L2-normalize ``features['sentence_embedding']`` along dimension 1.

        :param features: mapping that must contain a ``'sentence_embedding'``
            tensor; any other entries are left untouched.
            NOTE(review): presumably the tensor is (batch, embedding_dim),
            which is what ``dim=1`` suggests -- confirm against the caller.
        :return: the same ``features`` object that was passed in, with the
            ``'sentence_embedding'`` entry replaced by its normalized version.
        :raises KeyError: if ``'sentence_embedding'`` is missing.
        """
        embeddings = features['sentence_embedding']
        # The dict is updated in place and returned, so the caller's mapping
        # sees the normalized tensor as well.
        features['sentence_embedding'] = F.normalize(embeddings, p=2, dim=1)
        return features

    def save(self, output_path):
        """
        Do nothing: the layer holds no state, so nothing is written.

        :param output_path: accepted but unused.
        """
        pass

    @staticmethod
    def load(input_path):
        """
        Return a fresh ``Normalize`` layer.

        :param input_path: accepted but unused; nothing is read from it.
        """
        return Normalize()