|
import torch |
|
from torch import nn, Tensor |
|
from typing import Union, Tuple, List, Iterable, Dict |
|
import logging |
|
import gzip |
|
from tqdm import tqdm |
|
import numpy as np |
|
import os |
|
import json |
|
from ..util import import_from_string, fullname, http_get |
|
from .tokenizer import WordTokenizer, WhitespaceTokenizer |
|
|
|
|
|
class CNN(nn.Module):
    """CNN layer that applies several kernel sizes over the word embeddings.

    One ``nn.Conv1d`` is built per kernel size. All convolutions read the same
    token embeddings and their outputs are concatenated along the channel
    axis, so the resulting embedding dimension is
    ``out_channels * len(kernel_sizes)``.
    """

    def __init__(self, in_word_embedding_dimension: int, out_channels: int = 256, kernel_sizes: List[int] = None, stride_sizes: List[int] = None):
        """Build one convolution per kernel size.

        :param in_word_embedding_dimension: size of each incoming token embedding
            (used as ``in_channels`` of every convolution).
        :param out_channels: number of output channels of each convolution.
        :param kernel_sizes: kernel width of each convolution; defaults to
            ``[1, 3, 5]`` when ``None``.
        :param stride_sizes: stride of each convolution, one entry per kernel
            size; defaults to a stride of 1 for every kernel when ``None``.
        :raises ValueError: if ``stride_sizes`` and ``kernel_sizes`` differ in length.
        """
        super().__init__()

        # A ``None`` sentinel replaces the former list default: the default list
        # object was stored on the instance and therefore shared by all instances.
        if kernel_sizes is None:
            kernel_sizes = [1, 3, 5]
        # Copy so that later mutation of the caller's list cannot change the config.
        kernel_sizes = list(kernel_sizes)

        if stride_sizes is None:
            stride_sizes = [1] * len(kernel_sizes)
        stride_sizes = list(stride_sizes)

        # zip() below would silently drop the surplus entries and leave
        # embeddings_dimension inconsistent with the convolutions actually built.
        if len(stride_sizes) != len(kernel_sizes):
            raise ValueError(
                "stride_sizes must have one entry per kernel size "
                "(got {} strides for {} kernels)".format(len(stride_sizes), len(kernel_sizes))
            )

        # 'stride_sizes' is part of the config so that save()/load() rebuilds the
        # same convolutions; configs written without it fall back to stride 1.
        self.config_keys = ['in_word_embedding_dimension', 'out_channels', 'kernel_sizes', 'stride_sizes']
        self.in_word_embedding_dimension = in_word_embedding_dimension
        self.out_channels = out_channels
        self.kernel_sizes = kernel_sizes
        self.stride_sizes = stride_sizes

        self.embeddings_dimension = out_channels * len(kernel_sizes)
        self.convs = nn.ModuleList()

        in_channels = in_word_embedding_dimension
        for kernel_size, stride in zip(kernel_sizes, stride_sizes):
            # "Same"-style padding; NOTE(review): only odd kernel sizes keep the
            # sequence length unchanged at stride 1 - mixing odd and even sizes
            # presumably makes the concatenation in forward() fail.
            padding_size = (kernel_size - 1) // 2
            conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=kernel_size,
                             stride=stride,
                             padding=padding_size)
            self.convs.append(conv)

    def forward(self, features: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Replace ``features['token_embeddings']`` with the convolved embeddings.

        :param features: dict holding a ``'token_embeddings'`` tensor; assumed to
            be shaped (batch, tokens, embedding_dim) - TODO confirm against callers.
        :return: the same dict object, updated in place.
        """
        token_embeddings = features['token_embeddings']

        # Conv1d convolves over the last axis, so swap axis 1 with the last axis
        # before convolving and swap back afterwards.
        token_embeddings = token_embeddings.transpose(1, -1)
        vectors = [conv(token_embeddings) for conv in self.convs]
        out = torch.cat(vectors, 1).transpose(1, -1)

        features.update({'token_embeddings': out})
        return features

    def get_word_embedding_dimension(self) -> int:
        """Return the output embedding size: ``out_channels * len(kernel_sizes)``."""
        return self.embeddings_dimension

    def tokenize(self, text: str) -> List[int]:
        """Not supported by this layer; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def save(self, output_path: str):
        """Write ``cnn_config.json`` and ``pytorch_model.bin`` into ``output_path``.

        :param output_path: existing directory that receives both files.
        """
        with open(os.path.join(output_path, 'cnn_config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

        torch.save(self.state_dict(), os.path.join(output_path, 'pytorch_model.bin'))

    def get_config_dict(self):
        """Return the constructor arguments listed in ``self.config_keys`` as a dict."""
        return {key: self.__dict__[key] for key in self.config_keys}

    @staticmethod
    def load(input_path: str):
        """Rebuild a ``CNN`` from the files written by :meth:`save`.

        :param input_path: directory containing ``cnn_config.json`` and
            ``pytorch_model.bin``.
        :return: a new ``CNN`` with the stored configuration and weights,
            loaded onto the CPU.
        """
        with open(os.path.join(input_path, 'cnn_config.json'), 'r') as fIn:
            config = json.load(fIn)

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        weights = torch.load(os.path.join(input_path, 'pytorch_model.bin'), map_location=torch.device('cpu'))
        model = CNN(**config)
        model.load_state_dict(weights)
        return model
|
|
|
|