lengocduc195's picture
pushNe
2359bda
raw
history blame
1.16 kB
import torch
from torch import Tensor
from torch import nn
from typing import Union, Tuple, List, Iterable, Dict
import os
import json
class LayerNorm(nn.Module):
    """Apply layer normalization to the ``'sentence_embedding'`` entry of a features dict.

    Thin wrapper around :class:`torch.nn.LayerNorm` for use in a features-dict
    pipeline: :meth:`forward` reads ``features['sentence_embedding']``,
    normalizes it, and writes the result back under the same key.

    Args:
        dimension: Size of the last axis of the embedding tensor; passed to
            ``nn.LayerNorm`` as its normalized shape.
    """

    def __init__(self, dimension: int):
        super().__init__()
        self.dimension = dimension
        self.norm = nn.LayerNorm(dimension)

    def forward(self, features: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Normalize ``features['sentence_embedding']`` and return ``features``.

        The input dict is mutated (its ``'sentence_embedding'`` value is
        replaced by the normalized tensor) and the same dict object is
        returned; all other keys are left untouched.

        Raises:
            KeyError: if ``'sentence_embedding'`` is not present.
        """
        features['sentence_embedding'] = self.norm(features['sentence_embedding'])
        return features

    def get_sentence_embedding_dimension(self) -> int:
        """Return the embedding dimension this module was constructed with."""
        return self.dimension

    def save(self, output_path: str) -> None:
        """Write ``config.json`` and ``pytorch_model.bin`` into ``output_path``.

        The directory (and any missing parents) is created if needed.
        """
        # Create the target directory so the module can also be saved on its
        # own, not only when a caller has already prepared the directory.
        os.makedirs(output_path, exist_ok=True)
        with open(os.path.join(output_path, 'config.json'), 'w', encoding='utf-8') as f_out:
            json.dump({'dimension': self.dimension}, f_out, indent=2)
        torch.save(self.state_dict(), os.path.join(output_path, 'pytorch_model.bin'))

    @classmethod
    def load(cls, input_path: str) -> 'LayerNorm':
        """Rebuild a module previously written by :meth:`save`.

        Instantiates ``cls`` (rather than a hard-coded ``LayerNorm``) so that
        subclasses round-trip to their own type. Weights are mapped to CPU
        regardless of the device they were saved from.
        """
        with open(os.path.join(input_path, 'config.json'), encoding='utf-8') as f_in:
            config = json.load(f_in)
        model = cls(**config)
        # NOTE(review): torch.load unpickles the checkpoint. Before torch 2.6 the
        # default is weights_only=False, which can execute arbitrary code from an
        # untrusted file. Consider passing weights_only=True where the installed
        # torch supports it (>= 1.13); a plain state_dict loads fine that way.
        state_dict = torch.load(
            os.path.join(input_path, 'pytorch_model.bin'),
            map_location=torch.device('cpu'),
        )
        model.load_state_dict(state_dict)
        return model